Read floating point constants up front (#868)

* Read floating point constants before sanitize

* Fix roadmap
This commit is contained in:
MS
2024-04-29 14:33:16 -04:00
committed by GitHub
parent 7c6c68d6f9
commit e7670f9a81
7 changed files with 97 additions and 47 deletions

View File

@@ -35,16 +35,6 @@ def from_hex(string: str) -> Optional[int]:
return None
def bytes_to_float(b: bytes) -> Optional[float]:
if len(b) == 4:
return struct.unpack("<f", b)[0]
if len(b) == 8:
return struct.unpack("<d", b)[0]
return None
def bytes_to_dword(b: bytes) -> Optional[int]:
if len(b) == 4:
return struct.unpack("<L", b)[0]
@@ -74,18 +64,6 @@ class ParseAsm:
return False
def float_replace(self, addr: int, data_size: int) -> Optional[str]:
if callable(self.bin_lookup):
float_bytes = self.bin_lookup(addr, data_size)
if float_bytes is None:
return None
float_value = bytes_to_float(float_bytes)
if float_value is not None:
return f"{float_value} (FLOAT)"
return None
def lookup(
self, addr: int, use_cache: bool = True, exact: bool = False
) -> Optional[str]:
@@ -165,25 +143,6 @@ class ParseAsm:
return match.group(0).replace(match.group(1), self.replace(value))
def hex_replace_float(self, match: re.Match) -> str:
"""Special case for replacements on float instructions.
If the pointer is a float constant, read it from the binary."""
value = int(match.group(1), 16)
# If we can find a variable name for this pointer, use it.
placeholder = self.lookup(value)
# Read what's under the pointer and show the decimal value.
if placeholder is None:
float_size = 8 if "qword" in match.string else 4
placeholder = self.float_replace(value, float_size)
# If we can't read the float, use a regular placeholder.
if placeholder is None:
placeholder = self.replace(value)
return match.group(0).replace(match.group(1), placeholder)
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
# For jumps or calls, if the entire op_str is a hex number, the value
# is a relative offset.
@@ -224,9 +183,6 @@ class ParseAsm:
if inst.mnemonic == "call":
# Special handling for absolute indirect CALL.
op_str = ptr_replace_regex.sub(self.hex_replace_indirect, inst.op_str)
elif inst.mnemonic.startswith("f"):
# If floating point instruction
op_str = ptr_replace_regex.sub(self.hex_replace_float, inst.op_str)
else:
op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)

View File

@@ -82,6 +82,7 @@ class Compare:
self._load_cvdump()
self._load_markers()
self._find_original_strings()
self._find_float_const()
self._match_imports()
self._match_exports()
self._match_thunks()
@@ -249,6 +250,18 @@ class Compare:
self._db.match_string(addr, string)
def _find_float_const(self):
"""Add floating point constants in each binary to the database.
We are not matching anything right now because these values are not
deduped like strings."""
for addr, size, float_value in self.orig_bin.find_float_consts():
self._db.set_orig_symbol(addr, SymbolType.FLOAT, str(float_value), size)
for addr, size, float_value in self.recomp_bin.find_float_consts():
self._db.set_recomp_symbol(
addr, SymbolType.FLOAT, str(float_value), None, size
)
def _match_imports(self):
"""We can match imported functions based on the DLL name and
function symbol name."""

View File

@@ -84,6 +84,23 @@ class CompareDb:
self._db = sqlite3.connect(":memory:")
self._db.executescript(_SETUP_SQL)
def set_orig_symbol(
self,
addr: int,
compare_type: Optional[SymbolType],
name: Optional[str],
size: Optional[int],
):
# Ignore collisions here.
if self._orig_used(addr):
return
compare_value = compare_type.value if compare_type is not None else None
self._db.execute(
"INSERT INTO `symbols` (orig_addr, compare_type, name, size) VALUES (?,?,?,?)",
(addr, compare_value, name, size),
)
def set_recomp_symbol(
self,
addr: int,