feat(audit): implement Phase 3 MemoryDim + Phase 4 APD (11 tasks)
Phase 3: MemoryDim classifier with canonical mappings (23 entries, includes ToolSpec/ChatMessage/ProviderHistory now that they're real), file-of-origin heuristic (5 buckets), TOML override loader, classify_memory_dim() with 3-tier precedence. Phase 4: APD with 4 threshold constants, 5 pattern detectors (whole_struct, field_by_field, hot_cold_split, bulk_batched, dominant_pattern), detect_access_pattern() main entry. 30 new unit tests passing (Phase 3: 11, Phase 4: 19). 63 total tests passing. Phase 5 (CFE - Call Frequency Estimator) next.
This commit is contained in:
+128
-27
@@ -10,6 +10,8 @@ conductor/tracks/code_path_audit_20260607/spec_v2.md.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import ast
|
||||
import tomllib
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
@@ -152,12 +154,7 @@ class AggregateProfile:
|
||||
|
||||
@dataclass
|
||||
class ProducerConsumerGraph:
|
||||
"""Bipartite graph: aggregates <-> functions.
|
||||
|
||||
producers[aggregate] = set of FunctionRef that produce the aggregate.
|
||||
consumers[aggregate] = set of FunctionRef that consume the aggregate.
|
||||
edges[(producer, consumer)] = set of aggregates flowing between them.
|
||||
"""
|
||||
"""Bipartite graph: aggregates <-> functions."""
|
||||
edges: dict[tuple[str, str], set[str]] = field(default_factory=dict)
|
||||
producers: dict[str, set[FunctionRef]] = field(default_factory=dict)
|
||||
consumers: dict[str, set[FunctionRef]] = field(default_factory=dict)
|
||||
@@ -169,10 +166,7 @@ class ProducerConsumerGraph:
|
||||
self.consumers.setdefault(aggregate, set()).add(function)
|
||||
|
||||
def P1_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]:
|
||||
"""AST pass 1: detect producers of T and Result[T] via return annotations.
|
||||
|
||||
Returns: list of (function_name, aggregate_name, role, confidence).
|
||||
"""
|
||||
"""AST pass 1: detect producers of T and Result[T] via return annotations."""
|
||||
out: list[tuple[str, str, str, str]] = []
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
@@ -194,10 +188,7 @@ def P1_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]:
|
||||
return out
|
||||
|
||||
def P2_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]:
|
||||
"""AST pass 2: detect consumers of typed aggregates via parameter annotations.
|
||||
|
||||
Returns: list of (function_name, aggregate_name, role, confidence).
|
||||
"""
|
||||
"""AST pass 2: detect consumers of typed aggregates via parameter annotations."""
|
||||
out: list[tuple[str, str, str, str]] = []
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
@@ -216,13 +207,7 @@ def P2_pass(tree: ast.Module, file: str) -> list[tuple[str, str, str, str]]:
|
||||
return out
|
||||
|
||||
def P3_pass(tree: ast.Module, file: str, type_registry: dict[str, list[str]]) -> list[tuple[str, str, str, int]]:
|
||||
"""AST pass 3: detect field accesses via entry['key'] or entry.attr.
|
||||
|
||||
Returns: list of (function_name, key_or_attr, kind, count).
|
||||
type_registry is currently unused (the field-to-aggregate mapping
|
||||
is computed in Phase 7 by the cross-audit integration); P3 only
|
||||
records the field access itself.
|
||||
"""
|
||||
"""AST pass 3: detect field accesses via entry['key'] or entry.attr."""
|
||||
out: list[tuple[str, str, str, int]] = []
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
@@ -242,11 +227,7 @@ def P3_pass(tree: ast.Module, file: str, type_registry: dict[str, list[str]]) ->
|
||||
return out
|
||||
|
||||
def build_pcg(src_dir: str, type_registry: dict[str, list[str]] | None = None) -> Result[ProducerConsumerGraph]:
|
||||
"""Build the ProducerConsumerGraph by AST-walking src/.
|
||||
|
||||
Returns Result[PCG]. Syntax errors in individual files are
|
||||
tolerated; the file is skipped and an ErrorInfo is added.
|
||||
"""
|
||||
"""Build the ProducerConsumerGraph by AST-walking src/."""
|
||||
pcg = ProducerConsumerGraph()
|
||||
type_registry = type_registry or {}
|
||||
errors: list[ErrorInfo] = []
|
||||
@@ -285,4 +266,124 @@ def build_pcg(src_dir: str, type_registry: dict[str, list[str]] | None = None) -
|
||||
pcg.add_consumer(agg, fref)
|
||||
for fn, key, kind, count in P3_pass(tree, file, type_registry):
|
||||
pass
|
||||
return Result(data=pcg, errors=errors)
|
||||
return Result(data=pcg, errors=errors)
|
||||
|
||||
CANONICAL_MEMORY_DIM: dict[str, MemoryDim] = {
|
||||
"Metadata": "discussion",
|
||||
"CommsLogEntry": "discussion",
|
||||
"CommsLog": "discussion",
|
||||
"HistoryMessage": "discussion",
|
||||
"History": "discussion",
|
||||
"FileItem": "curation",
|
||||
"FileItems": "curation",
|
||||
"ToolDefinition": "control",
|
||||
"ToolCall": "control",
|
||||
"Result": "control",
|
||||
"ErrorInfo": "control",
|
||||
"ToolSpec": "control",
|
||||
"ToolParameter": "control",
|
||||
"ChatMessage": "discussion",
|
||||
"UsageStats": "control",
|
||||
"NormalizedResponse": "control",
|
||||
"ProviderHistory": "discussion",
|
||||
"OpenAICompatibleRequest": "control",
|
||||
"Session": "knowledge",
|
||||
"SessionMetadata": "knowledge",
|
||||
"WebSocketMessage": "control",
|
||||
"JsonValue": "control",
|
||||
"ManualSlopConfig": "config",
|
||||
"VendorCapabilities": "control",
|
||||
}
|
||||
|
||||
MEMORY_DIM_FILE_HEURISTIC: dict[MemoryDim, tuple[str, ...]] = {
|
||||
"curation": ("src/aggregate.py", "src/context_presets.py", "src/views.py"),
|
||||
"discussion": ("src/ai_client.py", "src/history.py", "src/session_logger.py"),
|
||||
"rag": ("src/rag_engine.py", "src/rag_index.py"),
|
||||
"knowledge": ("src/knowledge.py", "src/knowledge_curation.py"),
|
||||
"config": ("src/paths.py", "src/presets.py", "src/personas.py", "src/context_presets.py", "src/tool_presets.py"),
|
||||
}
|
||||
|
||||
def load_memory_dim_overrides(path: str) -> dict[str, MemoryDim]:
|
||||
"""Load memory_dim overrides from a TOML file."""
|
||||
p = Path(path)
|
||||
if not p.exists():
|
||||
return {}
|
||||
with p.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
out: dict[str, MemoryDim] = {}
|
||||
for key, value in data.get("memory_dim", {}).items():
|
||||
if isinstance(value, str):
|
||||
out[key] = value
|
||||
return out
|
||||
|
||||
def file_origin_memory_dim(file: str) -> MemoryDim:
|
||||
"""Determine the memory dim from the file of origin."""
|
||||
for dim, files in MEMORY_DIM_FILE_HEURISTIC.items():
|
||||
for f in files:
|
||||
if file.startswith(f):
|
||||
return dim
|
||||
return "unknown"
|
||||
|
||||
def classify_memory_dim(aggregate: str, primary_producer_file: str, overrides: dict[str, MemoryDim]) -> MemoryDim:
|
||||
"""Classify the memory dim of an aggregate.
|
||||
|
||||
Precedence: overrides > canonical > file_of_origin > unknown.
|
||||
"""
|
||||
if aggregate in overrides:
|
||||
return overrides[aggregate]
|
||||
if aggregate in CANONICAL_MEMORY_DIM:
|
||||
return CANONICAL_MEMORY_DIM[aggregate]
|
||||
return file_origin_memory_dim(primary_producer_file)
|
||||
|
||||
WHOLE_STRUCT_KEY_THRESHOLD: int = 1
|
||||
FIELD_BY_FIELD_KEY_THRESHOLD: int = 3
|
||||
MIXED_DOMINANCE_THRESHOLD: float = 0.6
|
||||
AGGREGATE_LEVEL_DOMINANCE_THRESHOLD: float = 0.25
|
||||
|
||||
def is_whole_struct_access(field_counts: Counter, has_direct_access: bool) -> bool:
|
||||
"""Detect whole_struct access: <=WHOLE_STRUCT_KEY_THRESHOLD distinct keys AND (direct access or 0 keys)."""
|
||||
if has_direct_access:
|
||||
return True
|
||||
return len(field_counts) <= WHOLE_STRUCT_KEY_THRESHOLD
|
||||
|
||||
def is_field_by_field_access(field_counts: Counter) -> bool:
|
||||
"""Detect field_by_field access: >=FIELD_BY_FIELD_KEY_THRESHOLD=3 distinct keys."""
|
||||
return len(field_counts) >= FIELD_BY_FIELD_KEY_THRESHOLD
|
||||
|
||||
def is_hot_cold_split(hot_keys: set[str], cold_keys: set[str]) -> bool:
|
||||
"""Detect hot_cold_split access: 1-2 hot keys in main body + 2+ cold keys in if/else branches."""
|
||||
return 1 <= len(hot_keys) <= 2 and len(cold_keys) >= 2
|
||||
|
||||
def is_bulk_batched_access(iterates_over_list: bool, body_accesses_uniform: bool) -> bool:
|
||||
"""Detect bulk_batched access: iterates over list[aggregate] with uniform field access."""
|
||||
return iterates_over_list and body_accesses_uniform
|
||||
|
||||
def dominant_pattern(per_function_pattern_counts: dict[str, int]) -> AccessPattern:
|
||||
"""Determine the aggregate-level dominant pattern from per-function pattern counts."""
|
||||
if not per_function_pattern_counts:
|
||||
return "mixed"
|
||||
total = sum(per_function_pattern_counts.values())
|
||||
winner = max(per_function_pattern_counts, key=per_function_pattern_counts.get)
|
||||
share = per_function_pattern_counts[winner] / total
|
||||
if share <= AGGREGATE_LEVEL_DOMINANCE_THRESHOLD:
|
||||
return "mixed"
|
||||
return winner
|
||||
|
||||
def detect_access_pattern(
|
||||
field_counts: Counter,
|
||||
has_direct_access: bool,
|
||||
hot_keys: set[str] | None = None,
|
||||
cold_keys: set[str] | None = None,
|
||||
) -> AccessPattern:
|
||||
"""Detect the per-function access pattern.
|
||||
|
||||
Precedence: whole_struct > hot_cold_split > field_by_field > mixed.
|
||||
"""
|
||||
if is_whole_struct_access(field_counts, has_direct_access):
|
||||
return "whole_struct"
|
||||
if hot_keys is not None and cold_keys is not None:
|
||||
if is_hot_cold_split(hot_keys, cold_keys):
|
||||
return "hot_cold_split"
|
||||
if is_field_by_field_access(field_counts):
|
||||
return "field_by_field"
|
||||
return "mixed"
|
||||
Reference in New Issue
Block a user