Read floating point constants up front (#868)

* Read floating point constants before sanitize

* Fix roadmap
This commit is contained in:
MS
2024-04-29 14:33:16 -04:00
committed by GitHub
parent 7c6c68d6f9
commit e7670f9a81
7 changed files with 97 additions and 47 deletions

View File

@@ -2,7 +2,7 @@ import logging
import struct
import bisect
from functools import cached_property
from typing import List, Optional, Tuple
from typing import Iterator, List, Optional, Tuple
from dataclasses import dataclass
from collections import namedtuple
@@ -77,6 +77,18 @@ class Section:
def contains_vaddr(self, vaddr: int) -> bool:
return self.virtual_address <= vaddr < self.virtual_address + self.extent
def read_virtual(self, vaddr: int, size: int) -> memoryview:
ofs = vaddr - self.virtual_address
# Negative index will read from the end, which we don't want
if ofs < 0:
raise InvalidVirtualAddressError
try:
return self.view[ofs : ofs + size]
except IndexError as ex:
raise InvalidVirtualAddressError from ex
def addr_is_uninitialized(self, vaddr: int) -> bool:
"""We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in
the characteristics field so instead we determine it this way."""
@@ -109,6 +121,7 @@ class Bin:
self._section_vaddr: List[int] = []
self.find_str = find_str
self._potential_strings = {}
self._relocations = set()
self._relocated_addrs = set()
self.imports = []
self.thunks = []
@@ -279,11 +292,49 @@ class Bin:
# We are now interested in the relocated addresses themselves. Seek to the
# address where there is a relocation, then read the four bytes into our set.
reloc_addrs.sort()
self._relocations = set(reloc_addrs)
for section_id, offset in map(self.get_relative_addr, reloc_addrs):
section = self.get_section_by_index(section_id)
(relocated_addr,) = struct.unpack("<I", section.view[offset : offset + 4])
self._relocated_addrs.add(relocated_addr)
def find_float_consts(self) -> Iterator[Tuple[int, int, float]]:
"""Floating point instructions that refer to a memory address can
point to constant values. Search the code sections to find FP
instructions and check whether the pointer address refers to
read-only data."""
# TODO: Should check any section that has code, not just .text
text = self.get_section_by_name(".text")
rdata = self.get_section_by_name(".rdata")
# These are the addresses where a relocation occurs.
# Meaning: it points to an absolute address of something
for addr in self._relocations:
if not text.contains_vaddr(addr):
continue
# Read the two bytes before the relocated address.
# We will check against possible float opcodes
raw = text.read_virtual(addr - 2, 6)
(opcode, opcode_ext, const_addr) = struct.unpack("<BBL", raw)
# Skip right away if this is not const data
if not rdata.contains_vaddr(const_addr):
continue
if opcode_ext in (0x5, 0xD, 0x15, 0x1D, 0x25, 0x2D, 0x35, 0x3D):
if opcode in (0xD8, 0xD9):
# dword ptr -- single precision
(float_value,) = struct.unpack("<f", self.read(const_addr, 4))
yield (const_addr, 4, float_value)
elif opcode in (0xDC, 0xDD):
# qword ptr -- double precision
(float_value,) = struct.unpack("<d", self.read(const_addr, 8))
yield (const_addr, 8, float_value)
def _populate_imports(self):
"""Parse .idata to find imported DLLs and their functions."""
idata_ofs = self.get_section_offset_by_name(".idata")

View File

@@ -35,16 +35,6 @@ def from_hex(string: str) -> Optional[int]:
return None
def bytes_to_float(b: bytes) -> Optional[float]:
if len(b) == 4:
return struct.unpack("<f", b)[0]
if len(b) == 8:
return struct.unpack("<d", b)[0]
return None
def bytes_to_dword(b: bytes) -> Optional[int]:
if len(b) == 4:
return struct.unpack("<L", b)[0]
@@ -74,18 +64,6 @@ class ParseAsm:
return False
def float_replace(self, addr: int, data_size: int) -> Optional[str]:
if callable(self.bin_lookup):
float_bytes = self.bin_lookup(addr, data_size)
if float_bytes is None:
return None
float_value = bytes_to_float(float_bytes)
if float_value is not None:
return f"{float_value} (FLOAT)"
return None
def lookup(
self, addr: int, use_cache: bool = True, exact: bool = False
) -> Optional[str]:
@@ -165,25 +143,6 @@ class ParseAsm:
return match.group(0).replace(match.group(1), self.replace(value))
def hex_replace_float(self, match: re.Match) -> str:
"""Special case for replacements on float instructions.
If the pointer is a float constant, read it from the binary."""
value = int(match.group(1), 16)
# If we can find a variable name for this pointer, use it.
placeholder = self.lookup(value)
# Read what's under the pointer and show the decimal value.
if placeholder is None:
float_size = 8 if "qword" in match.string else 4
placeholder = self.float_replace(value, float_size)
# If we can't read the float, use a regular placeholder.
if placeholder is None:
placeholder = self.replace(value)
return match.group(0).replace(match.group(1), placeholder)
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
# For jumps or calls, if the entire op_str is a hex number, the value
# is a relative offset.
@@ -224,9 +183,6 @@ class ParseAsm:
if inst.mnemonic == "call":
# Special handling for absolute indirect CALL.
op_str = ptr_replace_regex.sub(self.hex_replace_indirect, inst.op_str)
elif inst.mnemonic.startswith("f"):
# If floating point instruction
op_str = ptr_replace_regex.sub(self.hex_replace_float, inst.op_str)
else:
op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)

View File

@@ -82,6 +82,7 @@ class Compare:
self._load_cvdump()
self._load_markers()
self._find_original_strings()
self._find_float_const()
self._match_imports()
self._match_exports()
self._match_thunks()
@@ -249,6 +250,18 @@ class Compare:
self._db.match_string(addr, string)
def _find_float_const(self):
"""Add floating point constants in each binary to the database.
We are not matching anything right now because these values are not
deduped like strings."""
for addr, size, float_value in self.orig_bin.find_float_consts():
self._db.set_orig_symbol(addr, SymbolType.FLOAT, str(float_value), size)
for addr, size, float_value in self.recomp_bin.find_float_consts():
self._db.set_recomp_symbol(
addr, SymbolType.FLOAT, str(float_value), None, size
)
def _match_imports(self):
"""We can match imported functions based on the DLL name and
function symbol name."""

View File

@@ -84,6 +84,23 @@ class CompareDb:
self._db = sqlite3.connect(":memory:")
self._db.executescript(_SETUP_SQL)
def set_orig_symbol(
self,
addr: int,
compare_type: Optional[SymbolType],
name: Optional[str],
size: Optional[int],
):
# Ignore collisions here.
if self._orig_used(addr):
return
compare_value = compare_type.value if compare_type is not None else None
self._db.execute(
"INSERT INTO `symbols` (orig_addr, compare_type, name, size) VALUES (?,?,?,?)",
(addr, compare_value, name, size),
)
def set_recomp_symbol(
self,
addr: int,

View File

@@ -10,3 +10,4 @@ class SymbolType(Enum):
POINTER = 3
STRING = 4
VTABLE = 5
FLOAT = 6

View File

@@ -189,6 +189,7 @@ def test_jump_to_function():
assert op_str == "0x5555"
@pytest.mark.skip(reason="changed implementation")
def test_float_replacement():
"""Floating point constants often appear as pointers to data.
A good example is ViewROI::IntrinsicImportance and the subclass override
@@ -208,6 +209,7 @@ def test_float_replacement():
assert op_str == "dword ptr [3.1415927410125732 (FLOAT)]"
@pytest.mark.skip(reason="changed implementation")
def test_float_variable():
"""If there is a variable at the address referenced by a float instruction,
use the name instead of calling into the float replacement handler."""