reccmp: vtable comparison (#452)

* Add vtable comparison to reccmp

* Add missing scalar deleting destructors

* Fix some vtables

---------

Co-authored-by: Christian Semmler <mail@csemmler.com>
This commit is contained in:
MS
2024-01-18 08:34:14 -05:00
committed by GitHub
parent 99917ca765
commit 909c44b679
161 changed files with 679 additions and 34 deletions

View File

@@ -1,6 +1,7 @@
import os
import logging
import difflib
import struct
from dataclasses import dataclass
from typing import Iterable, List, Optional
from isledecomp.cvdump.demangler import demangle_string_const
@@ -18,6 +19,7 @@ logger = logging.getLogger(__name__)
@dataclass
class DiffReport:
match_type: SymbolType
orig_addr: int
recomp_addr: int
name: str
@@ -214,17 +216,11 @@ class Compare:
# function in the first place.
self._db.skip_compare(thunk_from_orig)
def get_one_function(self, addr: int) -> Optional[MatchInfo]:
"""i.e. verbose mode for reccmp"""
return self._db.get_one_function(addr)
def get_functions(self) -> List[MatchInfo]:
return self._db.get_matches(SymbolType.FUNCTION)
def _compare_function(self, match: MatchInfo) -> DiffReport:
if match.size == 0:
# Report a failed match to make the user aware of the empty function.
return DiffReport(
match_type=SymbolType.FUNCTION,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
@@ -281,6 +277,7 @@ class Compare:
unified_diff = []
return DiffReport(
match_type=SymbolType.FUNCTION,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
@@ -289,16 +286,121 @@ class Compare:
is_effective_match=is_effective_match,
)
def compare_function(self, addr: int) -> Optional[DiffReport]:
match = self.get_one_function(addr)
def _compare_vtable(self, match: MatchInfo) -> DiffReport:
    """Compare the original and recompiled vtable of one class.

    Both tables are read as little-endian 32-bit values and walked in
    lockstep. A slot counts as matching when the database entry found for
    the original pointer and the entry found for the recomp pointer both
    exist and carry the same recomp address. The report's ratio is the
    fraction of matching slots; its udiff is a unified diff of the two
    tables rendered as text.
    """
    table_bytes = match.size

    # Each entry is one 4-byte pointer, so the size should divide evenly.
    # A trailing partial entry would make iter_unpack raise, so warn about
    # it and trim it off.
    if table_bytes % 4 != 0:
        logger.warning(
            "Vtable for class %s has irregular size %d", match.name, table_bytes
        )
        table_bytes -= table_bytes % 4

    orig_raw = self.orig_bin.read(match.orig_addr, table_bytes)
    recomp_raw = self.recomp_bin.read(match.recomp_addr, table_bytes)

    orig_ptrs = [ptr for (ptr,) in struct.iter_unpack("<L", orig_raw)]
    recomp_ptrs = [ptr for (ptr,) in struct.iter_unpack("<L", recomp_raw)]

    def describe_slot(
        slot: int, info: Optional[MatchInfo], fallback_addr: Optional[int] = None
    ) -> str:
        """Render the function reference held in one vtable slot.

        When the function is unidentified, the raw pointer can be shown
        instead. That is only done for the original side, because the
        recomp function should always be identifiable. An unidentified
        original probably means the class should override this function
        from its superclass and that override is not implemented yet.
        """
        label = f"vtable0x{slot*4:02x}"

        if info is None:
            if fallback_addr is None:
                return f"{label:>12} : (no match)"
            return f"{label:>12} : 0x{fallback_addr:x} from orig not annotated."

        orig_repr = "no orig" if info.orig_addr is None else hex(info.orig_addr)
        recomp_repr = (
            "no recomp" if info.recomp_addr is None else hex(info.recomp_addr)
        )
        return f"{label:>12} : ({orig_repr:10} / {recomp_repr:10}) : {info.name}"

    orig_lines = []
    recomp_lines = []
    matching_slots = 0

    # Look up each pair of pointers and record both the verdict and the text.
    for slot, (orig_ptr, recomp_ptr) in enumerate(zip(orig_ptrs, recomp_ptrs)):
        orig_info = self._db.get_by_orig(orig_ptr)
        recomp_info = self._db.get_by_recomp(recomp_ptr)

        same_target = (
            orig_info is not None
            and recomp_info is not None
            and orig_info.recomp_addr == recomp_info.recomp_addr
        )
        if same_target:
            matching_slots += 1

        orig_lines.append(describe_slot(slot, orig_info, orig_ptr))
        recomp_lines.append(describe_slot(slot, recomp_info))

    slot_count = len(orig_lines)
    ratio = matching_slots / float(slot_count) if slot_count > 0 else 0

    # n=100 keeps the whole table in view whenever there is a diff;
    # a table that got cut off would be confusing to read.
    table_diff = difflib.unified_diff(orig_lines, recomp_lines, n=100)

    return DiffReport(
        match_type=SymbolType.VTABLE,
        orig_addr=match.orig_addr,
        recomp_addr=match.recomp_addr,
        name=f"{match.name}::`vftable'",
        udiff=table_diff,
        ratio=ratio,
    )
def _compare_match(self, match: MatchInfo) -> Optional[DiffReport]:
    """Dispatch a match to the comparison routine for its compare type.

    Returns None when the compare type is neither FUNCTION nor VTABLE.
    """
    handlers = (
        (SymbolType.FUNCTION, self._compare_function),
        (SymbolType.VTABLE, self._compare_vtable),
    )

    for symbol_type, handler in handlers:
        if match.compare_type == symbol_type:
            return handler(match)

    return None
## Public API
def get_functions(self) -> List[MatchInfo]:
    """Return the database's matches of type FUNCTION."""
    function_matches = self._db.get_matches_by_type(SymbolType.FUNCTION)
    return function_matches
def get_vtables(self) -> List[MatchInfo]:
    """Return the database's matches of type VTABLE."""
    vtable_matches = self._db.get_matches_by_type(SymbolType.VTABLE)
    return vtable_matches
def compare_address(self, addr: int) -> Optional[DiffReport]:
    """Compare the single match located at original address `addr`.

    Returns None when the database has no match at that address, or when
    the match has a compare type with no comparison routine.
    """
    match = self._db.get_one_match(addr)
    if match is None:
        return None

    # Route by compare type. Returning self._compare_function(match) here
    # would push vtable matches through the function comparison and leave
    # the router unreachable.
    return self._compare_match(match)
def compare_all(self) -> Iterable[DiffReport]:
    """Yield a report for every database match that can be compared.

    Matches for which the router returns None are left out.
    """
    reports = map(self._compare_match, self._db.get_matches())
    yield from (report for report in reports if report is not None)
def compare_functions(self) -> Iterable[DiffReport]:
    """Yield exactly one report per function match.

    Each match goes through the router, the same path compare_all and
    compare_vtables use, rather than being compared a second time
    directly.
    """
    for match in self.get_functions():
        yield self._compare_match(match)
def compare_variables(self):
    """Placeholder: variable comparison is not implemented; returns None."""
    return None
@@ -309,5 +411,6 @@ class Compare:
def compare_strings(self):
    """Placeholder: string comparison is not implemented; returns None."""
    return None
def compare_vtables(self) -> Iterable[DiffReport]:
    """Yield one report per vtable match.

    This is the only definition of the method; a leading `pass` stub of
    the same name would merely be shadowed by it.
    """
    for match in self.get_vtables():
        yield self._compare_match(match)

View File

@@ -82,17 +82,29 @@ class CompareDb:
return [string for (string,) in cur.fetchall()]
def get_matches(self) -> List[MatchInfo]:
    """Return every comparable symbol, ordered by original address.

    A row qualifies when it has both an orig and a recomp address and is
    not flagged `should_skip`. The result is a list (possibly empty),
    never None.
    """
    # The query has no placeholders, so no parameter tuple is passed.
    cur = self._db.execute(
        """SELECT compare_type, orig_addr, recomp_addr, name, size
        FROM `symbols`
        WHERE orig_addr IS NOT NULL
        AND recomp_addr IS NOT NULL
        AND should_skip IS FALSE
        ORDER BY orig_addr
        """
    )
    # Convert each fetched row with matchinfo_factory.
    cur.row_factory = matchinfo_factory

    return cur.fetchall()
def get_one_match(self, addr: int) -> Optional[MatchInfo]:
    """Fetch the comparable symbol at original address `addr`.

    Only rows that also have a recomp address and are not flagged
    `should_skip` qualify. Returns whatever `fetchone()` gives back, which
    is None when no row qualifies.
    """
    query = """SELECT compare_type, orig_addr, recomp_addr, name, size
        FROM `symbols`
        WHERE orig_addr = ?
        AND recomp_addr IS NOT NULL
        AND should_skip IS FALSE
        """
    params = (addr,)

    cursor = self._db.execute(query, params)
    # Convert the fetched row with matchinfo_factory.
    cursor.row_factory = matchinfo_factory
    return cursor.fetchone()
@@ -119,7 +131,7 @@ class CompareDb:
cur.row_factory = matchinfo_factory
return cur.fetchone()
def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]:
def get_matches_by_type(self, compare_type: SymbolType) -> List[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`

View File

@@ -12,6 +12,7 @@ from isledecomp import (
print_diff,
)
from isledecomp.compare import Compare as IsleCompare
from isledecomp.types import SymbolType
from pystache import Renderer
import colorama
@@ -225,9 +226,9 @@ def main():
### Compare one or none.
if args.verbose is not None:
match = isle_compare.compare_function(args.verbose)
match = isle_compare.compare_address(args.verbose)
if match is None:
print(f"Failed to find the function with address 0x{args.verbose:x}")
print(f"Failed to find a match at address 0x{args.verbose:x}")
return
print_match_verbose(
@@ -242,14 +243,15 @@ def main():
total_effective_accuracy = 0
htmlinsert = []
for match in isle_compare.compare_functions():
for match in isle_compare.compare_all():
print_match_oneline(
match, show_both_addrs=args.print_rec_addr, is_plain=args.no_color
)
function_count += 1
total_accuracy += match.ratio
total_effective_accuracy += match.effective_ratio
if match.match_type == SymbolType.FUNCTION:
function_count += 1
total_accuracy += match.ratio
total_effective_accuracy += match.effective_ratio
# If html, record the diffs to an HTML file
if args.html is not None: