# Mirror of https://github.com/isledecomp/isle.git
# Synced 2025-10-23 00:14:22 +00:00
# 382 lines, 13 KiB, Python
"""Wrapper for database (here an in-memory sqlite database) that collects the
|
|
addresses/symbols that we want to compare between the original and recompiled binaries."""
|
|
import sqlite3
|
|
import logging
|
|
from typing import List, Optional
|
|
from isledecomp.types import SymbolType
|
|
|
|
# Schema script executed once per CompareDb instance.
# - `symbols`: one row per symbol, keyed informally by recomp_addr; orig_addr
#   stays NULL until the symbol is paired with an address in the original binary.
# - `match_options`: per-address named options (a NULL value acts as a boolean flag).
# - `match_info`: the five-column projection of `symbols` that MatchInfo is built
#   from, ordered by original address with unmatched (NULL) rows last.
_SETUP_SQL = """
    DROP TABLE IF EXISTS `symbols`;
    DROP TABLE IF EXISTS `match_options`;

    CREATE TABLE `symbols` (
        compare_type int,
        orig_addr int,
        recomp_addr int,
        name text,
        decorated_name text,
        size int
    );

    CREATE TABLE `match_options` (
        addr int not null,
        name text not null,
        value text,
        primary key (addr, name)
    ) without rowid;

    CREATE VIEW IF NOT EXISTS `match_info`
    (compare_type, orig_addr, recomp_addr, name, size) AS
        SELECT compare_type, orig_addr, recomp_addr, name, size
        FROM `symbols`
        ORDER BY orig_addr NULLS LAST;

    CREATE INDEX `symbols_or` ON `symbols` (orig_addr);
    CREATE INDEX `symbols_re` ON `symbols` (recomp_addr);
    CREATE INDEX `symbols_na` ON `symbols` (name);
"""
|
|
|
|
|
|
class MatchInfo:
    """A symbol and its address pairing, built from one `match_info` row.

    The constructor arguments follow the column order of the `match_info`
    view (compare_type, orig_addr, recomp_addr, name, size). Any of them may
    be None when the corresponding column is NULL.
    """

    def __init__(
        self,
        ctype: Optional[int],
        orig: Optional[int],
        recomp: Optional[int],
        name: Optional[str],
        size: Optional[int],
    ) -> None:
        # The database stores the enum's integer value; convert it back to
        # SymbolType here and keep NULL as None.
        self.compare_type = SymbolType(ctype) if ctype is not None else None
        self.orig_addr = orig
        self.recomp_addr = recomp
        self.name = name
        self.size = size

    def match_name(self) -> Optional[str]:
        """Combination of the name and compare type.
        Intended for name substitution in the diff. If there is a diff,
        it will be more obvious what this symbol indicates.

        Returns None when the symbol has no name. A symbol without a compare
        type is labelled "UNK"; string symbols show the repr() of their value
        so that escapes and quotes are visible."""
        if self.name is None:
            return None

        ctype = self.compare_type.name if self.compare_type is not None else "UNK"
        name = repr(self.name) if ctype == "STRING" else self.name
        return f"{name} ({ctype})"
|
|
|
|
|
|
def matchinfo_factory(_, row):
    """sqlite3 row factory that turns a result row into a MatchInfo.

    The first argument (the cursor supplied by sqlite3) is not used. The row's
    values are passed positionally to the MatchInfo constructor."""
    match_info = MatchInfo(*row)
    return match_info
|
|
|
|
|
|
# Module-level logger; CompareDb uses it to report failed or conflicting matches.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CompareDb:
    """Store of symbols from the recompiled binary and their matches.

    Backed by an in-memory sqlite database created from `_SETUP_SQL`. Symbols
    are inserted by recomp address (`set_recomp_symbol`) and later paired with
    an address in the original binary, either directly (`set_pair`,
    `set_pair_tentative`) or through a name lookup (`match_*` methods).
    """

    def __init__(self):
        self._db = sqlite3.connect(":memory:")
        self._db.executescript(_SETUP_SQL)

    def set_recomp_symbol(
        self,
        addr: int,
        compare_type: Optional[SymbolType],
        name: Optional[str],
        decorated_name: Optional[str],
        size: Optional[int],
    ):
        """Insert a symbol from the recompiled binary.

        Only the first symbol recorded for a given recomp address is kept;
        later calls for the same address are silently ignored."""
        # Ignore collisions here. The same recomp address can have
        # multiple names (e.g. _strlwr and __strlwr)
        if self._recomp_used(addr):
            return

        compare_value = compare_type.value if compare_type is not None else None
        self._db.execute(
            "INSERT INTO `symbols` (recomp_addr, compare_type, name, decorated_name, size) VALUES (?,?,?,?,?)",
            (addr, compare_value, name, decorated_name, size),
        )

    def get_unmatched_strings(self) -> List[str]:
        """Return any strings not already identified by STRING markers."""

        cur = self._db.execute(
            "SELECT name FROM `symbols` WHERE compare_type = ? AND orig_addr IS NULL",
            (SymbolType.STRING.value,),
        )

        return [string for (string,) in cur.fetchall()]

    def get_all(self) -> List[MatchInfo]:
        """Return every symbol, matched or not, as MatchInfo objects."""
        cur = self._db.execute("SELECT * FROM `match_info`")
        cur.row_factory = matchinfo_factory

        return cur.fetchall()

    def get_matches(self) -> List[MatchInfo]:
        """Return all symbols that have both an original and a recomp address."""
        cur = self._db.execute(
            """SELECT * FROM `match_info`
            WHERE orig_addr IS NOT NULL
            AND recomp_addr IS NOT NULL
            """,
        )
        cur.row_factory = matchinfo_factory

        return cur.fetchall()

    def get_one_match(self, addr: int) -> Optional[MatchInfo]:
        """Return the matched symbol at original address `addr`, or None if
        that address is unknown or has no recomp address."""
        cur = self._db.execute(
            """SELECT * FROM `match_info`
            WHERE orig_addr = ?
            AND recomp_addr IS NOT NULL
            """,
            (addr,),
        )
        cur.row_factory = matchinfo_factory
        return cur.fetchone()

    def get_by_orig(self, addr: int) -> Optional[MatchInfo]:
        """Return the symbol with original address `addr`, or None."""
        cur = self._db.execute(
            """SELECT * FROM `match_info`
            WHERE orig_addr = ?
            """,
            (addr,),
        )
        cur.row_factory = matchinfo_factory
        return cur.fetchone()

    def get_by_recomp(self, addr: int) -> Optional[MatchInfo]:
        """Return the symbol with recomp address `addr`, or None."""
        cur = self._db.execute(
            """SELECT * FROM `match_info`
            WHERE recomp_addr = ?
            """,
            (addr,),
        )
        cur.row_factory = matchinfo_factory
        return cur.fetchone()

    def get_matches_by_type(self, compare_type: SymbolType) -> List[MatchInfo]:
        """Return all matched symbols (both addresses set) of the given type."""
        cur = self._db.execute(
            """SELECT * FROM `match_info`
            WHERE compare_type = ?
            AND orig_addr IS NOT NULL
            AND recomp_addr IS NOT NULL
            """,
            (compare_type.value,),
        )
        cur.row_factory = matchinfo_factory

        return cur.fetchall()

    def _orig_used(self, addr: int) -> bool:
        """True if some symbol is already paired with original address `addr`."""
        cur = self._db.execute("SELECT 1 FROM symbols WHERE orig_addr = ?", (addr,))
        return cur.fetchone() is not None

    def _recomp_used(self, addr: int) -> bool:
        """True if a symbol with recomp address `addr` has been recorded."""
        cur = self._db.execute("SELECT 1 FROM symbols WHERE recomp_addr = ?", (addr,))
        return cur.fetchone() is not None

    def set_pair(
        self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None
    ) -> bool:
        """Pair original address `orig` with the symbol at recomp address `recomp`.

        The symbol's compare_type is overwritten with the one given (including
        with NULL when `compare_type` is None). Returns False, and logs an
        error, if `orig` is already paired; returns False if no symbol exists
        at `recomp`. Returns True if a row was updated."""
        if self._orig_used(orig):
            logger.error("Original address %s not unique!", hex(orig))
            return False

        compare_value = compare_type.value if compare_type is not None else None
        cur = self._db.execute(
            "UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?",
            (orig, compare_value, recomp),
        )

        return cur.rowcount > 0

    def set_pair_tentative(
        self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None
    ) -> bool:
        """Declare a match for the original and recomp addresses given, but only if:
        1. The original address is not used elsewhere (as with set_pair)
        2. The recomp address has not already been matched
        If the compare_type is given, update this also, but only if NULL in the db.

        The purpose here is to set matches found via some automated analysis
        but to not overwrite a match provided by the human operator."""
        if self._orig_used(orig):
            # Probable and expected situation. Just ignore it.
            return False

        compare_value = compare_type.value if compare_type is not None else None

        cur = self._db.execute(
            """UPDATE `symbols`
            SET orig_addr = ?, compare_type = coalesce(compare_type, ?)
            WHERE recomp_addr = ?
            AND orig_addr IS NULL""",
            (orig, compare_value, recomp),
        )

        return cur.rowcount > 0

    def set_function_pair(self, orig: int, recomp: int) -> bool:
        """For lineref match or _entry"""
        return self.set_pair(orig, recomp, SymbolType.FUNCTION)

    def _set_opt_bool(self, addr: int, option: str, enabled: bool = True):
        """Set or clear a boolean option for an address.

        An enabled option is stored as a row with a NULL value; disabling it
        deletes the row. Enabling an option that is already present is a no-op."""
        if enabled:
            self._db.execute(
                """INSERT OR IGNORE INTO `match_options`
                (addr, name)
                VALUES (?, ?)""",
                (addr, option),
            )
        else:
            self._db.execute(
                """DELETE FROM `match_options` WHERE addr = ? AND name = ?""",
                (addr, option),
            )

    def mark_stub(self, orig: int):
        """Flag the given address with the "stub" option."""
        self._set_opt_bool(orig, "stub")

    def skip_compare(self, orig: int):
        """Flag the given address with the "skip" option."""
        self._set_opt_bool(orig, "skip")

    def get_match_options(self, addr: int) -> Optional[dict]:
        """Return the options recorded for `addr` as a {name: value} dict.

        Options stored without a value (boolean flags) map to True. An address
        with no options yields an empty dict."""
        cur = self._db.execute(
            """SELECT name, value FROM `match_options` WHERE addr = ?""", (addr,)
        )

        return {
            option: value if value is not None else True
            for (option, value) in cur.fetchall()
        }

    def _find_potential_match(
        self, name: str, compare_type: SymbolType
    ) -> Optional[int]:
        """Name lookup

        Returns the recomp address of an as-yet unmatched symbol whose name
        equals `name` and whose compare type is either unset or equal to
        `compare_type`, or None if there is no such symbol. Names beginning
        with "?" are compared against the decorated name instead, except for
        strings, whose value may legitimately start with "?"."""
        match_decorate = compare_type != SymbolType.STRING and name.startswith("?")
        if match_decorate:
            sql = """
                SELECT recomp_addr
                FROM `symbols`
                WHERE orig_addr IS NULL
                AND decorated_name = ?
                AND (compare_type IS NULL OR compare_type = ?)
                LIMIT 1
                """
        else:
            sql = """
                SELECT recomp_addr
                FROM `symbols`
                WHERE orig_addr IS NULL
                AND name = ?
                AND (compare_type IS NULL OR compare_type = ?)
                LIMIT 1
                """

        row = self._db.execute(sql, (name, compare_type.value)).fetchone()
        return row[0] if row is not None else None

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape `text` for literal use in a LIKE pattern with ESCAPE '\\'.

        The backslash is escaped first so the escapes added for the `%` and
        `_` wildcards are not themselves doubled."""
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def _find_static_variable(
        self, variable_name: str, function_sym: str
    ) -> Optional[int]:
        """Get the recomp address of a static function variable.
        Matches using a LIKE clause on the combination of:
        1. The variable name read from decomp marker.
        2. The decorated name of the enclosing function.
        For example, the variable "g_startupDelay" from function "IsleApp::Tick"
        has symbol: `?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA`
        The function's decorated name is: `?Tick@IsleApp@@QAEXH@Z`

        Both names are matched literally: LIKE wildcards occurring in them
        (notably the underscore in names such as "g_startupDelay") are escaped.
        Returns None if nothing matches or if either name is missing."""

        # The enclosing function may have no decorated name in the database.
        # A NULL operand makes the LIKE comparison match nothing, so report
        # "not found" directly instead of trying to escape None.
        if variable_name is None or function_sym is None:
            return None

        row = self._db.execute(
            """SELECT recomp_addr FROM `symbols`
            WHERE decorated_name LIKE '%' || ? || '%' || ? || '%' ESCAPE '\\'
            AND orig_addr IS NULL
            AND (compare_type = ? OR compare_type = ? OR compare_type IS NULL)""",
            (
                self._escape_like(variable_name),
                self._escape_like(function_sym),
                SymbolType.DATA.value,
                SymbolType.POINTER.value,
            ),
        ).fetchone()
        return row[0] if row is not None else None

    def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool:
        """Look up an unmatched symbol by name and pair it with original
        address `addr`. Returns True only if a symbol was found and paired."""
        # Update the compare_type here too since the marker tells us what we should do

        # Truncate the name to 255 characters. It will not be possible to match a name
        # longer than that because MSVC truncates the debug symbols to this length.
        # See also: warning C4786.
        name = name[:255]

        logger.debug("Looking for %s %s", compare_type.name.lower(), name)
        recomp_addr = self._find_potential_match(name, compare_type)
        if recomp_addr is None:
            return False

        return self.set_pair(addr, recomp_addr, compare_type)

    def match_function(self, addr: int, name: str) -> bool:
        """Match the function `name` to original address `addr`.
        Logs an error and returns False on failure."""
        did_match = self._match_on(SymbolType.FUNCTION, addr, name)
        if not did_match:
            logger.error("Failed to find function symbol with name: %s", name)

        return did_match

    def match_vtable(self, addr: int, name: str) -> bool:
        """Match the vtable of class `name` to original address `addr`.
        Logs an error and returns False on failure."""
        did_match = self._match_on(SymbolType.VTABLE, addr, name)
        if not did_match:
            logger.error("Failed to find vtable for class: %s", name)

        return did_match

    def match_static_variable(self, addr: int, name: str, function_addr: int) -> bool:
        """Matching a static function variable by combining the variable name
        with the decorated (mangled) name of its parent function.

        `function_addr` is the original address of the parent function, which
        must already be matched. Logs an error and returns False on failure."""

        cur = self._db.execute(
            """SELECT name, decorated_name
            FROM `symbols`
            WHERE orig_addr = ?""",
            (function_addr,),
        )

        if (result := cur.fetchone()) is None:
            logger.error("No function for static variable: %s", name)
            return False

        # Get the friendly name for the "failed to match" error message
        (function_name, decorated_name) = result

        recomp_addr = self._find_static_variable(name, decorated_name)
        if recomp_addr is not None:
            # TODO: This variable could be a pointer, but I don't think we
            # have a way to tell that right now.
            if self.set_pair(addr, recomp_addr, SymbolType.DATA):
                return True

        logger.error(
            "Failed to match static variable %s from function %s",
            name,
            function_name,
        )

        return False

    def match_variable(self, addr: int, name: str) -> bool:
        """Match the variable `name` to original address `addr`, trying it as
        DATA first and then as POINTER. Logs an error and returns False on
        failure."""
        did_match = self._match_on(SymbolType.DATA, addr, name) or self._match_on(
            SymbolType.POINTER, addr, name
        )
        if not did_match:
            logger.error("Failed to find variable: %s", name)

        return did_match

    def match_string(self, addr: int, value: str) -> bool:
        """Match the string with text `value` to original address `addr`.
        Logs an error (with the escaped text) and returns False on failure."""
        did_match = self._match_on(SymbolType.STRING, addr, value)
        if not did_match:
            escaped = repr(value)
            logger.error("Failed to find string: %s", escaped)

        return did_match
|