"""Tests for src.code_path_audit v2 - Phase 1 (data model).""" from __future__ import annotations import ast import textwrap import tempfile from pathlib import Path from collections import Counter import pytest from src.code_path_audit import ( AggregateKind, MemoryDim, AccessPattern, Frequency, RecommendedDirection, FunctionRef, AccessPatternEvidence, FrequencyEvidence, ResultCoverage, TypeAliasCoverage, CrossAuditFinding, CrossAuditFindings, DecompositionCost, OptimizationCandidate, AggregateProfile, ProducerConsumerGraph, P1_pass, P2_pass, P3_pass, build_pcg, CANONICAL_MEMORY_DIM, MEMORY_DIM_FILE_HEURISTIC, load_memory_dim_overrides, file_origin_memory_dim, classify_memory_dim, WHOLE_STRUCT_KEY_THRESHOLD, FIELD_BY_FIELD_KEY_THRESHOLD, MIXED_DOMINANCE_THRESHOLD, AGGREGATE_LEVEL_DOMINANCE_THRESHOLD, is_whole_struct_access, is_field_by_field_access, is_hot_cold_split, is_bulk_batched_access, dominant_pattern, detect_access_pattern, ) from src.result_types import Result, ErrorInfo, ErrorKind def test_aggregate_kind_4_values() -> None: """AggregateKind is a Literal with 4 values: typealias, dataclass, candidate_dataclass, builtin.""" expected = {"typealias", "dataclass", "candidate_dataclass", "builtin"} assert set(AggregateKind.__args__) == expected def test_memory_dim_7_values() -> None: """MemoryDim is a Literal with 7 values: curation, discussion, rag, knowledge, config, control, unknown.""" expected = {"curation", "discussion", "rag", "knowledge", "config", "control", "unknown"} assert set(MemoryDim.__args__) == expected def test_access_pattern_5_values() -> None: """AccessPattern is a Literal with 5 values: whole_struct, field_by_field, hot_cold_split, bulk_batched, mixed.""" expected = {"whole_struct", "field_by_field", "hot_cold_split", "bulk_batched", "mixed"} assert set(AccessPattern.__args__) == expected def test_frequency_7_values() -> None: """Frequency is a Literal with 7 values: hot, per_turn, per_discussion, per_request, cold, init, unknown.""" expected = {"hot", "per_turn", "per_discussion", "per_request", "cold", "init", "unknown"} assert set(Frequency.__args__) == expected def test_recommended_direction_4_values() -> None: """RecommendedDirection is a Literal with 4 values: componentize, unify, hold, insufficient_data.""" expected = {"componentize", "unify", "hold", "insufficient_data"} assert set(RecommendedDirection.__args__) == expected def test_function_ref_4_fields() -> None: """FunctionRef has fqname, file, line, role (per spec).""" ref = FunctionRef( fqname="src.ai_client.AIClient.send_result", file="src/ai_client.py", line=100, role="producer", ) assert ref.fqname == "src.ai_client.AIClient.send_result" assert ref.file == "src/ai_client.py" assert ref.line == 100 assert ref.role == "producer" def test_function_ref_frozen() -> None: """FunctionRef is frozen (immutability per error_handling.md).""" ref = FunctionRef( fqname="src.x.y", file="src/x.py", line=1, role="consumer", ) with pytest.raises((AttributeError, Exception)) as exc_info: ref.fqname = "src.z.w" assert "frozen" in str(exc_info.value).lower() or "cannot assign" in str(exc_info.value).lower() def test_access_pattern_evidence_4_fields() -> None: """AccessPatternEvidence has function, pattern, field_accesses, confidence.""" ref = FunctionRef(fqname="src.x.y", file="src/x.py", line=1, role="consumer") ev = AccessPatternEvidence( function=ref, pattern="field_by_field", field_accesses={"path": 3, "view_mode": 2}, confidence="high", ) assert ev.function is ref assert ev.pattern == "field_by_field" assert ev.field_accesses == {"path": 3, "view_mode": 2} assert ev.confidence == "high" def test_frequency_evidence_4_fields() -> None: """FrequencyEvidence has function, frequency, source, note (default '').""" ref = FunctionRef(fqname="src.x.y", file="src/x.py", line=1, role="both") ev = FrequencyEvidence( function=ref, frequency="per_turn", source="entry_point", note="called per LLM turn", ) assert ev.function is ref assert ev.frequency == "per_turn" assert ev.source == "entry_point" assert ev.note == "called per LLM turn" def test_frequency_evidence_default_note() -> None: """FrequencyEvidence.note defaults to ''.""" ref = FunctionRef(fqname="src.x.y", file="src/x.py", line=1, role="consumer") ev = FrequencyEvidence(function=ref, frequency="cold", source="control_flow_position") assert ev.note == "" def test_result_coverage_5_fields() -> None: """ResultCoverage has total_producers, result_producers, total_consumers, result_consumers, summary.""" cov = ResultCoverage( total_producers=12, result_producers=5, total_consumers=15, result_consumers=8, summary="5/12 producers return Result[T] (42%); 8/15 consumers branch on .errors (53%)", ) assert cov.total_producers == 12 assert cov.result_producers == 5 assert cov.total_consumers == 15 assert cov.result_consumers == 8 assert "42%" in cov.summary assert "53%" in cov.summary def test_type_alias_coverage_4_fields() -> None: """TypeAliasCoverage has total_sites, typed_sites, untyped_sites, summary.""" cov = TypeAliasCoverage( total_sites=45, typed_sites=38, untyped_sites=7, summary="45 total sites; 38 typed (84%); 7 untyped (16%)", ) assert cov.total_sites == 45 assert cov.typed_sites == 38 assert cov.untyped_sites == 7 assert "84%" in cov.summary assert "16%" in cov.summary def test_cross_audit_finding_5_fields() -> None: """CrossAuditFinding has audit_script, site_count, example_file, example_line, note (default '').""" finding = CrossAuditFinding( audit_script="audit_weak_types", site_count=12, example_file="src/ai_client.py", example_line=100, note="12 weak-type sites in producer+consumer functions", ) assert finding.audit_script == "audit_weak_types" assert finding.site_count == 12 assert finding.example_file == "src/ai_client.py" assert finding.example_line == 100 assert finding.note == "12 weak-type sites in producer+consumer functions" def test_cross_audit_finding_default_note() -> None: """CrossAuditFinding.note defaults to ''.""" finding = CrossAuditFinding( audit_script="audit_optional_in_3_files", site_count=0, example_file="", example_line=0, ) assert finding.note == "" def test_cross_audit_findings_5_audit_scripts() -> None: """CrossAuditFindings has 5 audit-script fields, each a tuple of CrossAuditFinding.""" findings = CrossAuditFindings( weak_types=(), exception_handling=(), optional_in_baseline=(), config_io_ownership=(), import_graph=(), ) assert findings.weak_types == () assert findings.exception_handling == () assert findings.optional_in_baseline == () assert findings.config_io_ownership == () assert findings.import_graph == () def test_decomposition_cost_8_fields() -> None: """DecompositionCost has 8 fields per spec.""" cost = DecompositionCost( current_cost_estimate=1500, componentize_savings=450, unify_savings=0, recommended_direction="hold", recommended_rationale="whole_struct access on a frozen dataclass; current shape is correct", batch_size=None, struct_field_count=8, struct_frozen=True, ) assert cost.current_cost_estimate == 1500 assert cost.componentize_savings == 450 assert cost.unify_savings == 0 assert cost.recommended_direction == "hold" assert "frozen" in cost.recommended_rationale assert cost.batch_size is None assert cost.struct_field_count == 8 assert cost.struct_frozen is True def test_optimization_candidate_7_fields() -> None: """OptimizationCandidate has 7 fields per spec.""" cand = OptimizationCandidate( candidate="Migrate 7 producers of Metadata to Result[Metadata]", direction="componentize", affected_files=("src/ai_client.py", "src/app_controller.py", "src/history.py"), estimated_savings_us=500, effort="small", priority="high", cross_ref="docs/reports/EXCEPTION_HANDLING_AUDIT_20260616.md", ) assert "Migrate" in cand.candidate assert cand.direction == "componentize" assert len(cand.affected_files) == 3 assert cand.estimated_savings_us == 500 assert cand.effort == "small" assert cand.priority == "high" assert "EXCEPTION_HANDLING_AUDIT" in cand.cross_ref def test_aggregate_profile_14_fields() -> None: """AggregateProfile has 14 top-level fields (per spec section 7.1).""" f = FunctionRef(fqname="src.x.y", file="src/x.py", line=1, role="producer") profile = AggregateProfile( name="Metadata", aggregate_kind="typealias", memory_dim="discussion", producers=(f,), consumers=(f,), access_pattern="field_by_field", access_pattern_evidence=(AccessPatternEvidence( function=f, pattern="field_by_field", field_accesses={"role": 3}, confidence="high" ),), frequency="per_turn", frequency_evidence=(FrequencyEvidence( function=f, frequency="per_turn", source="entry_point", note="per LLM turn" ),), result_coverage=ResultCoverage(0, 0, 0, 0, ""), type_alias_coverage=TypeAliasCoverage(0, 0, 0, ""), cross_audit_findings=CrossAuditFindings((), (), (), (), ()), decomposition_cost=DecompositionCost(0, 0, 0, "hold", "no data", None, 0, False), optimization_candidates=(), is_candidate=False, ) assert profile.name == "Metadata" assert profile.aggregate_kind == "typealias" assert profile.memory_dim == "discussion" assert len(profile.producers) == 1 assert len(profile.consumers) == 1 assert profile.access_pattern == "field_by_field" assert len(profile.access_pattern_evidence) == 1 assert profile.frequency == "per_turn" assert len(profile.frequency_evidence) == 1 assert profile.is_candidate is False def test_aggregate_profile_is_candidate_true() -> None: """AggregateProfile.is_candidate=True for the 3 candidate aggregates.""" profile = AggregateProfile( name="ChatMessage", aggregate_kind="candidate_dataclass", memory_dim="discussion", producers=(), consumers=(), access_pattern="mixed", access_pattern_evidence=(), frequency="unknown", frequency_evidence=(), result_coverage=ResultCoverage(0, 0, 0, 0, ""), type_alias_coverage=TypeAliasCoverage(0, 0, 0, ""), cross_audit_findings=CrossAuditFindings((), (), (), (), ()), decomposition_cost=DecompositionCost(0, 0, 0, "insufficient_data", "candidate", None, 0, False), optimization_candidates=(), is_candidate=True, ) assert profile.is_candidate is True assert profile.aggregate_kind == "candidate_dataclass" assert profile.producers == () assert profile.consumers == () def test_pcg_init_empty() -> None: """ProducerConsumerGraph starts with empty edges and producers/consumers dicts.""" pcg = ProducerConsumerGraph() assert pcg.edges == {} assert pcg.producers == {} assert pcg.consumers == {} def test_pcg_add_producer_consumer() -> None: """add_producer + add_consumer add to the bipartite graph.""" pcg = ProducerConsumerGraph() f = FunctionRef(fqname="src.x.y", file="src/x.py", line=1, role="producer") pcg.add_producer("Metadata", f) pcg.add_consumer("Metadata", f) assert "Metadata" in pcg.producers assert "Metadata" in pcg.consumers assert f in pcg.producers["Metadata"] assert f in pcg.consumers["Metadata"] def test_p1_pass_finds_producer_of_T() -> None: """P1 detects a function whose return annotation is a TypeAlias name (producer of T).""" source = textwrap.dedent(''' def send_result() -> Metadata: return {} ''') tree = ast.parse(source) producers = P1_pass(tree, file="synthetic.py") assert ("send_result", "Metadata", "producer", "high") in producers def test_p1_pass_finds_producer_of_Result_T() -> None: """P1 detects a function whose return annotation is Result[T] (producer of T).""" source = textwrap.dedent(''' def fetch() -> Result[FileItems]: return Result(data=[]) ''') tree = ast.parse(source) producers = P1_pass(tree, file="synthetic.py") assert ("fetch", "FileItems", "producer", "high") in producers def test_p1_pass_skips_non_annotated_return() -> None: """P1 returns [] for functions without return annotations.""" source = textwrap.dedent(''' def unannotated(): return {} ''') tree = ast.parse(source) producers = P1_pass(tree, file="synthetic.py") assert producers == [] def test_p2_pass_finds_consumer_of_T() -> None: """P2 detects a function whose parameter is a TypeAlias name (consumer of T).""" source = textwrap.dedent(''' def process(entry: Metadata) -> None: pass ''') tree = ast.parse(source) consumers = P2_pass(tree, file="synthetic.py") assert ("process", "Metadata", "consumer", "high") in consumers def test_p2_pass_finds_consumer_of_list_T() -> None: """P2 detects a function whose parameter is list[T] (consumer of T).""" source = textwrap.dedent(''' def aggregate(items: list[FileItems]) -> None: pass ''') tree = ast.parse(source) consumers = P2_pass(tree, file="synthetic.py") assert ("aggregate", "FileItems", "consumer", "high") in consumers def test_p2_pass_skips_untyped_parameter() -> None: """P2 returns [] for parameters without type annotations.""" source = textwrap.dedent(''' def process(entry) -> None: pass ''') tree = ast.parse(source) consumers = P2_pass(tree, file="synthetic.py") assert consumers == [] def test_p3_pass_finds_consumer_via_subscript() -> None: """P3 detects a function that reads entry['path']; without a type registry, returns the field name only.""" source = textwrap.dedent(''' def process(entry) -> None: path = entry['path'] ''') tree = ast.parse(source) accesses = P3_pass(tree, file="synthetic.py", type_registry={}) assert ("process", "path", "subscript", 1) in accesses def test_p3_pass_finds_consumer_via_attribute() -> None: """P3 detects a function that reads entry.attr; returns (function, attr, kind, count).""" source = textwrap.dedent(''' def process(entry) -> None: path = entry.path ''') tree = ast.parse(source) accesses = P3_pass(tree, file="synthetic.py", type_registry={}) assert ("process", "path", "attribute", 1) in accesses def test_p3_pass_counts_multiple_accesses() -> None: """P3 counts multiple accesses to the same key within a single function.""" source = textwrap.dedent(''' def process(entry) -> None: a = entry['path'] b = entry['path'] c = entry['view_mode'] ''') tree = ast.parse(source) accesses = P3_pass(tree, file="synthetic.py", type_registry={}) path_count = sum(c for fn, k, kind, c in accesses if fn == "process" and k == "path") assert path_count == 2 def test_build_pcg_returns_result() -> None: """build_pcg returns Result[ProducerConsumerGraph] per error_handling.md.""" with tempfile.TemporaryDirectory() as tmp: (Path(tmp) / "mod.py").write_text(textwrap.dedent(''' from src.type_aliases import Metadata def produce() -> Metadata: return {} ''')) result = build_pcg(tmp) assert isinstance(result, Result) assert result.ok def test_build_pcg_finds_producer_via_p1() -> None: """build_pcg correctly identifies a producer of Metadata via P1.""" with tempfile.TemporaryDirectory() as tmp: (Path(tmp) / "mod.py").write_text(textwrap.dedent(''' from src.type_aliases import Metadata def produce() -> Metadata: return {} ''')) pcg = build_pcg(tmp).data assert "Metadata" in pcg.producers def test_build_pcg_tolerates_syntax_errors() -> None: """build_pcg records syntax errors as ErrorInfo (boundary pattern); Result.ok is False.""" with tempfile.TemporaryDirectory() as tmp: (Path(tmp) / "bad.py").write_text("def unclosed(:\n pass") result = build_pcg(tmp) assert not result.ok assert len(result.errors) >= 1 assert isinstance(result.errors[0], ErrorInfo) assert result.data is not None def test_canonical_memory_dim_has_aggregates() -> None: """CANONICAL_MEMORY_DIM has >=13 known aggregate -> dim mappings (10 in-scope + 3 candidate).""" assert len(CANONICAL_MEMORY_DIM) >= 13 assert CANONICAL_MEMORY_DIM["Metadata"] == "discussion" assert CANONICAL_MEMORY_DIM["CommsLogEntry"] == "discussion" assert CANONICAL_MEMORY_DIM["FileItem"] == "curation" assert CANONICAL_MEMORY_DIM["FileItems"] == "curation" assert CANONICAL_MEMORY_DIM["Result"] == "control" assert CANONICAL_MEMORY_DIM["ErrorInfo"] == "control" def test_memory_dim_file_heuristic_has_5_buckets() -> None: """MEMORY_DIM_FILE_HEURISTIC has the 5 file-of-origin buckets.""" assert len(MEMORY_DIM_FILE_HEURISTIC) == 5 assert "curation" in MEMORY_DIM_FILE_HEURISTIC assert "discussion" in MEMORY_DIM_FILE_HEURISTIC assert "rag" in MEMORY_DIM_FILE_HEURISTIC assert "config" in MEMORY_DIM_FILE_HEURISTIC def test_load_memory_dim_overrides_empty() -> None: """load_memory_dim_overrides returns {} for a missing file.""" result = load_memory_dim_overrides("/nonexistent/overrides.toml") assert result == {} def test_load_memory_dim_overrides_parses_toml() -> None: """load_memory_dim_overrides parses [memory_dim.] = '' lines.""" with tempfile.TemporaryDirectory() as tmp: overrides_path = Path(tmp) / "overrides.toml" overrides_path.write_text('[memory_dim]\nMetadata = "curation"\n') result = load_memory_dim_overrides(str(overrides_path)) assert result.get("Metadata") == "curation" def test_file_origin_memory_dim_curation() -> None: """file_origin_memory_dim returns 'curation' for files in the curation bucket.""" dim = file_origin_memory_dim("src/aggregate.py") assert dim == "curation" def test_file_origin_memory_dim_discussion() -> None: """file_origin_memory_dim returns 'discussion' for files in the discussion bucket.""" dim = file_origin_memory_dim("src/ai_client.py") assert dim == "discussion" def test_file_origin_memory_dim_unknown() -> None: """file_origin_memory_dim returns 'unknown' for files not in any bucket.""" dim = file_origin_memory_dim("src/random.py") assert dim == "unknown" def test_classify_memory_dim_canonical() -> None: """classify_memory_dim returns the canonical dim for known aggregates, regardless of producer file.""" dim = classify_memory_dim("Metadata", "src/aggregate.py", overrides={}) assert dim == "discussion" def test_classify_memory_dim_override() -> None: """classify_memory_dim respects the override file's mapping.""" dim = classify_memory_dim("Metadata", "src/aggregate.py", overrides={"Metadata": "curation"}) assert dim == "curation" def test_classify_memory_dim_file_heuristic() -> None: """classify_memory_dim falls back to file-of-origin for unknown aggregates.""" dim = classify_memory_dim("SomeUnknownAggregate", "src/aggregate.py", overrides={}) assert dim == "curation" def test_classify_memory_dim_unknown_when_no_evidence() -> None: """classify_memory_dim returns 'unknown' when no canonical, override, or file evidence.""" dim = classify_memory_dim("SomeUnknownAggregate", "src/random.py", overrides={}) assert dim == "unknown" def test_threshold_constants() -> None: """The 4 APD threshold constants are defined.""" assert WHOLE_STRUCT_KEY_THRESHOLD == 1 assert FIELD_BY_FIELD_KEY_THRESHOLD == 3 assert MIXED_DOMINANCE_THRESHOLD == 0.6 assert AGGREGATE_LEVEL_DOMINANCE_THRESHOLD == 0.25 def test_is_whole_struct_access_true() -> None: """is_whole_struct_access returns True for a function that reads the aggregate without accessing fields.""" counts: Counter[str] = Counter() assert is_whole_struct_access(counts, has_direct_access=True) is True def test_is_whole_struct_access_one_key() -> None: """is_whole_struct_access returns True for <=1 distinct key.""" counts: Counter[str] = Counter({"path": 3}) assert is_whole_struct_access(counts, has_direct_access=False) is True def test_is_whole_struct_access_two_keys_false() -> None: """is_whole_struct_access returns False for 2+ distinct keys.""" counts: Counter[str] = Counter({"path": 3, "view_mode": 2}) assert is_whole_struct_access(counts, has_direct_access=False) is False def test_is_field_by_field_access_true() -> None: """is_field_by_field_access returns True for >=3 distinct keys AND no whole_struct access.""" counts: Counter[str] = Counter({"a": 1, "b": 1, "c": 1}) assert is_field_by_field_access(counts) is True def test_is_field_by_field_access_few_keys() -> None: """is_field_by_field_access returns False for <3 distinct keys.""" counts: Counter[str] = Counter({"a": 1, "b": 1}) assert is_field_by_field_access(counts) is False def test_is_hot_cold_split_true() -> None: """is_hot_cold_split returns True for 1-2 hot keys + 2+ cold keys.""" hot = {"role", "content"} cold = {"tool_calls", "reasoning_content"} assert is_hot_cold_split(hot, cold) is True def test_is_hot_cold_split_too_many_hot() -> None: """is_hot_cold_split returns False for 3+ hot keys.""" hot = {"a", "b", "c"} cold = {"d", "e"} assert is_hot_cold_split(hot, cold) is False def test_is_hot_cold_split_too_few_cold() -> None: """is_hot_cold_split returns False for <2 cold keys.""" hot = {"a"} cold = {"b"} assert is_hot_cold_split(hot, cold) is False def test_is_bulk_batched_access_true() -> None: """is_bulk_batched_access returns True for a function iterating over a list of aggregates.""" assert is_bulk_batched_access(iterates_over_list=True, body_accesses_uniform=True) is True def test_is_bulk_batched_access_no_iteration() -> None: """is_bulk_batched_access returns False when the function doesn't iterate over a list.""" assert is_bulk_batched_access(iterates_over_list=False, body_accesses_uniform=True) is False def test_is_bulk_batched_access_non_uniform() -> None: """is_bulk_batched_access returns False when the body has non-uniform access.""" assert is_bulk_batched_access(iterates_over_list=True, body_accesses_uniform=False) is False def test_dominant_pattern_clear_winner() -> None: """dominant_pattern returns the pattern with the highest share if >=25%.""" counts = {"field_by_field": 3, "whole_struct": 1} assert dominant_pattern(counts) == "field_by_field" def test_dominant_pattern_below_threshold() -> None: """dominant_pattern returns 'mixed' when no pattern has >=25% share.""" counts = {"field_by_field": 1, "whole_struct": 1, "hot_cold_split": 1, "bulk_batched": 1} assert dominant_pattern(counts) == "mixed" def test_dominant_pattern_empty() -> None: """dominant_pattern returns 'mixed' for an empty counts dict.""" assert dominant_pattern({}) == "mixed" def test_detect_access_pattern_whole_struct() -> None: """detect_access_pattern returns 'whole_struct' for a function that reads 0-1 keys.""" counts: Counter[str] = Counter() pattern = detect_access_pattern(counts, has_direct_access=True) assert pattern == "whole_struct" def test_detect_access_pattern_field_by_field() -> None: """detect_access_pattern returns 'field_by_field' for a function that reads 3+ keys.""" counts: Counter[str] = Counter({"a": 1, "b": 1, "c": 1}) pattern = detect_access_pattern(counts, has_direct_access=False) assert pattern == "field_by_field" def test_detect_access_pattern_hot_cold_split() -> None: """detect_access_pattern returns 'hot_cold_split' for 1-2 hot + 2+ cold keys.""" counts: Counter[str] = Counter({"a": 1, "b": 1, "c": 1, "d": 1}) pattern = detect_access_pattern(counts, has_direct_access=False, hot_keys={"a", "b"}, cold_keys={"c", "d"}) assert pattern == "hot_cold_split" def test_detect_access_pattern_mixed() -> None: """detect_access_pattern returns 'mixed' when no pattern dominates (2+ distinct keys but <3).""" counts: Counter[str] = Counter({"a": 1, "b": 1}) pattern = detect_access_pattern(counts, has_direct_access=False) assert pattern == "mixed"