Files
isle/tools/isledecomp/isledecomp/compare/core.py

533 lines
19 KiB
Python

import os
import logging
import difflib
import struct
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional
from isledecomp.bin import Bin as IsleBin
from isledecomp.cvdump.demangler import demangle_string_const
from isledecomp.cvdump import Cvdump, CvdumpAnalysis
from isledecomp.parser import DecompCodebase
from isledecomp.dir import walk_source_dir
from isledecomp.types import SymbolType
from isledecomp.compare.asm import ParseAsm, can_resolve_register_differences
from .db import CompareDb, MatchInfo
from .diff import combined_diff
from .lines import LinesDb
# Logger shared by everything in this module; it is named after the module
# itself through __name__.
logger = logging.getLogger(__name__)
@dataclass
class DiffReport:
    # pylint: disable=too-many-instance-attributes
    """Result of comparing one matched symbol between the two binaries."""

    # Kind of symbol that was compared.
    match_type: SymbolType
    # Address of the symbol in the original binary.
    orig_addr: int
    # Address of the symbol in the recompiled binary.
    recomp_addr: int
    # Display name of the symbol.
    name: str
    # Diff payload attached by the comparison; None when no diff was made.
    udiff: Optional[List[str]] = None
    # Similarity score of the comparison (1.0 is an exact match).
    ratio: float = 0.0
    # True when the remaining differences were judged to be immaterial.
    is_effective_match: bool = False
    # True when the symbol is flagged as a stub and was not compared.
    is_stub: bool = False

    @property
    def effective_ratio(self) -> float:
        """Return 1.0 for an effective match, otherwise the raw ratio."""
        if self.is_effective_match:
            return 1.0
        return self.ratio

    def __str__(self) -> str:
        """Debug rendering: name, original address, percentage, and a
        trailing '*' for effective matches. Proper diff printing (with
        coloring) is in another module."""
        return f"{self.name} (0x{self.orig_addr:x}) {self.ratio*100:.02f}%{'*' if self.is_effective_match else ''}"
def create_reloc_lookup(bin_file: IsleBin) -> Callable[[int], bool]:
    """Build a predicate that consults the relocation table of `bin_file`.

    The returned callable takes an address and answers False for anything
    at or below the image base; for higher addresses it defers to
    `bin_file.is_relocated_addr`.
    """

    def is_relocated(addr: int) -> bool:
        # Addresses not above the image base are never treated as relocated.
        if addr <= bin_file.imagebase:
            return False

        return bin_file.is_relocated_addr(addr)

    return is_relocated
def create_float_lookup(bin_file: IsleBin) -> Callable[[int, int], Optional[str]]:
    """Build a reader for floating point constants stored in `bin_file`.

    The returned callable takes an address and a size in bytes. It reads
    that many bytes from the binary, decodes them as a little-endian float
    (size 4) or double (any other size) and returns the value as a string.
    It returns None when the read yields nothing or the bytes cannot be
    unpacked with the chosen format.
    """

    def read_float(addr: int, size: int) -> Optional[str]:
        raw = bin_file.read(addr, size)

        # If this is a float constant, it should be initialized data.
        if raw is None:
            return None

        if size == 4:
            fmt = "<f"
        else:
            fmt = "<d"

        try:
            (value,) = struct.unpack(fmt, raw)
        except struct.error:
            # The buffer length does not fit the selected format.
            return None
        else:
            return str(value)

    return read_float
class Compare:
    # pylint: disable=too-many-instance-attributes
    """Pair up symbols of an original binary with those of a recompiled
    binary and report how closely matched functions and vtables agree.

    Creating an instance runs the whole matching pipeline: the PDB is parsed
    through cvdump, source annotations are read from `code_dir`, remaining
    strings are searched for in the original binary, and finally import
    thunks and exports are paired. The `compare_*` methods then produce
    `DiffReport` objects from the collected matches.
    """

    def __init__(
        self, orig_bin: IsleBin, recomp_bin: IsleBin, pdb_file: str, code_dir: str
    ):
        """Store the inputs and run every matching step.

        Args:
            orig_bin: The original binary.
            recomp_bin: The recompiled binary.
            pdb_file: Path of the PDB that is handed to cvdump.
            code_dir: Root directory of the annotated source code.
        """
        self.orig_bin = orig_bin
        self.recomp_bin = recomp_bin
        self.pdb_file = pdb_file
        self.code_dir = code_dir

        self._lines_db = LinesDb(code_dir)
        self._db = CompareDb()

        # The order matters: later steps rely on symbols and matches that
        # were registered by the earlier ones.
        self._load_cvdump()
        self._load_markers()
        self._find_original_strings()
        self._match_thunks()
        self._match_exports()

    def _load_cvdump(self):
        """Register every usable PDB symbol of the recompiled binary, record
        the verified source lines, and pair the two entry points."""
        logger.info("Parsing %s ...", self.pdb_file)
        cv = (
            Cvdump(self.pdb_file)
            .lines()
            .globals()
            .publics()
            .symbols()
            .section_contributions()
            .types()
            .run()
        )
        res = CvdumpAnalysis(cv)

        for sym in res.nodes:
            # The PDB might contain sections that do not line up with the
            # actual binary. The symbol "__except_list" is one example.
            # In these cases, just skip this symbol and move on because
            # we can't do much with it.
            if not self.recomp_bin.is_valid_section(sym.section):
                continue

            addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset)

            # If this symbol is the final one in its section, we were not able to
            # estimate its size because we didn't have the total size of that section.
            # We can get this estimate now and assume that the final symbol occupies
            # the remainder of the section.
            if sym.estimated_size is None:
                sym.estimated_size = (
                    self.recomp_bin.get_section_extent_by_index(sym.section)
                    - sym.offset
                )

            if sym.node_type == SymbolType.STRING:
                string_info = demangle_string_const(sym.decorated_name)
                if string_info is None:
                    logger.debug(
                        "Could not demangle string symbol: %s", sym.decorated_name
                    )
                    continue

                # TODO: skip unicode for now. will need to handle these differently.
                if string_info.is_utf16:
                    continue

                raw = self.recomp_bin.read(addr, sym.size())
                try:
                    # We use the string length reported in the mangled symbol as the
                    # data size, but this is not always accurate with respect to the
                    # null terminator.
                    # e.g. ??_C@_0BA@EFDM@MxObjectFactory?$AA@
                    # reported length: 16 (includes null terminator)
                    # c.f. ??_C@_03DPKJ@enz?$AA@
                    # reported length: 3 (does NOT include terminator)
                    # This will handle the case where the entire string contains "\x00"
                    # because those are distinct from the empty string of length 0.
                    decoded_string = raw.decode("latin1")
                    rstrip_string = decoded_string.rstrip("\x00")

                    if decoded_string != "" and rstrip_string != "":
                        sym.friendly_name = rstrip_string
                    else:
                        sym.friendly_name = decoded_string
                except UnicodeDecodeError:
                    pass

            self._db.set_recomp_symbol(
                addr, sym.node_type, sym.name(), sym.decorated_name, sym.size()
            )

        for (section, offset), (filename, line_no) in res.verified_lines.items():
            addr = self.recomp_bin.get_abs_addr(section, offset)
            self._lines_db.add_line(filename, line_no, addr)

        # The _entry symbol is referenced in the PE header so we get this match for free.
        self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry)

    def _load_markers(self):
        """Read the annotations in the source tree and turn them into
        matches for functions, variables, vtables and strings."""
        # Assume module name is the base filename of the original binary.
        (module, _) = os.path.splitext(os.path.basename(self.orig_bin.filename))

        codefiles = list(walk_source_dir(self.code_dir))
        codebase = DecompCodebase(codefiles, module.upper())

        # Match lineref functions first because this is a guaranteed match.
        # If we have two functions that share the same name, and one is
        # a lineref, we can match the nameref correctly because the lineref
        # was already removed from consideration.
        for fun in codebase.iter_line_functions():
            recomp_addr = self._lines_db.search_line(fun.filename, fun.line_number)
            if recomp_addr is not None:
                self._db.set_function_pair(fun.offset, recomp_addr)
                if fun.should_skip():
                    self._db.mark_stub(fun.offset)

        for fun in codebase.iter_name_functions():
            self._db.match_function(fun.offset, fun.name)
            if fun.should_skip():
                self._db.mark_stub(fun.offset)

        for var in codebase.iter_variables():
            if var.is_static and var.parent_function is not None:
                self._db.match_static_variable(
                    var.offset, var.name, var.parent_function
                )
            else:
                self._db.match_variable(var.offset, var.name)

        for tbl in codebase.iter_vtables():
            self._db.match_vtable(tbl.offset, tbl.name)

        for string in codebase.iter_strings():
            # Not that we don't trust you, but we're checking the string
            # annotation to make sure it is accurate.
            try:
                # TODO: would presumably fail for wchar_t strings
                orig = self.orig_bin.read_string(string.offset).decode("latin1")
                string_correct = string.name == orig
            except UnicodeDecodeError:
                string_correct = False

            if not string_correct:
                logger.error(
                    "Data at 0x%x does not match string %s",
                    string.offset,
                    repr(string.name),
                )
                continue

            self._db.match_string(string.offset, string.name)

    def _find_original_strings(self):
        """Go to the original binary and look for the specified string constants
        to find a match. This is a (relatively) expensive operation so we only
        look at strings that we have not already matched via a STRING annotation."""
        for string in self._db.get_unmatched_strings():
            addr = self.orig_bin.find_string(string.encode("latin1"))
            if addr is None:
                escaped = repr(string)
                logger.debug("Failed to find this string in the original: %s", escaped)
                continue

            self._db.match_string(addr, string)

    def _match_thunks(self):
        """Pair the import table entries of both binaries by (DLL, name) and
        then pair the thunk functions that refer to those entries."""
        orig_byaddr = {
            addr: (dll.upper(), name) for (dll, name, addr) in self.orig_bin.imports
        }
        recomp_byname = {
            (dll.upper(), name): addr for (dll, name, addr) in self.recomp_bin.imports
        }

        # Combine these two dictionaries. We don't care about imports from recomp
        # not found in orig because:
        # 1. They shouldn't be there
        # 2. They are already identified via cvdump
        orig_to_recomp = {
            addr: recomp_byname.get(pair, None) for addr, pair in orig_byaddr.items()
        }

        # Now: we have the IAT offset in each matched up, so we need to make
        # the connection between the thunk functions.
        # We already have the symbol name we need from the PDB.
        orig_thunks = {
            iat_ofs: func_ofs for (func_ofs, iat_ofs) in self.orig_bin.thunks
        }
        recomp_thunks = {
            iat_ofs: func_ofs for (func_ofs, iat_ofs) in self.recomp_bin.thunks
        }

        for orig, recomp in orig_to_recomp.items():
            self._db.set_pair(orig, recomp, SymbolType.POINTER)
            thunk_from_orig = orig_thunks.get(orig, None)
            thunk_from_recomp = recomp_thunks.get(recomp, None)

            if thunk_from_orig is not None and thunk_from_recomp is not None:
                self._db.set_function_pair(thunk_from_orig, thunk_from_recomp)
                # Don't compare thunk functions for now. The comparison isn't
                # "useful" in the usual sense. We are only looking at the 6
                # bytes of the jmp instruction and not the larger context of
                # where this function is. Also: these will always match 100%
                # because we are searching for a match to register this as a
                # function in the first place.
                self._db.skip_compare(thunk_from_orig)

    def _match_exports(self):
        """Tentatively pair exported symbols that carry the same name in
        both binaries."""
        # invert for name lookup
        orig_exports = {y: x for (x, y) in self.orig_bin.exports}

        for recomp_addr, export_name in self.recomp_bin.exports:
            orig_addr = orig_exports.get(export_name)
            if orig_addr is not None and self._db.set_pair_tentative(
                orig_addr, recomp_addr
            ):
                logger.debug("Matched export %s", repr(export_name))

    def _compare_function(self, match: MatchInfo) -> DiffReport:
        """Disassemble both versions of a matched function and diff the text.

        Returns:
            A FUNCTION report carrying the similarity ratio. When the ratio is
            not 1.0 it also carries the combined diff and the result of the
            register-difference check.
        """
        orig_raw = self.orig_bin.read(match.orig_addr, match.size)
        recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size)

        def orig_lookup(addr: int) -> Optional[str]:
            m = self._db.get_by_orig(addr)
            if m is None:
                return None

            return m.match_name()

        def recomp_lookup(addr: int) -> Optional[str]:
            m = self._db.get_by_recomp(addr)
            if m is None:
                return None

            return m.match_name()

        orig_should_replace = create_reloc_lookup(self.orig_bin)
        recomp_should_replace = create_reloc_lookup(self.recomp_bin)

        orig_float = create_float_lookup(self.orig_bin)
        recomp_float = create_float_lookup(self.recomp_bin)

        orig_parse = ParseAsm(
            relocate_lookup=orig_should_replace,
            name_lookup=orig_lookup,
            float_lookup=orig_float,
        )
        recomp_parse = ParseAsm(
            relocate_lookup=recomp_should_replace,
            name_lookup=recomp_lookup,
            float_lookup=recomp_float,
        )

        orig_combined = orig_parse.parse_asm(orig_raw, match.orig_addr)
        recomp_combined = recomp_parse.parse_asm(recomp_raw, match.recomp_addr)

        # Detach addresses from asm lines for the text diff.
        orig_asm = [x[1] for x in orig_combined]
        recomp_asm = [x[1] for x in recomp_combined]

        diff = difflib.SequenceMatcher(None, orig_asm, recomp_asm)
        ratio = diff.ratio()

        if ratio != 1.0:
            # Check whether we can resolve register swaps which are actually
            # perfect matches modulo compiler entropy.
            is_effective_match = can_resolve_register_differences(orig_asm, recomp_asm)
            unified_diff = combined_diff(
                diff, orig_combined, recomp_combined, context_size=10
            )
        else:
            is_effective_match = False
            unified_diff = []

        return DiffReport(
            match_type=SymbolType.FUNCTION,
            orig_addr=match.orig_addr,
            recomp_addr=match.recomp_addr,
            name=match.name,
            udiff=unified_diff,
            ratio=ratio,
            is_effective_match=is_effective_match,
        )

    def _compare_vtable(self, match: MatchInfo) -> DiffReport:
        """Compare the two copies of a vtable one pointer at a time.

        Returns:
            A VTABLE report whose ratio is the share of slots whose original
            and recompiled pointers resolve to the same recomp address.
        """
        vtable_size = match.size

        # The vtable size should always be a multiple of 4 because that
        # is the pointer size. If it is not (for whatever reason)
        # it would cause iter_unpack to blow up so let's just fix it.
        if vtable_size % 4 != 0:
            logger.warning(
                "Vtable for class %s has irregular size %d", match.name, vtable_size
            )
            vtable_size = 4 * (vtable_size // 4)

        orig_table = self.orig_bin.read(match.orig_addr, vtable_size)
        recomp_table = self.recomp_bin.read(match.recomp_addr, vtable_size)

        raw_addrs = zip(
            [t for (t,) in struct.iter_unpack("<L", orig_table)],
            [t for (t,) in struct.iter_unpack("<L", recomp_table)],
        )

        def match_text(m: Optional[MatchInfo], raw_addr: Optional[int] = None) -> str:
            """Format the function reference at this vtable index as text.
            If we have not identified this function, we have the option to
            display the raw address. This is only worth doing for the original addr
            because we should always be able to identify the recomp function.
            If the original function is missing then this probably means that the class
            should override the given function from the superclass, but we have not
            implemented this yet.
            """

            if m is not None:
                orig = hex(m.orig_addr) if m.orig_addr is not None else "no orig"
                recomp = (
                    hex(m.recomp_addr) if m.recomp_addr is not None else "no recomp"
                )
                return f"({orig} / {recomp}) : {m.name}"

            if raw_addr is not None:
                return f"0x{raw_addr:x} from orig not annotated."

            return "(no match)"

        orig_text = []
        recomp_text = []
        n_matched = 0
        n_entries = 0

        # Now compare each pointer from the two vtables.
        for i, (raw_orig, raw_recomp) in enumerate(raw_addrs):
            orig = self._db.get_by_orig(raw_orig)
            recomp = self._db.get_by_recomp(raw_recomp)

            if (
                orig is not None
                and recomp is not None
                and orig.recomp_addr == recomp.recomp_addr
            ):
                n_matched += 1

            n_entries += 1

            index = f"vtable0x{i*4:02x}"
            orig_text.append((index, match_text(orig, raw_orig)))
            recomp_text.append((index, match_text(recomp)))

        # An empty table gets a ratio of zero instead of dividing by zero.
        ratio = n_matched / float(n_entries) if n_entries > 0 else 0.0

        # n=100: Show the entire table if there is a diff to display.
        # Otherwise it would be confusing if the table got cut off.
        sm = difflib.SequenceMatcher(
            None,
            [x[1] for x in orig_text],
            [x[1] for x in recomp_text],
        )

        unified_diff = combined_diff(sm, orig_text, recomp_text, context_size=100)

        return DiffReport(
            match_type=SymbolType.VTABLE,
            orig_addr=match.orig_addr,
            recomp_addr=match.recomp_addr,
            name=f"{match.name}::`vftable'",
            udiff=unified_diff,
            ratio=ratio,
        )

    def _compare_match(self, match: MatchInfo) -> Optional[DiffReport]:
        """Router for comparison type.

        Returns None for matches of size zero, matches whose options carry
        "skip", and compare types other than FUNCTION and VTABLE. Matches
        whose options carry "stub" get a stub report without a comparison.
        """
        if match.size == 0:
            return None

        options = self._db.get_match_options(match.orig_addr)
        if options.get("skip", False):
            return None

        if options.get("stub", False):
            return DiffReport(
                match_type=match.compare_type,
                orig_addr=match.orig_addr,
                recomp_addr=match.recomp_addr,
                name=match.name,
                is_stub=True,
            )

        if match.compare_type == SymbolType.FUNCTION:
            return self._compare_function(match)

        if match.compare_type == SymbolType.VTABLE:
            return self._compare_vtable(match)

        return None

    def _compare_matches(self, matches: Iterable[MatchInfo]) -> Iterable[DiffReport]:
        """Compare each match exactly once and yield only the matches that
        produced a report."""
        for match in matches:
            diff = self._compare_match(match)
            if diff is not None:
                yield diff

    ## Public API

    def is_pointer_match(self, orig_addr, recomp_addr) -> bool:
        """Check whether these pointers point at the same thing"""

        # Null pointers considered matching
        if orig_addr == 0 and recomp_addr == 0:
            return True

        match = self._db.get_by_orig(orig_addr)
        if match is None:
            return False

        return match.recomp_addr == recomp_addr

    def get_by_orig(self, addr: int) -> Optional[MatchInfo]:
        """Look up the database entry for an address of the original binary."""
        return self._db.get_by_orig(addr)

    def get_by_recomp(self, addr: int) -> Optional[MatchInfo]:
        """Look up the database entry for an address of the recompiled binary."""
        return self._db.get_by_recomp(addr)

    def get_all(self) -> List[MatchInfo]:
        """Return whatever the database reports for `get_all`."""
        return self._db.get_all()

    def get_functions(self) -> List[MatchInfo]:
        """Return the matches of type FUNCTION."""
        return self._db.get_matches_by_type(SymbolType.FUNCTION)

    def get_vtables(self) -> List[MatchInfo]:
        """Return the matches of type VTABLE."""
        return self._db.get_matches_by_type(SymbolType.VTABLE)

    def get_variables(self) -> List[MatchInfo]:
        """Return the matches of type DATA."""
        return self._db.get_matches_by_type(SymbolType.DATA)

    def compare_address(self, addr: int) -> Optional[DiffReport]:
        """Compare the single match the database holds for `addr`; return
        None when there is no such match or it yields no report."""
        match = self._db.get_one_match(addr)
        if match is None:
            return None

        return self._compare_match(match)

    def compare_all(self) -> Iterable[DiffReport]:
        """Lazily yield a report for every match that produces one."""
        yield from self._compare_matches(self._db.get_matches())

    def compare_functions(self) -> Iterable[DiffReport]:
        """Lazily yield a report for every function match that produces one."""
        yield from self._compare_matches(self.get_functions())

    def compare_variables(self):
        """Not implemented yet."""

    def compare_pointers(self):
        """Not implemented yet."""

    def compare_strings(self):
        """Not implemented yet."""

    def compare_vtables(self) -> Iterable[DiffReport]:
        """Lazily yield a report for every vtable match that produces one.

        Each vtable is compared a single time and the report that passed the
        None check is the one that gets yielded.
        """
        yield from self._compare_matches(self.get_vtables())