"""Per-aggregate cross-audit mapping. Maps each audit finding (file:line) to one or more aggregates via the PCG's producers + consumers dictionaries. """ from __future__ import annotations from pathlib import Path from src.code_path_audit import ( CrossAuditFinding, CrossAuditFindings, FunctionRef, find_enclosing_function, ) AUDIT_BUCKET_FIELDS: dict[str, str] = { "audit_weak_types": "weak_types", "audit_exception_handling": "exception_handling", "audit_optional_in_3_files": "optional_in_baseline", "audit_no_models_config_io": "config_io_ownership", "audit_main_thread_imports": "import_graph", } def _all_function_refs( producers: dict[str, list[FunctionRef]], consumers: dict[str, list[FunctionRef]], ) -> list[FunctionRef]: """Flatten all FunctionRefs from the PCG dicts.""" out: list[FunctionRef] = [] for refs in producers.values(): out.extend(refs) for refs in consumers.values(): out.extend(refs) return out def _file_to_aggregates( producers: dict[str, list[FunctionRef]], consumers: dict[str, list[FunctionRef]], ) -> dict[str, set[str]]: """Build a {file: {aggregate, ...}} index for file-level fallback mapping.""" out: dict[str, set[str]] = {} for aggregate, refs in producers.items(): for r in refs: out.setdefault(_normalize_path(r.file), set()).add(aggregate) for aggregate, refs in consumers.items(): for r in refs: out.setdefault(_normalize_path(r.file), set()).add(aggregate) return out def _aggregate_for_fqname( fqname: str, producers: dict[str, list[FunctionRef]], consumers: dict[str, list[FunctionRef]], ) -> str: """Find which aggregate this FunctionRef is associated with.""" for ag, refs in producers.items(): if any(r.fqname == fqname for r in refs): return ag for ag, refs in consumers.items(): if any(r.fqname == fqname for r in refs): return ag return "" def _normalize_path(p: str) -> str: """Normalize file path separators for comparison.""" return p.replace("\\", "/") def map_finding_to_aggregates( file: str, line: int, producers: dict[str, list[FunctionRef]], consumers: dict[str, list[FunctionRef]], ) -> set[str]: """Map a (file, line) finding to a set of aggregate names. Tier 1: function lookup via find_enclosing_function (with line=0 fallback to file-only match). Tier 2: file heuristic via the PCG's file index. File paths are normalized to forward-slash form for comparison. """ all_refs = _all_function_refs(producers, consumers) normalized = _normalize_path(file) fref = find_enclosing_function(file=normalized, line=line, function_refs=all_refs) if fref is None: same_file = [r for r in all_refs if _normalize_path(r.file) == normalized] return {_aggregate_for_fqname(r.fqname, producers, consumers) for r in same_file} return {_aggregate_for_fqname(fref.fqname, producers, consumers)} def aggregate_findings( audit_name: str, findings: list[dict], producers: dict[str, list[FunctionRef]], consumers: dict[str, list[FunctionRef]], ) -> dict[str, list[CrossAuditFinding]]: """Group findings by aggregate via the PCG. Mapping tiers: 1. Function lookup (find_enclosing_function) -> exact match 2. File-level fallback (file has any producer/consumer of the aggregate) 3. Unbucketed (the file has no Metadata-touching functions) """ out: dict[str, list[CrossAuditFinding]] = {} file_index = _file_to_aggregates(producers, consumers) for finding in findings: file = finding.get("file", "") or finding.get("filename", "") line = int(finding.get("line", 0) or 0) note = finding.get("category", "") or finding.get("body_summary", "") or finding.get("note", "") or "" aggregates = map_finding_to_aggregates(file, line, producers, consumers) if not aggregates: normalized = _normalize_path(file) aggregates = file_index.get(normalized, set()) if not aggregates: aggregates = {""} for aggregate in aggregates: cf = CrossAuditFinding( audit_script=audit_name, site_count=1, example_file=file, example_line=line, note=note, ) out.setdefault(aggregate, []).append(cf) return out def build_cross_audit_findings_for_aggregate( aggregate: str, aggregated: dict[str, dict[str, list[CrossAuditFinding]]], ) -> CrossAuditFindings: """Build a CrossAuditFindings struct for one aggregate from aggregated data.""" weak = () exc = () opt = () cfg = () imp = () for audit_name, by_agg in aggregated.items(): findings = by_agg.get(aggregate, []) if not findings: continue bucket = AUDIT_BUCKET_FIELDS.get(audit_name, "") total = len(findings) first = findings[0] combined = CrossAuditFinding( audit_script=audit_name, site_count=total, example_file=first.example_file, example_line=first.example_line, note=f"{total} sites", ) if bucket == "weak_types": weak = (combined,) elif bucket == "exception_handling": exc = (combined,) elif bucket == "optional_in_baseline": opt = (combined,) elif bucket == "config_io_ownership": cfg = (combined,) elif bucket == "import_graph": imp = (combined,) return CrossAuditFindings( weak_types=weak, exception_handling=exc, optional_in_baseline=opt, config_io_ownership=cfg, import_graph=imp, )