Performance enhancements (#527)

This commit is contained in:
MS
2024-02-04 13:37:37 -05:00
committed by GitHub
parent b4c9d78eb4
commit 8cc79ad4de
9 changed files with 328 additions and 123 deletions

View File

@@ -1,12 +1,13 @@
"""Converts x86 machine code into text (i.e. assembly). The end goal is to
compare the code in the original and recomp binaries, using longest common
subsequence (LCS), i.e. difflib.SequenceMatcher.
The capstone library takes the raw bytes and gives us the mnemnonic
The capstone library takes the raw bytes and gives us the mnemonic
and operand(s) for each instruction. We need to "sanitize" the text further
so that virtual addresses are replaced by symbol name or a generic
placeholder string."""
import re
from functools import cache
from typing import Callable, List, Optional, Tuple
from collections import namedtuple
from isledecomp.bin import InvalidVirtualAddressError
@@ -19,6 +20,7 @@ ptr_replace_regex = re.compile(r"(?P<data_size>\w+) ptr \[(?P<addr>0x[0-9a-fA-F]
DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str")
@cache
def from_hex(string: str) -> Optional[int]:
try:
return int(string, 16)
@@ -97,6 +99,9 @@ class ParseAsm:
# Nothing to sanitize
return (inst.mnemonic, "")
if "0x" not in inst.op_str:
return (inst.mnemonic, inst.op_str)
# For jumps or calls, if the entire op_str is a hex number, the value
# is a relative offset.
# Otherwise (i.e. it looks like `dword ptr [address]`) it is an
@@ -167,21 +172,20 @@ class ParseAsm:
else:
op_str = ptr_replace_regex.sub(filter_out_ptr, inst.op_str)
def replace_immediate(chunk: str) -> str:
if (inttest := from_hex(chunk)) is not None:
# If this value is a virtual address, it is referenced absolutely,
# which means it must be in the relocation table.
if self.is_relocated(inttest):
return self.replace(inttest)
return chunk
# Performance hack:
# Skip this step if there is nothing left to consider replacing.
if "0x" in op_str:
# Replace immediate values with name or placeholder (where appropriate)
words = op_str.split(", ")
for i, word in enumerate(words):
try:
inttest = int(word, 16)
# If this value is a virtual address, it is referenced absolutely,
# which means it must be in the relocation table.
if self.is_relocated(inttest):
words[i] = self.replace(inttest)
except ValueError:
pass
op_str = ", ".join(words)
op_str = ", ".join(map(replace_immediate, op_str.split(", ")))
return inst.mnemonic, op_str

View File

@@ -17,6 +17,7 @@ _SETUP_SQL = """
);
CREATE INDEX `symbols_or` ON `symbols` (orig_addr);
CREATE INDEX `symbols_re` ON `symbols` (recomp_addr);
CREATE INDEX `symbols_na` ON `symbols` (name);
"""

View File

@@ -2,6 +2,7 @@
between FUNCTION markers and PDB analysis."""
import sqlite3
import logging
from functools import cache
from typing import Optional
from pathlib import Path
from isledecomp.dir import PathResolver
@@ -22,6 +23,16 @@ _SETUP_SQL = """
logger = logging.getLogger(__name__)
@cache
def my_samefile(path: str, source_path: str) -> bool:
return Path(path).samefile(source_path)
@cache
def my_basename_lower(path: str) -> str:
return Path(path).name.lower()
class LinesDb:
def __init__(self, code_dir) -> None:
self._db = sqlite3.connect(":memory:")
@@ -31,7 +42,7 @@ class LinesDb:
def add_line(self, path: str, line_no: int, addr: int):
"""To be added from the LINES section of cvdump."""
sourcepath = self._path_resolver.resolve_cvdump(path)
filename = Path(sourcepath).name.lower()
filename = my_basename_lower(sourcepath)
self._db.execute(
"INSERT INTO `lineref` (path, filename, line, addr) VALUES (?,?,?,?)",
@@ -41,13 +52,13 @@ class LinesDb:
def search_line(self, path: str, line_no: int) -> Optional[int]:
"""Using path and line number from FUNCTION marker,
get the address of this function in the recomp."""
filename = Path(path).name.lower()
filename = my_basename_lower(path)
cur = self._db.execute(
"SELECT path, addr FROM `lineref` WHERE filename = ? AND line = ?",
(filename, line_no),
)
for source_path, addr in cur.fetchall():
if Path(path).samefile(source_path):
if my_samefile(path, source_path):
return addr
logger.error(