mirror of
https://github.com/isledecomp/isle.git
synced 2025-10-26 09:54:18 +00:00
Reccmp comparison engine refactor (#405)
* Reccmp comparison engine refactor * Remove redundant references to 'entry' symbol
This commit is contained in:
1
tools/isledecomp/isledecomp/compare/__init__.py
Normal file
1
tools/isledecomp/isledecomp/compare/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .core import Compare
|
||||
149
tools/isledecomp/isledecomp/compare/core.py
Normal file
149
tools/isledecomp/isledecomp/compare/core.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from isledecomp.cvdump.demangler import demangle_string_const
|
||||
from isledecomp.cvdump import Cvdump, CvdumpAnalysis
|
||||
from isledecomp.parser import DecompCodebase
|
||||
from isledecomp.dir import walk_source_dir
|
||||
from isledecomp.types import SymbolType
|
||||
from .db import CompareDb, MatchInfo
|
||||
from .lines import LinesDb
|
||||
|
||||
|
||||
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Compare:
    """Build the database of symbols shared by an original and a recompiled binary.

    All of the matching work happens during construction, in this order:

    1. ``_load_cvdump``: read the PDB through cvdump and record every recomp
       symbol and every (file, line) -> address reference.
    2. ``_load_markers``: read the annotations found under ``code_dir`` and
       attach original addresses to the recorded symbols.
    3. ``_find_original_strings``: search the original binary for string
       constants that no annotation has matched yet.
    """

    # pylint: disable=too-many-instance-attributes
    def __init__(self, orig_bin, recomp_bin, pdb_file, code_dir):
        """Load and match symbols.

        Args:
            orig_bin: The original binary. This class reads its ``entry``
                attribute and calls ``find_string()`` on it.
            recomp_bin: The recompiled binary. This class reads its ``entry``
                attribute and calls ``is_valid_section()``, ``get_abs_addr()``,
                ``get_section_extent_by_index()`` and ``read()`` on it.
            pdb_file: Path of the PDB file handed to cvdump. Its base name
                (without extension) is also used as the module name.
            code_dir: Root directory of the annotated source code.
        """
        self.orig_bin = orig_bin
        self.recomp_bin = recomp_bin
        self.pdb_file = pdb_file
        self.code_dir = code_dir

        self._lines_db = LinesDb(code_dir)
        self._db = CompareDb()

        # Order matters: markers are matched against the symbols loaded from
        # cvdump, and the string search only looks at what is still unmatched.
        self._load_cvdump()
        self._load_markers()
        self._find_original_strings()

    def _load_cvdump(self):
        """Parse the PDB with cvdump and record the recomp symbols and line refs."""
        logger.info("Parsing %s ...", self.pdb_file)
        cv = (
            Cvdump(self.pdb_file)
            .lines()
            .globals()
            .publics()
            .symbols()
            .section_contributions()
            .run()
        )
        res = CvdumpAnalysis(cv)

        for sym in res.nodes:
            # The PDB might contain sections that do not line up with the
            # actual binary. The symbol "__except_list" is one example.
            # In these cases, just skip this symbol and move on because
            # we can't do much with it.
            if not self.recomp_bin.is_valid_section(sym.section):
                continue

            addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset)

            # If this symbol is the final one in its section, we were not able to
            # estimate its size because we didn't have the total size of that section.
            # We can get this estimate now and assume that the final symbol occupies
            # the remainder of the section.
            if sym.estimated_size is None:
                sym.estimated_size = (
                    self.recomp_bin.get_section_extent_by_index(sym.section)
                    - sym.offset
                )

            if sym.node_type == SymbolType.STRING:
                string_info = demangle_string_const(sym.decorated_name)
                # TODO: skip unicode for now. will need to handle these differently.
                if string_info.is_utf16:
                    continue

                # Use the string's text as its name. latin1 maps every byte
                # value to a code point, so this decode cannot fail.
                raw = self.recomp_bin.read(addr, sym.size())
                sym.friendly_name = raw.decode("latin1")

            self._db.set_recomp_symbol(addr, sym.node_type, sym.name(), sym.size())

        for lineref in cv.lines:
            addr = self.recomp_bin.get_abs_addr(lineref.section, lineref.offset)
            self._lines_db.add_line(lineref.filename, lineref.line_no, addr)

        # The _entry symbol is referenced in the PE header so we get this match for free.
        self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry)

    def _load_markers(self):
        """Read the annotations in ``code_dir`` and match them to recomp symbols."""
        # Guess at module name from PDB file name
        # reccmp checks the original binary filename; we could use this too
        (module, _) = os.path.splitext(os.path.basename(self.pdb_file))

        codefiles = list(walk_source_dir(self.code_dir))
        codebase = DecompCodebase(codefiles, module)

        # Match lineref functions first because this is a guaranteed match.
        # If we have two functions that share the same name, and one is
        # a lineref, we can match the nameref correctly because the lineref
        # was already removed from consideration.
        for fun in codebase.iter_line_functions():
            recomp_addr = self._lines_db.search_line(fun.filename, fun.line_number)
            if recomp_addr is None:
                # search_line() has already logged the failure.
                continue

            self._db.set_function_pair(fun.offset, recomp_addr)
            if fun.should_skip():
                self._db.skip_compare(fun.offset)

        for fun in codebase.iter_name_functions():
            self._db.match_function(fun.offset, fun.name)
            if fun.should_skip():
                self._db.skip_compare(fun.offset)

        for var in codebase.iter_variables():
            self._db.match_variable(var.offset, var.name)

        for tbl in codebase.iter_vtables():
            self._db.match_vtable(tbl.offset, tbl.name)

    def _find_original_strings(self):
        """Go to the original binary and look for the specified string constants
        to find a match. This is a (relatively) expensive operation so we only
        look at strings that we have not already matched via a STRING annotation."""

        for string in self._db.get_unmatched_strings():
            # The name was produced by a latin1 decode in _load_cvdump(), so
            # encoding it the same way recovers the original bytes.
            addr = self.orig_bin.find_string(string.encode("latin1"))
            if addr is None:
                escaped = repr(string)
                logger.debug("Failed to find this string in the original: %s", escaped)
                continue

            self._db.match_string(addr, string)

    def get_one_function(self, addr: int) -> Optional[MatchInfo]:
        """Return the matched function at original address ``addr``, or None.

        i.e. verbose mode for reccmp.
        """
        return self._db.get_one_function(addr)

    def get_functions(self) -> List[MatchInfo]:
        """Return every matched function that is not marked to be skipped."""
        return self._db.get_matches(SymbolType.FUNCTION)

    def compare_functions(self):
        """Placeholder: not implemented yet."""
        pass

    def compare_variables(self):
        """Placeholder: not implemented yet."""
        pass

    def compare_pointers(self):
        """Placeholder: not implemented yet."""
        pass

    def compare_strings(self):
        """Placeholder: not implemented yet."""
        pass

    def compare_vtables(self):
        """Placeholder: not implemented yet."""
        pass
|
||||
149
tools/isledecomp/isledecomp/compare/db.py
Normal file
149
tools/isledecomp/isledecomp/compare/db.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Wrapper for database (here an in-memory sqlite database) that collects the
|
||||
addresses/symbols that we want to compare between the original and recompiled binaries."""
|
||||
import sqlite3
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional
|
||||
from isledecomp.types import SymbolType
|
||||
|
||||
_SETUP_SQL = """
|
||||
DROP TABLE IF EXISTS `symbols`;
|
||||
CREATE TABLE `symbols` (
|
||||
compare_type int,
|
||||
orig_addr int,
|
||||
recomp_addr int,
|
||||
name text,
|
||||
size int,
|
||||
should_skip int default(FALSE)
|
||||
);
|
||||
CREATE INDEX `symbols_re` ON `symbols` (recomp_addr);
|
||||
CREATE INDEX `symbols_na` ON `symbols` (compare_type, name);
|
||||
"""
|
||||
|
||||
|
||||
# One matched symbol, with fields in the same order as the columns selected
# by the match queries: orig_addr, recomp_addr, size, name.
MatchInfo = namedtuple("MatchInfo", ["orig_addr", "recomp_addr", "size", "name"])


def matchinfo_factory(_, row):
    """Row factory for sqlite3 cursors: wrap a result row in a MatchInfo.

    The first argument (the cursor, in the sqlite3 row_factory protocol) is
    not needed and is ignored.
    """
    return MatchInfo._make(row)
|
||||
|
||||
|
||||
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompareDb:
    """In-memory sqlite database of the symbols to compare.

    Rows are first inserted with only their recomp address
    (``set_recomp_symbol``). The ``match_*`` and ``set_function_pair``
    methods then fill in the original address, which is what makes a row
    a "match".
    """

    def __init__(self):
        """Create a fresh in-memory database with an empty `symbols` table."""
        self._db = sqlite3.connect(":memory:")
        self._db.executescript(_SETUP_SQL)

    def set_recomp_symbol(
        self,
        addr: int,
        compare_type: Optional[SymbolType],
        name: Optional[str],
        size: Optional[int],
    ):
        """Insert a symbol known only from the recompiled binary.

        Args:
            addr: Address of the symbol in the recomp.
            compare_type: Kind of symbol, or None if not known; its integer
                ``value`` is what gets stored.
            name: Symbol name, if any.
            size: Symbol size, if known.
        """
        compare_value = compare_type.value if compare_type is not None else None
        self._db.execute(
            "INSERT INTO `symbols` (recomp_addr, compare_type, name, size) VALUES (?,?,?,?)",
            (addr, compare_value, name, size),
        )

    def get_unmatched_strings(self) -> List[str]:
        """Return any strings not already identified by STRING markers."""

        cur = self._db.execute(
            "SELECT name FROM `symbols` WHERE compare_type = ? AND orig_addr IS NULL",
            (SymbolType.STRING.value,),
        )

        return [string for (string,) in cur.fetchall()]

    def get_one_function(self, addr: int) -> Optional[MatchInfo]:
        """Return the matched, non-skipped function at original address ``addr``.

        Returns None when no such row exists.
        """
        cur = self._db.execute(
            """SELECT orig_addr, recomp_addr, size, name
            FROM `symbols`
            WHERE compare_type = ?
            AND orig_addr = ?
            AND recomp_addr IS NOT NULL
            AND should_skip IS FALSE
            ORDER BY orig_addr
            """,
            (SymbolType.FUNCTION.value, addr),
        )
        cur.row_factory = matchinfo_factory
        return cur.fetchone()

    def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]:
        """Return all matched, non-skipped symbols of ``compare_type``.

        A row counts as matched when both of its addresses are set. Results
        are ordered by original address.
        """
        cur = self._db.execute(
            """SELECT orig_addr, recomp_addr, size, name
            FROM `symbols`
            WHERE compare_type = ?
            AND orig_addr IS NOT NULL
            AND recomp_addr IS NOT NULL
            AND should_skip IS FALSE
            ORDER BY orig_addr
            """,
            (compare_type.value,),
        )
        cur.row_factory = matchinfo_factory

        return cur.fetchall()

    def set_function_pair(self, orig: int, recomp: int) -> bool:
        """For lineref match or _entry.

        Sets the original address on the row(s) at recomp address ``recomp``
        and forces their type to FUNCTION. Returns True if any row was updated.
        """
        cur = self._db.execute(
            "UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?",
            (orig, SymbolType.FUNCTION.value, recomp),
        )

        return cur.rowcount > 0
        # TODO: Both ways required?

    def skip_compare(self, orig: int):
        """Mark the row(s) with original address ``orig`` to be skipped."""
        self._db.execute(
            "UPDATE `symbols` SET should_skip = TRUE WHERE orig_addr = ?", (orig,)
        )

    def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool:
        """Assign original address ``addr`` to unmatched rows named ``name``.

        Only rows whose type is ``compare_type`` or not yet set are eligible.
        Every eligible row is updated. Returns True if any row was updated.
        """
        # Update the compare_type here too since the marker tells us what we should do
        logger.debug("Looking for %s %s", compare_type.name.lower(), name)
        cur = self._db.execute(
            """UPDATE `symbols`
            SET orig_addr = ?, compare_type = ?
            WHERE name = ?
            AND orig_addr IS NULL
            AND (compare_type = ? OR compare_type IS NULL)""",
            (addr, compare_type.value, name, compare_type.value),
        )

        return cur.rowcount > 0

    def match_function(self, addr: int, name: str) -> bool:
        """Match a function by name; log an error and return False on failure."""
        did_match = self._match_on(SymbolType.FUNCTION, addr, name)
        if not did_match:
            logger.error("Failed to find function symbol with name: %s", name)

        return did_match

    def match_vtable(self, addr: int, name: str) -> bool:
        """Match a vtable by class name; log an error and return False on failure."""
        did_match = self._match_on(SymbolType.VTABLE, addr, name)
        if not did_match:
            logger.error("Failed to find vtable for class: %s", name)

        return did_match

    def match_variable(self, addr: int, name: str) -> bool:
        """Match a variable by name, trying DATA first and then POINTER.

        Logs an error and returns False if neither attempt matches.
        """
        did_match = self._match_on(SymbolType.DATA, addr, name) or self._match_on(
            SymbolType.POINTER, addr, name
        )
        if not did_match:
            logger.error("Failed to find variable: %s", name)

        return did_match

    def match_string(self, addr: int, value: str) -> bool:
        """Match a string constant by its text; log an error and return False on failure."""
        did_match = self._match_on(SymbolType.STRING, addr, value)
        if not did_match:
            escaped = repr(value)
            logger.error("Failed to find string: %s", escaped)

        # Return the result like the other match_* methods do, as the
        # annotated return type promises.
        return did_match
|
||||
58
tools/isledecomp/isledecomp/compare/lines.py
Normal file
58
tools/isledecomp/isledecomp/compare/lines.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Database used to match (filename, line_number) pairs
|
||||
between FUNCTION markers and PDB analysis."""
|
||||
import sqlite3
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from isledecomp.dir import PathResolver
|
||||
|
||||
|
||||
_SETUP_SQL = """
|
||||
DROP TABLE IF EXISTS `lineref`;
|
||||
CREATE TABLE `lineref` (
|
||||
path text not null,
|
||||
filename text not null,
|
||||
line int not null,
|
||||
addr int not null
|
||||
);
|
||||
CREATE INDEX `file_line` ON `lineref` (filename, line);
|
||||
"""
|
||||
|
||||
|
||||
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LinesDb:
    """In-memory sqlite database mapping (source file, line) to a recomp address."""

    def __init__(self, code_dir) -> None:
        """Create an empty database.

        Args:
            code_dir: Passed to PathResolver, which is used to resolve the
                paths reported by cvdump.
        """
        self._db = sqlite3.connect(":memory:")
        self._db.executescript(_SETUP_SQL)
        self._path_resolver = PathResolver(code_dir)

    def add_line(self, path: str, line_no: int, addr: int):
        """To be added from the LINES section of cvdump.

        Args:
            path: Source path as reported by cvdump; it is resolved before
                being stored.
            line_no: Line number within that file.
            addr: Address of the line in the recomp.
        """
        sourcepath = self._path_resolver.resolve_cvdump(path)
        # The lower-cased base name is the indexed key used by search_line().
        filename = Path(sourcepath).name.lower()

        self._db.execute(
            "INSERT INTO `lineref` (path, filename, line, addr) VALUES (?,?,?,?)",
            (sourcepath, filename, line_no, addr),
        )

    def search_line(self, path: str, line_no: int) -> Optional[int]:
        """Using path and line number from FUNCTION marker,
        get the address of this function in the recomp.

        Candidates are first narrowed down by base name and line number, then
        confirmed by checking that they refer to the same file as ``path``.
        Logs an error and returns None if nothing matches.
        """
        marker_path = Path(path)
        filename = marker_path.name.lower()
        cur = self._db.execute(
            "SELECT path, addr FROM `lineref` WHERE filename = ? AND line = ?",
            (filename, line_no),
        )
        for source_path, addr in cur.fetchall():
            try:
                if marker_path.samefile(source_path):
                    return addr
            except OSError:
                # samefile() has to stat both paths. If either one cannot be
                # accessed (e.g. the PDB references a file that is not on
                # disk), this candidate cannot be confirmed, so move on
                # instead of aborting the whole search.
                continue

        logger.error(
            "Failed to find function symbol with filename and line: %s:%d",
            path,
            line_no,
        )
        return None
|
||||
Reference in New Issue
Block a user