mirror of
				https://github.com/isledecomp/isle.git
				synced 2025-10-26 18:04:06 +00:00 
			
		
		
		
	Match vtables with virtual inheritance (#717)
* Match vtables with virtual inheritance * Simplify vtable name check * Thunk alert
This commit is contained in:
		| @@ -86,6 +86,7 @@ class Compare: | ||||
|         self._find_original_strings() | ||||
|         self._match_thunks() | ||||
|         self._match_exports() | ||||
|         self._find_vtordisp() | ||||
| 
 | ||||
|     def _load_cvdump(self): | ||||
|         logger.info("Parsing %s ...", self.pdb_file) | ||||
| @@ -198,7 +199,7 @@ class Compare: | ||||
|                 self._db.match_variable(var.offset, var.name) | ||||
| 
 | ||||
|         for tbl in codebase.iter_vtables(): | ||||
|             self._db.match_vtable(tbl.offset, tbl.name) | ||||
|             self._db.match_vtable(tbl.offset, tbl.name, tbl.base_class) | ||||
| 
 | ||||
|         for string in codebase.iter_strings(): | ||||
|             # Not that we don't trust you, but we're checking the string | ||||
| @@ -285,10 +286,105 @@ class Compare: | ||||
|             ): | ||||
|                 logger.debug("Matched export %s", repr(export_name)) | ||||
| 
 | ||||
|     def _find_vtordisp(self): | ||||
|         """If there are any cases of virtual inheritance, we can read | ||||
|         through the vtables for those classes and find the vtable thunk | ||||
|         functions (vtordisp). | ||||
| 
 | ||||
|         Our approach is this: walk both vtables and check where we have a | ||||
|         vtordisp in the recomp table. Inspect the function at that vtable | ||||
|         position (in both) and check whether we jump to the same function. | ||||
| 
 | ||||
|         One potential pitfall here is that the virtual displacement could | ||||
|         differ between the thunks. We are not (yet) checking for this, so the | ||||
|         result is that the vtable will appear to match but we will have a diff | ||||
|         on the thunk in our regular function comparison. | ||||
| 
 | ||||
|         We could do this differently and check only the original vtable, | ||||
|         construct the name of the vtordisp function and match based on that.""" | ||||
| 
 | ||||
|         for match in self._db.get_matches_by_type(SymbolType.VTABLE): | ||||
|             # We need some method of identifying vtables that | ||||
|             # might have thunks, and this ought to work okay. | ||||
|             if "{for" not in match.name: | ||||
|                 continue | ||||
| 
 | ||||
|             # TODO: We might want to fix this at the source (cvdump) instead. | ||||
|             # Any problem will be logged later when we compare the vtable. | ||||
|             vtable_size = 4 * (match.size // 4) | ||||
|             orig_table = self.orig_bin.read(match.orig_addr, vtable_size) | ||||
|             recomp_table = self.recomp_bin.read(match.recomp_addr, vtable_size) | ||||
| 
 | ||||
|             raw_addrs = zip( | ||||
|                 [t for (t,) in struct.iter_unpack("<L", orig_table)], | ||||
|                 [t for (t,) in struct.iter_unpack("<L", recomp_table)], | ||||
|             ) | ||||
| 
 | ||||
|             # Now walk both vtables looking for thunks. | ||||
|             for orig_addr, recomp_addr in raw_addrs: | ||||
|                 if not self._db.is_vtordisp(recomp_addr): | ||||
|                     continue | ||||
| 
 | ||||
|                 thunk_fn = self.get_by_recomp(recomp_addr) | ||||
| 
 | ||||
|                 # Read the function bytes here. | ||||
|                 # In practice, the adjuster thunk will be under 16 bytes. | ||||
|                 # If we have thunks of unequal size, we can still tell whether | ||||
|                 # they are thunking the same function by grabbing the | ||||
|                 # JMP instruction at the end. | ||||
|                 thunk_presumed_size = max(thunk_fn.size, 16) | ||||
| 
 | ||||
|                 # Strip off MSVC padding 0xcc bytes. | ||||
|                 # This should be safe to do; it is highly unlikely that | ||||
|                 # the MSB of the jump displacement would be 0xcc. (huge jump) | ||||
|                 orig_thunk_bin = self.orig_bin.read( | ||||
|                     orig_addr, thunk_presumed_size | ||||
|                 ).rstrip(b"\xcc") | ||||
| 
 | ||||
|                 recomp_thunk_bin = self.recomp_bin.read( | ||||
|                     recomp_addr, thunk_presumed_size | ||||
|                 ).rstrip(b"\xcc") | ||||
| 
 | ||||
|                 # Read jump opcode and displacement (last 5 bytes) | ||||
|                 (orig_jmp, orig_disp) = struct.unpack("<Bi", orig_thunk_bin[-5:]) | ||||
|                 (recomp_jmp, recomp_disp) = struct.unpack("<Bi", recomp_thunk_bin[-5:]) | ||||
| 
 | ||||
|                 # Make sure it's a JMP | ||||
|                 if orig_jmp != 0xE9 or recomp_jmp != 0xE9: | ||||
|                     continue | ||||
| 
 | ||||
|                 # Calculate jump destination from the end of the JMP instruction | ||||
|                 # i.e. the end of the function | ||||
|                 orig_actual = orig_addr + len(orig_thunk_bin) + orig_disp | ||||
|                 recomp_actual = recomp_addr + len(recomp_thunk_bin) + recomp_disp | ||||
| 
 | ||||
|                 # If they are thunking the same function, then this must be a match. | ||||
|                 if self.is_pointer_match(orig_actual, recomp_actual): | ||||
|                     if len(orig_thunk_bin) != len(recomp_thunk_bin): | ||||
|                         logger.warning( | ||||
|                             "Adjuster thunk %s (0x%x) is not exact", | ||||
|                             thunk_fn.name, | ||||
|                             orig_addr, | ||||
|                         ) | ||||
|                     self._db.set_function_pair(orig_addr, recomp_addr) | ||||
| 
 | ||||
|     def _compare_function(self, match: MatchInfo) -> DiffReport: | ||||
|         orig_raw = self.orig_bin.read(match.orig_addr, match.size) | ||||
|         recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size) | ||||
| 
 | ||||
|         # It's unlikely that a function other than an adjuster thunk would | ||||
|         # start with a SUB instruction, so alert to a possible wrong | ||||
|         # annotation here. | ||||
|         # There's probably a better place to do this, but we're reading | ||||
|         # the function bytes here already. | ||||
|         try: | ||||
|             if orig_raw[0] == 0x2B and recomp_raw[0] != 0x2B: | ||||
|                 logger.warning( | ||||
|                     "Possible thunk at 0x%x (%s)", match.orig_addr, match.name | ||||
|                 ) | ||||
|         except IndexError: | ||||
|             pass | ||||
| 
 | ||||
|         def orig_lookup(addr: int) -> Optional[str]: | ||||
|             m = self._db.get_by_orig(addr) | ||||
|             if m is None: | ||||
| @@ -432,7 +528,7 @@ class Compare: | ||||
|             match_type=SymbolType.VTABLE, | ||||
|             orig_addr=match.orig_addr, | ||||
|             recomp_addr=match.recomp_addr, | ||||
|             name=f"{match.name}::`vftable'", | ||||
|             name=match.name, | ||||
|             udiff=unified_diff, | ||||
|             ratio=ratio, | ||||
|         ) | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import sqlite3 | ||||
| import logging | ||||
| from typing import List, Optional | ||||
| from isledecomp.types import SymbolType | ||||
| from isledecomp.cvdump.demangler import get_vtordisp_name | ||||
| 
 | ||||
| _SETUP_SQL = """ | ||||
|     DROP TABLE IF EXISTS `symbols`; | ||||
| @@ -249,6 +250,37 @@ class CompareDb: | ||||
|             for (option, value) in cur.fetchall() | ||||
|         } | ||||
| 
 | ||||
|     def is_vtordisp(self, recomp_addr: int) -> bool: | ||||
|         """Check whether this function is a vtordisp based on its | ||||
|         decorated name. If its demangled name is missing the vtordisp | ||||
|         indicator, correct that.""" | ||||
|         row = self._db.execute( | ||||
|             """SELECT name, decorated_name | ||||
|             FROM `symbols` | ||||
|             WHERE recomp_addr = ?""", | ||||
|             (recomp_addr,), | ||||
|         ).fetchone() | ||||
| 
 | ||||
|         if row is None: | ||||
|             return False | ||||
| 
 | ||||
|         (name, decorated_name) = row | ||||
|         if "`vtordisp" in name: | ||||
|             return True | ||||
| 
 | ||||
|         new_name = get_vtordisp_name(decorated_name) | ||||
|         if new_name is None: | ||||
|             return False | ||||
| 
 | ||||
|         self._db.execute( | ||||
|             """UPDATE `symbols` | ||||
|             SET name = ? | ||||
|             WHERE recomp_addr = ?""", | ||||
|             (new_name, recomp_addr), | ||||
|         ) | ||||
| 
 | ||||
|         return True | ||||
| 
 | ||||
|     def _find_potential_match( | ||||
|         self, name: str, compare_type: SymbolType | ||||
|     ) -> Optional[int]: | ||||
| @@ -323,12 +355,34 @@ class CompareDb: | ||||
| 
 | ||||
|         return did_match | ||||
| 
 | ||||
|     def match_vtable(self, addr: int, name: str) -> bool: | ||||
|         did_match = self._match_on(SymbolType.VTABLE, addr, name) | ||||
|         if not did_match: | ||||
|             logger.error("Failed to find vtable for class: %s", name) | ||||
|     def match_vtable( | ||||
|         self, addr: int, name: str, base_class: Optional[str] = None | ||||
|     ) -> bool: | ||||
|         # Only allow a match against "Class:`vftable'" | ||||
|         # if this is the derived class. | ||||
|         name = ( | ||||
|             f"{name}::`vftable'" | ||||
|             if base_class is None or base_class == name | ||||
|             else f"{name}::`vftable'{{for `{base_class}'}}" | ||||
|         ) | ||||
| 
 | ||||
|         return did_match | ||||
|         row = self._db.execute( | ||||
|             """ | ||||
|             SELECT recomp_addr | ||||
|             FROM `symbols` | ||||
|             WHERE orig_addr IS NULL | ||||
|             AND name = ? | ||||
|             AND (compare_type = ?) | ||||
|             LIMIT 1 | ||||
|             """, | ||||
|             (name, SymbolType.VTABLE.value), | ||||
|         ).fetchone() | ||||
| 
 | ||||
|         if row is not None and self.set_pair(addr, row[0], SymbolType.VTABLE): | ||||
|             return True | ||||
| 
 | ||||
|         logger.error("Failed to find vtable for class: %s", name) | ||||
|         return False | ||||
| 
 | ||||
|     def match_static_variable(self, addr: int, name: str, function_addr: int) -> bool: | ||||
|         """Matching a static function variable by combining the variable name | ||||
|   | ||||
| @@ -5,6 +5,7 @@ https://en.wikiversity.org/wiki/Visual_C%2B%2B_name_mangling | ||||
| import re | ||||
| from collections import namedtuple | ||||
| from typing import Optional | ||||
| import pydemangler | ||||
| 
 | ||||
| 
 | ||||
| class InvalidEncodedNumberError(Exception): | ||||
| @@ -51,8 +52,52 @@ def demangle_string_const(symbol: str) -> Optional[StringConstInfo]: | ||||
|     return StringConstInfo(len=strlen, is_utf16=is_utf16) | ||||
| 
 | ||||
| 
 | ||||
| def get_vtordisp_name(symbol: str) -> Optional[str]: | ||||
|     # pylint: disable=c-extension-no-member | ||||
|     """For adjuster thunk functions, the PDB will sometimes use a name | ||||
|     that contains "vtordisp" but often will just reuse the name of the | ||||
|     function being thunked. We want to use the vtordisp name if possible.""" | ||||
|     name = pydemangler.demangle(symbol) | ||||
|     if name is None: | ||||
|         return None | ||||
| 
 | ||||
|     if "`vtordisp" not in name: | ||||
|         return None | ||||
| 
 | ||||
|     # Now we remove the parts of the friendly name that we don't need | ||||
|     try: | ||||
|         # Assuming this is the last of the function prefixes | ||||
|         thiscall_idx = name.index("__thiscall") | ||||
|         # To match the end of the `vtordisp{x,y}' string | ||||
|         end_idx = name.index("}'") | ||||
|         return name[thiscall_idx + 11 : end_idx + 2] | ||||
|     except ValueError: | ||||
|         return name | ||||
| 
 | ||||
| 
 | ||||
| def demangle_vtable(symbol: str) -> str: | ||||
|     # pylint: disable=c-extension-no-member | ||||
|     """Get the class name referenced in the vtable symbol.""" | ||||
|     raw = pydemangler.demangle(symbol) | ||||
| 
 | ||||
|     if raw is None: | ||||
|         pass  # TODO: This shouldn't happen if MSVC behaves | ||||
| 
 | ||||
|     # Remove storage class and other stuff we don't care about | ||||
|     return ( | ||||
|         raw.replace("<class ", "<") | ||||
|         .replace("<struct ", "<") | ||||
|         .replace("const ", "") | ||||
|         .replace("volatile ", "") | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def demangle_vtable_ourselves(symbol: str) -> str: | ||||
|     """Parked implementation of MSVC symbol demangling. | ||||
|     We only use this for vtables and it works okay with the simple cases or | ||||
|     templates that refer to other classes/structs. Some namespace support. | ||||
|     Does not support backrefs, primitive types, or vtables with | ||||
|     virtual inheritance.""" | ||||
| 
 | ||||
|     # Seek ahead 4 chars to strip off "??_7" prefix | ||||
|     t = symbol[4:].split("@") | ||||
| @@ -66,11 +111,11 @@ def demangle_vtable(symbol: str) -> str: | ||||
|         else: | ||||
|             generic = t[1][1:] | ||||
| 
 | ||||
|         return f"{class_name}<{generic}>" | ||||
|         return f"{class_name}<{generic}>::`vftable'" | ||||
| 
 | ||||
|     # If we have two classes listed, it is a namespace hierarchy. | ||||
|     # @@6B@ is a common generic suffix for these vtable symbols. | ||||
|     if t[1] != "" and t[1] != "6B": | ||||
|         return t[1] + "::" + t[0] | ||||
|         return t[1] + "::" + t[0] + "::`vftable'" | ||||
| 
 | ||||
|     return t[0] | ||||
|     return t[0] + "::`vftable'" | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| import re | ||||
| from typing import Optional | ||||
| from typing import Optional, Tuple | ||||
| from enum import Enum | ||||
| 
 | ||||
| 
 | ||||
| @@ -29,18 +29,20 @@ class MarkerType(Enum): | ||||
| 
 | ||||
| 
 | ||||
| markerRegex = re.compile( | ||||
|     r"\s*//\s*(?P<type>\w+):\s*(?P<module>\w+)\s+(?P<offset>0x[a-f0-9]+)", | ||||
|     r"\s*//\s*(?P<type>\w+):\s*(?P<module>\w+)\s+(?P<offset>0x[a-f0-9]+) *(?P<extra>\S.+\S)?", | ||||
|     flags=re.I, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| markerExactRegex = re.compile( | ||||
|     r"\s*// (?P<type>[A-Z]+): (?P<module>[A-Z0-9]+) (?P<offset>0x[a-f0-9]+)$" | ||||
|     r"\s*// (?P<type>[A-Z]+): (?P<module>[A-Z0-9]+) (?P<offset>0x[a-f0-9]+)(?: (?P<extra>\S.+\S))?\n?$" | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class DecompMarker: | ||||
|     def __init__(self, marker_type: str, module: str, offset: int) -> None: | ||||
|     def __init__( | ||||
|         self, marker_type: str, module: str, offset: int, extra: Optional[str] = None | ||||
|     ) -> None: | ||||
|         try: | ||||
|             self._type = MarkerType[marker_type.upper()] | ||||
|         except KeyError: | ||||
| @@ -51,6 +53,7 @@ class DecompMarker: | ||||
|         # we will emit a syntax error. | ||||
|         self._module: str = module.upper() | ||||
|         self._offset: int = offset | ||||
|         self._extra: Optional[str] = extra | ||||
| 
 | ||||
|     @property | ||||
|     def type(self) -> MarkerType: | ||||
| @@ -64,6 +67,10 @@ class DecompMarker: | ||||
|     def offset(self) -> int: | ||||
|         return self._offset | ||||
| 
 | ||||
|     @property | ||||
|     def extra(self) -> Optional[str]: | ||||
|         return self._extra | ||||
| 
 | ||||
|     @property | ||||
|     def category(self) -> MarkerCategory: | ||||
|         if self.is_vtable(): | ||||
| @@ -81,6 +88,11 @@ class DecompMarker: | ||||
| 
 | ||||
|         return MarkerCategory.ADDRESS | ||||
| 
 | ||||
|     @property | ||||
|     def key(self) -> Tuple[str, str, Optional[str]]: | ||||
|         """For use with the MarkerDict. To detect/avoid marker collision.""" | ||||
|         return (self.category, self.module, self.extra) | ||||
| 
 | ||||
|     def is_regular_function(self) -> bool: | ||||
|         """Regular function, meaning: not an explicit byname lookup. FUNCTION | ||||
|         markers can be _implicit_ byname. | ||||
| @@ -126,6 +138,7 @@ def match_marker(line: str) -> Optional[DecompMarker]: | ||||
|         marker_type=match.group("type"), | ||||
|         module=match.group("module"), | ||||
|         offset=int(match.group("offset"), 16), | ||||
|         extra=match.group("extra"), | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
|   | ||||
| @@ -55,7 +55,7 @@ class ParserVariable(ParserSymbol): | ||||
| 
 | ||||
| @dataclass | ||||
| class ParserVtable(ParserSymbol): | ||||
|     pass | ||||
|     base_class: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
|   | ||||
| @@ -47,15 +47,16 @@ class MarkerDict: | ||||
| 
 | ||||
|     def insert(self, marker: DecompMarker) -> bool: | ||||
|         """Return True if this insert would overwrite""" | ||||
|         key = (marker.category, marker.module) | ||||
|         if key in self.markers: | ||||
|         if marker.key in self.markers: | ||||
|             return True | ||||
| 
 | ||||
|         self.markers[key] = marker | ||||
|         self.markers[marker.key] = marker | ||||
|         return False | ||||
| 
 | ||||
|     def query(self, category: MarkerCategory, module: str) -> Optional[DecompMarker]: | ||||
|         return self.markers.get((category, module)) | ||||
|     def query( | ||||
|         self, category: MarkerCategory, module: str, extra: Optional[str] = None | ||||
|     ) -> Optional[DecompMarker]: | ||||
|         return self.markers.get((category, module, extra)) | ||||
| 
 | ||||
|     def iter(self) -> Iterator[DecompMarker]: | ||||
|         for _, marker in self.markers.items(): | ||||
| @@ -275,6 +276,7 @@ class DecompParser: | ||||
|                     module=marker.module, | ||||
|                     offset=marker.offset, | ||||
|                     name=self.curly.get_prefix(class_name), | ||||
|                     base_class=marker.extra, | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 MS
					MS