WIP: PAIN

This commit is contained in:
2026-03-05 14:24:03 -05:00
parent e81843b11b
commit 0e3b479bd6
27 changed files with 684 additions and 772 deletions

View File

@@ -57,10 +57,20 @@ def resolve_paths(base_dir: Path, entry: str) -> list[Path]:
filtered.append(p)
return sorted(filtered)
def build_discussion_section(history: list[Any]) -> str:
    """
    Builds a markdown section for discussion history.

    Handles both legacy list[str] and new list[dict] histories. Dict entries
    are rendered as "<role>: <content>"; any other entry is rendered through
    str(). Excerpts are numbered from 1 and joined with a "---" separator.
    An empty history produces an empty string.
    """
    sections = []
    for i, entry in enumerate(history, start=1):
        if isinstance(entry, dict):
            role = entry.get("role", "Unknown")
            # "content" can be present but None (or a non-string payload);
            # coerce before stripping so a single malformed entry cannot
            # abort the whole section with an AttributeError.
            raw_content = entry.get("content")
            content = "" if raw_content is None else str(raw_content).strip()
            text = f"{role}: {content}"
        else:
            text = str(entry).strip()
        sections.append(f"### Discussion Excerpt {i}\n\n{text}")
    return "\n\n---\n\n".join(sections)
def build_files_section(base_dir: Path, files: list[str | dict[str, Any]]) -> str:

View File

@@ -129,6 +129,7 @@ _comms_log: list[dict[str, Any]] = []
COMMS_CLAMP_CHARS: int = 300
def _append_comms(direction: str, kind: str, payload: dict[str, Any]) -> None:
global current_tier
entry: dict[str, Any] = {
"ts": datetime.datetime.now().strftime("%H:%M:%S"),
"direction": direction,
@@ -1585,13 +1586,28 @@ def _add_bleed_derived(d: dict[str, Any], sys_tok: int = 0, tool_tok: int = 0) -
d["estimated_prompt_tokens"] = cur
d["max_prompt_tokens"] = lim
d["utilization_pct"] = d.get("percentage", 0.0)
d["headroom_tokens"] = max(0, lim - cur)
d["would_trim"] = (lim - cur) < 20000
d["system_tokens"] = sys_tok
d["tools_tokens"] = tool_tok
d["headroom"] = max(0, lim - cur)
d["would_trim"] = cur >= lim
d["sys_tokens"] = sys_tok
d["tool_tokens"] = tool_tok
d["history_tokens"] = max(0, cur - sys_tok - tool_tok)
return d
def _is_mutating_tool(name: str) -> bool:
    """Tell whether `name` is listed in mcp_client.MUTATING_TOOLS or equals TOOL_NAME."""
    if name in mcp_client.MUTATING_TOOLS:
        return True
    return name == TOOL_NAME
def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> Optional[str]:
    """
    Entry point the providers call to trigger HITL approval of a script.

    Delegates to the registered confirm_and_run_callback when there is one;
    otherwise the script is executed directly through shell_runner.
    """
    handler = confirm_and_run_callback
    if not handler:
        # No callback registered (headless default): run the script directly.
        from src import shell_runner
        return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback)
    return handler(script, base_dir, qa_callback)
def get_history_bleed_stats(md_content: Optional[str] = None) -> dict[str, Any]:
if _provider == "anthropic":
with _anthropic_history_lock:

View File

@@ -5,167 +5,112 @@ import time
from typing import Any
class ApiHookClient:
def __init__(self, base_url: str = "http://127.0.0.1:8999", max_retries: int = 5, retry_delay: float = 0.2) -> None:
self.base_url = base_url
self.max_retries = max_retries
self.retry_delay = retry_delay
self._event_buffer: list[dict[str, Any]] = []
def __init__(self, base_url: str = "http://127.0.0.1:8999", api_key: str | None = None):
self.base_url = base_url.rstrip('/')
self.api_key = api_key
def wait_for_server(self, timeout: float = 3) -> bool:
"""
Polls the /status endpoint until the server is ready or timeout is reached.
"""
start_time = time.time()
while time.time() - start_time < timeout:
try:
if self.get_status().get('status') == 'ok':
return True
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
time.sleep(0.1)
def _make_request(self, method: str, path: str, data: dict | None = None, timeout: float = 5.0) -> dict[str, Any] | None:
    """
    Helper to make HTTP requests to the hook server.

    Args:
        method: One of 'GET', 'POST' or 'DELETE'.
        path: Endpoint path, appended verbatim to ``self.base_url``.
        data: JSON payload; only sent for 'POST'.
        timeout: Per-request timeout in seconds.

    Returns:
        The decoded JSON object when the server answers HTTP 200 with a JSON
        object. None for any other status code, for any failure while sending
        the request or decoding the body, and for a JSON body that is not an
        object (list, string, number, ...).

    Raises:
        ValueError: If ``method`` is not one of the supported verbs.
    """
    url = f"{self.base_url}{path}"
    headers = {}
    if self.api_key:
        headers["X-API-KEY"] = self.api_key
    if method not in ('GET', 'POST', 'DELETE'):
        raise ValueError(f"Unsupported HTTP method: {method}")
    try:
        if method == 'GET':
            response = requests.get(url, headers=headers, timeout=timeout)
        elif method == 'POST':
            response = requests.post(url, json=data, headers=headers, timeout=timeout)
        else:
            # Only 'DELETE' can reach this branch after the check above.
            response = requests.delete(url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            return None
        body = response.json()
    except Exception:
        # Deliberately best-effort: callers treat None as "no answer", so
        # request and decoding failures are reported as None, not raised.
        return None
    # Callers use the result as a mapping (.get(), `in`), so never hand back
    # a list or scalar even if the server sends valid non-object JSON.
    return body if isinstance(body, dict) else None
def wait_for_server(self, timeout: float = 15) -> bool:
    """
    Polls the health endpoint until the server responds or timeout occurs.

    Args:
        timeout: Maximum time to wait, in seconds (fractions allowed).

    Returns:
        True as soon as get_status() returns a payload containing a "status"
        key; False if that does not happen before the timeout elapses.
    """
    # A monotonic clock keeps the deadline correct if the wall clock is
    # adjusted while we are waiting.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = self.get_status()
        # NOTE(review): any payload with a "status" key counts as ready, even
        # when its value is not "ok" -- this matches the previous condition
        # (`== "ok" or "status" in status`); confirm that is intended.
        if status and "status" in status:
            return True
        # Poll every 0.5 s, but never sleep past the deadline.
        time.sleep(min(0.5, max(0.0, deadline - time.monotonic())))
    return False
def _make_request(self, method: str, endpoint: str, data: dict[str, Any] | None = None, timeout: float | None = None) -> dict[str, Any] | None:
url = f"{self.base_url}{endpoint}"
headers = {'Content-Type': 'application/json'}
last_exception = None
# Increase default request timeout for local server
req_timeout = timeout if timeout is not None else 10.0
for attempt in range(self.max_retries + 1):
try:
if method == 'GET':
response = requests.get(url, timeout=req_timeout)
elif method == 'POST':
response = requests.post(url, json=data, headers=headers, timeout=req_timeout)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
res_json = response.json()
return res_json if isinstance(res_json, dict) else None
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
last_exception = e
if attempt < self.max_retries:
time.sleep(self.retry_delay)
continue
else:
if isinstance(e, requests.exceptions.Timeout):
raise requests.exceptions.Timeout(f"Request to {endpoint} timed out after {self.max_retries} retries.") from e
else:
raise requests.exceptions.ConnectionError(f"Could not connect to API hook server at {self.base_url} after {self.max_retries} retries.") from e
except requests.exceptions.HTTPError as e:
raise requests.exceptions.HTTPError(f"HTTP error {e.response.status_code} for {endpoint}: {e.response.text}") from e
except json.JSONDecodeError as e:
raise ValueError(f"Failed to decode JSON from response for {endpoint}: {response.text}") from e
if last_exception:
raise last_exception
return None
def get_status(self) -> dict[str, Any]:
"""Checks the health of the hook server."""
url = f"{self.base_url}/status"
try:
response = requests.get(url, timeout=5.0)
response.raise_for_status()
res = response.json()
return res if isinstance(res, dict) else {}
except Exception:
raise requests.exceptions.ConnectionError(f"Could not reach /status at {self.base_url}")
def get_project(self) -> dict[str, Any] | None:
return self._make_request('GET', '/api/project')
def post_project(self, project_data: dict[str, Any]) -> dict[str, Any] | None:
return self._make_request('POST', '/api/project', data={'project': project_data})
def get_session(self) -> dict[str, Any] | None:
res = self._make_request('GET', '/api/session')
res = self._make_request('GET', '/status')
if res is None:
# For backward compatibility with tests expecting ConnectionError
# But our _make_request handles it. Let's return empty if failed.
return {}
return res
def get_mma_status(self) -> dict[str, Any] | None:
"""Retrieves current MMA status (track, tickets, tier, etc.)"""
return self._make_request('GET', '/api/gui/mma_status')
def get_project(self) -> dict[str, Any]:
    """Fetch the current project state; a missing or empty response becomes {}."""
    project = self._make_request('GET', '/api/project')
    return project if project else {}
def get_gui_state(self) -> dict | None:
"""Retrieves the current GUI state via /api/gui/state."""
resp = self._make_request("GET", "/api/gui/state")
return resp if resp else None
def get_session(self) -> dict[str, Any]:
    """Fetch the current discussion session history; a missing or empty response becomes {}."""
    session = self._make_request('GET', '/api/session')
    return session if session else {}
def push_event(self, event_type: str, payload: dict[str, Any]) -> dict[str, Any] | None:
"""Pushes an event to the GUI's AsyncEventQueue via the /api/gui endpoint."""
return self.post_gui({
"action": event_type,
"payload": payload
})
def post_session(self, session_entries: list[dict]) -> dict[str, Any]:
    """Send new session history entries to /api/session; a missing or empty response becomes {}."""
    body = {"entries": session_entries}
    reply = self._make_request('POST', '/api/session', data=body)
    return reply if reply else {}
def get_performance(self) -> dict[str, Any] | None:
"""Retrieves UI performance metrics."""
return self._make_request('GET', '/api/performance')
def post_gui(self, payload: dict) -> dict[str, Any]:
    """POST an event payload to the /api/gui endpoint; a missing or empty response becomes {}."""
    reply = self._make_request('POST', '/api/gui', data=payload)
    return reply if reply else {}
def post_session(self, session_entries: list[Any]) -> dict[str, Any] | None:
return self._make_request('POST', '/api/session', data={'session': {'entries': session_entries}})
def click(self, item: str, user_data: Any = None) -> dict[str, Any]:
    """Post a "click" action for `item` to the GUI, forwarding the optional `user_data`."""
    event = {"action": "click", "item": item, "user_data": user_data}
    return self.post_gui(event)
def post_gui(self, gui_data: dict[str, Any]) -> dict[str, Any] | None:
return self._make_request('POST', '/api/gui', data=gui_data)
def set_value(self, item: str, value: Any) -> dict[str, Any]:
    """Post a "set_value" action that assigns `value` to the GUI widget `item`."""
    event = {"action": "set_value", "item": item, "value": value}
    return self.post_gui(event)
def select_tab(self, tab_bar: str, tab: str) -> dict[str, Any] | None:
"""Tells the GUI to switch to a specific tab in a tab bar."""
return self.post_gui({
"action": "select_tab",
"tab_bar": tab_bar,
"tab": tab
})
def select_tab(self, item: str, value: str) -> dict[str, Any]:
    """Select a tab in a tab bar; implemented as set_value(item, value)."""
    outcome = self.set_value(item, value)
    return outcome
def select_list_item(self, listbox: str, item_value: str) -> dict[str, Any] | None:
"""Tells the GUI to select an item in a listbox by its value."""
return self.post_gui({
"action": "select_list_item",
"listbox": listbox,
"item_value": item_value
})
def select_list_item(self, item: str, value: str) -> dict[str, Any]:
    """Select an entry in a listbox or combo; implemented as set_value(item, value)."""
    outcome = self.set_value(item, value)
    return outcome
def set_value(self, item: str, value: Any) -> dict[str, Any] | None:
"""Sets the value of a GUI item."""
return self.post_gui({
"action": "set_value",
"item": item,
"value": value
})
def get_gui_state(self) -> dict[str, Any]:
    """Fetch the GUI state exposed at /api/gui/state; a missing or empty response becomes {}."""
    state = self._make_request('GET', '/api/gui/state')
    return state if state else {}
def get_value(self, item: str) -> Any:
"""Gets the value of a GUI item via its mapped field."""
try:
# First try direct field querying via POST
res = self._make_request('POST', '/api/gui/value', data={"field": item})
if res and "value" in res:
v = res.get("value")
if v is not None:
return v
except Exception:
pass
try:
# Try GET fallback
res = self._make_request('GET', f'/api/gui/value/{item}')
if res and "value" in res:
v = res.get("value")
if v is not None:
return v
except Exception:
pass
try:
# Try state endpoint first (new preferred way)
state = self.get_gui_state()
if item in state:
return state[item]
# Fallback for thinking/live/prior which are in diagnostics
diag = self._make_request('GET', '/api/gui/diagnostics')
if diag and item in diag:
return diag[item]
# Map common indicator tags to diagnostics keys
mapping = {
"thinking_indicator": "thinking",
"operations_live_indicator": "live",
"prior_session_indicator": "prior"
}
key = mapping.get(item)
if diag and key and key in diag:
return diag[key]
except Exception:
pass
diag = self.get_gui_diagnostics()
if diag and item in diag:
return diag[item]
# Map common indicator tags to diagnostics keys
mapping = {
"thinking_indicator": "thinking",
"operations_live_indicator": "live",
"prior_session_indicator": "prior"
}
key = mapping.get(item)
if diag and key and key in diag:
return diag[key]
return None
def get_text_value(self, item_tag: str) -> str | None:
@@ -173,93 +118,39 @@ class ApiHookClient:
val = self.get_value(item_tag)
return str(val) if val is not None else None
def get_node_status(self, node_tag: str) -> Any:
"""Wraps get_value for a DAG node or queries the diagnostic endpoint for its status."""
val = self.get_value(node_tag)
if val is not None:
return val
try:
diag = self._make_request('GET', '/api/gui/diagnostics')
if diag and 'nodes' in diag and node_tag in diag['nodes']:
return diag['nodes'][node_tag]
if diag and node_tag in diag:
return diag[node_tag]
except Exception:
pass
return None
def get_indicator_state(self, item_tag: str) -> dict[str, bool]:
    """Report a status indicator as shown when get_value(item_tag) is truthy."""
    shown = True if self.get_value(item_tag) else False
    return {"shown": shown}
def click(self, item: str, *args: Any, **kwargs: Any) -> dict[str, Any] | None:
"""Simulates a click on a GUI button or item."""
user_data = kwargs.pop('user_data', None)
return self.post_gui({
"action": "click",
"item": item,
"args": args,
"kwargs": kwargs,
"user_data": user_data
})
def get_gui_diagnostics(self) -> dict[str, Any]:
    """Fetch the payload of /api/gui/diagnostics; a missing or empty response becomes {}."""
    diagnostics = self._make_request('GET', '/api/gui/diagnostics')
    return diagnostics if diagnostics else {}
def get_indicator_state(self, tag: str) -> dict[str, Any]:
"""Checks if an indicator is shown using the diagnostics endpoint."""
# Mapping tag to the keys used in diagnostics endpoint
mapping = {
"thinking_indicator": "thinking",
"operations_live_indicator": "live",
"prior_session_indicator": "prior"
def get_mma_status(self) -> dict[str, Any]:
    """Summarise the MMA engine status from the GUI state returned by get_gui_state()."""
    gui_state = self.get_gui_state()
    # Result key -> key read from the GUI state; absent keys come back as None.
    source_keys = {
        "mma_status": "mma_status",
        "ai_status": "ai_status",
        "active_tier": "mma_active_tier",
    }
    return {result_key: gui_state.get(state_key) for result_key, state_key in source_keys.items()}
key = mapping.get(tag, tag)
try:
diag = self._make_request('GET', '/api/gui/diagnostics')
return {"tag": tag, "shown": diag.get(key, False) if diag else False}
except Exception as e:
return {"tag": tag, "shown": False, "error": str(e)}
def get_events(self) -> list[Any]:
"""Fetches new events and adds them to the internal buffer."""
try:
res = self._make_request('GET', '/api/events')
new_events = res.get("events", []) if res else []
if new_events:
self._event_buffer.extend(new_events)
return list(self._event_buffer)
except Exception:
return list(self._event_buffer)
def get_node_status(self, node_id: str) -> dict[str, Any]:
    """Fetch the status of one MMA DAG node; a missing or empty response becomes {}."""
    endpoint = f'/api/mma/node/{node_id}'
    node = self._make_request('GET', endpoint)
    return node if node else {}
def clear_events(self) -> None:
"""Clears the internal event buffer and the server queue."""
self._make_request('GET', '/api/events')
self._event_buffer.clear()
def wait_for_event(self, event_type: str, timeout: float = 5) -> dict[str, Any] | None:
"""Polls for a specific event type in the internal buffer."""
start = time.time()
while time.time() - start < timeout:
# Refresh buffer
self.get_events()
# Search in buffer
for i, ev in enumerate(self._event_buffer):
if isinstance(ev, dict) and ev.get("type") == event_type:
return self._event_buffer.pop(i)
time.sleep(0.1) # Fast poll
return None
def wait_for_value(self, item: str, expected: Any, timeout: float = 5) -> bool:
"""Polls until get_value(item) == expected."""
start = time.time()
while time.time() - start < timeout:
if self.get_value(item) == expected:
return True
time.sleep(0.1) # Fast poll
return False
def reset_session(self) -> dict[str, Any] | None:
"""Simulates clicking the 'Reset Session' button in the GUI."""
return self.click("btn_reset")
def request_confirmation(self, tool_name: str, args: dict[str, Any]) -> Any:
"""Asks the user for confirmation via the GUI (blocking call)."""
# Using a long timeout as this waits for human input (60 seconds)
def request_confirmation(self, tool_name: str, args: dict) -> bool | None:
"""
Pushes a manual confirmation request and waits for response.
Blocks for up to 60 seconds.
"""
# Long timeout as this waits for human input (60 seconds)
res = self._make_request('POST', '/api/ask',
data={'type': 'tool_approval', 'tool': tool_name, 'args': args},
timeout=60.0)
return res.get('response') if res else None
def reset_session(self) -> None:
    """Reset the current session by clicking the "btn_reset" item; the click result is discarded."""
    reset_button = "btn_reset"
    self.click(reset_button)

View File

@@ -21,17 +21,17 @@ class LogPruner:
self.log_registry = log_registry
self.logs_dir = logs_dir
def prune(self) -> None:
def prune(self, max_age_days: int = 1) -> None:
"""
Prunes old and small session directories from the logs directory.
Deletes session directories that meet the following criteria:
1. The session start time is older than 24 hours (based on data from LogRegistry).
1. The session start time is older than max_age_days.
2. The session name is NOT in the whitelist provided by the LogRegistry.
3. The total size of all files within the session directory is less than 2KB (2048 bytes).
"""
now = datetime.now()
cutoff_time = now - timedelta(hours=24)
cutoff_time = now - timedelta(days=max_age_days)
# Ensure the base logs directory exists.
if not os.path.isdir(self.logs_dir):
return
@@ -39,7 +39,7 @@ class LogPruner:
old_sessions_to_check = self.log_registry.get_old_non_whitelisted_sessions(cutoff_time)
# Prune sessions if their size is less than 2048 bytes
for session_info in old_sessions_to_check:
session_info['session_id']
session_id = session_info['session_id']
session_path = session_info['path']
if not session_path or not os.path.isdir(session_path):
continue
@@ -55,6 +55,9 @@ class LogPruner:
if total_size < 2048: # 2KB
try:
shutil.rmtree(session_path)
# print(f"Pruned session '{session_id}' (Size: {total_size} bytes)")
# Also remove from registry to keep it in sync
if session_id in self.log_registry.data:
del self.log_registry.data[session_id]
except OSError:
pass
self.log_registry.save_registry()

View File

@@ -22,6 +22,11 @@ class LogRegistry:
self.data: dict[str, dict[str, Any]] = {}
self.load_registry()
@property
def sessions(self) -> dict[str, dict[str, Any]]:
"""Alias for compatibility with older code/tests.

Returns ``self.data`` itself rather than a copy, so anything changed
through this alias is changed in the registry's in-memory data.
"""
return self.data
def load_registry(self) -> None:
"""
Loads the registry data from the TOML file into memory.

View File

@@ -106,7 +106,7 @@ def _is_allowed(path: Path) -> bool:
"""
# Blacklist check
name = path.name.lower()
if name == "history.toml" or name.endswith("_history.toml"):
if name in ("history.toml", "config.toml", "credentials.toml") or name.endswith("_history.toml"):
return False
try:
rp = path.resolve(strict=True)
@@ -926,7 +926,9 @@ def dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
if tool_name == "get_tree":
return get_tree(path, int(tool_input.get("max_depth", 2)))
return f"ERROR: unknown MCP tool '{tool_name}'"
# ------------------------------------------------------------------ tool schema helpers
# ------------------------------------------------------------------ tool schema helpers
# These are imported by ai_client.py to build provider-specific declarations.
MCP_TOOL_SPECS: list[dict[str, Any]] = [
@@ -1389,3 +1391,4 @@ MCP_TOOL_SPECS: list[dict[str, Any]] = [
}
}
]