mirror of
				https://github.com/isledecomp/isle.git
				synced 2025-10-26 09:54:18 +00:00 
			
		
		
		
	Performance enhancements (#527)
This commit is contained in:
		| @@ -1,5 +1,7 @@ | ||||
| import logging | ||||
| import struct | ||||
| import bisect | ||||
| from functools import cached_property | ||||
| from typing import List, Optional, Tuple | ||||
| from dataclasses import dataclass | ||||
| from collections import namedtuple | ||||
| @@ -36,33 +38,44 @@ PEHeader = namedtuple( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| ImageSectionHeader = namedtuple( | ||||
|     "ImageSectionHeader", | ||||
|     [ | ||||
|         "name", | ||||
|         "virtual_size", | ||||
|         "virtual_address", | ||||
|         "size_of_raw_data", | ||||
|         "pointer_to_raw_data", | ||||
|         "pointer_to_relocations", | ||||
|         "pointer_to_line_numbers", | ||||
|         "number_of_relocations", | ||||
|         "number_of_line_numbers", | ||||
|         "characteristics", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class ImageSectionHeader: | ||||
|     # pylint: disable=too-many-instance-attributes | ||||
|     # Most attributes are unused, but this is the struct format | ||||
|     name: bytes | ||||
| class Section: | ||||
|     name: str | ||||
|     virtual_size: int | ||||
|     virtual_address: int | ||||
|     size_of_raw_data: int | ||||
|     pointer_to_raw_data: int | ||||
|     pointer_to_relocations: int | ||||
|     pointer_to_line_numbers: int | ||||
|     number_of_relocations: int | ||||
|     number_of_line_numbers: int | ||||
|     characteristics: int | ||||
|     view: memoryview | ||||
| 
 | ||||
|     @property | ||||
|     @cached_property | ||||
|     def size_of_raw_data(self) -> int: | ||||
|         return len(self.view) | ||||
| 
 | ||||
|     @cached_property | ||||
|     def extent(self): | ||||
|         """Get the highest possible offset of this section""" | ||||
|         return max(self.size_of_raw_data, self.virtual_size) | ||||
| 
 | ||||
|     def match_name(self, name: str) -> bool: | ||||
|         return self.name == struct.pack("8s", name.encode("ascii")) | ||||
|         return self.name == name | ||||
| 
 | ||||
|     def contains_vaddr(self, vaddr: int) -> bool: | ||||
|         ofs = vaddr - self.virtual_address | ||||
|         return 0 <= ofs < self.extent | ||||
|         return self.virtual_address <= vaddr < self.virtual_address + self.extent | ||||
| 
 | ||||
|     def addr_is_uninitialized(self, vaddr: int) -> bool: | ||||
|         """We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in | ||||
| @@ -89,11 +102,11 @@ class Bin: | ||||
|     def __init__(self, filename: str, find_str: bool = False) -> None: | ||||
|         logger.debug('Parsing headers of "%s"... ', filename) | ||||
|         self.filename = filename | ||||
|         self.file = None | ||||
|         self.view: memoryview = None | ||||
|         self.imagebase = None | ||||
|         self.entry = None | ||||
|         self.sections: List[ImageSectionHeader] = [] | ||||
|         self.last_section = None | ||||
|         self.sections: List[Section] = [] | ||||
|         self._section_vaddr: List[int] = [] | ||||
|         self.find_str = find_str | ||||
|         self._potential_strings = {} | ||||
|         self._relocated_addrs = set() | ||||
| @@ -102,36 +115,51 @@ class Bin: | ||||
| 
 | ||||
|     def __enter__(self): | ||||
|         logger.debug("Bin %s Enter", self.filename) | ||||
|         self.file = open(self.filename, "rb") | ||||
|         with open(self.filename, "rb") as f: | ||||
|             self.view = memoryview(f.read()) | ||||
| 
 | ||||
|         (mz_str,) = struct.unpack("2s", self.file.read(2)) | ||||
|         (mz_str,) = struct.unpack("2s", self.view[0:2]) | ||||
|         if mz_str != b"MZ": | ||||
|             raise MZHeaderNotFoundError | ||||
| 
 | ||||
|         # Skip to PE header offset in MZ header. | ||||
|         self.file.seek(0x3C) | ||||
|         (pe_header_start,) = struct.unpack("<I", self.file.read(4)) | ||||
|         (pe_header_start,) = struct.unpack("<I", self.view[0x3C:0x40]) | ||||
| 
 | ||||
|         # PE header offset is absolute, so seek there | ||||
|         self.file.seek(pe_header_start) | ||||
|         pe_hdr = PEHeader(*struct.unpack("<2s2x2H3I2H", self.file.read(0x18))) | ||||
|         pe_header_view = self.view[pe_header_start:] | ||||
|         pe_hdr = PEHeader(*struct.unpack("<2s2x2H3I2H", pe_header_view[:0x18])) | ||||
| 
 | ||||
|         if pe_hdr.Signature != b"PE": | ||||
|             raise PEHeaderNotFoundError | ||||
| 
 | ||||
|         optional_hdr = self.file.read(pe_hdr.SizeOfOptionalHeader) | ||||
|         optional_hdr = pe_header_view[0x18:] | ||||
|         (self.imagebase,) = struct.unpack("<i", optional_hdr[0x1C:0x20]) | ||||
|         (entry,) = struct.unpack("<i", optional_hdr[0x10:0x14]) | ||||
|         self.entry = entry + self.imagebase | ||||
| 
 | ||||
|         self.sections = [ | ||||
|             ImageSectionHeader(*struct.unpack("<8s6I2HI", self.file.read(0x28))) | ||||
|             for i in range(pe_hdr.NumberOfSections) | ||||
|         headers_view = optional_hdr[ | ||||
|             pe_hdr.SizeOfOptionalHeader : pe_hdr.SizeOfOptionalHeader | ||||
|             + 0x28 * pe_hdr.NumberOfSections | ||||
|         ] | ||||
|         section_headers = [ | ||||
|             ImageSectionHeader(*h) for h in struct.iter_unpack("<8s6I2HI", headers_view) | ||||
|         ] | ||||
| 
 | ||||
|         # Add the imagebase here because we almost never need the base vaddr without it | ||||
|         for sect in self.sections: | ||||
|             sect.virtual_address += self.imagebase | ||||
|         self.sections = [ | ||||
|             Section( | ||||
|                 name=hdr.name.decode("ascii").rstrip("\x00"), | ||||
|                 virtual_address=self.imagebase + hdr.virtual_address, | ||||
|                 virtual_size=hdr.virtual_size, | ||||
|                 view=self.view[ | ||||
|                     hdr.pointer_to_raw_data : hdr.pointer_to_raw_data | ||||
|                     + hdr.size_of_raw_data | ||||
|                 ], | ||||
|             ) | ||||
|             for hdr in section_headers | ||||
|         ] | ||||
| 
 | ||||
|         # bisect does not support key on the github CI version of python | ||||
|         self._section_vaddr = [section.virtual_address for section in self.sections] | ||||
| 
 | ||||
|         self._populate_relocations() | ||||
|         self._populate_imports() | ||||
| @@ -143,16 +171,12 @@ class Bin: | ||||
|         if self.find_str: | ||||
|             self._prepare_string_search() | ||||
| 
 | ||||
|         text_section = self._get_section_by_name(".text") | ||||
|         self.last_section = text_section | ||||
| 
 | ||||
|         logger.debug("... Parsing finished") | ||||
|         return self | ||||
| 
 | ||||
|     def __exit__(self, exc_type, exc_value, exc_traceback): | ||||
|         logger.debug("Bin %s Exit", self.filename) | ||||
|         if self.file: | ||||
|             self.file.close() | ||||
|         self.view.release() | ||||
| 
 | ||||
|     def get_relocated_addresses(self) -> List[int]: | ||||
|         return sorted(self._relocated_addrs) | ||||
| @@ -186,8 +210,8 @@ class Bin: | ||||
|         def is_ascii(b): | ||||
|             return b" " <= b < b"\x7f" | ||||
| 
 | ||||
|         sect_data = self._get_section_by_name(".data") | ||||
|         sect_rdata = self._get_section_by_name(".rdata") | ||||
|         sect_data = self.get_section_by_name(".data") | ||||
|         sect_rdata = self.get_section_by_name(".rdata") | ||||
|         potentials = filter( | ||||
|             lambda a: sect_data.contains_vaddr(a) or sect_rdata.contains_vaddr(a), | ||||
|             self.get_relocated_addresses(), | ||||
| @@ -212,7 +236,8 @@ class Bin: | ||||
|         One use case is to tell whether an immediate value in an operand represents | ||||
|         a virtual address or just a big number.""" | ||||
| 
 | ||||
|         ofs = self.get_section_offset_by_name(".reloc") | ||||
|         reloc = self.get_section_by_name(".reloc").view | ||||
|         ofs = 0 | ||||
|         reloc_addrs = [] | ||||
| 
 | ||||
|         # Parse the structure in .reloc to get the list locations to check. | ||||
| @@ -223,12 +248,12 @@ class Bin: | ||||
|         # If the entry read in is zero, we are at the end of this section and | ||||
|         # these are padding bytes. | ||||
|         while True: | ||||
|             (page_base, block_size) = struct.unpack("<2I", self.read(ofs, 8)) | ||||
|             (page_base, block_size) = struct.unpack("<2I", reloc[ofs : ofs + 8]) | ||||
|             if block_size == 0: | ||||
|                 break | ||||
| 
 | ||||
|             # HACK: ignore the relocation type for now (the top 4 bits of the value). | ||||
|             values = list(struct.iter_unpack("<H", self.read(ofs + 8, block_size - 8))) | ||||
|             values = list(struct.iter_unpack("<H", reloc[ofs + 8 : ofs + block_size])) | ||||
|             reloc_addrs += [ | ||||
|                 self.imagebase + page_base + (v[0] & 0xFFF) for v in values if v[0] != 0 | ||||
|             ] | ||||
| @@ -238,8 +263,9 @@ class Bin: | ||||
|         # We are now interested in the relocated addresses themselves. Seek to the | ||||
|         # address where there is a relocation, then read the four bytes into our set. | ||||
|         reloc_addrs.sort() | ||||
|         for addr in reloc_addrs: | ||||
|             (relocated_addr,) = struct.unpack("<I", self.read(addr, 4)) | ||||
|         for section_id, offset in map(self.get_relative_addr, reloc_addrs): | ||||
|             section = self.get_section_by_index(section_id) | ||||
|             (relocated_addr,) = struct.unpack("<I", section.view[offset : offset + 4]) | ||||
|             self._relocated_addrs.add(relocated_addr) | ||||
| 
 | ||||
|     def _populate_imports(self): | ||||
| @@ -296,15 +322,13 @@ class Bin: | ||||
|         instruction in the function is a jmp to the address in .idata. | ||||
|         Search .text to find these functions.""" | ||||
| 
 | ||||
|         text_sect = self._get_section_by_name(".text") | ||||
|         idata_sect = self._get_section_by_name(".idata") | ||||
|         text_sect = self.get_section_by_name(".text") | ||||
|         idata_sect = self.get_section_by_name(".idata") | ||||
|         start = text_sect.virtual_address | ||||
|         ofs = start | ||||
| 
 | ||||
|         bs = self.read(ofs, text_sect.size_of_raw_data) | ||||
| 
 | ||||
|         for shift in (0, 2, 4): | ||||
|             window = bs[shift:] | ||||
|             window = text_sect.view[shift:] | ||||
|             win_end = 6 * (len(window) // 6) | ||||
|             for i, (b0, b1, jmp_ofs) in enumerate( | ||||
|                 struct.iter_unpack("<2BL", window[:win_end]) | ||||
| @@ -314,23 +338,7 @@ class Bin: | ||||
|                     thunk_ofs = ofs + shift + i * 6 | ||||
|                     self.thunks.append((thunk_ofs, jmp_ofs)) | ||||
| 
 | ||||
|     def _set_section_for_vaddr(self, vaddr: int): | ||||
|         if self.last_section is not None and self.last_section.contains_vaddr(vaddr): | ||||
|             return | ||||
| 
 | ||||
|         # TODO: assumes no potential for section overlap. reasonable? | ||||
|         self.last_section = next( | ||||
|             filter( | ||||
|                 lambda section: section.contains_vaddr(vaddr), | ||||
|                 self.sections, | ||||
|             ), | ||||
|             None, | ||||
|         ) | ||||
| 
 | ||||
|         if self.last_section is None: | ||||
|             raise InvalidVirtualAddressError(f"0x{vaddr:08x}") | ||||
| 
 | ||||
|     def _get_section_by_name(self, name: str): | ||||
|     def get_section_by_name(self, name: str) -> Section: | ||||
|         section = next( | ||||
|             filter(lambda section: section.match_name(name), self.sections), | ||||
|             None, | ||||
| @@ -341,8 +349,12 @@ class Bin: | ||||
| 
 | ||||
|         return section | ||||
| 
 | ||||
|     def get_section_by_index(self, index: int) -> Section: | ||||
|         """Convert 1-based index into 0-based.""" | ||||
|         return self.sections[index - 1] | ||||
| 
 | ||||
|     def get_section_extent_by_index(self, index: int) -> int: | ||||
|         return self.sections[index - 1].extent | ||||
|         return self.get_section_by_index(index).extent | ||||
| 
 | ||||
|     def get_section_offset_by_index(self, index: int) -> int: | ||||
|         """The symbols output from cvdump gives addresses in this format: AAAA.BBBBBBBB | ||||
| @@ -350,14 +362,12 @@ class Bin: | ||||
|         This will return the virtual address for the start of the section at the given index | ||||
|         so you can get the virtual address for whatever symbol you are looking at. | ||||
|         """ | ||||
| 
 | ||||
|         section = self.sections[index - 1] | ||||
|         return section.virtual_address | ||||
|         return self.get_section_by_index(index).virtual_address | ||||
| 
 | ||||
|     def get_section_offset_by_name(self, name: str) -> int: | ||||
|         """Same as above, but use the section name as the lookup""" | ||||
| 
 | ||||
|         section = self._get_section_by_name(name) | ||||
|         section = self.get_section_by_name(name) | ||||
|         return section.virtual_address | ||||
| 
 | ||||
|     def get_abs_addr(self, section: int, offset: int) -> int: | ||||
| @@ -367,41 +377,32 @@ class Bin: | ||||
| 
 | ||||
|     def get_relative_addr(self, addr: int) -> Tuple[int, int]: | ||||
|         """Convert an absolute address back into a (section, offset) pair.""" | ||||
|         for i, section in enumerate(self.sections): | ||||
|             if section.contains_vaddr(addr): | ||||
|                 return (i + 1, addr - section.virtual_address) | ||||
|         i = bisect.bisect_right(self._section_vaddr, addr) - 1 | ||||
|         i = max(0, i) | ||||
| 
 | ||||
|         return (0, 0) | ||||
|         section = self.sections[i] | ||||
|         if section.contains_vaddr(addr): | ||||
|             return (i + 1, addr - section.virtual_address) | ||||
| 
 | ||||
|     def get_raw_addr(self, vaddr: int) -> int: | ||||
|         """Returns the raw offset in the PE binary for the given virtual address.""" | ||||
|         self._set_section_for_vaddr(vaddr) | ||||
|         return ( | ||||
|             vaddr | ||||
|             - self.last_section.virtual_address | ||||
|             + self.last_section.pointer_to_raw_data | ||||
|         ) | ||||
|         raise InvalidVirtualAddressError(hex(addr)) | ||||
| 
 | ||||
|     def is_valid_section(self, section: int) -> bool: | ||||
|     def is_valid_section(self, section_id: int) -> bool: | ||||
|         """The PDB will refer to sections that are not listed in the headers | ||||
|         and so should ignore these references.""" | ||||
|         try: | ||||
|             _ = self.sections[section - 1] | ||||
|             _ = self.get_section_by_index(section_id) | ||||
|             return True | ||||
|         except IndexError: | ||||
|             return False | ||||
| 
 | ||||
|     def is_valid_vaddr(self, vaddr: int) -> bool: | ||||
|         """Does this virtual address point to anything in the exe?""" | ||||
|         section = next( | ||||
|             filter( | ||||
|                 lambda section: section.contains_vaddr(vaddr), | ||||
|                 self.sections, | ||||
|             ), | ||||
|             None, | ||||
|         ) | ||||
|         try: | ||||
|             (_, __) = self.get_relative_addr(vaddr) | ||||
|         except InvalidVirtualAddressError: | ||||
|             return False | ||||
| 
 | ||||
|         return section is not None | ||||
|         return True | ||||
| 
 | ||||
|     def read_string(self, offset: int, chunk_size: int = 1000) -> Optional[bytes]: | ||||
|         """Read until we find a zero byte.""" | ||||
| @@ -415,23 +416,16 @@ class Bin: | ||||
|             # No terminator found, just return what we have | ||||
|             return b | ||||
| 
 | ||||
|     def read(self, offset: int, size: int) -> Optional[bytes]: | ||||
|     def read(self, vaddr: int, size: int) -> Optional[bytes]: | ||||
|         """Read (at most) the given number of bytes at the given virtual address. | ||||
|         If we return None, the given address points to uninitialized data.""" | ||||
|         self._set_section_for_vaddr(offset) | ||||
|         (section_id, offset) = self.get_relative_addr(vaddr) | ||||
|         section = self.sections[section_id - 1] | ||||
| 
 | ||||
|         if self.last_section.addr_is_uninitialized(offset): | ||||
|         if section.addr_is_uninitialized(vaddr): | ||||
|             return None | ||||
| 
 | ||||
|         raw_addr = self.get_raw_addr(offset) | ||||
|         self.file.seek(raw_addr) | ||||
| 
 | ||||
|         # Clamp the read within the extent of the current section. | ||||
|         # Reading off the end will most likely misrepresent the virtual addressing. | ||||
|         _size = min( | ||||
|             size, | ||||
|             self.last_section.pointer_to_raw_data | ||||
|             + self.last_section.size_of_raw_data | ||||
|             - raw_addr, | ||||
|         ) | ||||
|         return self.file.read(_size) | ||||
|         _size = min(size, section.size_of_raw_data - offset) | ||||
|         return bytes(section.view[offset : offset + _size]) | ||||
|   | ||||
| @@ -1,12 +1,13 @@ | ||||
| """Converts x86 machine code into text (i.e. assembly). The end goal is to | ||||
| compare the code in the original and recomp binaries, using longest common | ||||
| subsequence (LCS), i.e. difflib.SequenceMatcher. | ||||
| The capstone library takes the raw bytes and gives us the mnemnonic | ||||
| The capstone library takes the raw bytes and gives us the mnemonic | ||||
| and operand(s) for each instruction. We need to "sanitize" the text further | ||||
| so that virtual addresses are replaced by symbol name or a generic | ||||
| placeholder string.""" | ||||
| 
 | ||||
| import re | ||||
| from functools import cache | ||||
| from typing import Callable, List, Optional, Tuple | ||||
| from collections import namedtuple | ||||
| from isledecomp.bin import InvalidVirtualAddressError | ||||
| @@ -19,6 +20,7 @@ ptr_replace_regex = re.compile(r"(?P<data_size>\w+) ptr \[(?P<addr>0x[0-9a-fA-F] | ||||
| DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str") | ||||
| 
 | ||||
| 
 | ||||
| @cache | ||||
| def from_hex(string: str) -> Optional[int]: | ||||
|     try: | ||||
|         return int(string, 16) | ||||
| @@ -97,6 +99,9 @@ class ParseAsm: | ||||
|             # Nothing to sanitize | ||||
|             return (inst.mnemonic, "") | ||||
| 
 | ||||
|         if "0x" not in inst.op_str: | ||||
|             return (inst.mnemonic, inst.op_str) | ||||
| 
 | ||||
|         # For jumps or calls, if the entire op_str is a hex number, the value | ||||
|         # is a relative offset. | ||||
|         # Otherwise (i.e. it looks like `dword ptr [address]`) it is an | ||||
| @@ -167,21 +172,20 @@ class ParseAsm: | ||||
|         else: | ||||
|             op_str = ptr_replace_regex.sub(filter_out_ptr, inst.op_str) | ||||
| 
 | ||||
|         def replace_immediate(chunk: str) -> str: | ||||
|             if (inttest := from_hex(chunk)) is not None: | ||||
|                 # If this value is a virtual address, it is referenced absolutely, | ||||
|                 # which means it must be in the relocation table. | ||||
|                 if self.is_relocated(inttest): | ||||
|                     return self.replace(inttest) | ||||
| 
 | ||||
|             return chunk | ||||
| 
 | ||||
|         # Performance hack: | ||||
|         # Skip this step if there is nothing left to consider replacing. | ||||
|         if "0x" in op_str: | ||||
|             # Replace immediate values with name or placeholder (where appropriate) | ||||
|             words = op_str.split(", ") | ||||
|             for i, word in enumerate(words): | ||||
|                 try: | ||||
|                     inttest = int(word, 16) | ||||
|                     # If this value is a virtual address, it is referenced absolutely, | ||||
|                     # which means it must be in the relocation table. | ||||
|                     if self.is_relocated(inttest): | ||||
|                         words[i] = self.replace(inttest) | ||||
|                 except ValueError: | ||||
|                     pass | ||||
|             op_str = ", ".join(words) | ||||
|             op_str = ", ".join(map(replace_immediate, op_str.split(", "))) | ||||
| 
 | ||||
|         return inst.mnemonic, op_str | ||||
| 
 | ||||
|   | ||||
| @@ -17,6 +17,7 @@ _SETUP_SQL = """ | ||||
|     ); | ||||
|     CREATE INDEX `symbols_or` ON `symbols` (orig_addr); | ||||
|     CREATE INDEX `symbols_re` ON `symbols` (recomp_addr); | ||||
|     CREATE INDEX `symbols_na` ON `symbols` (name); | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
| between FUNCTION markers and PDB analysis.""" | ||||
| import sqlite3 | ||||
| import logging | ||||
| from functools import cache | ||||
| from typing import Optional | ||||
| from pathlib import Path | ||||
| from isledecomp.dir import PathResolver | ||||
| @@ -22,6 +23,16 @@ _SETUP_SQL = """ | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| @cache | ||||
| def my_samefile(path: str, source_path: str) -> bool: | ||||
|     return Path(path).samefile(source_path) | ||||
| 
 | ||||
| 
 | ||||
| @cache | ||||
| def my_basename_lower(path: str) -> str: | ||||
|     return Path(path).name.lower() | ||||
| 
 | ||||
| 
 | ||||
| class LinesDb: | ||||
|     def __init__(self, code_dir) -> None: | ||||
|         self._db = sqlite3.connect(":memory:") | ||||
| @@ -31,7 +42,7 @@ class LinesDb: | ||||
|     def add_line(self, path: str, line_no: int, addr: int): | ||||
|         """To be added from the LINES section of cvdump.""" | ||||
|         sourcepath = self._path_resolver.resolve_cvdump(path) | ||||
|         filename = Path(sourcepath).name.lower() | ||||
|         filename = my_basename_lower(sourcepath) | ||||
| 
 | ||||
|         self._db.execute( | ||||
|             "INSERT INTO `lineref` (path, filename, line, addr) VALUES (?,?,?,?)", | ||||
| @@ -41,13 +52,13 @@ class LinesDb: | ||||
|     def search_line(self, path: str, line_no: int) -> Optional[int]: | ||||
|         """Using path and line number from FUNCTION marker, | ||||
|         get the address of this function in the recomp.""" | ||||
|         filename = Path(path).name.lower() | ||||
|         filename = my_basename_lower(path) | ||||
|         cur = self._db.execute( | ||||
|             "SELECT path, addr FROM `lineref` WHERE filename = ? AND line = ?", | ||||
|             (filename, line_no), | ||||
|         ) | ||||
|         for source_path, addr in cur.fetchall(): | ||||
|             if Path(path).samefile(source_path): | ||||
|             if my_samefile(path, source_path): | ||||
|                 return addr | ||||
| 
 | ||||
|         logger.error( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 MS
					MS