reccmp: Unique addresses and stub reporting (#554)

This commit is contained in:
MS
2024-02-13 20:25:51 -05:00
committed by GitHub
parent eb3b339454
commit 1b696e4bd8
8 changed files with 326 additions and 99 deletions

View File

@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
@dataclass
class DiffReport:
# pylint: disable=too-many-instance-attributes
match_type: SymbolType
orig_addr: int
recomp_addr: int
@@ -27,6 +28,7 @@ class DiffReport:
udiff: Optional[List[str]] = None
ratio: float = 0.0
is_effective_match: bool = False
is_stub: bool = False
@property
def effective_ratio(self) -> float:
@@ -130,7 +132,23 @@ class Compare:
raw = self.recomp_bin.read(addr, sym.size())
try:
sym.friendly_name = raw.decode("latin1").rstrip("\x00")
# We use the string length reported in the mangled symbol as the
# data size, but this is not always accurate with respect to the
# null terminator.
# e.g. ??_C@_0BA@EFDM@MxObjectFactory?$AA@
# reported length: 16 (includes null terminator)
# c.f. ??_C@_03DPKJ@enz?$AA@
# reported length: 3 (does NOT include terminator)
# This will handle the case where the entire string contains "\x00"
# because those are distinct from the empty string of length 0.
decoded_string = raw.decode("latin1")
rstrip_string = decoded_string.rstrip("\x00")
if decoded_string != "" and rstrip_string != "":
sym.friendly_name = rstrip_string
else:
sym.friendly_name = decoded_string
except UnicodeDecodeError:
pass
@@ -162,12 +180,12 @@ class Compare:
if recomp_addr is not None:
self._db.set_function_pair(fun.offset, recomp_addr)
if fun.should_skip():
self._db.skip_compare(fun.offset)
self._db.mark_stub(fun.offset)
for fun in codebase.iter_name_functions():
self._db.match_function(fun.offset, fun.name)
if fun.should_skip():
self._db.skip_compare(fun.offset)
self._db.mark_stub(fun.offset)
for var in codebase.iter_variables():
if var.is_static and var.parent_function is not None:
@@ -255,15 +273,6 @@ class Compare:
self._db.skip_compare(thunk_from_orig)
def _compare_function(self, match: MatchInfo) -> DiffReport:
if match.size == 0:
# Report a failed match to make the user aware of the empty function.
return DiffReport(
match_type=SymbolType.FUNCTION,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
)
orig_raw = self.orig_bin.read(match.orig_addr, match.size)
recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size)
@@ -406,6 +415,23 @@ class Compare:
def _compare_match(self, match: MatchInfo) -> Optional[DiffReport]:
"""Router for comparison type"""
if match.size == 0:
return None
options = self._db.get_match_options(match.orig_addr)
if options.get("skip", False):
return None
if options.get("stub", False):
return DiffReport(
match_type=match.compare_type,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
is_stub=True,
)
if match.compare_type == SymbolType.FUNCTION:
return self._compare_function(match)
@@ -440,7 +466,9 @@ class Compare:
def compare_functions(self) -> Iterable[DiffReport]:
for match in self.get_functions():
yield self._compare_match(match)
diff = self._compare_match(match)
if diff is not None:
yield diff
def compare_variables(self):
pass
@@ -453,4 +481,6 @@ class Compare:
def compare_vtables(self) -> Iterable[DiffReport]:
for match in self.get_vtables():
yield self._compare_match(match)
diff = self._compare_match(match)
if diff is not None:
yield self._compare_match(match)

View File

@@ -7,15 +7,30 @@ from isledecomp.types import SymbolType
_SETUP_SQL = """
DROP TABLE IF EXISTS `symbols`;
DROP TABLE IF EXISTS `match_options`;
CREATE TABLE `symbols` (
compare_type int,
orig_addr int,
recomp_addr int,
name text,
decorated_name text,
size int,
should_skip int default(FALSE)
size int
);
CREATE TABLE `match_options` (
addr int not null,
name text not null,
value text,
primary key (addr, name)
) without rowid;
CREATE VIEW IF NOT EXISTS `match_info`
(compare_type, orig_addr, recomp_addr, name, size) AS
SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
ORDER BY orig_addr NULLS LAST;
CREATE INDEX `symbols_or` ON `symbols` (orig_addr);
CREATE INDEX `symbols_re` ON `symbols` (recomp_addr);
CREATE INDEX `symbols_na` ON `symbols` (name);
@@ -69,6 +84,11 @@ class CompareDb:
decorated_name: Optional[str],
size: Optional[int],
):
# Ignore collisions here. The same recomp address can have
# multiple names (e.g. _strlwr and __strlwr)
if self.recomp_used(addr):
return
compare_value = compare_type.value if compare_type is not None else None
self._db.execute(
"INSERT INTO `symbols` (recomp_addr, compare_type, name, decorated_name, size) VALUES (?,?,?,?,?)",
@@ -86,24 +106,16 @@ class CompareDb:
return [string for (string,) in cur.fetchall()]
def get_all(self) -> List[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
ORDER BY orig_addr NULLS LAST
""",
)
cur = self._db.execute("SELECT * FROM `match_info`")
cur.row_factory = matchinfo_factory
return cur.fetchall()
def get_matches(self) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
"""SELECT * FROM `match_info`
WHERE orig_addr IS NOT NULL
AND recomp_addr IS NOT NULL
AND should_skip IS FALSE
ORDER BY orig_addr
""",
)
cur.row_factory = matchinfo_factory
@@ -112,11 +124,9 @@ class CompareDb:
def get_one_match(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
"""SELECT * FROM `match_info`
WHERE orig_addr = ?
AND recomp_addr IS NOT NULL
AND should_skip IS FALSE
""",
(addr,),
)
@@ -125,8 +135,7 @@ class CompareDb:
def get_by_orig(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
"""SELECT * FROM `match_info`
WHERE orig_addr = ?
""",
(addr,),
@@ -136,8 +145,7 @@ class CompareDb:
def get_by_recomp(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
"""SELECT * FROM `match_info`
WHERE recomp_addr = ?
""",
(addr,),
@@ -147,13 +155,10 @@ class CompareDb:
def get_matches_by_type(self, compare_type: SymbolType) -> List[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
"""SELECT * FROM `match_info`
WHERE compare_type = ?
AND orig_addr IS NOT NULL
AND recomp_addr IS NOT NULL
AND should_skip IS FALSE
ORDER BY orig_addr
""",
(compare_type.value,),
)
@@ -161,9 +166,21 @@ class CompareDb:
return cur.fetchall()
def orig_used(self, addr: int) -> bool:
cur = self._db.execute("SELECT 1 FROM symbols WHERE orig_addr = ?", (addr,))
return cur.fetchone() is not None
def recomp_used(self, addr: int) -> bool:
cur = self._db.execute("SELECT 1 FROM symbols WHERE recomp_addr = ?", (addr,))
return cur.fetchone() is not None
def set_pair(
self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None
) -> bool:
if self.orig_used(orig):
logger.error("Original address %s not unique!", hex(orig))
return False
compare_value = compare_type.value if compare_type is not None else None
cur = self._db.execute(
"UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?",
@@ -174,14 +191,90 @@ class CompareDb:
def set_function_pair(self, orig: int, recomp: int) -> bool:
"""For lineref match or _entry"""
self.set_pair(orig, recomp, SymbolType.FUNCTION)
# TODO: Both ways required?
return self.set_pair(orig, recomp, SymbolType.FUNCTION)
def _set_opt_bool(self, addr: int, option: str, enabled: bool = True):
if enabled:
self._db.execute(
"""INSERT OR IGNORE INTO `match_options`
(addr, name)
VALUES (?, ?)""",
(addr, option),
)
else:
self._db.execute(
"""DELETE FROM `match_options` WHERE addr = ? AND name = ?""",
(addr, option),
)
def mark_stub(self, orig: int):
self._set_opt_bool(orig, "stub")
def skip_compare(self, orig: int):
self._db.execute(
"UPDATE `symbols` SET should_skip = TRUE WHERE orig_addr = ?", (orig,)
self._set_opt_bool(orig, "skip")
def get_match_options(self, addr: int) -> Optional[dict]:
cur = self._db.execute(
"""SELECT name, value FROM `match_options` WHERE addr = ?""", (addr,)
)
return {
option: value if value is not None else True
for (option, value) in cur.fetchall()
}
def _find_potential_match(
self, name: str, compare_type: SymbolType
) -> Optional[int]:
"""Name lookup"""
match_decorate = compare_type != SymbolType.STRING and name.startswith("?")
if match_decorate:
sql = """
SELECT recomp_addr
FROM `symbols`
WHERE orig_addr IS NULL
AND decorated_name = ?
AND (compare_type IS NULL OR compare_type = ?)
LIMIT 1
"""
else:
sql = """
SELECT recomp_addr
FROM `symbols`
WHERE orig_addr IS NULL
AND name = ?
AND (compare_type IS NULL OR compare_type = ?)
LIMIT 1
"""
row = self._db.execute(sql, (name, compare_type.value)).fetchone()
return row[0] if row is not None else None
def _find_static_variable(
self, variable_name: str, function_sym: str
) -> Optional[int]:
"""Get the recomp address of a static function variable.
Matches using a LIKE clause on the combination of:
1. The variable name read from decomp marker.
2. The decorated name of the enclosing function.
For example, the variable "g_startupDelay" from function "IsleApp::Tick"
has symbol: `?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA`
The function's decorated name is: `?Tick@IsleApp@@QAEXH@Z`"""
row = self._db.execute(
"""SELECT recomp_addr FROM `symbols`
WHERE decorated_name LIKE '%' || ? || '%' || ? || '%'
AND orig_addr IS NULL
AND (compare_type = ? OR compare_type = ? OR compare_type IS NULL)""",
(
variable_name,
function_sym,
SymbolType.DATA.value,
SymbolType.POINTER.value,
),
).fetchone()
return row[0] if row is not None else None
def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool:
# Update the compare_type here too since the marker tells us what we should do
@@ -191,16 +284,11 @@ class CompareDb:
name = name[:255]
logger.debug("Looking for %s %s", compare_type.name.lower(), name)
cur = self._db.execute(
"""UPDATE `symbols`
SET orig_addr = ?, compare_type = ?
WHERE name = ?
AND orig_addr IS NULL
AND (compare_type = ? OR compare_type IS NULL)""",
(addr, compare_type.value, name, compare_type.value),
)
recomp_addr = self._find_potential_match(name, compare_type)
if recomp_addr is None:
return False
return cur.rowcount > 0
return self.set_pair(addr, recomp_addr, compare_type)
def match_function(self, addr: int, name: str) -> bool:
did_match = self._match_on(SymbolType.FUNCTION, addr, name)
@@ -234,37 +322,20 @@ class CompareDb:
# Get the friendly name for the "failed to match" error message
(function_name, decorated_name) = result
# Now we have to combine the variable name (read from the marker)
# and the decorated name of the enclosing function (the above variable)
# into a LIKE clause and try to match.
# For example, the variable "g_startupDelay" from function "IsleApp::Tick"
# has symbol: "?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA"
# The function's decorated name is: "?Tick@IsleApp@@QAEXH@Z"
cur = self._db.execute(
"""UPDATE `symbols`
SET orig_addr = ?
WHERE name LIKE '%' || ? || '%' || ? || '%'
AND orig_addr IS NULL
AND (compare_type = ? OR compare_type = ? OR compare_type IS NULL)""",
(
addr,
name,
decorated_name,
SymbolType.DATA.value,
SymbolType.POINTER.value,
),
recomp_addr = self._find_static_variable(name, decorated_name)
if recomp_addr is not None:
# TODO: This variable could be a pointer, but I don't think we
# have a way to tell that right now.
if self.set_pair(addr, recomp_addr, SymbolType.DATA):
return True
logger.error(
"Failed to match static variable %s from function %s",
name,
function_name,
)
did_match = cur.rowcount > 0
if not did_match:
logger.error(
"Failed to match static variable %s from function %s",
name,
function_name,
)
return did_match
return False
def match_variable(self, addr: int, name: str) -> bool:
did_match = self._match_on(SymbolType.DATA, addr, name) or self._match_on(

View File

@@ -4,6 +4,9 @@ import colorama
def print_diff(udiff, plain):
if udiff is None:
return False
has_diff = False
for line in udiff:
has_diff = True