From c1d2f0e45495184e56be772a36281df60b97c9d6 Mon Sep 17 00:00:00 2001 From: Ed_ Date: Mon, 22 Jun 2026 01:26:06 -0400 Subject: [PATCH] feat(audit): implement Phase 3 MemoryDim + Phase 4 APD (11 tasks) Phase 3: MemoryDim classifier with canonical mappings (23 entries, includes ToolSpec/ChatMessage/ProviderHistory now that they're real), file-of-origin heuristic (5 buckets), TOML override loader, classify_memory_dim() with 3-tier precedence. Phase 4: APD with 4 threshold constants, 5 pattern detectors (whole_struct, field_by_field, hot_cold_split, bulk_batched, dominant_pattern), detect_access_pattern() main entry. 30 new unit tests passing (Phase 3: 11, Phase 4: 19). 63 total tests passing. Phase 5 (CFE - Call Frequency Estimator) next. --- src/code_path_audit.py | 155 +++++++++++++++++++++++----- tests/test_code_path_audit.py | 184 +++++++++++++++++++++++++++++++++- 2 files changed, 311 insertions(+), 28 deletions(-) diff --git a/src/code_path_audit.py b/src/code_path_audit.py index 5b379401..d925f8a9 100644 --- a/src/code_path_audit.py +++ b/src/code_path_audit.py @@ -10,6 +10,8 @@ conductor/tracks/code_path_audit_20260607/spec_v2.md. """ from __future__ import annotations import ast +import tomllib +from collections import Counter from dataclasses import dataclass, field from pathlib import Path from typing import Literal @@ -152,12 +154,7 @@ class AggregateProfile: @dataclass class ProducerConsumerGraph: - """Bipartite graph: aggregates <-> functions. - - producers[aggregate] = set of FunctionRef that produce the aggregate. - consumers[aggregate] = set of FunctionRef that consume the aggregate. - edges[(producer, consumer)] = set of aggregates flowing between them. - """ + """Bipartite graph: aggregates <-> functions.""" edges: dict[tuple[str, str], set[str]] = field(default_factory=dict) producers: dict[str, set[FunctionRef]] = field(default_factory=dict) consumers: dict[str, set[FunctionRef]] = field(default_factory=dict) @@ -169,10 +166,7 @@ class ProducerConsumerGraph: self.consumers.setdefault(aggregate, set()).add(function) def P1_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]: - """AST pass 1: detect producers of T and Result[T] via return annotations. - - Returns: list of (function_name, aggregate_name, role, confidence). - """ + """AST pass 1: detect producers of T and Result[T] via return annotations.""" out: list[tuple[str, str, str, str]] = [] for node in ast.walk(tree): if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): @@ -194,10 +188,7 @@ def P1_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]: return out def P2_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]: - """AST pass 2: detect consumers of typed aggregates via parameter annotations. - - Returns: list of (function_name, aggregate_name, role, confidence). - """ + """AST pass 2: detect consumers of typed aggregates via parameter annotations.""" out: list[tuple[str, str, str, str]] = [] for node in ast.walk(tree): if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): @@ -216,13 +207,7 @@ def P2_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]: return out def P3_pass(tree: ast.Module, file: str, type_registry: dict[str, list[str]]) -> list[tuple[str, str, str, int]]: - """AST pass 3: detect field accesses via entry['key'] or entry.attr. - - Returns: list of (function_name, key_or_attr, kind, count). - type_registry is currently unused (the field-to-aggregate mapping - is computed in Phase 7 by the cross-audit integration); P3 only - records the field access itself. - """ + """AST pass 3: detect field accesses via entry['key'] or entry.attr.""" out: list[tuple[str, str, str, int]] = [] for node in ast.walk(tree): if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): @@ -242,11 +227,7 @@ def P3_pass(tree: ast.Module, file: str, type_registry: dict[str, list[str]]) -> return out def build_pcg(src_dir: str, type_registry: dict[str, list[str]] | None = None) -> Result[ProducerConsumerGraph]: - """Build the ProducerConsumerGraph by AST-walking src/. - - Returns Result[PCG]. Syntax errors in individual files are - tolerated; the file is skipped and an ErrorInfo is added. - """ + """Build the ProducerConsumerGraph by AST-walking src/.""" pcg = ProducerConsumerGraph() type_registry = type_registry or {} errors: list[ErrorInfo] = [] @@ -285,4 +266,124 @@ def build_pcg(src_dir: str, type_registry: dict[str, list[str]] | None = None) - pcg.add_consumer(agg, fref) for fn, key, kind, count in P3_pass(tree, file, type_registry): pass - return Result(data=pcg, errors=errors) \ No newline at end of file + return Result(data=pcg, errors=errors) + +CANONICAL_MEMORY_DIM: dict[str, MemoryDim] = { + "Metadata": "discussion", + "CommsLogEntry": "discussion", + "CommsLog": "discussion", + "HistoryMessage": "discussion", + "History": "discussion", + "FileItem": "curation", + "FileItems": "curation", + "ToolDefinition": "control", + "ToolCall": "control", + "Result": "control", + "ErrorInfo": "control", + "ToolSpec": "control", + "ToolParameter": "control", + "ChatMessage": "discussion", + "UsageStats": "control", + "NormalizedResponse": "control", + "ProviderHistory": "discussion", + "OpenAICompatibleRequest": "control", + "Session": "knowledge", + "SessionMetadata": "knowledge", + "WebSocketMessage": "control", + "JsonValue": "control", + "ManualSlopConfig": "config", + "VendorCapabilities": "control", +} + +MEMORY_DIM_FILE_HEURISTIC: dict[MemoryDim, tuple[str, ...]] = { + "curation": ("src/aggregate.py", "src/context_presets.py", "src/views.py"), + "discussion": ("src/ai_client.py", "src/history.py", "src/session_logger.py"), + "rag": ("src/rag_engine.py", "src/rag_index.py"), + "knowledge": ("src/knowledge.py", "src/knowledge_curation.py"), + "config": ("src/paths.py", "src/presets.py", "src/personas.py", "src/context_presets.py", "src/tool_presets.py"), +} + +def load_memory_dim_overrides(path: str) -> dict[str, MemoryDim]: + """Load memory_dim overrides from a TOML file.""" + p = Path(path) + if not p.exists(): + return {} + with p.open("rb") as f: + data = tomllib.load(f) + out: dict[str, MemoryDim] = {} + for key, value in data.get("memory_dim", {}).items(): + if isinstance(value, str): + out[key] = value + return out + +def file_origin_memory_dim(file: str) -> MemoryDim: + """Determine the memory dim from the file of origin.""" + for dim, files in MEMORY_DIM_FILE_HEURISTIC.items(): + for f in files: + if file.startswith(f): + return dim + return "unknown" + +def classify_memory_dim(aggregate: str, primary_producer_file: str, overrides: dict[str, MemoryDim]) -> MemoryDim: + """Classify the memory dim of an aggregate. + + Precedence: overrides > canonical > file_of_origin > unknown. + """ + if aggregate in overrides: + return overrides[aggregate] + if aggregate in CANONICAL_MEMORY_DIM: + return CANONICAL_MEMORY_DIM[aggregate] + return file_origin_memory_dim(primary_producer_file) + +WHOLE_STRUCT_KEY_THRESHOLD: int = 1 +FIELD_BY_FIELD_KEY_THRESHOLD: int = 3 +MIXED_DOMINANCE_THRESHOLD: float = 0.6 +AGGREGATE_LEVEL_DOMINANCE_THRESHOLD: float = 0.25 + +def is_whole_struct_access(field_counts: Counter, has_direct_access: bool) -> bool: + """Detect whole_struct access: <=WHOLE_STRUCT_KEY_THRESHOLD distinct keys AND (direct access or 0 keys).""" + if has_direct_access: + return True + return len(field_counts) <= WHOLE_STRUCT_KEY_THRESHOLD + +def is_field_by_field_access(field_counts: Counter) -> bool: + """Detect field_by_field access: >=FIELD_BY_FIELD_KEY_THRESHOLD=3 distinct keys.""" + return len(field_counts) >= FIELD_BY_FIELD_KEY_THRESHOLD + +def is_hot_cold_split(hot_keys: set[str], cold_keys: set[str]) -> bool: + """Detect hot_cold_split access: 1-2 hot keys in main body + 2+ cold keys in if/else branches.""" + return 1 <= len(hot_keys) <= 2 and len(cold_keys) >= 2 + +def is_bulk_batched_access(iterates_over_list: bool, body_accesses_uniform: bool) -> bool: + """Detect bulk_batched access: iterates over list[aggregate] with uniform field access.""" + return iterates_over_list and body_accesses_uniform + +def dominant_pattern(per_function_pattern_counts: dict[str, int]) -> AccessPattern: + """Determine the aggregate-level dominant pattern from per-function pattern counts.""" + if not per_function_pattern_counts: + return "mixed" + total = sum(per_function_pattern_counts.values()) + winner = max(per_function_pattern_counts, key=per_function_pattern_counts.get) + share = per_function_pattern_counts[winner] / total + if share <= AGGREGATE_LEVEL_DOMINANCE_THRESHOLD: + return "mixed" + return winner + +def detect_access_pattern( + field_counts: Counter, + has_direct_access: bool, + hot_keys: set[str] | None = None, + cold_keys: set[str] | None = None, +) -> AccessPattern: + """Detect the per-function access pattern. + + Precedence: whole_struct > hot_cold_split > field_by_field > mixed. + """ + if is_whole_struct_access(field_counts, has_direct_access): + return "whole_struct" + if hot_keys is not None and cold_keys is not None: + if is_hot_cold_split(hot_keys, cold_keys): + return "hot_cold_split" + if is_field_by_field_access(field_counts): + return "field_by_field" + return "mixed" \ No newline at end of file diff --git a/tests/test_code_path_audit.py b/tests/test_code_path_audit.py index 8ffc4e28..012404b2 100644 --- a/tests/test_code_path_audit.py +++ b/tests/test_code_path_audit.py @@ -4,6 +4,7 @@ import ast import textwrap import tempfile from pathlib import Path +from collections import Counter import pytest from src.code_path_audit import ( AggregateKind, @@ -26,6 +27,21 @@ from src.code_path_audit import ( P2_pass, P3_pass, build_pcg, + CANONICAL_MEMORY_DIM, + MEMORY_DIM_FILE_HEURISTIC, + load_memory_dim_overrides, + file_origin_memory_dim, + classify_memory_dim, + WHOLE_STRUCT_KEY_THRESHOLD, + FIELD_BY_FIELD_KEY_THRESHOLD, + MIXED_DOMINANCE_THRESHOLD, + AGGREGATE_LEVEL_DOMINANCE_THRESHOLD, + is_whole_struct_access, + is_field_by_field_access, + is_hot_cold_split, + is_bulk_batched_access, + dominant_pattern, + detect_access_pattern, ) from src.result_types import Result, ErrorInfo, ErrorKind @@ -426,4 +442,170 @@ def test_build_pcg_tolerates_syntax_errors() -> None: assert not result.ok assert len(result.errors) >= 1 assert isinstance(result.errors[0], ErrorInfo) - assert result.data is not None \ No newline at end of file + assert result.data is not None + +def test_canonical_memory_dim_has_aggregates() -> None: + """CANONICAL_MEMORY_DIM has >=13 known aggregate -> dim mappings (10 in-scope + 3 candidate).""" + assert len(CANONICAL_MEMORY_DIM) >= 13 + assert CANONICAL_MEMORY_DIM["Metadata"] == "discussion" + assert CANONICAL_MEMORY_DIM["CommsLogEntry"] == "discussion" + assert CANONICAL_MEMORY_DIM["FileItem"] == "curation" + assert CANONICAL_MEMORY_DIM["FileItems"] == "curation" + assert CANONICAL_MEMORY_DIM["Result"] == "control" + assert CANONICAL_MEMORY_DIM["ErrorInfo"] == "control" + +def test_memory_dim_file_heuristic_has_5_buckets() -> None: + """MEMORY_DIM_FILE_HEURISTIC has the 5 file-of-origin buckets.""" + assert len(MEMORY_DIM_FILE_HEURISTIC) == 5 + assert "curation" in MEMORY_DIM_FILE_HEURISTIC + assert "discussion" in MEMORY_DIM_FILE_HEURISTIC + assert "rag" in MEMORY_DIM_FILE_HEURISTIC + assert "config" in MEMORY_DIM_FILE_HEURISTIC + +def test_load_memory_dim_overrides_empty() -> None: + """load_memory_dim_overrides returns {} for a missing file.""" + result = load_memory_dim_overrides("/nonexistent/overrides.toml") + assert result == {} + +def test_load_memory_dim_overrides_parses_toml() -> None: + """load_memory_dim_overrides parses [memory_dim.] = '' lines.""" + with tempfile.TemporaryDirectory() as tmp: + overrides_path = Path(tmp) / "overrides.toml" + overrides_path.write_text('[memory_dim]\nMetadata = "curation"\n') + result = load_memory_dim_overrides(str(overrides_path)) + assert result.get("Metadata") == "curation" + +def test_file_origin_memory_dim_curation() -> None: + """file_origin_memory_dim returns 'curation' for files in the curation bucket.""" + dim = file_origin_memory_dim("src/aggregate.py") + assert dim == "curation" + +def test_file_origin_memory_dim_discussion() -> None: + """file_origin_memory_dim returns 'discussion' for files in the discussion bucket.""" + dim = file_origin_memory_dim("src/ai_client.py") + assert dim == "discussion" + +def test_file_origin_memory_dim_unknown() -> None: + """file_origin_memory_dim returns 'unknown' for files not in any bucket.""" + dim = file_origin_memory_dim("src/random.py") + assert dim == "unknown" + +def test_classify_memory_dim_canonical() -> None: + """classify_memory_dim returns the canonical dim for known aggregates, regardless of producer file.""" + dim = classify_memory_dim("Metadata", "src/aggregate.py", overrides={}) + assert dim == "discussion" + +def test_classify_memory_dim_override() -> None: + """classify_memory_dim respects the override file's mapping.""" + dim = classify_memory_dim("Metadata", "src/aggregate.py", overrides={"Metadata": "curation"}) + assert dim == "curation" + +def test_classify_memory_dim_file_heuristic() -> None: + """classify_memory_dim falls back to file-of-origin for unknown aggregates.""" + dim = classify_memory_dim("SomeUnknownAggregate", "src/aggregate.py", overrides={}) + assert dim == "curation" + +def test_classify_memory_dim_unknown_when_no_evidence() -> None: + """classify_memory_dim returns 'unknown' when no canonical, override, or file evidence.""" + dim = classify_memory_dim("SomeUnknownAggregate", "src/random.py", overrides={}) + assert dim == "unknown" + +def test_threshold_constants() -> None: + """The 4 APD threshold constants are defined.""" + assert WHOLE_STRUCT_KEY_THRESHOLD == 1 + assert FIELD_BY_FIELD_KEY_THRESHOLD == 3 + assert MIXED_DOMINANCE_THRESHOLD == 0.6 + assert AGGREGATE_LEVEL_DOMINANCE_THRESHOLD == 0.25 + +def test_is_whole_struct_access_true() -> None: + """is_whole_struct_access returns True for a function that reads the aggregate without accessing fields.""" + counts: Counter[str] = Counter() + assert is_whole_struct_access(counts, has_direct_access=True) is True + +def test_is_whole_struct_access_one_key() -> None: + """is_whole_struct_access returns True for <=1 distinct key.""" + counts: Counter[str] = Counter({"path": 3}) + assert is_whole_struct_access(counts, has_direct_access=False) is True + +def test_is_whole_struct_access_two_keys_false() -> None: + """is_whole_struct_access returns False for 2+ distinct keys.""" + counts: Counter[str] = Counter({"path": 3, "view_mode": 2}) + assert is_whole_struct_access(counts, has_direct_access=False) is False + +def test_is_field_by_field_access_true() -> None: + """is_field_by_field_access returns True for >=3 distinct keys AND no whole_struct access.""" + counts: Counter[str] = Counter({"a": 1, "b": 1, "c": 1}) + assert is_field_by_field_access(counts) is True + +def test_is_field_by_field_access_few_keys() -> None: + """is_field_by_field_access returns False for <3 distinct keys.""" + counts: Counter[str] = Counter({"a": 1, "b": 1}) + assert is_field_by_field_access(counts) is False + +def test_is_hot_cold_split_true() -> None: + """is_hot_cold_split returns True for 1-2 hot keys + 2+ cold keys.""" + hot = {"role", "content"} + cold = {"tool_calls", "reasoning_content"} + assert is_hot_cold_split(hot, cold) is True + +def test_is_hot_cold_split_too_many_hot() -> None: + """is_hot_cold_split returns False for 3+ hot keys.""" + hot = {"a", "b", "c"} + cold = {"d", "e"} + assert is_hot_cold_split(hot, cold) is False + +def test_is_hot_cold_split_too_few_cold() -> None: + """is_hot_cold_split returns False for <2 cold keys.""" + hot = {"a"} + cold = {"b"} + assert is_hot_cold_split(hot, cold) is False + +def test_is_bulk_batched_access_true() -> None: + """is_bulk_batched_access returns True for a function iterating over a list of aggregates.""" + assert is_bulk_batched_access(iterates_over_list=True, body_accesses_uniform=True) is True + +def test_is_bulk_batched_access_no_iteration() -> None: + """is_bulk_batched_access returns False when the function doesn't iterate over a list.""" + assert is_bulk_batched_access(iterates_over_list=False, body_accesses_uniform=True) is False + +def test_is_bulk_batched_access_non_uniform() -> None: + """is_bulk_batched_access returns False when the body has non-uniform access.""" + assert is_bulk_batched_access(iterates_over_list=True, body_accesses_uniform=False) is False + +def test_dominant_pattern_clear_winner() -> None: + """dominant_pattern returns the pattern with the highest share if >=25%.""" + counts = {"field_by_field": 3, "whole_struct": 1} + assert dominant_pattern(counts) == "field_by_field" + +def test_dominant_pattern_below_threshold() -> None: + """dominant_pattern returns 'mixed' when no pattern has >=25% share.""" + counts = {"field_by_field": 1, "whole_struct": 1, "hot_cold_split": 1, "bulk_batched": 1} + assert dominant_pattern(counts) == "mixed" + +def test_dominant_pattern_empty() -> None: + """dominant_pattern returns 'mixed' for an empty counts dict.""" + assert dominant_pattern({}) == "mixed" + +def test_detect_access_pattern_whole_struct() -> None: + """detect_access_pattern returns 'whole_struct' for a function that reads 0-1 keys.""" + counts: Counter[str] = Counter() + pattern = detect_access_pattern(counts, has_direct_access=True) + assert pattern == "whole_struct" + +def test_detect_access_pattern_field_by_field() -> None: + """detect_access_pattern returns 'field_by_field' for a function that reads 3+ keys.""" + counts: Counter[str] = Counter({"a": 1, "b": 1, "c": 1}) + pattern = detect_access_pattern(counts, has_direct_access=False) + assert pattern == "field_by_field" + +def test_detect_access_pattern_hot_cold_split() -> None: + """detect_access_pattern returns 'hot_cold_split' for 1-2 hot + 2+ cold keys.""" + counts: Counter[str] = Counter({"a": 1, "b": 1, "c": 1, "d": 1}) + pattern = detect_access_pattern(counts, has_direct_access=False, hot_keys={"a", "b"}, cold_keys={"c", "d"}) + assert pattern == "hot_cold_split" + +def test_detect_access_pattern_mixed() -> None: + """detect_access_pattern returns 'mixed' when no pattern dominates (2+ distinct keys but <3).""" + counts: Counter[str] = Counter({"a": 1, "b": 1}) + pattern = detect_access_pattern(counts, has_direct_access=False) + assert pattern == "mixed" \ No newline at end of file