Match vtables with virtual inheritance (#717)

* Match vtables with virtual inheritance

* Simplify vtable name check

* Thunk alert
This commit is contained in:
MS
2024-03-23 18:01:40 -04:00
committed by GitHub
parent b279e8b8b9
commit 3f03940fcb
11 changed files with 350 additions and 24 deletions

View File

@@ -86,6 +86,7 @@ class Compare:
self._find_original_strings()
self._match_thunks()
self._match_exports()
self._find_vtordisp()
def _load_cvdump(self):
logger.info("Parsing %s ...", self.pdb_file)
@@ -198,7 +199,7 @@ class Compare:
self._db.match_variable(var.offset, var.name)
for tbl in codebase.iter_vtables():
self._db.match_vtable(tbl.offset, tbl.name)
self._db.match_vtable(tbl.offset, tbl.name, tbl.base_class)
for string in codebase.iter_strings():
# Not that we don't trust you, but we're checking the string
@@ -285,10 +286,105 @@ class Compare:
):
logger.debug("Matched export %s", repr(export_name))
def _find_vtordisp(self):
    """If there are any cases of virtual inheritance, we can read
    through the vtables for those classes and find the vtable thunk
    functions (vtordisp).

    Our approach is this: walk both vtables and check where we have a
    vtordisp in the recomp table. Inspect the function at that vtable
    position (in both) and check whether we jump to the same function.

    One potential pitfall here is that the virtual displacement could
    differ between the thunks. We are not (yet) checking for this, so the
    result is that the vtable will appear to match but we will have a diff
    on the thunk in our regular function comparison.

    We could do this differently and check only the original vtable,
    construct the name of the vtordisp function and match based on that.

    Vtable slots whose thunk bytes are too short to contain a 5-byte JMP
    (after removing 0xcc padding) are skipped rather than raising."""

    # A near JMP is one opcode byte followed by a signed 32-bit displacement.
    jmp_size = 5

    for match in self._db.get_matches_by_type(SymbolType.VTABLE):
        # We need some method of identifying vtables that
        # might have thunks, and this ought to work okay.
        if "{for" not in match.name:
            continue

        # TODO: We might want to fix this at the source (cvdump) instead.
        # Any problem will be logged later when we compare the vtable.
        # Round the size down to a whole number of 4-byte entries so that
        # iter_unpack does not reject the buffer.
        vtable_size = 4 * (match.size // 4)
        orig_table = self.orig_bin.read(match.orig_addr, vtable_size)
        recomp_table = self.recomp_bin.read(match.recomp_addr, vtable_size)

        raw_addrs = zip(
            (t for (t,) in struct.iter_unpack("<L", orig_table)),
            (t for (t,) in struct.iter_unpack("<L", recomp_table)),
        )

        # Now walk both vtables looking for thunks.
        for orig_addr, recomp_addr in raw_addrs:
            if not self._db.is_vtordisp(recomp_addr):
                continue

            thunk_fn = self.get_by_recomp(recomp_addr)

            # Read the function bytes here.
            # In practice, the adjuster thunk will be under 16 bytes.
            # If we have thunks of unequal size, we can still tell whether
            # they are thunking the same function by grabbing the
            # JMP instruction at the end.
            thunk_presumed_size = max(thunk_fn.size, 16)

            # Strip off MSVC padding 0xcc bytes.
            # This should be safe to do; it is highly unlikely that
            # the MSB of the jump displacement would be 0xcc. (huge jump)
            orig_thunk_bin = self.orig_bin.read(
                orig_addr, thunk_presumed_size
            ).rstrip(b"\xcc")

            recomp_thunk_bin = self.recomp_bin.read(
                recomp_addr, thunk_presumed_size
            ).rstrip(b"\xcc")

            # struct.unpack needs exactly 5 bytes. If either side has fewer
            # left after stripping, it cannot end in a JMP, so this slot
            # is not a thunk pair we can verify.
            if len(orig_thunk_bin) < jmp_size or len(recomp_thunk_bin) < jmp_size:
                continue

            # Read jump opcode and displacement (last 5 bytes)
            (orig_jmp, orig_disp) = struct.unpack("<Bi", orig_thunk_bin[-jmp_size:])
            (recomp_jmp, recomp_disp) = struct.unpack(
                "<Bi", recomp_thunk_bin[-jmp_size:]
            )

            # Make sure it's a JMP
            if orig_jmp != 0xE9 or recomp_jmp != 0xE9:
                continue

            # Calculate jump destination from the end of the JMP instruction
            # i.e. the end of the function
            orig_actual = orig_addr + len(orig_thunk_bin) + orig_disp
            recomp_actual = recomp_addr + len(recomp_thunk_bin) + recomp_disp

            # If they are thunking the same function, then this must be a match.
            if self.is_pointer_match(orig_actual, recomp_actual):
                if len(orig_thunk_bin) != len(recomp_thunk_bin):
                    logger.warning(
                        "Adjuster thunk %s (0x%x) is not exact",
                        thunk_fn.name,
                        orig_addr,
                    )

                self._db.set_function_pair(orig_addr, recomp_addr)
def _compare_function(self, match: MatchInfo) -> DiffReport:
orig_raw = self.orig_bin.read(match.orig_addr, match.size)
recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size)
# It's unlikely that a function other than an adjuster thunk would
# start with a SUB instruction, so alert to a possible wrong
# annotation here.
# There's probably a better place to do this, but we're reading
# the function bytes here already.
try:
if orig_raw[0] == 0x2B and recomp_raw[0] != 0x2B:
logger.warning(
"Possible thunk at 0x%x (%s)", match.orig_addr, match.name
)
except IndexError:
pass
def orig_lookup(addr: int) -> Optional[str]:
m = self._db.get_by_orig(addr)
if m is None:
@@ -432,7 +528,7 @@ class Compare:
match_type=SymbolType.VTABLE,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=f"{match.name}::`vftable'",
name=match.name,
udiff=unified_diff,
ratio=ratio,
)

View File

@@ -4,6 +4,7 @@ import sqlite3
import logging
from typing import List, Optional
from isledecomp.types import SymbolType
from isledecomp.cvdump.demangler import get_vtordisp_name
_SETUP_SQL = """
DROP TABLE IF EXISTS `symbols`;
@@ -249,6 +250,37 @@ class CompareDb:
for (option, value) in cur.fetchall()
}
def is_vtordisp(self, recomp_addr: int) -> bool:
    """Check whether this function is a vtordisp based on its
    decorated name. If its demangled name is missing the vtordisp
    indicator, correct that.

    Returns True when the symbol at `recomp_addr` is a vtordisp thunk,
    either because its stored name already says so or because a vtordisp
    name could be derived from its decorated name (in which case the
    stored name is updated). Returns False when there is no symbol at
    that address, when the symbol has no decorated name to inspect, or
    when no vtordisp name can be derived."""

    row = self._db.execute(
        """SELECT name, decorated_name
        FROM `symbols`
        WHERE recomp_addr = ?""",
        (recomp_addr,),
    ).fetchone()

    if row is None:
        return False

    (name, decorated_name) = row

    # A NULL name comes back from sqlite as None; the substring test
    # would raise TypeError on it, so check for None first.
    if name is not None and "`vtordisp" in name:
        return True

    # Without a decorated name there is nothing to derive a vtordisp
    # name from.
    if decorated_name is None:
        return False

    new_name = get_vtordisp_name(decorated_name)
    if new_name is None:
        return False

    # Store the corrected name so later lookups see the vtordisp indicator.
    self._db.execute(
        """UPDATE `symbols`
        SET name = ?
        WHERE recomp_addr = ?""",
        (new_name, recomp_addr),
    )

    return True
def _find_potential_match(
self, name: str, compare_type: SymbolType
) -> Optional[int]:
@@ -323,12 +355,34 @@ class CompareDb:
return did_match
def match_vtable(self, addr: int, name: str) -> bool:
did_match = self._match_on(SymbolType.VTABLE, addr, name)
if not did_match:
logger.error("Failed to find vtable for class: %s", name)
def match_vtable(
self, addr: int, name: str, base_class: Optional[str] = None
) -> bool:
"""Pair the original vtable at `addr` with a not-yet-matched vtable symbol.

The symbol name to search for is built from `name` and `base_class`:
"<name>::`vftable'" when `base_class` is None or equal to `name`,
otherwise "<name>::`vftable'{for `<base_class>'}".

Returns True if a row with that name and the VTABLE compare type was
found with no orig_addr and `set_pair` reported success. Otherwise an
error is logged and False is returned."""
# Only allow a match against "Class::`vftable'"
# if this is the derived class.
name = (
f"{name}::`vftable'"
if base_class is None or base_class == name
else f"{name}::`vftable'{{for `{base_class}'}}"
)
# NOTE(review): `did_match` is never assigned in this function. The next
# line looks like a leftover from the previous implementation (removed
# side of the diff) - confirm and drop it.
return did_match
# Look for one symbol with this exact name that has no original address yet.
row = self._db.execute(
"""
SELECT recomp_addr
FROM `symbols`
WHERE orig_addr IS NULL
AND name = ?
AND (compare_type = ?)
LIMIT 1
""",
(name, SymbolType.VTABLE.value),
).fetchone()
# row[0] is the recomp_addr of the candidate row.
if row is not None and self.set_pair(addr, row[0], SymbolType.VTABLE):
return True
# Reached when no candidate row exists or set_pair did not succeed.
logger.error("Failed to find vtable for class: %s", name)
return False
def match_static_variable(self, addr: int, name: str, function_addr: int) -> bool:
"""Matching a static function variable by combining the variable name