Private
Public Access
0
0

feat(audit): implement Phase 3 MemoryDim + Phase 4 APD (11 tasks)

Phase 3: MemoryDim classifier with canonical mappings (24 entries,
includes ToolSpec/ChatMessage/ProviderHistory now that they're real),
file-of-origin heuristic (5 buckets), TOML override loader,
classify_memory_dim() with 3-tier precedence.

Phase 4: APD with 4 threshold constants, 5 pattern detectors
(whole_struct, field_by_field, hot_cold_split, bulk_batched,
dominant_pattern), detect_access_pattern() main entry.

30 new unit tests passing (Phase 3: 11, Phase 4: 19).
63 total tests passing.

Phase 5 (CFE - Call Frequency Estimator) next.
This commit is contained in:
2026-06-22 01:26:06 -04:00
parent a42a60b8bf
commit c1d2f0e454
2 changed files with 311 additions and 28 deletions
+128 -27
View File
@@ -10,6 +10,8 @@ conductor/tracks/code_path_audit_20260607/spec_v2.md.
"""
from __future__ import annotations
import ast
import tomllib
from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal
@@ -152,12 +154,7 @@ class AggregateProfile:
@dataclass
class ProducerConsumerGraph:
"""Bipartite graph: aggregates <-> functions.
producers[aggregate] = set of FunctionRef that produce the aggregate.
consumers[aggregate] = set of FunctionRef that consume the aggregate.
edges[(producer, consumer)] = set of aggregates flowing between them.
"""
"""Bipartite graph: aggregates <-> functions."""
edges: dict[tuple[str, str], set[str]] = field(default_factory=dict)
producers: dict[str, set[FunctionRef]] = field(default_factory=dict)
consumers: dict[str, set[FunctionRef]] = field(default_factory=dict)
@@ -169,10 +166,7 @@ class ProducerConsumerGraph:
self.consumers.setdefault(aggregate, set()).add(function)
def P1_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]:
"""AST pass 1: detect producers of T and Result[T] via return annotations.
Returns: list of (function_name, aggregate_name, role, confidence).
"""
"""AST pass 1: detect producers of T and Result[T] via return annotations."""
out: list[tuple[str, str, str, str]] = []
for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
@@ -194,10 +188,7 @@ def P1_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]:
return out
def P2_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]:
"""AST pass 2: detect consumers of typed aggregates via parameter annotations.
Returns: list of (function_name, aggregate_name, role, confidence).
"""
"""AST pass 2: detect consumers of typed aggregates via parameter annotations."""
out: list[tuple[str, str, str, str]] = []
for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
@@ -216,13 +207,7 @@ def P2_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]:
return out
def P3_pass(tree: ast.Module, file: str, type_registry: dict[str, list[str]]) -> list[tuple[str, str, str, int]]:
"""AST pass 3: detect field accesses via entry['key'] or entry.attr.
Returns: list of (function_name, key_or_attr, kind, count).
type_registry is currently unused (the field-to-aggregate mapping
is computed in Phase 7 by the cross-audit integration); P3 only
records the field access itself.
"""
"""AST pass 3: detect field accesses via entry['key'] or entry.attr."""
out: list[tuple[str, str, str, int]] = []
for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
@@ -242,11 +227,7 @@ def P3_pass(tree: ast.Module, file: str, type_registry: dict[str, list[str]]) ->
return out
def build_pcg(src_dir: str, type_registry: dict[str, list[str]] | None = None) -> Result[ProducerConsumerGraph]:
"""Build the ProducerConsumerGraph by AST-walking src/.
Returns Result[PCG]. Syntax errors in individual files are
tolerated; the file is skipped and an ErrorInfo is added.
"""
"""Build the ProducerConsumerGraph by AST-walking src/."""
pcg = ProducerConsumerGraph()
type_registry = type_registry or {}
errors: list[ErrorInfo] = []
@@ -285,4 +266,124 @@ def build_pcg(src_dir: str, type_registry: dict[str, list[str]] | None = None) -
pcg.add_consumer(agg, fref)
for fn, key, kind, count in P3_pass(tree, file, type_registry):
pass
return Result(data=pcg, errors=errors)
return Result(data=pcg, errors=errors)
# Canonical aggregate-name -> memory-dim table.
# classify_memory_dim() consults this after the caller-supplied overrides and
# before falling back to the file-of-origin heuristic, so an entry here always
# beats the heuristic but can still be overridden per project.
CANONICAL_MEMORY_DIM: dict[str, MemoryDim] = {
"Metadata": "discussion",
"CommsLogEntry": "discussion",
"CommsLog": "discussion",
"HistoryMessage": "discussion",
"History": "discussion",
"FileItem": "curation",
"FileItems": "curation",
"ToolDefinition": "control",
"ToolCall": "control",
"Result": "control",
"ErrorInfo": "control",
"ToolSpec": "control",
"ToolParameter": "control",
"ChatMessage": "discussion",
"UsageStats": "control",
"NormalizedResponse": "control",
"ProviderHistory": "discussion",
"OpenAICompatibleRequest": "control",
"Session": "knowledge",
"SessionMetadata": "knowledge",
"WebSocketMessage": "control",
"JsonValue": "control",
"ManualSlopConfig": "config",
"VendorCapabilities": "control",
}
# File-of-origin heuristic: memory dim -> path prefixes of the files whose
# aggregates belong to that dim. file_origin_memory_dim() walks the buckets in
# insertion order and returns the first dim with a matching prefix.
# NOTE(review): "src/context_presets.py" is listed under both "curation" and
# "config"; because "curation" comes first, the "config" entry for that file
# can never be selected -- confirm which bucket is intended.
MEMORY_DIM_FILE_HEURISTIC: dict[MemoryDim, tuple[str, ...]] = {
"curation": ("src/aggregate.py", "src/context_presets.py", "src/views.py"),
"discussion": ("src/ai_client.py", "src/history.py", "src/session_logger.py"),
"rag": ("src/rag_engine.py", "src/rag_index.py"),
"knowledge": ("src/knowledge.py", "src/knowledge_curation.py"),
"config": ("src/paths.py", "src/presets.py", "src/personas.py", "src/context_presets.py", "src/tool_presets.py"),
}
def load_memory_dim_overrides(path: str) -> dict[str, MemoryDim]:
"""Load memory_dim overrides from a TOML file."""
p = Path(path)
if not p.exists():
return {}
with p.open("rb") as f:
data = tomllib.load(f)
out: dict[str, MemoryDim] = {}
for key, value in data.get("memory_dim", {}).items():
if isinstance(value, str):
out[key] = value
return out
def file_origin_memory_dim(file: str) -> MemoryDim:
    """Guess an aggregate's memory dim from the file that produces it.

    Buckets of MEMORY_DIM_FILE_HEURISTIC are checked in insertion order; the
    first bucket holding a prefix of ``file`` wins. Returns ``"unknown"`` when
    no bucket matches.
    """
    matching_dims = (
        bucket
        for bucket, prefixes in MEMORY_DIM_FILE_HEURISTIC.items()
        # str.startswith accepts a tuple and is true if any prefix matches.
        if file.startswith(prefixes)
    )
    return next(matching_dims, "unknown")
def classify_memory_dim(aggregate: str, primary_producer_file: str, overrides: dict[str, MemoryDim]) -> MemoryDim:
    """Resolve the memory dim of ``aggregate``.

    Lookup order: caller-supplied ``overrides``, then CANONICAL_MEMORY_DIM,
    then the file-of-origin heuristic applied to ``primary_producer_file``
    (which itself yields ``"unknown"`` when nothing matches).
    """
    # Explicit tables, highest precedence first.
    for table in (overrides, CANONICAL_MEMORY_DIM):
        if aggregate in table:
            return table[aggregate]
    return file_origin_memory_dim(primary_producer_file)
# Access-pattern detection thresholds.
# Largest number of distinct keys still counted as whole_struct access
# (see is_whole_struct_access).
WHOLE_STRUCT_KEY_THRESHOLD: int = 1
# Smallest number of distinct keys counted as field_by_field access
# (see is_field_by_field_access).
FIELD_BY_FIELD_KEY_THRESHOLD: int = 3
# NOTE(review): not referenced by any detector visible in this section --
# presumably a per-function dominance cutoff; confirm where it is used.
MIXED_DOMINANCE_THRESHOLD: float = 0.6
# dominant_pattern() reports "mixed" unless the most frequent pattern's share
# of all counts is strictly greater than this value.
AGGREGATE_LEVEL_DOMINANCE_THRESHOLD: float = 0.25
def is_whole_struct_access(field_counts: Counter, has_direct_access: bool) -> bool:
    """Report whether a function's accesses look like whole_struct access.

    True when ``has_direct_access`` is set, or when the number of distinct
    keys in ``field_counts`` is at most WHOLE_STRUCT_KEY_THRESHOLD.

    NOTE(review): an earlier summary of this check described the two
    conditions as joined by AND; the implemented behaviour is OR (direct
    access alone is sufficient, regardless of key count). Confirm against
    the spec.
    """
    return True if has_direct_access else len(field_counts) <= WHOLE_STRUCT_KEY_THRESHOLD
def is_field_by_field_access(field_counts: Counter) -> bool:
    """Report whether at least FIELD_BY_FIELD_KEY_THRESHOLD distinct keys are accessed."""
    distinct_keys = len(field_counts)
    return distinct_keys >= FIELD_BY_FIELD_KEY_THRESHOLD
def is_hot_cold_split(hot_keys: set[str], cold_keys: set[str]) -> bool:
    """Report whether the key sets form a hot/cold split.

    A split requires one or two hot keys together with at least two cold keys.
    """
    hot_count = len(hot_keys)
    if hot_count < 1 or hot_count > 2:
        return False
    return len(cold_keys) >= 2
def is_bulk_batched_access(iterates_over_list: bool, body_accesses_uniform: bool) -> bool:
    """Report bulk_batched access: list iteration whose body accesses fields uniformly."""
    if iterates_over_list:
        return body_accesses_uniform
    return iterates_over_list
def dominant_pattern(per_function_pattern_counts: dict[str, int]) -> AccessPattern:
    """Pick the aggregate-level access pattern from per-function pattern counts.

    Args:
        per_function_pattern_counts: Pattern name -> number of functions that
            exhibit it.

    Returns:
        The pattern with the highest count when its share of the total is
        strictly greater than AGGREGATE_LEVEL_DOMINANCE_THRESHOLD; otherwise
        ``"mixed"``. Also ``"mixed"`` when the mapping is empty or its counts
        sum to zero. Ties go to the pattern that appears first in the mapping.
    """
    total = sum(per_function_pattern_counts.values())
    # Covers both the empty mapping and a mapping of all-zero counts; the
    # latter would otherwise divide by zero below.
    if total <= 0:
        return "mixed"
    winner = max(per_function_pattern_counts, key=per_function_pattern_counts.get)
    share = per_function_pattern_counts[winner] / total
    if share <= AGGREGATE_LEVEL_DOMINANCE_THRESHOLD:
        return "mixed"
    return winner
def detect_access_pattern(
    field_counts: Counter,
    has_direct_access: bool,
    hot_keys: set[str] | None = None,
    cold_keys: set[str] | None = None,
) -> AccessPattern:
    """Classify a single function's access pattern.

    Checks run in this order and the first match wins: whole_struct,
    hot_cold_split (only evaluated when both ``hot_keys`` and ``cold_keys``
    are supplied), field_by_field. Anything else is ``"mixed"``.
    bulk_batched is not evaluated here.
    """
    if is_whole_struct_access(field_counts, has_direct_access):
        return "whole_struct"
    split_info_given = hot_keys is not None and cold_keys is not None
    if split_info_given and is_hot_cold_split(hot_keys, cold_keys):
        return "hot_cold_split"
    return "field_by_field" if is_field_by_field_access(field_counts) else "mixed"