Reccmp: Use symbol names in asm output (#433)

* Name substitution for reccmp asm output

* Decomp marker corrections

* Fix a few annotations

* Fix IslePathActor dtor

* Fix audio presenter

* Fix LegoEntity::Create

* Fix Pizza and related

* Fix path part

* Add missing annotations

* Add missing annotations

* Add more missing annotations

* Fix MxNotificationParam

* More fixes

* More fixes

* Add missing annotations

* Fixes

* More annotations

* More annotations

* More annotations

* More annotations

* Fixes and annotations

* Find imports and thunk functions

* Fix more bugs

* Add some markers for LEGO1 imports, fix SIZE comment

* Add more annotations

* Rename annotation

* Fix bugs and annotations

* Fix bug

* Order

* Update legoanimpresenter.h

* Re-enable print-rec-addr option

---------

Co-authored-by: Christian Semmler <mail@csemmler.com>
This commit is contained in:
MS
2024-01-14 16:28:46 -05:00
committed by GitHub
parent 7f7e6e37dd
commit 7e9d3bde65
73 changed files with 1357 additions and 427 deletions

View File

@@ -97,6 +97,8 @@ class Bin:
self.find_str = find_str
self._potential_strings = {}
self._relocated_addrs = set()
self.imports = []
self.thunks = []
def __enter__(self):
logger.debug("Bin %s Enter", self.filename)
@@ -132,6 +134,8 @@ class Bin:
sect.virtual_address += self.imagebase
self._populate_relocations()
self._populate_imports()
self._populate_thunks()
# This is a (semi) expensive lookup that is not necesssary in every case.
# We can find strings in the original if we have coverage using STRING markers.
@@ -238,6 +242,78 @@ class Bin:
(relocated_addr,) = struct.unpack("<I", self.read(addr, 4))
self._relocated_addrs.add(relocated_addr)
def _populate_imports(self):
"""Parse .idata to find imported DLLs and their functions."""
idata_ofs = self.get_section_offset_by_name(".idata")
def iter_image_import():
ofs = idata_ofs
while True:
# Read 5 dwords until all are zero.
image_import_descriptor = struct.unpack("<5I", self.read(ofs, 20))
ofs += 20
if all(x == 0 for x in image_import_descriptor):
break
(rva_ilt, _, __, dll_name, rva_iat) = image_import_descriptor
# Convert relative virtual addresses into absolute
yield (
self.imagebase + rva_ilt,
self.imagebase + dll_name,
self.imagebase + rva_iat,
)
image_import_descriptors = list(iter_image_import())
def iter_imports():
# ILT = Import Lookup Table
# IAT = Import Address Table
# ILT gives us the symbol name of the import.
# IAT gives the address. The compiler generated a thunk function
# that jumps to the value of this address.
for start_ilt, dll_addr, start_iat in image_import_descriptors:
dll_name = self.read_string(dll_addr).decode("ascii")
ofs_ilt = start_ilt
# Address of "__imp__*" symbols.
ofs_iat = start_iat
while True:
(lookup_addr,) = struct.unpack("<L", self.read(ofs_ilt, 4))
(import_addr,) = struct.unpack("<L", self.read(ofs_iat, 4))
if lookup_addr == 0 or import_addr == 0:
break
# Skip the "Hint" field, 2 bytes
name_ofs = lookup_addr + self.imagebase + 2
symbol_name = self.read_string(name_ofs).decode("ascii")
yield (dll_name, symbol_name, ofs_iat)
ofs_ilt += 4
ofs_iat += 4
self.imports = list(iter_imports())
def _populate_thunks(self):
"""For each imported function, we generate a thunk function. The only
instruction in the function is a jmp to the address in .idata.
Search .text to find these functions."""
text_sect = self._get_section_by_name(".text")
idata_sect = self._get_section_by_name(".idata")
start = text_sect.virtual_address
ofs = start
bs = self.read(ofs, text_sect.size_of_raw_data)
for shift in (0, 2, 4):
window = bs[shift:]
win_end = 6 * (len(window) // 6)
for i, (b0, b1, jmp_ofs) in enumerate(
struct.iter_unpack("<2BL", window[:win_end])
):
if (b0, b1) == (0xFF, 0x25) and idata_sect.contains_vaddr(jmp_ofs):
# Record the address of the jmp instruction and the destination in .idata
thunk_ofs = ofs + shift + i * 6
self.thunks.append((thunk_ofs, jmp_ofs))
def _set_section_for_vaddr(self, vaddr: int):
if self.last_section is not None and self.last_section.contains_vaddr(vaddr):
return
@@ -319,6 +395,18 @@ class Bin:
return section is not None
def read_string(self, offset: int, chunk_size: int = 1000) -> Optional[bytes]:
"""Read until we find a zero byte."""
b = self.read(offset, chunk_size)
if b is None:
return None
try:
return b[: b.index(b"\x00")]
except ValueError:
# No terminator found, just return what we have
return b
def read(self, offset: int, size: int) -> Optional[bytes]:
"""Read (at most) the given number of bytes at the given virtual address.
If we return None, the given address points to uninitialized data."""

View File

@@ -0,0 +1,2 @@
from .parse import ParseAsm
from .swap import can_resolve_register_differences

View File

@@ -0,0 +1,152 @@
"""Converts x86 machine code into text (i.e. assembly). The end goal is to
compare the code in the original and recomp binaries, using longest common
subsequence (LCS), i.e. difflib.SequenceMatcher.
The capstone library takes the raw bytes and gives us the mnemnonic
and operand(s) for each instruction. We need to "sanitize" the text further
so that virtual addresses are replaced by symbol name or a generic
placeholder string."""
import re
from typing import Callable, List, Optional, Tuple
from collections import namedtuple
from capstone import Cs, CS_ARCH_X86, CS_MODE_32
disassembler = Cs(CS_ARCH_X86, CS_MODE_32)
ptr_replace_regex = re.compile(r"ptr \[(0x[0-9a-fA-F]+)\]")
DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str")
def from_hex(string: str) -> Optional[int]:
try:
return int(string, 16)
except ValueError:
pass
return None
class ParseAsm:
def __init__(
self,
relocate_lookup: Optional[Callable[[int], bool]] = None,
name_lookup: Optional[Callable[[int], str]] = None,
) -> None:
self.relocate_lookup = relocate_lookup
self.name_lookup = name_lookup
self.replacements = {}
self.number_placeholders = True
def reset(self):
self.replacements = {}
def is_relocated(self, addr: int) -> bool:
if callable(self.relocate_lookup):
return self.relocate_lookup(addr)
return False
def lookup(self, addr: int) -> Optional[str]:
"""Return a replacement name for this address if we find one."""
if (cached := self.replacements.get(addr, None)) is not None:
return cached
if callable(self.name_lookup):
if (name := self.name_lookup(addr)) is not None:
self.replacements[addr] = name
return name
return None
def replace(self, addr: int) -> str:
"""Same function as lookup above, but here we return a placeholder
if there is no better name to use."""
if (name := self.lookup(addr)) is not None:
return name
# The placeholder number corresponds to the number of addresses we have
# already replaced. This is so the number will be consistent across the diff
# if we can replace some symbols with actual names in recomp but not orig.
idx = len(self.replacements) + 1
placeholder = f"<OFFSET{idx}>" if self.number_placeholders else "<OFFSET>"
self.replacements[addr] = placeholder
return placeholder
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
if len(inst.op_str) == 0:
# Nothing to sanitize
return (inst.mnemonic, "")
# For jumps or calls, if the entire op_str is a hex number, the value
# is a relative offset.
# Otherwise (i.e. it looks like `dword ptr [address]`) it is an
# absolute indirect that we will handle below.
# Providing the starting address of the function to capstone.disasm has
# automatically resolved relative offsets to an absolute address.
# We will have to undo this for some of the jumps or they will not match.
op_str_address = from_hex(inst.op_str)
if op_str_address is not None:
if inst.mnemonic == "call":
return (inst.mnemonic, self.replace(op_str_address))
if inst.mnemonic == "jmp":
# The unwind section contains JMPs to other functions.
# If we have a name for this address, use it. If not,
# do not create a new placeholder. We will instead
# fall through to generic jump handling below.
potential_name = self.lookup(op_str_address)
if potential_name is not None:
return (inst.mnemonic, potential_name)
if inst.mnemonic.startswith("j"):
# i.e. if this is any jump
# Show the jump offset rather than the absolute address
jump_displacement = op_str_address - (inst.address + inst.size)
return (inst.mnemonic, hex(jump_displacement))
def filter_out_ptr(match):
"""Helper for re.sub, see below"""
offset = from_hex(match.group(1))
if offset is not None:
# We assume this is always an address to replace
placeholder = self.replace(offset)
return f"ptr [{placeholder}]"
# Strict regex should ensure we can read the hex number.
# But just in case: return the string with no changes
return match.group(0)
op_str = ptr_replace_regex.sub(filter_out_ptr, inst.op_str)
# Performance hack:
# Skip this step if there is nothing left to consider replacing.
if "0x" in op_str:
# Replace immediate values with name or placeholder (where appropriate)
words = op_str.split(", ")
for i, word in enumerate(words):
try:
inttest = int(word, 16)
# If this value is a virtual address, it is referenced absolutely,
# which means it must be in the relocation table.
if self.is_relocated(inttest):
words[i] = self.replace(inttest)
except ValueError:
pass
op_str = ", ".join(words)
return inst.mnemonic, op_str
def parse_asm(self, data: bytes, start_addr: Optional[int] = 0) -> List[str]:
asm = []
for inst in disassembler.disasm_lite(data, start_addr):
# Use heuristics to disregard some differences that aren't representative
# of the accuracy of a function (e.g. global offsets)
result = self.sanitize(DisasmLiteInst(*inst))
# mnemonic + " " + op_str
asm.append(" ".join(result))
return asm

View File

@@ -0,0 +1,80 @@
import re
REGISTER_LIST = set(
[
"ax",
"bp",
"bx",
"cx",
"di",
"dx",
"eax",
"ebp",
"ebx",
"ecx",
"edi",
"edx",
"esi",
"esp",
"si",
"sp",
]
)
WORDS = re.compile(r"\w+")
def get_registers(line: str):
to_replace = []
# use words regex to find all matching positions:
for match in WORDS.finditer(line):
reg = match.group(0)
if reg in REGISTER_LIST:
to_replace.append((reg, match.start()))
return to_replace
def replace_register(
lines: list[str], start_line: int, reg: str, replacement: str
) -> list[str]:
return [
line.replace(reg, replacement) if i >= start_line else line
for i, line in enumerate(lines)
]
# Is it possible to make new_asm the same as original_asm by swapping registers?
def can_resolve_register_differences(original_asm, new_asm):
# Split the ASM on spaces to get more granularity, and so
# that we don't modify the original arrays passed in.
original_asm = [part for line in original_asm for part in line.split()]
new_asm = [part for line in new_asm for part in line.split()]
# Swapping ain't gonna help if the lengths are different
if len(original_asm) != len(new_asm):
return False
# Look for the mismatching lines
for i, original_line in enumerate(original_asm):
new_line = new_asm[i]
if new_line != original_line:
# Find all the registers to replace
to_replace = get_registers(original_line)
for replace in to_replace:
(reg, reg_index) = replace
replacing_reg = new_line[reg_index : reg_index + len(reg)]
if replacing_reg in REGISTER_LIST:
if replacing_reg != reg:
# Do a three-way swap replacing in all the subsequent lines
temp_reg = "&" * len(reg)
new_asm = replace_register(new_asm, i, replacing_reg, temp_reg)
new_asm = replace_register(new_asm, i, reg, replacing_reg)
new_asm = replace_register(new_asm, i, temp_reg, reg)
else:
# No replacement to do, different code, bail out
return False
# Check if the lines are now the same
for i, original_line in enumerate(original_asm):
if new_asm[i] != original_line:
return False
return True

View File

@@ -1,11 +1,14 @@
import os
import logging
from typing import List, Optional
import difflib
from dataclasses import dataclass
from typing import Iterable, List, Optional
from isledecomp.cvdump.demangler import demangle_string_const
from isledecomp.cvdump import Cvdump, CvdumpAnalysis
from isledecomp.parser import DecompCodebase
from isledecomp.dir import walk_source_dir
from isledecomp.types import SymbolType
from isledecomp.compare.asm import ParseAsm, can_resolve_register_differences
from .db import CompareDb, MatchInfo
from .lines import LinesDb
@@ -13,6 +16,24 @@ from .lines import LinesDb
logger = logging.getLogger(__name__)
@dataclass
class DiffReport:
orig_addr: int
recomp_addr: int
name: str
udiff: Optional[List[str]] = None
ratio: float = 0.0
is_effective_match: bool = False
@property
def effective_ratio(self) -> float:
return 1.0 if self.is_effective_match else self.ratio
def __str__(self) -> str:
"""For debug purposes. Proper diff printing (with coloring) is in another module."""
return f"{self.name} (0x{self.orig_addr:x}) {self.ratio*100:.02f}%{'*' if self.is_effective_match else ''}"
class Compare:
# pylint: disable=too-many-instance-attributes
def __init__(self, orig_bin, recomp_bin, pdb_file, code_dir):
@@ -27,6 +48,7 @@ class Compare:
self._load_cvdump()
self._load_markers()
self._find_original_strings()
self._match_thunks()
def _load_cvdump(self):
logger.info("Parsing %s ...", self.pdb_file)
@@ -126,6 +148,46 @@ class Compare:
self._db.match_string(addr, string)
def _match_thunks(self):
orig_byaddr = {
addr: (dll.upper(), name) for (dll, name, addr) in self.orig_bin.imports
}
recomp_byname = {
(dll.upper(), name): addr for (dll, name, addr) in self.recomp_bin.imports
}
# Combine these two dictionaries. We don't care about imports from recomp
# not found in orig because:
# 1. They shouldn't be there
# 2. They are already identified via cvdump
orig_to_recomp = {
addr: recomp_byname.get(pair, None) for addr, pair in orig_byaddr.items()
}
# Now: we have the IAT offset in each matched up, so we need to make
# the connection between the thunk functions.
# We already have the symbol name we need from the PDB.
orig_thunks = {
iat_ofs: func_ofs for (func_ofs, iat_ofs) in self.orig_bin.thunks
}
recomp_thunks = {
iat_ofs: func_ofs for (func_ofs, iat_ofs) in self.recomp_bin.thunks
}
for orig, recomp in orig_to_recomp.items():
self._db.set_pair(orig, recomp, SymbolType.POINTER)
thunk_from_orig = orig_thunks.get(orig, None)
thunk_from_recomp = recomp_thunks.get(recomp, None)
if thunk_from_orig is not None and thunk_from_recomp is not None:
self._db.set_function_pair(thunk_from_orig, thunk_from_recomp)
# Don't compare thunk functions for now. The comparison isn't
# "useful" in the usual sense. We are only looking at the 6
# bytes of the jmp instruction and not the larger context of
# where this function is. Also: these will always match 100%
# because we are searching for a match to register this as a
# function in the first place.
self._db.skip_compare(thunk_from_orig)
def get_one_function(self, addr: int) -> Optional[MatchInfo]:
"""i.e. verbose mode for reccmp"""
return self._db.get_one_function(addr)
@@ -133,8 +195,84 @@ class Compare:
def get_functions(self) -> List[MatchInfo]:
return self._db.get_matches(SymbolType.FUNCTION)
def compare_functions(self):
pass
def _compare_function(self, match: MatchInfo) -> DiffReport:
if match.size == 0:
# Report a failed match to make the user aware of the empty function.
return DiffReport(
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
)
orig_raw = self.orig_bin.read(match.orig_addr, match.size)
recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size)
def orig_should_replace(addr: int) -> bool:
return addr > self.orig_bin.imagebase and self.orig_bin.is_relocated_addr(
addr
)
def recomp_should_replace(addr: int) -> bool:
return (
addr > self.recomp_bin.imagebase
and self.recomp_bin.is_relocated_addr(addr)
)
def orig_lookup(addr: int) -> Optional[str]:
m = self._db.get_by_orig(addr)
if m is None:
return None
return m.match_name()
def recomp_lookup(addr: int) -> Optional[str]:
m = self._db.get_by_recomp(addr)
if m is None:
return None
return m.match_name()
orig_parse = ParseAsm(
relocate_lookup=orig_should_replace, name_lookup=orig_lookup
)
recomp_parse = ParseAsm(
relocate_lookup=recomp_should_replace, name_lookup=recomp_lookup
)
orig_asm = orig_parse.parse_asm(orig_raw, match.orig_addr)
recomp_asm = recomp_parse.parse_asm(recomp_raw, match.recomp_addr)
diff = difflib.SequenceMatcher(None, orig_asm, recomp_asm)
ratio = diff.ratio()
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
is_effective_match = can_resolve_register_differences(orig_asm, recomp_asm)
unified_diff = difflib.unified_diff(orig_asm, recomp_asm, n=10)
else:
is_effective_match = False
unified_diff = []
return DiffReport(
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
udiff=unified_diff,
ratio=ratio,
is_effective_match=is_effective_match,
)
def compare_function(self, addr: int) -> Optional[DiffReport]:
match = self.get_one_function(addr)
if match is None:
return None
return self._compare_function(match)
def compare_functions(self) -> Iterable[DiffReport]:
for match in self.get_functions():
yield self._compare_function(match)
def compare_variables(self):
pass

View File

@@ -2,7 +2,6 @@
addresses/symbols that we want to compare between the original and recompiled binaries."""
import sqlite3
import logging
from collections import namedtuple
from typing import List, Optional
from isledecomp.types import SymbolType
@@ -16,12 +15,35 @@ _SETUP_SQL = """
size int,
should_skip int default(FALSE)
);
CREATE INDEX `symbols_or` ON `symbols` (orig_addr);
CREATE INDEX `symbols_re` ON `symbols` (recomp_addr);
CREATE INDEX `symbols_na` ON `symbols` (compare_type, name);
"""
MatchInfo = namedtuple("MatchInfo", "orig_addr, recomp_addr, size, name")
class MatchInfo:
def __init__(
self,
ctype: Optional[int],
orig: Optional[int],
recomp: Optional[int],
name: Optional[str],
size: Optional[int],
) -> None:
self.compare_type = SymbolType(ctype) if ctype is not None else None
self.orig_addr = orig
self.recomp_addr = recomp
self.name = name
self.size = size
def match_name(self) -> str:
"""Combination of the name and compare type.
Intended for name substitution in the diff. If there is a diff,
it will be more obvious what this symbol indicates."""
if self.name is None:
return None
ctype = self.compare_type.name if self.compare_type is not None else "UNK"
return f"{self.name} ({ctype})"
def matchinfo_factory(_, row):
@@ -61,7 +83,7 @@ class CompareDb:
def get_one_function(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT orig_addr, recomp_addr, size, name
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
WHERE compare_type = ?
AND orig_addr = ?
@@ -74,9 +96,31 @@ class CompareDb:
cur.row_factory = matchinfo_factory
return cur.fetchone()
def get_by_orig(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
WHERE orig_addr = ?
""",
(addr,),
)
cur.row_factory = matchinfo_factory
return cur.fetchone()
def get_by_recomp(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
WHERE recomp_addr = ?
""",
(addr,),
)
cur.row_factory = matchinfo_factory
return cur.fetchone()
def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]:
cur = self._db.execute(
"""SELECT orig_addr, recomp_addr, size, name
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
WHERE compare_type = ?
AND orig_addr IS NOT NULL
@@ -90,14 +134,20 @@ class CompareDb:
return cur.fetchall()
def set_function_pair(self, orig: int, recomp: int) -> bool:
"""For lineref match or _entry"""
def set_pair(
self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None
) -> bool:
compare_value = compare_type.value if compare_type is not None else None
cur = self._db.execute(
"UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?",
(orig, SymbolType.FUNCTION.value, recomp),
(orig, compare_value, recomp),
)
return cur.rowcount > 0
def set_function_pair(self, orig: int, recomp: int) -> bool:
"""For lineref match or _entry"""
self.set_pair(orig, recomp, SymbolType.FUNCTION)
# TODO: Both ways required?
def skip_compare(self, orig: int):

View File

@@ -36,7 +36,7 @@ _section_contrib_regex = re.compile(
# e.g. `S_GDATA32: [0003:000004A4], Type: T_32PRCHAR(0470), g_set`
_gdata32_regex = re.compile(
r"S_GDATA32: \[(?P<section>\w{4}):(?P<offset>\w{8})\], Type:\s*(?P<type>\S+), (?P<name>\S+)"
r"S_GDATA32: \[(?P<section>\w{4}):(?P<offset>\w{8})\], Type:\s*(?P<type>\S+), (?P<name>.+)"
)

View File

@@ -26,17 +26,3 @@ def print_diff(udiff, plain):
def get_file_in_script_dir(fn):
return os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), fn)
class OffsetPlaceholderGenerator:
def __init__(self):
self.counter = 0
self.replacements = {}
def get(self, replace_addr):
if replace_addr in self.replacements:
return self.replacements[replace_addr]
self.counter += 1
replacement = f"<OFFSET{self.counter}>"
self.replacements[replace_addr] = replacement
return replacement

View File

@@ -0,0 +1,179 @@
from typing import Optional
import pytest
from isledecomp.compare.asm.parse import DisasmLiteInst, ParseAsm
def mock_inst(mnemonic: str, op_str: str) -> DisasmLiteInst:
"""Mock up the named tuple DisasmLite from just a mnemonic and op_str.
To be used for tests on sanitize that do not require the instruction address
or size. i.e. any non-jump instruction."""
return DisasmLiteInst(0, 0, mnemonic, op_str)
identity_cases = [
("", ""),
("sti", ""),
("push", "ebx"),
("ret", ""),
("ret", "4"),
("mov", "eax, 0x1234"),
]
@pytest.mark.parametrize("mnemonic, op_str", identity_cases)
def test_identity(mnemonic, op_str):
"""Confirm that nothing is substituted."""
p = ParseAsm()
inst = mock_inst(mnemonic, op_str)
result = p.sanitize(inst)
assert result == (mnemonic, op_str)
ptr_replace_cases = [
("byte ptr [0x5555]", "byte ptr [<OFFSET1>]"),
("word ptr [0x5555]", "word ptr [<OFFSET1>]"),
("dword ptr [0x5555]", "dword ptr [<OFFSET1>]"),
("qword ptr [0x5555]", "qword ptr [<OFFSET1>]"),
("eax, dword ptr [0x5555]", "eax, dword ptr [<OFFSET1>]"),
("dword ptr [0x5555], eax", "dword ptr [<OFFSET1>], eax"),
("dword ptr [0x5555], 0", "dword ptr [<OFFSET1>], 0"),
("dword ptr [0x5555], 8", "dword ptr [<OFFSET1>], 8"),
# Same value, assumed to be an addr in the first appearance
# because it is designated as 'ptr', but we have not provided the
# relocation table lookup method so we do not replace the second appearance.
("dword ptr [0x5555], 0x5555", "dword ptr [<OFFSET1>], 0x5555"),
]
@pytest.mark.parametrize("start, end", ptr_replace_cases)
def test_ptr_replace(start, end):
"""Anything in square brackets (with the 'ptr' prefix) will always be replaced."""
p = ParseAsm()
inst = mock_inst("", start)
(_, op_str) = p.sanitize(inst)
assert op_str == end
call_replace_cases = [
("ebx", "ebx"),
("0x1234", "<OFFSET1>"),
("dword ptr [0x1234]", "dword ptr [<OFFSET1>]"),
("dword ptr [ecx + 0x10]", "dword ptr [ecx + 0x10]"),
]
@pytest.mark.parametrize("start, end", call_replace_cases)
def test_call_replace(start, end):
"""Call with hex operand is always replaced.
Otherwise, ptr replacement rules apply, but skip `this` calls."""
p = ParseAsm()
inst = mock_inst("call", start)
(_, op_str) = p.sanitize(inst)
assert op_str == end
def test_jump_displacement():
"""Display jump displacement (offset from end of jump instruction)
instead of destination address."""
p = ParseAsm()
inst = DisasmLiteInst(0x1000, 2, "je", "0x1000")
(_, op_str) = p.sanitize(inst)
assert op_str == "-0x2"
@pytest.mark.xfail(reason="Not implemented yet")
def test_jmp_table():
"""Should detect the characteristic jump table instruction
(for a switch statement) and use placeholder."""
p = ParseAsm()
inst = mock_inst("jmp", "dword ptr [eax*4 + 0x5555]")
(_, op_str) = p.sanitize(inst)
assert op_str == "dword ptr [eax*4 + <OFFSET1>]"
name_replace_cases = [
("byte ptr [0x5555]", "byte ptr [_substitute_]"),
("word ptr [0x5555]", "word ptr [_substitute_]"),
("dword ptr [0x5555]", "dword ptr [_substitute_]"),
("qword ptr [0x5555]", "qword ptr [_substitute_]"),
]
@pytest.mark.parametrize("start, end", name_replace_cases)
def test_name_replace(start, end):
"""Make sure the name lookup function is called if present"""
def substitute(_: int) -> str:
return "_substitute_"
p = ParseAsm(name_lookup=substitute)
inst = mock_inst("mov", start)
(_, op_str) = p.sanitize(inst)
assert op_str == end
def test_replacement_cache():
p = ParseAsm()
inst = mock_inst("inc", "dword ptr [0x1234]")
(_, op_str) = p.sanitize(inst)
assert op_str == "dword ptr [<OFFSET1>]"
(_, op_str) = p.sanitize(inst)
assert op_str == "dword ptr [<OFFSET1>]"
def test_replacement_numbering():
"""If we can use the name lookup for the first address but not the second,
the second replacement should be <OFFSET2> not <OFFSET1>."""
def substitute_1234(addr: int) -> Optional[str]:
return "_substitute_" if addr == 0x1234 else None
p = ParseAsm(name_lookup=substitute_1234)
(_, op_str) = p.sanitize(mock_inst("inc", "dword ptr [0x1234]"))
assert op_str == "dword ptr [_substitute_]"
(_, op_str) = p.sanitize(mock_inst("inc", "dword ptr [0x5555]"))
assert op_str == "dword ptr [<OFFSET2>]"
def test_relocate_lookup():
"""Immediate values would be relocated if they are actually addresses.
So we can use the relocation table to check whether a given value is an
address or just some number."""
def relocate_lookup(addr: int) -> bool:
return addr == 0x1234
p = ParseAsm(relocate_lookup=relocate_lookup)
(_, op_str) = p.sanitize(mock_inst("mov", "eax, 0x1234"))
assert op_str == "eax, <OFFSET1>"
(_, op_str) = p.sanitize(mock_inst("mov", "eax, 0x5555"))
assert op_str == "eax, 0x5555"
def test_jump_to_function():
"""A jmp instruction can lead us directly to a function. This can be found
in the unwind section at the end of a function. However: we do not want to
assume this is the case for all jumps. Only replace the jump with a name
if we can find it using our lookup."""
def substitute_1234(addr: int) -> Optional[str]:
return "_substitute_" if addr == 0x1234 else None
p = ParseAsm(name_lookup=substitute_1234)
inst = DisasmLiteInst(0x1000, 2, "jmp", "0x1234")
(_, op_str) = p.sanitize(inst)
assert op_str == "_substitute_"
# Should not replace this jump.
# 0x1000 (start addr)
# + 2 (size of jump instruction)
# + 0x5555 (displacement, the value we want)
# = 0x6557
inst = DisasmLiteInst(0x1000, 2, "jmp", "0x6557")
(_, op_str) = p.sanitize(inst)
assert op_str == "0x5555"