Add Ghidra function import script (#909)

* Add draft for Ghidra function import script

* feature: Basic PDB analysis [skip ci]

This is a draft with a lot of open questions left. Please do not merge

* Refactor: Introduce submodules and reload remedy

* refactor types and make them Python 3.9 compatible

* run black

* WIP: save progress

* fix types and small type safety violations

* fix another Python 3.9 syntax incompatibility

* Implement struct imports [skip ci]

- This code is still in dire need of refactoring and tests
- Only a single-digit number of issues remain, and 2600 functions can be imported
- The biggest remaining error is mismatched stacks

* Refactor, implement enums, fix lots of bugs

* fix Python 3.9 issue

* refactor: address review comments

Not sure why VS Code suddenly decided to remove some whitespace, but it did not make sense anyway

* add unit tests for new type parsers, fix linter issue

* refactor: db access from pdb_extraction.py

* Fix stack layout offset error

* fix: Undo incorrect reference change

* Fix CI issue

* Improve READMEs (fix typos, add information)

---------

Co-authored-by: jonschz <jonschz@users.noreply.github.com>
Author: jonschz
Date: 2024-06-09 14:41:24 +02:00
Committed by: GitHub
Parent: 88805f9fcb
Commit: f26c30974a
21 changed files with 1824 additions and 114 deletions


@@ -0,0 +1,47 @@
class Lego1Exception(Exception):
"""
Our own base class for exceptions.
Makes it easier to distinguish expected and unexpected errors.
"""
class TypeNotFoundError(Lego1Exception):
def __str__(self):
return f"Type not found in PDB: {self.args[0]}"
class TypeNotFoundInGhidraError(Lego1Exception):
def __str__(self):
return f"Type not found in Ghidra: {self.args[0]}"
class TypeNotImplementedError(Lego1Exception):
def __str__(self):
return f"Import not implemented for type: {self.args[0]}"
class ClassOrNamespaceNotFoundInGhidraError(Lego1Exception):
def __init__(self, namespaceHierachy: list[str]):
super().__init__(namespaceHierachy)
def get_namespace_str(self) -> str:
return "::".join(self.args[0])
def __str__(self):
return f"Class or namespace not found in Ghidra: {self.get_namespace_str()}"
class MultipleTypesFoundInGhidraError(Lego1Exception):
def __str__(self):
return (
f"Found multiple types matching '{self.args[0]}' in Ghidra: {self.args[1]}"
)
class StackOffsetMismatchError(Lego1Exception):
pass
class StructModificationError(Lego1Exception):
def __str__(self):
return f"Failed to modify struct in Ghidra: '{self.args[0]}'\nDetailed error: {self.__cause__}"


@@ -0,0 +1,241 @@
# This file can only be imported successfully when run from Ghidra using Ghidrathon.
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
import logging
from typing import Optional
from ghidra.program.model.listing import Function, Parameter
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.listing import ParameterImpl
from ghidra.program.model.symbol import SourceType
from isledecomp.compare.db import MatchInfo
from lego_util.pdb_extraction import (
FunctionSignature,
CppRegisterSymbol,
CppStackSymbol,
)
from lego_util.ghidra_helper import (
get_ghidra_namespace,
sanitize_name,
)
from lego_util.exceptions import StackOffsetMismatchError
from lego_util.type_importer import PdbTypeImporter
logger = logging.getLogger(__name__)
# pylint: disable=too-many-instance-attributes
class PdbFunctionImporter:
"""A representation of a function from the PDB with each type replaced by a Ghidra type instance."""
def __init__(
self,
api: FlatProgramAPI,
match_info: MatchInfo,
signature: FunctionSignature,
type_importer: "PdbTypeImporter",
):
self.api = api
self.match_info = match_info
self.signature = signature
self.type_importer = type_importer
if signature.class_type is not None:
# Import the base class so the namespace exists
self.type_importer.import_pdb_type_into_ghidra(signature.class_type)
assert match_info.name is not None
colon_split = sanitize_name(match_info.name).split("::")
self.name = colon_split.pop()
namespace_hierachy = colon_split
self.namespace = get_ghidra_namespace(api, namespace_hierachy)
self.return_type = type_importer.import_pdb_type_into_ghidra(
signature.return_type
)
self.arguments = [
ParameterImpl(
f"param{index}",
type_importer.import_pdb_type_into_ghidra(type_name),
api.getCurrentProgram(),
)
for (index, type_name) in enumerate(signature.arglist)
]
@property
def call_type(self):
return self.signature.call_type
@property
def stack_symbols(self):
return self.signature.stack_symbols
def get_full_name(self) -> str:
return f"{self.namespace.getName()}::{self.name}"
def matches_ghidra_function(self, ghidra_function: Function) -> bool:
"""Checks whether this function declaration already matches the description in Ghidra"""
name_match = self.name == ghidra_function.getName(False)
namespace_match = self.namespace == ghidra_function.getParentNamespace()
return_type_match = self.return_type == ghidra_function.getReturnType()
# match arguments: decide if thiscall or not
thiscall_matches = (
self.signature.call_type == ghidra_function.getCallingConventionName()
)
if thiscall_matches:
if self.signature.call_type == "__thiscall":
args_match = self._matches_thiscall_parameters(ghidra_function)
else:
args_match = self._matches_non_thiscall_parameters(ghidra_function)
else:
args_match = False
logger.debug(
"Matches: namespace=%s name=%s return_type=%s thiscall=%s args=%s",
namespace_match,
name_match,
return_type_match,
thiscall_matches,
args_match,
)
return (
name_match
and namespace_match
and return_type_match
and thiscall_matches
and args_match
)
def _matches_non_thiscall_parameters(self, ghidra_function: Function) -> bool:
return self._parameter_lists_match(ghidra_function.getParameters())
def _matches_thiscall_parameters(self, ghidra_function: Function) -> bool:
ghidra_params = list(ghidra_function.getParameters())
# remove the `this` argument which we don't generate ourselves
ghidra_params.pop(0)
return self._parameter_lists_match(ghidra_params)
def _parameter_lists_match(self, ghidra_params: "list[Parameter]") -> bool:
if len(self.arguments) != len(ghidra_params):
logger.info("Mismatching argument count")
return False
for this_arg, ghidra_arg in zip(self.arguments, ghidra_params):
# compare argument types
if this_arg.getDataType() != ghidra_arg.getDataType():
logger.debug(
"Mismatching arg type: expected %s, found %s",
this_arg.getDataType(),
ghidra_arg.getDataType(),
)
return False
# compare argument names
stack_match = self.get_matching_stack_symbol(ghidra_arg.getStackOffset())
if stack_match is None:
logger.debug("Not found on stack: %s", ghidra_arg)
return False
# "__formal" is the placeholder for arguments without a name
if (
stack_match.name != ghidra_arg.getName()
and not stack_match.name.startswith("__formal")
):
logger.debug(
"Argument name mismatch: expected %s, found %s",
stack_match.name,
ghidra_arg.getName(),
)
return False
return True
def overwrite_ghidra_function(self, ghidra_function: Function):
"""Replace the function declaration in Ghidra by the one derived from C++."""
ghidra_function.setName(self.name, SourceType.USER_DEFINED)
ghidra_function.setParentNamespace(self.namespace)
ghidra_function.setReturnType(self.return_type, SourceType.USER_DEFINED)
ghidra_function.setCallingConvention(self.call_type)
ghidra_function.replaceParameters(
Function.FunctionUpdateType.DYNAMIC_STORAGE_ALL_PARAMS,
True,
SourceType.USER_DEFINED,
self.arguments,
)
# When we set the parameters, Ghidra will generate the layout.
# Now we read them again and match them against the stack layout in the PDB,
# both to verify and to set the parameter names.
ghidra_parameters: list[Parameter] = ghidra_function.getParameters()
# Try to add Ghidra function names
for index, param in enumerate(ghidra_parameters):
if param.isStackVariable():
self._rename_stack_parameter(index, param)
else:
if param.getName() == "this":
# 'this' parameters are auto-generated and cannot be changed
continue
# Appears to never happen - could in theory be relevant to __fastcall__ functions,
# which we haven't seen yet
logger.warning("Unhandled register variable in %s", self.get_full_name)
continue
def _rename_stack_parameter(self, index: int, param: Parameter):
match = self.get_matching_stack_symbol(param.getStackOffset())
if match is None:
raise StackOffsetMismatchError(
f"Could not find a matching symbol at offset {param.getStackOffset()} in {self.get_full_name()}"
)
if match.data_type == "T_NOTYPE(0000)":
logger.warning("Skipping stack parameter of type NOTYPE")
return
if param.getDataType() != self.type_importer.import_pdb_type_into_ghidra(
match.data_type
):
logger.error(
"Type mismatch for parameter: %s in Ghidra, %s in PDB", param, match
)
return
name = match.name
if name == "__formal":
# these can cause name collisions if multiple ones are present
name = f"__formal_{index}"
param.setName(name, SourceType.USER_DEFINED)
def get_matching_stack_symbol(self, stack_offset: int) -> Optional[CppStackSymbol]:
return next(
(
symbol
for symbol in self.stack_symbols
if isinstance(symbol, CppStackSymbol)
and symbol.stack_offset == stack_offset
),
None,
)
def get_matching_register_symbol(
self, register: str
) -> Optional[CppRegisterSymbol]:
return next(
(
symbol
for symbol in self.stack_symbols
if isinstance(symbol, CppRegisterSymbol) and symbol.register == register
),
None,
)
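
For context, a hedged sketch of how this class is presumably driven from the main import loop (the loop itself is not part of this excerpt; `lookup_function` is a placeholder for resolving a Ghidra Function from the matched address):

importer = PdbFunctionImporter(api, match_info, signature, type_importer)
ghidra_function = lookup_function(match_info.orig_addr)  # placeholder lookup
if not importer.matches_ghidra_function(ghidra_function):
    # only touch functions whose declaration differs from the PDB
    importer.overwrite_ghidra_function(ghidra_function)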


@@ -0,0 +1,100 @@
"""A collection of helper functions for the interaction with Ghidra."""
import logging
from lego_util.exceptions import (
ClassOrNamespaceNotFoundInGhidraError,
TypeNotFoundInGhidraError,
MultipleTypesFoundInGhidraError,
)
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
from ghidra.program.model.data import PointerDataType
from ghidra.program.model.data import DataTypeConflictHandler
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.data import DataType
from ghidra.program.model.symbol import Namespace
logger = logging.getLogger(__name__)
def get_ghidra_type(api: FlatProgramAPI, type_name: str):
"""
Searches for the type named `type_name` in Ghidra.
Raises:
- TypeNotFoundInGhidraError
- MultipleTypesFoundInGhidraError
"""
result = api.getDataTypes(type_name)
if len(result) == 0:
raise TypeNotFoundInGhidraError(type_name)
if len(result) == 1:
return result[0]
raise MultipleTypesFoundInGhidraError(type_name, result)
def add_pointer_type(api: FlatProgramAPI, pointee: DataType) -> DataType:
new_data_type = PointerDataType(pointee)
new_data_type.setCategoryPath(pointee.getCategoryPath())
result_data_type = (
api.getCurrentProgram()
.getDataTypeManager()
.addDataType(new_data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
if result_data_type is not new_data_type:
logger.debug(
"New pointer replaced by existing one. Fresh pointer: %s (class: %s)",
result_data_type,
result_data_type.__class__,
)
return result_data_type
def get_ghidra_namespace(
api: FlatProgramAPI, namespace_hierachy: list[str]
) -> Namespace:
namespace = api.getCurrentProgram().getGlobalNamespace()
for part in namespace_hierachy:
namespace = api.getNamespace(namespace, part)
if namespace is None:
raise ClassOrNamespaceNotFoundInGhidraError(namespace_hierachy)
return namespace
def create_ghidra_namespace(
api: FlatProgramAPI, namespace_hierachy: list[str]
) -> Namespace:
namespace = api.getCurrentProgram().getGlobalNamespace()
for part in namespace_hierachy:
namespace = api.getNamespace(namespace, part)
if namespace is None:
namespace = api.createNamespace(namespace, part)
return namespace
def sanitize_name(name: str) -> str:
"""
Takes a full class or function name and replaces characters not accepted by Ghidra.
Applies mostly to templates and names like `vbase destructor`.
"""
new_class_name = (
name.replace("<", "[")
.replace(">", "]")
.replace("*", "#")
.replace(" ", "_")
.replace("`", "'")
)
if "<" in name:
new_class_name = "_template_" + new_class_name
if new_class_name != name:
logger.warning(
"Class or function name contains characters forbidden by Ghidra, changing from '%s' to '%s'",
name,
new_class_name,
)
return new_class_name
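
To illustrate the renaming rules, a short editorial example (the template instance below is hypothetical; `vbase destructor` is taken from the docstring above):

from lego_util.ghidra_helper import sanitize_name

assert sanitize_name("MxList<MxCore *>") == "_template_MxList[MxCore_#]"
assert sanitize_name("`vbase destructor'") == "'vbase_destructor'"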


@@ -0,0 +1,19 @@
from typing import TypeVar
import ghidra
# pylint: disable=invalid-name,unused-argument
T = TypeVar("T")
# from ghidra.app.script.GhidraScript
def currentProgram() -> "ghidra.program.model.listing.Program": ...
def getAddressFactory() -> "ghidra.program.model.address.AddressFactory": ...
def state() -> "ghidra.app.script.GhidraState": ...
def askChoice(title: str, message: str, choices: list[T], defaultValue: T) -> T: ...
def askYesNo(title: str, question: str) -> bool: ...
def getFunctionAt(
entryPoint: ghidra.program.model.address.Address,
) -> ghidra.program.model.listing.Function: ...
def createFunction(
entryPoint: ghidra.program.model.address.Address, name: str
) -> ghidra.program.model.listing.Function: ...
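
These stubs exist so that linters and type checkers can resolve the globals that Ghidrathon injects at runtime. A hedged illustration of the kind of call they are meant to cover (the address and function name below are made up):

# At runtime the real objects come from Ghidrathon; the stubs only satisfy the checker.
addr = getAddressFactory().getAddress("0x10001000")  # hypothetical address
func = getFunctionAt(addr) or createFunction(addr, "FUN_10001000")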


@@ -0,0 +1,166 @@
from dataclasses import dataclass
import re
from typing import Any, Optional
import logging
from isledecomp.cvdump.symbols import SymbolsEntry
from isledecomp.compare import Compare as IsleCompare
from isledecomp.compare.db import MatchInfo
logger = logging.getLogger(__file__)
@dataclass
class CppStackOrRegisterSymbol:
name: str
data_type: str
@dataclass
class CppStackSymbol(CppStackOrRegisterSymbol):
stack_offset: int
"""Should have a value iff `symbol_type=='S_BPREL32'."""
@dataclass
class CppRegisterSymbol(CppStackOrRegisterSymbol):
register: str
"""Should have a value iff `symbol_type=='S_REGISTER'.` Should always be set/converted to lowercase."""
@dataclass
class FunctionSignature:
original_function_symbol: SymbolsEntry
call_type: str
arglist: list[str]
return_type: str
class_type: Optional[str]
stack_symbols: list[CppStackOrRegisterSymbol]
class PdbFunctionExtractor:
"""
Extracts all information on a given function from the parsed PDB
and prepares the data for the import in Ghidra.
"""
def __init__(self, compare: IsleCompare):
self.compare = compare
scalar_type_regex = re.compile(r"t_(?P<typename>\w+)(?:\((?P<type_id>\d+)\))?")
_call_type_map = {
"ThisCall": "__thiscall",
"C Near": "__thiscall",
"STD Near": "__stdcall",
}
def _get_cvdump_type(self, type_name: Optional[str]) -> Optional[dict[str, Any]]:
return (
None
if type_name is None
else self.compare.cv.types.keys.get(type_name.lower())
)
def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
function_type_str = fn.func_type
if function_type_str == "T_NOTYPE(0000)":
logger.debug(
"Skipping a NOTYPE (synthetic or template + synthetic): %s", fn.name
)
return None
# get corresponding function type
function_type = self.compare.cv.types.keys.get(function_type_str.lower())
if function_type is None:
logger.error(
"Could not find function type %s for function %s", fn.func_type, fn.name
)
return None
class_type = function_type.get("class_type")
arg_list_type = self._get_cvdump_type(function_type.get("arg_list_type"))
assert arg_list_type is not None
arg_list_pdb_types = arg_list_type.get("args", [])
assert arg_list_type["argcount"] == len(arg_list_pdb_types)
stack_symbols: list[CppStackOrRegisterSymbol] = []
# for some unexplained reason, the reported stack is offset by 4 when this flag is set
stack_offset_delta = -4 if fn.frame_pointer_present else 0
for symbol in fn.stack_symbols:
if symbol.symbol_type == "S_REGISTER":
stack_symbols.append(
CppRegisterSymbol(
symbol.name,
symbol.data_type,
symbol.location,
)
)
elif symbol.symbol_type == "S_BPREL32":
stack_offset = int(symbol.location[1:-1], 16)
stack_symbols.append(
CppStackSymbol(
symbol.name,
symbol.data_type,
stack_offset + stack_offset_delta,
)
)
call_type = self._call_type_map[function_type["call_type"]]
return FunctionSignature(
original_function_symbol=fn,
call_type=call_type,
arglist=arg_list_pdb_types,
return_type=function_type["return_type"],
class_type=class_type,
stack_symbols=stack_symbols,
)
def get_function_list(self) -> list[tuple[MatchInfo, FunctionSignature]]:
handled = (
self.handle_matched_function(match)
for match in self.compare.get_functions()
)
return [signature for signature in handled if signature is not None]
def handle_matched_function(
self, match_info: MatchInfo
) -> Optional[tuple[MatchInfo, FunctionSignature]]:
assert match_info.orig_addr is not None
match_options = self.compare.get_match_options(match_info.orig_addr)
assert match_options is not None
if match_options.get("skip", False) or match_options.get("stub", False):
return None
function_data = next(
(
y
for y in self.compare.cvdump_analysis.nodes
if y.addr == match_info.recomp_addr
),
None,
)
if not function_data:
logger.error(
"Did not find function in nodes, skipping: %s", match_info.name
)
return None
function_symbol = function_data.symbol_entry
if function_symbol is None:
logger.debug(
"Could not find function symbol (likely a PUBLICS entry): %s",
match_info.name,
)
return None
function_signature = self.get_func_signature(function_symbol)
if function_signature is None:
return None
return match_info, function_signature
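
A short, hedged illustration of the parsing conventions assumed above (the example strings follow cvdump's notation; the bracketed stack location is an assumed format):

from lego_util.pdb_extraction import PdbFunctionExtractor

m = PdbFunctionExtractor.scalar_type_regex.match("t_int4(0074)")
assert m is not None
assert m.group("typename") == "int4" and m.group("type_id") == "0074"
# S_BPREL32 locations such as "[0x8]" are stripped of brackets and parsed as hex
assert int("[0x8]"[1:-1], 16) == 8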


@@ -0,0 +1,68 @@
from dataclasses import dataclass, field
import logging
from lego_util.exceptions import (
TypeNotFoundInGhidraError,
ClassOrNamespaceNotFoundInGhidraError,
)
logger = logging.getLogger(__name__)
@dataclass
class Statistics:
functions_changed: int = 0
successes: int = 0
failures: dict[str, int] = field(default_factory=dict)
known_missing_types: dict[str, int] = field(default_factory=dict)
known_missing_namespaces: dict[str, int] = field(default_factory=dict)
def track_failure_and_tell_if_new(self, error: Exception) -> bool:
"""
Adds the error to the statistics. Returns `False` if logging the error would be redundant
(e.g. because it is a `TypeNotFoundInGhidraError` with a type that has been logged before).
"""
error_type_name = error.__class__.__name__
self.failures[error_type_name] = (
self.failures.setdefault(error_type_name, 0) + 1
)
if isinstance(error, TypeNotFoundInGhidraError):
return self._add_occurence_and_check_if_new(
self.known_missing_types, error.args[0]
)
if isinstance(error, ClassOrNamespaceNotFoundInGhidraError):
return self._add_occurence_and_check_if_new(
self.known_missing_namespaces, error.get_namespace_str()
)
# We do not have detailed tracking for other errors, so we want to log them every time
return True
def _add_occurence_and_check_if_new(self, target: dict[str, int], key: str) -> bool:
old_count = target.setdefault(key, 0)
target[key] = old_count + 1
return old_count == 0
def log(self):
logger.info("Statistics:\n~~~~~")
logger.info(
"Missing types (with number of occurences): %s\n~~~~~",
self.format_statistics(self.known_missing_types),
)
logger.info(
"Missing classes/namespaces (with number of occurences): %s\n~~~~~",
self.format_statistics(self.known_missing_namespaces),
)
logger.info("Successes: %d", self.successes)
logger.info("Failures: %s", self.failures)
logger.info("Functions changed: %d", self.functions_changed)
def format_statistics(self, stats: dict[str, int]) -> str:
if len(stats) == 0:
return "<none>"
return ", ".join(
f"{entry[0]} ({entry[1]})"
for entry in sorted(stats.items(), key=lambda x: x[1], reverse=True)
)
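
A hedged sketch of the intended call pattern (the surrounding import loop is not part of this excerpt; `matches` and `import_one_function` are placeholders):

stats = Statistics()
for match in matches:  # placeholder iterable of matched functions
    try:
        import_one_function(match)  # placeholder import step
        stats.successes += 1
    except Exception as error:  # in practice the narrower Lego1Exception types
        # only log the first occurrence of each missing type or namespace
        if stats.track_failure_and_tell_if_new(error):
            logger.error("%s", error)
stats.log()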


@@ -0,0 +1,313 @@
import logging
from typing import Any
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
# pylint: disable=too-many-return-statements # a `match` would be better, but for now we are stuck with Python 3.9
# pylint: disable=no-else-return # Not sure why this rule even is a thing, this is great for checking exhaustiveness
from lego_util.exceptions import (
ClassOrNamespaceNotFoundInGhidraError,
TypeNotFoundError,
TypeNotFoundInGhidraError,
TypeNotImplementedError,
StructModificationError,
)
from lego_util.ghidra_helper import (
add_pointer_type,
create_ghidra_namespace,
get_ghidra_namespace,
get_ghidra_type,
sanitize_name,
)
from lego_util.pdb_extraction import PdbFunctionExtractor
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.data import (
ArrayDataType,
CategoryPath,
DataType,
DataTypeConflictHandler,
EnumDataType,
StructureDataType,
StructureInternal,
)
from ghidra.util.task import ConsoleTaskMonitor
logger = logging.getLogger(__name__)
class PdbTypeImporter:
"""Allows PDB types to be imported into Ghidra."""
def __init__(self, api: FlatProgramAPI, extraction: PdbFunctionExtractor):
self.api = api
self.extraction = extraction
# tracks the structs/classes we have already started to import, otherwise we run into infinite recursion
self.handled_structs: set[str] = set()
self.struct_call_stack: list[str] = []
@property
def types(self):
return self.extraction.compare.cv.types
def import_pdb_type_into_ghidra(self, type_index: str) -> DataType:
"""
Recursively imports a type from the PDB into Ghidra.
@param type_index Either a scalar type like `T_INT4(...)` or a PDB reference like `0x10ba`
"""
type_index_lower = type_index.lower()
if type_index_lower.startswith("t_"):
return self._import_scalar_type(type_index_lower)
try:
type_pdb = self.extraction.compare.cv.types.keys[type_index_lower]
except KeyError as e:
raise TypeNotFoundError(
f"Failed to find referenced type '{type_index_lower}'"
) from e
type_category = type_pdb["type"]
# follow forward reference (class, struct, union)
if type_pdb.get("is_forward_ref", False):
return self._import_forward_ref_type(type_index_lower, type_pdb)
if type_category == "LF_POINTER":
return add_pointer_type(
self.api, self.import_pdb_type_into_ghidra(type_pdb["element_type"])
)
elif type_category in ["LF_CLASS", "LF_STRUCTURE"]:
return self._import_class_or_struct(type_pdb)
elif type_category == "LF_ARRAY":
return self._import_array(type_pdb)
elif type_category == "LF_ENUM":
return self._import_enum(type_pdb)
elif type_category == "LF_PROCEDURE":
logger.warning(
"Not implemented: Function-valued argument or return type will be replaced by void pointer: %s",
type_pdb,
)
return get_ghidra_type(self.api, "void")
elif type_category == "LF_UNION":
return self._import_union(type_pdb)
else:
raise TypeNotImplementedError(type_pdb)
_scalar_type_map = {
"rchar": "char",
"int4": "int",
"uint4": "uint",
"real32": "float",
"real64": "double",
}
def _scalar_type_to_cpp(self, scalar_type: str) -> str:
if scalar_type.startswith("32p"):
return f"{self._scalar_type_to_cpp(scalar_type[3:])} *"
return self._scalar_type_map.get(scalar_type, scalar_type)
def _import_scalar_type(self, type_index_lower: str) -> DataType:
if (match := self.extraction.scalar_type_regex.match(type_index_lower)) is None:
raise TypeNotFoundError(f"Type has unexpected format: {type_index_lower}")
scalar_cpp_type = self._scalar_type_to_cpp(match.group("typename"))
return get_ghidra_type(self.api, scalar_cpp_type)
def _import_forward_ref_type(
self, type_index, type_pdb: dict[str, Any]
) -> DataType:
referenced_type = type_pdb.get("udt") or type_pdb.get("modifies")
if referenced_type is None:
try:
# Example: HWND__, needs to be created manually
return get_ghidra_type(self.api, type_pdb["name"])
except TypeNotFoundInGhidraError as e:
raise TypeNotImplementedError(
f"{type_index}: forward ref without target, needs to be created manually: {type_pdb}"
) from e
logger.debug(
"Following forward reference from %s to %s",
type_index,
referenced_type,
)
return self.import_pdb_type_into_ghidra(referenced_type)
def _import_array(self, type_pdb: dict[str, Any]) -> DataType:
inner_type = self.import_pdb_type_into_ghidra(type_pdb["array_type"])
array_total_bytes: int = type_pdb["size"]
data_type_size = inner_type.getLength()
array_length, modulus = divmod(array_total_bytes, data_type_size)
assert (
modulus == 0
), f"Data type size {data_type_size} does not divide array size {array_total_bytes}"
return ArrayDataType(inner_type, array_length, 0)
def _import_union(self, type_pdb: dict[str, Any]) -> DataType:
try:
logger.debug("Dereferencing union %s", type_pdb)
union_type = get_ghidra_type(self.api, type_pdb["name"])
assert (
union_type.getLength() == type_pdb["size"]
), f"Wrong size of existing union type '{type_pdb['name']}': expected {type_pdb['size']}, got {union_type.getLength()}"
return union_type
except TypeNotFoundInGhidraError as e:
# We have so few instances, it is not worth implementing this
raise TypeNotImplementedError(
f"Writing union types is not supported. Please add by hand: {type_pdb}"
) from e
def _import_enum(self, type_pdb: dict[str, Any]) -> DataType:
underlying_type = self.import_pdb_type_into_ghidra(type_pdb["underlying_type"])
field_list = self.extraction.compare.cv.types.keys.get(type_pdb["field_type"])
assert field_list is not None, f"Failed to find field list for enum {type_pdb}"
result = EnumDataType(
CategoryPath("/imported"), type_pdb["name"], underlying_type.getLength()
)
variants: list[dict[str, Any]] = field_list["variants"]
for variant in variants:
result.add(variant["name"], variant["value"])
return result
def _import_class_or_struct(self, type_in_pdb: dict[str, Any]) -> DataType:
field_list_type: str = type_in_pdb["field_list_type"]
field_list = self.types.keys[field_list_type.lower()]
class_size: int = type_in_pdb["size"]
class_name_with_namespace: str = sanitize_name(type_in_pdb["name"])
if class_name_with_namespace in self.handled_structs:
logger.debug(
"Class has been handled or is being handled: %s",
class_name_with_namespace,
)
return get_ghidra_type(self.api, class_name_with_namespace)
logger.debug(
"--- Beginning to import class/struct '%s'", class_name_with_namespace
)
# Add as soon as we start to avoid infinite recursion
self.handled_structs.add(class_name_with_namespace)
self._get_or_create_namespace(class_name_with_namespace)
data_type = self._get_or_create_struct_data_type(
class_name_with_namespace, class_size
)
if (old_size := data_type.getLength()) != class_size:
logger.warning(
"Existing class %s had incorrect size %d. Setting to %d...",
class_name_with_namespace,
old_size,
class_size,
)
logger.info("Adding class data type %s", class_name_with_namespace)
logger.debug("Class information: %s", type_in_pdb)
data_type.deleteAll()
data_type.growStructure(class_size)
# this case happened e.g. for IUnknown, which linked to an (incorrect) existing library, and some other types as well.
# Unfortunately, we don't get proper error handling for read-only types.
# However, we really do NOT want to do this every time because the type might be self-referential and partially imported.
if data_type.getLength() != class_size:
data_type = self._delete_and_recreate_struct_data_type(
class_name_with_namespace, class_size, data_type
)
# can be missing when no new fields are declared
components: list[dict[str, Any]] = field_list.get("members") or []
super_type = field_list.get("super")
if super_type is not None:
components.insert(0, {"type": super_type, "offset": 0, "name": "base"})
for component in components:
ghidra_type = self.import_pdb_type_into_ghidra(component["type"])
logger.debug("Adding component to class: %s", component)
try:
# for better logs
data_type.replaceAtOffset(
component["offset"], ghidra_type, -1, component["name"], None
)
except Exception as e:
raise StructModificationError(type_in_pdb) from e
logger.info("Finished importing class %s", class_name_with_namespace)
return data_type
def _get_or_create_namespace(self, class_name_with_namespace: str):
colon_split = class_name_with_namespace.split("::")
class_name = colon_split[-1]
try:
get_ghidra_namespace(self.api, colon_split)
logger.debug("Found existing class/namespace %s", class_name_with_namespace)
except ClassOrNamespaceNotFoundInGhidraError:
logger.info("Creating class/namespace %s", class_name_with_namespace)
class_name = colon_split.pop()
parent_namespace = create_ghidra_namespace(self.api, colon_split)
self.api.createClass(parent_namespace, class_name)
def _get_or_create_struct_data_type(
self, class_name_with_namespace: str, class_size: int
) -> StructureInternal:
try:
data_type = get_ghidra_type(self.api, class_name_with_namespace)
logger.debug(
"Found existing data type %s under category path %s",
class_name_with_namespace,
data_type.getCategoryPath(),
)
except TypeNotFoundInGhidraError:
# Create a new struct data type
data_type = StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
)
data_type = (
self.api.getCurrentProgram()
.getDataTypeManager()
.addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
logger.info("Created new data type %s", class_name_with_namespace)
assert isinstance(
data_type, StructureInternal
), f"Found type sharing its name with a class/struct, but is not a struct: {class_name_with_namespace}"
return data_type
def _delete_and_recreate_struct_data_type(
self,
class_name_with_namespace: str,
class_size: int,
existing_data_type: DataType,
) -> StructureInternal:
logger.warning(
"Failed to modify data type %s. Will try to delete the existing one and re-create the imported one.",
class_name_with_namespace,
)
assert (
self.api.getCurrentProgram()
.getDataTypeManager()
.remove(existing_data_type, ConsoleTaskMonitor())
), f"Failed to delete and re-create data type {class_name_with_namespace}"
data_type = StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
)
data_type = (
self.api.getCurrentProgram()
.getDataTypeManager()
.addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
assert isinstance(data_type, StructureInternal) # for type checking
return data_type
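
To make the scalar handling above concrete, a hedged illustration (`importer` stands for an already constructed PdbTypeImporter; the inputs are cvdump-style scalar names as used elsewhere in this commit):

assert importer._scalar_type_to_cpp("int4") == "int"
assert importer._scalar_type_to_cpp("rchar") == "char"
# the "32p" prefix marks a 32-bit pointer to the remaining scalar type
assert importer._scalar_type_to_cpp("32pvoid") == "void *"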