Reccmp comparison engine refactor (#405)

* Reccmp comparison engine refactor

* Remove redundant references to 'entry' symbol
This commit is contained in:
MS
2024-01-04 18:12:55 -05:00
committed by GitHub
parent eeb980fa0f
commit ce68a7b1f4
19 changed files with 987 additions and 279 deletions

View File

@@ -0,0 +1 @@
from .core import Compare

View File

@@ -0,0 +1,149 @@
import os
import logging
from typing import List, Optional
from isledecomp.cvdump.demangler import demangle_string_const
from isledecomp.cvdump import Cvdump, CvdumpAnalysis
from isledecomp.parser import DecompCodebase
from isledecomp.dir import walk_source_dir
from isledecomp.types import SymbolType
from .db import CompareDb, MatchInfo
from .lines import LinesDb
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Compare:
    # pylint: disable=too-many-instance-attributes
    """Gather symbols from the recompiled binary and its PDB, then pair them
    with addresses in the original binary.

    All loading happens in the constructor; afterwards the accessor methods
    read the matches back out of the internal CompareDb."""

    def __init__(self, orig_bin, recomp_bin, pdb_file, code_dir):
        """Store the inputs and run the three loading passes.

        orig_bin and recomp_bin are the binary reader objects for the original
        and recompiled executables, pdb_file is the path handed to Cvdump, and
        code_dir is the source tree scanned for markers."""
        self.orig_bin = orig_bin
        self.recomp_bin = recomp_bin
        self.pdb_file = pdb_file
        self.code_dir = code_dir

        self._lines_db = LinesDb(code_dir)
        self._db = CompareDb()

        # Order matters: the marker pass and the string pass only update rows
        # that the cvdump pass has already inserted.
        self._load_cvdump()
        self._load_markers()
        self._find_original_strings()

    def _load_cvdump(self):
        """Run cvdump on the PDB and record every usable symbol and line
        reference against its absolute address in the recompiled binary."""
        logger.info("Parsing %s ...", self.pdb_file)
        dump = (
            Cvdump(self.pdb_file)
            .lines()
            .globals()
            .publics()
            .symbols()
            .section_contributions()
            .run()
        )

        for node in CvdumpAnalysis(dump).nodes:
            self._add_recomp_symbol(node)

        for ref in dump.lines:
            line_addr = self.recomp_bin.get_abs_addr(ref.section, ref.offset)
            self._lines_db.add_line(ref.filename, ref.line_no, line_addr)

        # The _entry symbol is referenced in the PE header so we get this match for free.
        self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry)

    def _add_recomp_symbol(self, sym):
        """Insert one cvdump symbol into the compare database, unless it
        cannot be placed in the recompiled binary or is a UTF-16 string."""
        recomp = self.recomp_bin

        # The PDB might contain sections that do not line up with the
        # actual binary. The symbol "__except_list" is one example.
        # We can't do much with such a symbol, so it is dropped.
        if not recomp.is_valid_section(sym.section):
            return

        addr = recomp.get_abs_addr(sym.section, sym.offset)

        # The final symbol in a section has no size estimate because the
        # total size of the section was not known during analysis. Assume
        # it occupies the remainder of the section.
        if sym.estimated_size is None:
            section_extent = recomp.get_section_extent_by_index(sym.section)
            sym.estimated_size = section_extent - sym.offset

        if sym.node_type == SymbolType.STRING:
            # TODO: skip unicode for now. will need to handle these differently.
            if demangle_string_const(sym.decorated_name).is_utf16:
                return

            # Read the string's bytes out of the recompiled binary so the
            # symbol can be identified by its text.
            raw = recomp.read(addr, sym.size())
            try:
                sym.friendly_name = raw.decode("latin1")
            except UnicodeDecodeError:
                pass

        self._db.set_recomp_symbol(addr, sym.node_type, sym.name(), sym.size())

    def _load_markers(self):
        """Scan the source tree for markers and use them to attach original
        addresses to the symbols recorded from the PDB."""
        # Guess at module name from PDB file name
        # reccmp checks the original binary filename; we could use this too
        module = os.path.splitext(os.path.basename(self.pdb_file))[0]
        codebase = DecompCodebase(list(walk_source_dir(self.code_dir)), module)

        # Lineref functions go first because they are a guaranteed match.
        # When two functions share a name and one of them is a lineref,
        # the nameref can then be matched correctly because the lineref
        # has already been removed from consideration.
        for fun in codebase.iter_line_functions():
            recomp_addr = self._lines_db.search_line(fun.filename, fun.line_number)
            if recomp_addr is None:
                continue

            self._db.set_function_pair(fun.offset, recomp_addr)
            if fun.should_skip():
                self._db.skip_compare(fun.offset)

        for fun in codebase.iter_name_functions():
            self._db.match_function(fun.offset, fun.name)
            if fun.should_skip():
                self._db.skip_compare(fun.offset)

        for var in codebase.iter_variables():
            self._db.match_variable(var.offset, var.name)

        for tbl in codebase.iter_vtables():
            self._db.match_vtable(tbl.offset, tbl.name)

    def _find_original_strings(self):
        """Search the original binary for each string constant that is still
        unmatched and record the address where it is found.

        This is a (relatively) expensive operation, so only strings not
        already matched via a STRING annotation are looked up."""
        for string in self._db.get_unmatched_strings():
            addr = self.orig_bin.find_string(string.encode("latin1"))
            if addr is not None:
                self._db.match_string(addr, string)
                continue

            logger.debug(
                "Failed to find this string in the original: %s", repr(string)
            )

    def get_one_function(self, addr: int) -> Optional[MatchInfo]:
        """Return the matched function at original address addr, or None.

        i.e. verbose mode for reccmp"""
        return self._db.get_one_function(addr)

    def get_functions(self) -> List[MatchInfo]:
        """Return every matched function that is not marked to be skipped."""
        return self._db.get_matches(SymbolType.FUNCTION)

    def compare_functions(self):
        """Placeholder; currently does nothing."""

    def compare_variables(self):
        """Placeholder; currently does nothing."""

    def compare_pointers(self):
        """Placeholder; currently does nothing."""

    def compare_strings(self):
        """Placeholder; currently does nothing."""

    def compare_vtables(self):
        """Placeholder; currently does nothing."""

View File

@@ -0,0 +1,149 @@
"""Wrapper for database (here an in-memory sqlite database) that collects the
addresses/symbols that we want to compare between the original and recompiled binaries."""
import sqlite3
import logging
from collections import namedtuple
from typing import List, Optional
from isledecomp.types import SymbolType
# Schema script run once by CompareDb.__init__ via executescript().
# Each row of `symbols` is one symbol from the recompiled binary:
#   compare_type - SymbolType value, or NULL when the kind is not yet known
#   orig_addr    - address in the original binary; NULL until matched
#   recomp_addr  - address in the recompiled binary
#   should_skip  - set by skip_compare() to exclude the row from match queries
# The two indexes cover lookups by recomp address and by (type, name).
_SETUP_SQL = """
DROP TABLE IF EXISTS `symbols`;
CREATE TABLE `symbols` (
compare_type int,
orig_addr int,
recomp_addr int,
name text,
size int,
should_skip int default(FALSE)
);
CREATE INDEX `symbols_re` ON `symbols` (recomp_addr);
CREATE INDEX `symbols_na` ON `symbols` (compare_type, name);
"""
# Result record for a matched symbol. The field order must agree with the
# column order of the SELECT statements that use matchinfo_factory.
MatchInfo = namedtuple("MatchInfo", "orig_addr, recomp_addr, size, name")


def matchinfo_factory(_, row):
    """sqlite3 row factory that wraps a result row in a MatchInfo.

    The first argument (the cursor) is ignored."""
    return MatchInfo._make(row)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class CompareDb:
    """In-memory sqlite table of symbols from the recompiled binary and the
    original-binary addresses they have been matched to.

    Rows are inserted by set_recomp_symbol(). Every other mutator only
    updates existing rows, so a symbol must be inserted before it can be
    matched."""

    def __init__(self):
        """Open a fresh in-memory database and create the `symbols` table."""
        self._db = sqlite3.connect(":memory:")
        self._db.executescript(_SETUP_SQL)

    def set_recomp_symbol(
        self,
        addr: int,
        compare_type: Optional[SymbolType],
        name: Optional[str],
        size: Optional[int],
    ):
        """Insert a row for a symbol at recomp address addr.

        compare_type may be None when the kind of symbol is not known yet;
        a later match_* call fills it in."""
        compare_value = compare_type.value if compare_type is not None else None
        self._db.execute(
            "INSERT INTO `symbols` (recomp_addr, compare_type, name, size) VALUES (?,?,?,?)",
            (addr, compare_value, name, size),
        )

    def get_unmatched_strings(self) -> List[str]:
        """Return any strings not already identified by STRING markers."""
        cur = self._db.execute(
            "SELECT name FROM `symbols` WHERE compare_type = ? AND orig_addr IS NULL",
            (SymbolType.STRING.value,),
        )
        return [string for (string,) in cur.fetchall()]

    def get_one_function(self, addr: int) -> Optional[MatchInfo]:
        """Return the matched, non-skipped function whose original address
        is addr, or None if there is no such row."""
        cur = self._db.execute(
            """SELECT orig_addr, recomp_addr, size, name
FROM `symbols`
WHERE compare_type = ?
AND orig_addr = ?
AND recomp_addr IS NOT NULL
AND should_skip IS FALSE
ORDER BY orig_addr
""",
            (SymbolType.FUNCTION.value, addr),
        )
        cur.row_factory = matchinfo_factory
        return cur.fetchone()

    def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]:
        """Return all matched, non-skipped rows of the given type, ordered
        by original address."""
        cur = self._db.execute(
            """SELECT orig_addr, recomp_addr, size, name
FROM `symbols`
WHERE compare_type = ?
AND orig_addr IS NOT NULL
AND recomp_addr IS NOT NULL
AND should_skip IS FALSE
ORDER BY orig_addr
""",
            (compare_type.value,),
        )
        cur.row_factory = matchinfo_factory
        return cur.fetchall()

    def set_function_pair(self, orig: int, recomp: int) -> bool:
        """For lineref match or _entry.

        Sets the original address on the row(s) at recomp address recomp and
        forces their type to FUNCTION. Returns True if any row was updated."""
        cur = self._db.execute(
            "UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?",
            (orig, SymbolType.FUNCTION.value, recomp),
        )
        return cur.rowcount > 0

    # TODO: Both ways required?
    def skip_compare(self, orig: int):
        """Flag the row(s) with original address orig so that the match
        queries leave them out."""
        self._db.execute(
            "UPDATE `symbols` SET should_skip = TRUE WHERE orig_addr = ?", (orig,)
        )

    def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool:
        """Assign original address addr to the unmatched row(s) called name.

        Only rows whose type is compare_type or still NULL are considered.
        The UPDATE has no row limit, so every unmatched row with this name
        receives the same address. Returns True if any row was updated."""
        # Update the compare_type here too since the marker tells us what we should do
        logger.debug("Looking for %s %s", compare_type.name.lower(), name)
        cur = self._db.execute(
            """UPDATE `symbols`
SET orig_addr = ?, compare_type = ?
WHERE name = ?
AND orig_addr IS NULL
AND (compare_type = ? OR compare_type IS NULL)""",
            (addr, compare_type.value, name, compare_type.value),
        )
        return cur.rowcount > 0

    def match_function(self, addr: int, name: str) -> bool:
        """Match a function by name; log an error and return False on a miss."""
        did_match = self._match_on(SymbolType.FUNCTION, addr, name)
        if not did_match:
            logger.error("Failed to find function symbol with name: %s", name)
        return did_match

    def match_vtable(self, addr: int, name: str) -> bool:
        """Match a vtable by class name; log an error and return False on a miss."""
        did_match = self._match_on(SymbolType.VTABLE, addr, name)
        if not did_match:
            logger.error("Failed to find vtable for class: %s", name)
        return did_match

    def match_variable(self, addr: int, name: str) -> bool:
        """Match a variable by name, trying DATA first and POINTER second;
        log an error and return False if neither matches."""
        did_match = self._match_on(SymbolType.DATA, addr, name) or self._match_on(
            SymbolType.POINTER, addr, name
        )
        if not did_match:
            logger.error("Failed to find variable: %s", name)
        return did_match

    def match_string(self, addr: int, value: str) -> bool:
        """Match a string constant by its text; log an error and return
        False on a miss."""
        did_match = self._match_on(SymbolType.STRING, addr, value)
        if not did_match:
            escaped = repr(value)
            logger.error("Failed to find string: %s", escaped)
        # Report the outcome like the other match_* methods do; without this
        # the method returned None despite its bool annotation.
        return did_match

View File

@@ -0,0 +1,58 @@
"""Database used to match (filename, line_number) pairs
between FUNCTION markers and PDB analysis."""
import sqlite3
import logging
from typing import Optional
from pathlib import Path
from isledecomp.dir import PathResolver
# Schema script run once by LinesDb.__init__ via executescript().
# Each row of `lineref` ties a source line to an address:
#   path     - source path as returned by PathResolver.resolve_cvdump()
#   filename - lower-cased final component of that path (indexed with line)
#   line     - line number
#   addr     - address supplied by the caller of add_line()
_SETUP_SQL = """
DROP TABLE IF EXISTS `lineref`;
CREATE TABLE `lineref` (
path text not null,
filename text not null,
line int not null,
addr int not null
);
CREATE INDEX `file_line` ON `lineref` (filename, line);
"""
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class LinesDb:
    """In-memory sqlite table mapping (source file, line number) pairs to
    addresses, used to find the function that a FUNCTION marker refers to."""

    def __init__(self, code_dir) -> None:
        """Create the empty table and a PathResolver rooted at code_dir."""
        self._db = sqlite3.connect(":memory:")
        self._db.executescript(_SETUP_SQL)
        self._path_resolver = PathResolver(code_dir)

    def add_line(self, path: str, line_no: int, addr: int):
        """To be added from the LINES section of cvdump.

        The path is passed through the PathResolver before it is stored, and
        the lower-cased file name is stored next to it for fast lookup."""
        sourcepath = self._path_resolver.resolve_cvdump(path)
        filename = Path(sourcepath).name.lower()

        self._db.execute(
            "INSERT INTO `lineref` (path, filename, line, addr) VALUES (?,?,?,?)",
            (sourcepath, filename, line_no, addr),
        )

    def search_line(self, path: str, line_no: int) -> Optional[int]:
        """Using path and line number from FUNCTION marker,
        get the address of this function in the recomp.

        Candidates are first narrowed by lower-cased file name and line
        number, then confirmed by checking that the stored path and path
        refer to the same file on disk. Returns None, after logging an
        error, when nothing matches."""
        marker_path = Path(path)
        filename = marker_path.name.lower()
        cur = self._db.execute(
            "SELECT path, addr FROM `lineref` WHERE filename = ? AND line = ?",
            (filename, line_no),
        )
        for source_path, addr in cur.fetchall():
            try:
                if marker_path.samefile(source_path):
                    return addr
            except OSError:
                # samefile() has to stat both paths. A stored path that
                # cannot be accessed on this machine is simply not a match;
                # move on so it does not hide a valid candidate after it.
                continue

        logger.error(
            "Failed to find function symbol with filename and line: %s:%d",
            path,
            line_no,
        )
        return None