mirror of
				https://github.com/isledecomp/isle.git
				synced 2025-10-26 09:54:18 +00:00 
			
		
		
		
	reccmp: Unique addresses and stub reporting (#554)
This commit is contained in:
		| @@ -20,6 +20,7 @@ logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| @dataclass | ||||
| class DiffReport: | ||||
|     # pylint: disable=too-many-instance-attributes | ||||
|     match_type: SymbolType | ||||
|     orig_addr: int | ||||
|     recomp_addr: int | ||||
| @@ -27,6 +28,7 @@ class DiffReport: | ||||
|     udiff: Optional[List[str]] = None | ||||
|     ratio: float = 0.0 | ||||
|     is_effective_match: bool = False | ||||
|     is_stub: bool = False | ||||
| 
 | ||||
|     @property | ||||
|     def effective_ratio(self) -> float: | ||||
| @@ -130,7 +132,23 @@ class Compare: | ||||
| 
 | ||||
|                 raw = self.recomp_bin.read(addr, sym.size()) | ||||
|                 try: | ||||
|                     sym.friendly_name = raw.decode("latin1").rstrip("\x00") | ||||
|                     # We use the string length reported in the mangled symbol as the | ||||
|                     # data size, but this is not always accurate with respect to the | ||||
|                     # null terminator. | ||||
|                     # e.g. ??_C@_0BA@EFDM@MxObjectFactory?$AA@ | ||||
|                     # reported length: 16 (includes null terminator) | ||||
|                     # c.f. ??_C@_03DPKJ@enz?$AA@ | ||||
|                     # reported length: 3 (does NOT include terminator) | ||||
|                     # This will handle the case where the entire string contains "\x00" | ||||
|                     # because those are distinct from the empty string of length 0. | ||||
|                     decoded_string = raw.decode("latin1") | ||||
|                     rstrip_string = decoded_string.rstrip("\x00") | ||||
| 
 | ||||
|                     if decoded_string != "" and rstrip_string != "": | ||||
|                         sym.friendly_name = rstrip_string | ||||
|                     else: | ||||
|                         sym.friendly_name = decoded_string | ||||
| 
 | ||||
|                 except UnicodeDecodeError: | ||||
|                     pass | ||||
| 
 | ||||
| @@ -162,12 +180,12 @@ class Compare: | ||||
|             if recomp_addr is not None: | ||||
|                 self._db.set_function_pair(fun.offset, recomp_addr) | ||||
|                 if fun.should_skip(): | ||||
|                     self._db.skip_compare(fun.offset) | ||||
|                     self._db.mark_stub(fun.offset) | ||||
| 
 | ||||
|         for fun in codebase.iter_name_functions(): | ||||
|             self._db.match_function(fun.offset, fun.name) | ||||
|             if fun.should_skip(): | ||||
|                 self._db.skip_compare(fun.offset) | ||||
|                 self._db.mark_stub(fun.offset) | ||||
| 
 | ||||
|         for var in codebase.iter_variables(): | ||||
|             if var.is_static and var.parent_function is not None: | ||||
| @@ -255,15 +273,6 @@ class Compare: | ||||
|                 self._db.skip_compare(thunk_from_orig) | ||||
| 
 | ||||
|     def _compare_function(self, match: MatchInfo) -> DiffReport: | ||||
|         if match.size == 0: | ||||
|             # Report a failed match to make the user aware of the empty function. | ||||
|             return DiffReport( | ||||
|                 match_type=SymbolType.FUNCTION, | ||||
|                 orig_addr=match.orig_addr, | ||||
|                 recomp_addr=match.recomp_addr, | ||||
|                 name=match.name, | ||||
|             ) | ||||
| 
 | ||||
|         orig_raw = self.orig_bin.read(match.orig_addr, match.size) | ||||
|         recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size) | ||||
| 
 | ||||
| @@ -406,6 +415,23 @@ class Compare: | ||||
| 
 | ||||
|     def _compare_match(self, match: MatchInfo) -> Optional[DiffReport]: | ||||
|         """Router for comparison type""" | ||||
| 
 | ||||
|         if match.size == 0: | ||||
|             return None | ||||
| 
 | ||||
|         options = self._db.get_match_options(match.orig_addr) | ||||
|         if options.get("skip", False): | ||||
|             return None | ||||
| 
 | ||||
|         if options.get("stub", False): | ||||
|             return DiffReport( | ||||
|                 match_type=match.compare_type, | ||||
|                 orig_addr=match.orig_addr, | ||||
|                 recomp_addr=match.recomp_addr, | ||||
|                 name=match.name, | ||||
|                 is_stub=True, | ||||
|             ) | ||||
| 
 | ||||
|         if match.compare_type == SymbolType.FUNCTION: | ||||
|             return self._compare_function(match) | ||||
| 
 | ||||
| @@ -440,7 +466,9 @@ class Compare: | ||||
| 
 | ||||
|     def compare_functions(self) -> Iterable[DiffReport]: | ||||
|         for match in self.get_functions(): | ||||
|             yield self._compare_match(match) | ||||
|             diff = self._compare_match(match) | ||||
|             if diff is not None: | ||||
|                 yield diff | ||||
| 
 | ||||
|     def compare_variables(self): | ||||
|         pass | ||||
| @@ -453,4 +481,6 @@ class Compare: | ||||
| 
 | ||||
|     def compare_vtables(self) -> Iterable[DiffReport]: | ||||
|         for match in self.get_vtables(): | ||||
|             yield self._compare_match(match) | ||||
|             diff = self._compare_match(match) | ||||
|             if diff is not None: | ||||
|                 yield self._compare_match(match) | ||||
|   | ||||
| @@ -7,15 +7,30 @@ from isledecomp.types import SymbolType | ||||
| 
 | ||||
| _SETUP_SQL = """ | ||||
|     DROP TABLE IF EXISTS `symbols`; | ||||
|     DROP TABLE IF EXISTS `match_options`; | ||||
| 
 | ||||
|     CREATE TABLE `symbols` ( | ||||
|         compare_type int, | ||||
|         orig_addr int, | ||||
|         recomp_addr int, | ||||
|         name text, | ||||
|         decorated_name text, | ||||
|         size int, | ||||
|         should_skip int default(FALSE) | ||||
|         size int | ||||
|     ); | ||||
| 
 | ||||
|     CREATE TABLE `match_options` ( | ||||
|         addr int not null, | ||||
|         name text not null, | ||||
|         value text, | ||||
|         primary key (addr, name) | ||||
|     ) without rowid; | ||||
| 
 | ||||
|     CREATE VIEW IF NOT EXISTS `match_info` | ||||
|     (compare_type, orig_addr, recomp_addr, name, size) AS | ||||
|         SELECT compare_type, orig_addr, recomp_addr, name, size | ||||
|         FROM `symbols` | ||||
|         ORDER BY orig_addr NULLS LAST; | ||||
| 
 | ||||
|     CREATE INDEX `symbols_or` ON `symbols` (orig_addr); | ||||
|     CREATE INDEX `symbols_re` ON `symbols` (recomp_addr); | ||||
|     CREATE INDEX `symbols_na` ON `symbols` (name); | ||||
| @@ -69,6 +84,11 @@ class CompareDb: | ||||
|         decorated_name: Optional[str], | ||||
|         size: Optional[int], | ||||
|     ): | ||||
|         # Ignore collisions here. The same recomp address can have | ||||
|         # multiple names (e.g. _strlwr and __strlwr) | ||||
|         if self.recomp_used(addr): | ||||
|             return | ||||
| 
 | ||||
|         compare_value = compare_type.value if compare_type is not None else None | ||||
|         self._db.execute( | ||||
|             "INSERT INTO `symbols` (recomp_addr, compare_type, name, decorated_name, size) VALUES (?,?,?,?,?)", | ||||
| @@ -86,24 +106,16 @@ class CompareDb: | ||||
|         return [string for (string,) in cur.fetchall()] | ||||
| 
 | ||||
|     def get_all(self) -> List[MatchInfo]: | ||||
|         cur = self._db.execute( | ||||
|             """SELECT compare_type, orig_addr, recomp_addr, name, size | ||||
|             FROM `symbols` | ||||
|             ORDER BY orig_addr NULLS LAST | ||||
|             """, | ||||
|         ) | ||||
|         cur = self._db.execute("SELECT * FROM `match_info`") | ||||
|         cur.row_factory = matchinfo_factory | ||||
| 
 | ||||
|         return cur.fetchall() | ||||
| 
 | ||||
|     def get_matches(self) -> Optional[MatchInfo]: | ||||
|         cur = self._db.execute( | ||||
|             """SELECT compare_type, orig_addr, recomp_addr, name, size | ||||
|             FROM `symbols` | ||||
|             """SELECT * FROM `match_info` | ||||
|             WHERE orig_addr IS NOT NULL | ||||
|             AND recomp_addr IS NOT NULL | ||||
|             AND should_skip IS FALSE | ||||
|             ORDER BY orig_addr | ||||
|             """, | ||||
|         ) | ||||
|         cur.row_factory = matchinfo_factory | ||||
| @@ -112,11 +124,9 @@ class CompareDb: | ||||
| 
 | ||||
|     def get_one_match(self, addr: int) -> Optional[MatchInfo]: | ||||
|         cur = self._db.execute( | ||||
|             """SELECT compare_type, orig_addr, recomp_addr, name, size | ||||
|             FROM `symbols` | ||||
|             """SELECT * FROM `match_info` | ||||
|             WHERE orig_addr = ? | ||||
|             AND recomp_addr IS NOT NULL | ||||
|             AND should_skip IS FALSE | ||||
|             """, | ||||
|             (addr,), | ||||
|         ) | ||||
| @@ -125,8 +135,7 @@ class CompareDb: | ||||
| 
 | ||||
|     def get_by_orig(self, addr: int) -> Optional[MatchInfo]: | ||||
|         cur = self._db.execute( | ||||
|             """SELECT compare_type, orig_addr, recomp_addr, name, size | ||||
|             FROM `symbols` | ||||
|             """SELECT * FROM `match_info` | ||||
|             WHERE orig_addr = ? | ||||
|             """, | ||||
|             (addr,), | ||||
| @@ -136,8 +145,7 @@ class CompareDb: | ||||
| 
 | ||||
|     def get_by_recomp(self, addr: int) -> Optional[MatchInfo]: | ||||
|         cur = self._db.execute( | ||||
|             """SELECT compare_type, orig_addr, recomp_addr, name, size | ||||
|             FROM `symbols` | ||||
|             """SELECT * FROM `match_info` | ||||
|             WHERE recomp_addr = ? | ||||
|             """, | ||||
|             (addr,), | ||||
| @@ -147,13 +155,10 @@ class CompareDb: | ||||
| 
 | ||||
|     def get_matches_by_type(self, compare_type: SymbolType) -> List[MatchInfo]: | ||||
|         cur = self._db.execute( | ||||
|             """SELECT compare_type, orig_addr, recomp_addr, name, size | ||||
|             FROM `symbols` | ||||
|             """SELECT * FROM `match_info` | ||||
|             WHERE compare_type = ? | ||||
|             AND orig_addr IS NOT NULL | ||||
|             AND recomp_addr IS NOT NULL | ||||
|             AND should_skip IS FALSE | ||||
|             ORDER BY orig_addr | ||||
|             """, | ||||
|             (compare_type.value,), | ||||
|         ) | ||||
| @@ -161,9 +166,21 @@ class CompareDb: | ||||
| 
 | ||||
|         return cur.fetchall() | ||||
| 
 | ||||
|     def orig_used(self, addr: int) -> bool: | ||||
|         cur = self._db.execute("SELECT 1 FROM symbols WHERE orig_addr = ?", (addr,)) | ||||
|         return cur.fetchone() is not None | ||||
| 
 | ||||
|     def recomp_used(self, addr: int) -> bool: | ||||
|         cur = self._db.execute("SELECT 1 FROM symbols WHERE recomp_addr = ?", (addr,)) | ||||
|         return cur.fetchone() is not None | ||||
| 
 | ||||
|     def set_pair( | ||||
|         self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None | ||||
|     ) -> bool: | ||||
|         if self.orig_used(orig): | ||||
|             logger.error("Original address %s not unique!", hex(orig)) | ||||
|             return False | ||||
| 
 | ||||
|         compare_value = compare_type.value if compare_type is not None else None | ||||
|         cur = self._db.execute( | ||||
|             "UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?", | ||||
| @@ -174,14 +191,90 @@ class CompareDb: | ||||
| 
 | ||||
|     def set_function_pair(self, orig: int, recomp: int) -> bool: | ||||
|         """For lineref match or _entry""" | ||||
|         self.set_pair(orig, recomp, SymbolType.FUNCTION) | ||||
|         # TODO: Both ways required? | ||||
|         return self.set_pair(orig, recomp, SymbolType.FUNCTION) | ||||
| 
 | ||||
|     def _set_opt_bool(self, addr: int, option: str, enabled: bool = True): | ||||
|         if enabled: | ||||
|             self._db.execute( | ||||
|                 """INSERT OR IGNORE INTO `match_options` | ||||
|                 (addr, name) | ||||
|                 VALUES (?, ?)""", | ||||
|                 (addr, option), | ||||
|             ) | ||||
|         else: | ||||
|             self._db.execute( | ||||
|                 """DELETE FROM `match_options` WHERE addr = ? AND name = ?""", | ||||
|                 (addr, option), | ||||
|             ) | ||||
| 
 | ||||
|     def mark_stub(self, orig: int): | ||||
|         self._set_opt_bool(orig, "stub") | ||||
| 
 | ||||
|     def skip_compare(self, orig: int): | ||||
|         self._db.execute( | ||||
|             "UPDATE `symbols` SET should_skip = TRUE WHERE orig_addr = ?", (orig,) | ||||
|         self._set_opt_bool(orig, "skip") | ||||
| 
 | ||||
|     def get_match_options(self, addr: int) -> Optional[dict]: | ||||
|         cur = self._db.execute( | ||||
|             """SELECT name, value FROM `match_options` WHERE addr = ?""", (addr,) | ||||
|         ) | ||||
| 
 | ||||
|         return { | ||||
|             option: value if value is not None else True | ||||
|             for (option, value) in cur.fetchall() | ||||
|         } | ||||
| 
 | ||||
|     def _find_potential_match( | ||||
|         self, name: str, compare_type: SymbolType | ||||
|     ) -> Optional[int]: | ||||
|         """Name lookup""" | ||||
|         match_decorate = compare_type != SymbolType.STRING and name.startswith("?") | ||||
|         if match_decorate: | ||||
|             sql = """ | ||||
|             SELECT recomp_addr | ||||
|             FROM `symbols` | ||||
|             WHERE orig_addr IS NULL | ||||
|             AND decorated_name = ? | ||||
|             AND (compare_type IS NULL OR compare_type = ?) | ||||
|             LIMIT 1 | ||||
|             """ | ||||
|         else: | ||||
|             sql = """ | ||||
|             SELECT recomp_addr | ||||
|             FROM `symbols` | ||||
|             WHERE orig_addr IS NULL | ||||
|             AND name = ? | ||||
|             AND (compare_type IS NULL OR compare_type = ?) | ||||
|             LIMIT 1 | ||||
|             """ | ||||
| 
 | ||||
|         row = self._db.execute(sql, (name, compare_type.value)).fetchone() | ||||
|         return row[0] if row is not None else None | ||||
| 
 | ||||
|     def _find_static_variable( | ||||
|         self, variable_name: str, function_sym: str | ||||
|     ) -> Optional[int]: | ||||
|         """Get the recomp address of a static function variable. | ||||
|         Matches using a LIKE clause on the combination of: | ||||
|         1. The variable name read from decomp marker. | ||||
|         2. The decorated name of the enclosing function. | ||||
|         For example, the variable "g_startupDelay" from function "IsleApp::Tick" | ||||
|         has symbol: `?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA` | ||||
|         The function's decorated name is: `?Tick@IsleApp@@QAEXH@Z`""" | ||||
| 
 | ||||
|         row = self._db.execute( | ||||
|             """SELECT recomp_addr FROM `symbols` | ||||
|             WHERE decorated_name LIKE '%' || ? || '%' || ? || '%' | ||||
|             AND orig_addr IS NULL | ||||
|             AND (compare_type = ? OR compare_type = ? OR compare_type IS NULL)""", | ||||
|             ( | ||||
|                 variable_name, | ||||
|                 function_sym, | ||||
|                 SymbolType.DATA.value, | ||||
|                 SymbolType.POINTER.value, | ||||
|             ), | ||||
|         ).fetchone() | ||||
|         return row[0] if row is not None else None | ||||
| 
 | ||||
|     def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool: | ||||
|         # Update the compare_type here too since the marker tells us what we should do | ||||
| 
 | ||||
| @@ -191,16 +284,11 @@ class CompareDb: | ||||
|         name = name[:255] | ||||
| 
 | ||||
|         logger.debug("Looking for %s %s", compare_type.name.lower(), name) | ||||
|         cur = self._db.execute( | ||||
|             """UPDATE `symbols` | ||||
|             SET orig_addr = ?, compare_type = ? | ||||
|             WHERE name = ? | ||||
|             AND orig_addr IS NULL | ||||
|             AND (compare_type = ? OR compare_type IS NULL)""", | ||||
|             (addr, compare_type.value, name, compare_type.value), | ||||
|         ) | ||||
|         recomp_addr = self._find_potential_match(name, compare_type) | ||||
|         if recomp_addr is None: | ||||
|             return False | ||||
| 
 | ||||
|         return cur.rowcount > 0 | ||||
|         return self.set_pair(addr, recomp_addr, compare_type) | ||||
| 
 | ||||
|     def match_function(self, addr: int, name: str) -> bool: | ||||
|         did_match = self._match_on(SymbolType.FUNCTION, addr, name) | ||||
| @@ -234,37 +322,20 @@ class CompareDb: | ||||
|         # Get the friendly name for the "failed to match" error message | ||||
|         (function_name, decorated_name) = result | ||||
| 
 | ||||
|         # Now we have to combine the variable name (read from the marker) | ||||
|         # and the decorated name of the enclosing function (the above variable) | ||||
|         # into a LIKE clause and try to match. | ||||
|         # For example, the variable "g_startupDelay" from function "IsleApp::Tick" | ||||
|         # has symbol: "?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA" | ||||
|         # The function's decorated name is: "?Tick@IsleApp@@QAEXH@Z" | ||||
|         cur = self._db.execute( | ||||
|             """UPDATE `symbols` | ||||
|             SET orig_addr = ? | ||||
|             WHERE name LIKE '%' || ? || '%' || ? || '%' | ||||
|             AND orig_addr IS NULL | ||||
|             AND (compare_type = ? OR compare_type = ? OR compare_type IS NULL)""", | ||||
|             ( | ||||
|                 addr, | ||||
|                 name, | ||||
|                 decorated_name, | ||||
|                 SymbolType.DATA.value, | ||||
|                 SymbolType.POINTER.value, | ||||
|             ), | ||||
|         recomp_addr = self._find_static_variable(name, decorated_name) | ||||
|         if recomp_addr is not None: | ||||
|             # TODO: This variable could be a pointer, but I don't think we | ||||
|             # have a way to tell that right now. | ||||
|             if self.set_pair(addr, recomp_addr, SymbolType.DATA): | ||||
|                 return True | ||||
| 
 | ||||
|         logger.error( | ||||
|             "Failed to match static variable %s from function %s", | ||||
|             name, | ||||
|             function_name, | ||||
|         ) | ||||
| 
 | ||||
|         did_match = cur.rowcount > 0 | ||||
| 
 | ||||
|         if not did_match: | ||||
|             logger.error( | ||||
|                 "Failed to match static variable %s from function %s", | ||||
|                 name, | ||||
|                 function_name, | ||||
|             ) | ||||
| 
 | ||||
|         return did_match | ||||
|         return False | ||||
| 
 | ||||
|     def match_variable(self, addr: int, name: str) -> bool: | ||||
|         did_match = self._match_on(SymbolType.DATA, addr, name) or self._match_on( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 MS
					MS