refactor(types): Phase 4 type hint sweep — core modules

This commit is contained in:
2026-02-28 15:13:55 -05:00
parent ca04026db5
commit 46c2f9a0ca
10 changed files with 52 additions and 42 deletions

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
import requests import requests
import json import json
import time import time
from typing import Any
class ApiHookClient: class ApiHookClient:
def __init__(self, base_url="http://127.0.0.1:8999", max_retries=5, retry_delay=0.2): def __init__(self, base_url: str = "http://127.0.0.1:8999", max_retries: int = 5, retry_delay: float = 0.2) -> None:
self.base_url = base_url self.base_url = base_url
self.max_retries = max_retries self.max_retries = max_retries
self.retry_delay = retry_delay self.retry_delay = retry_delay
def wait_for_server(self, timeout=3): def wait_for_server(self, timeout: float = 3) -> bool:
""" """
Polls the /status endpoint until the server is ready or timeout is reached. Polls the /status endpoint until the server is ready or timeout is reached.
""" """
@@ -21,7 +23,7 @@ class ApiHookClient:
time.sleep(0.1) time.sleep(0.1)
return False return False
def _make_request(self, method, endpoint, data=None, timeout=None): def _make_request(self, method: str, endpoint: str, data: dict | None = None, timeout: float | None = None) -> dict | None:
url = f"{self.base_url}{endpoint}" url = f"{self.base_url}{endpoint}"
headers = {'Content-Type': 'application/json'} headers = {'Content-Type': 'application/json'}
last_exception = None last_exception = None
@@ -54,7 +56,7 @@ class ApiHookClient:
if last_exception: if last_exception:
raise last_exception raise last_exception
def get_status(self): def get_status(self) -> dict:
"""Checks the health of the hook server.""" """Checks the health of the hook server."""
url = f"{self.base_url}/status" url = f"{self.base_url}/status"
try: try:
@@ -64,37 +66,37 @@ class ApiHookClient:
except Exception: except Exception:
raise requests.exceptions.ConnectionError(f"Could not reach /status at {self.base_url}") raise requests.exceptions.ConnectionError(f"Could not reach /status at {self.base_url}")
def get_project(self): def get_project(self) -> dict | None:
return self._make_request('GET', '/api/project') return self._make_request('GET', '/api/project')
def post_project(self, project_data): def post_project(self, project_data: dict) -> dict | None:
return self._make_request('POST', '/api/project', data={'project': project_data}) return self._make_request('POST', '/api/project', data={'project': project_data})
def get_session(self): def get_session(self) -> dict | None:
return self._make_request('GET', '/api/session') return self._make_request('GET', '/api/session')
def get_mma_status(self): def get_mma_status(self) -> dict | None:
"""Retrieves current MMA status (track, tickets, tier, etc.)""" """Retrieves current MMA status (track, tickets, tier, etc.)"""
return self._make_request('GET', '/api/gui/mma_status') return self._make_request('GET', '/api/gui/mma_status')
def push_event(self, event_type, payload): def push_event(self, event_type: str, payload: dict) -> dict | None:
"""Pushes an event to the GUI's AsyncEventQueue via the /api/gui endpoint.""" """Pushes an event to the GUI's AsyncEventQueue via the /api/gui endpoint."""
return self.post_gui({ return self.post_gui({
"action": event_type, "action": event_type,
"payload": payload "payload": payload
}) })
def get_performance(self): def get_performance(self) -> dict | None:
"""Retrieves UI performance metrics.""" """Retrieves UI performance metrics."""
return self._make_request('GET', '/api/performance') return self._make_request('GET', '/api/performance')
def post_session(self, session_entries): def post_session(self, session_entries: list) -> dict | None:
return self._make_request('POST', '/api/session', data={'session': {'entries': session_entries}}) return self._make_request('POST', '/api/session', data={'session': {'entries': session_entries}})
def post_gui(self, gui_data): def post_gui(self, gui_data: dict) -> dict | None:
return self._make_request('POST', '/api/gui', data=gui_data) return self._make_request('POST', '/api/gui', data=gui_data)
def select_tab(self, tab_bar, tab): def select_tab(self, tab_bar: str, tab: str) -> dict | None:
"""Tells the GUI to switch to a specific tab in a tab bar.""" """Tells the GUI to switch to a specific tab in a tab bar."""
return self.post_gui({ return self.post_gui({
"action": "select_tab", "action": "select_tab",
@@ -102,7 +104,7 @@ class ApiHookClient:
"tab": tab "tab": tab
}) })
def select_list_item(self, listbox, item_value): def select_list_item(self, listbox: str, item_value: str) -> dict | None:
"""Tells the GUI to select an item in a listbox by its value.""" """Tells the GUI to select an item in a listbox by its value."""
return self.post_gui({ return self.post_gui({
"action": "select_list_item", "action": "select_list_item",
@@ -110,7 +112,7 @@ class ApiHookClient:
"item_value": item_value "item_value": item_value
}) })
def set_value(self, item, value): def set_value(self, item: str, value: Any) -> dict | None:
"""Sets the value of a GUI item.""" """Sets the value of a GUI item."""
return self.post_gui({ return self.post_gui({
"action": "set_value", "action": "set_value",
@@ -118,7 +120,7 @@ class ApiHookClient:
"value": value "value": value
}) })
def get_value(self, item): def get_value(self, item: str) -> Any:
"""Gets the value of a GUI item via its mapped field.""" """Gets the value of a GUI item via its mapped field."""
try: try:
# First try direct field querying via POST # First try direct field querying via POST
@@ -156,12 +158,12 @@ class ApiHookClient:
pass pass
return None return None
def get_text_value(self, item_tag): def get_text_value(self, item_tag: str) -> str | None:
"""Wraps get_value and returns its string representation, or None.""" """Wraps get_value and returns its string representation, or None."""
val = self.get_value(item_tag) val = self.get_value(item_tag)
return str(val) if val is not None else None return str(val) if val is not None else None
def get_node_status(self, node_tag): def get_node_status(self, node_tag: str) -> Any:
"""Wraps get_value for a DAG node or queries the diagnostic endpoint for its status.""" """Wraps get_value for a DAG node or queries the diagnostic endpoint for its status."""
val = self.get_value(node_tag) val = self.get_value(node_tag)
if val is not None: if val is not None:
@@ -176,7 +178,7 @@ class ApiHookClient:
pass pass
return None return None
def click(self, item, *args, **kwargs): def click(self, item: str, *args: Any, **kwargs: Any) -> dict | None:
"""Simulates a click on a GUI button or item.""" """Simulates a click on a GUI button or item."""
user_data = kwargs.pop('user_data', None) user_data = kwargs.pop('user_data', None)
return self.post_gui({ return self.post_gui({
@@ -187,7 +189,7 @@ class ApiHookClient:
"user_data": user_data "user_data": user_data
}) })
def get_indicator_state(self, tag): def get_indicator_state(self, tag: str) -> dict:
"""Checks if an indicator is shown using the diagnostics endpoint.""" """Checks if an indicator is shown using the diagnostics endpoint."""
# Mapping tag to the keys used in diagnostics endpoint # Mapping tag to the keys used in diagnostics endpoint
mapping = { mapping = {
@@ -202,14 +204,14 @@ class ApiHookClient:
except Exception as e: except Exception as e:
return {"tag": tag, "shown": False, "error": str(e)} return {"tag": tag, "shown": False, "error": str(e)}
def get_events(self): def get_events(self) -> list:
"""Fetches and clears the event queue from the server.""" """Fetches and clears the event queue from the server."""
try: try:
return self._make_request('GET', '/api/events').get("events", []) return self._make_request('GET', '/api/events').get("events", [])
except Exception: except Exception:
return [] return []
def wait_for_event(self, event_type, timeout=5): def wait_for_event(self, event_type: str, timeout: float = 5) -> dict | None:
"""Polls for a specific event type.""" """Polls for a specific event type."""
start = time.time() start = time.time()
while time.time() - start < timeout: while time.time() - start < timeout:
@@ -220,7 +222,7 @@ class ApiHookClient:
time.sleep(0.1) # Fast poll time.sleep(0.1) # Fast poll
return None return None
def wait_for_value(self, item, expected, timeout=5): def wait_for_value(self, item: str, expected: Any, timeout: float = 5) -> bool:
"""Polls until get_value(item) == expected.""" """Polls until get_value(item) == expected."""
start = time.time() start = time.time()
while time.time() - start < timeout: while time.time() - start < timeout:
@@ -229,11 +231,11 @@ class ApiHookClient:
time.sleep(0.1) # Fast poll time.sleep(0.1) # Fast poll
return False return False
def reset_session(self): def reset_session(self) -> dict | None:
"""Simulates clicking the 'Reset Session' button in the GUI.""" """Simulates clicking the 'Reset Session' button in the GUI."""
return self.click("btn_reset") return self.click("btn_reset")
def request_confirmation(self, tool_name, args): def request_confirmation(self, tool_name: str, args: dict) -> Any:
"""Asks the user for confirmation via the GUI (blocking call).""" """Asks the user for confirmation via the GUI (blocking call)."""
# Using a long timeout as this waits for human input (60 seconds) # Using a long timeout as this waits for human input (60 seconds)
res = self._make_request('POST', '/api/ask', res = self._make_request('POST', '/api/ask',

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
import json import json
import threading import threading
import uuid import uuid
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from typing import Any
import logging import logging
import session_logger import session_logger
class HookServerInstance(ThreadingHTTPServer): class HookServerInstance(ThreadingHTTPServer):
"""Custom HTTPServer that carries a reference to the main App instance.""" """Custom HTTPServer that carries a reference to the main App instance."""
def __init__(self, server_address, RequestHandlerClass, app): def __init__(self, server_address: tuple[str, int], RequestHandlerClass: type, app: Any) -> None:
super().__init__(server_address, RequestHandlerClass) super().__init__(server_address, RequestHandlerClass)
self.app = app self.app = app
@@ -273,11 +275,11 @@ class HookHandler(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
self.wfile.write(json.dumps({'error': str(e)}).encode('utf-8')) self.wfile.write(json.dumps({'error': str(e)}).encode('utf-8'))
def log_message(self, format, *args): def log_message(self, format: str, *args: Any) -> None:
logging.info("Hook API: " + format % args) logging.info("Hook API: " + format % args)
class HookServer: class HookServer:
def __init__(self, app, port=8999): def __init__(self, app: Any, port: int = 8999) -> None:
self.app = app self.app = app
self.port = port self.port = port
self.server = None self.server = None

View File

@@ -1,11 +1,13 @@
# gemini.py # gemini.py
from __future__ import annotations
import tomllib import tomllib
from pathlib import Path from pathlib import Path
from typing import Any
from google import genai from google import genai
from google.genai import types from google.genai import types
_client = None _client: genai.Client | None = None
_chat = None _chat: Any = None
def _load_key() -> str: def _load_key() -> str:
with open("credentials.toml", "rb") as f: with open("credentials.toml", "rb") as f:

View File

@@ -6,7 +6,7 @@ import os
import session_logger # Import session_logger import session_logger # Import session_logger
class GeminiCliAdapter: class GeminiCliAdapter:
def __init__(self, binary_path="gemini"): def __init__(self, binary_path: str = "gemini") -> None:
self.binary_path = binary_path self.binary_path = binary_path
self.last_usage = None self.last_usage = None
self.session_id = None self.session_id = None
@@ -23,7 +23,7 @@ class GeminiCliAdapter:
estimated_tokens = total_chars // 4 estimated_tokens = total_chars // 4
return estimated_tokens return estimated_tokens
def send(self, message, safety_settings=None, system_instruction=None, model: str = None): def send(self, message: str, safety_settings: list | None = None, system_instruction: str | None = None, model: str | None = None) -> str:
""" """
Sends a message to the Gemini CLI and processes the streaming JSON output. Sends a message to the Gemini CLI and processes the streaming JSON output.
Logs the CLI call details using session_logger.log_cli_call. Logs the CLI call details using session_logger.log_cli_call.

View File

@@ -1,3 +1,4 @@
from __future__ import annotations
import tomli_w import tomli_w
import tomllib import tomllib
from datetime import datetime from datetime import datetime
@@ -9,7 +10,7 @@ class LogRegistry:
Tracks session paths, start times, whitelisting status, and metadata. Tracks session paths, start times, whitelisting status, and metadata.
""" """
def __init__(self, registry_path): def __init__(self, registry_path: str) -> None:
""" """
Initializes the LogRegistry with a path to the registry file. Initializes the LogRegistry with a path to the registry file.
@@ -81,7 +82,7 @@ class LogRegistry:
except Exception as e: except Exception as e:
print(f"Error saving registry to {self.registry_path}: {e}") print(f"Error saving registry to {self.registry_path}: {e}")
def register_session(self, session_id, path, start_time): def register_session(self, session_id: str, path: str, start_time: datetime | str) -> None:
""" """
Registers a new session in the registry. Registers a new session in the registry.
@@ -105,7 +106,7 @@ class LogRegistry:
} }
self.save_registry() self.save_registry()
def update_session_metadata(self, session_id, message_count, errors, size_kb, whitelisted, reason): def update_session_metadata(self, session_id: str, message_count: int, errors: int, size_kb: int, whitelisted: bool, reason: str) -> None:
""" """
Updates metadata fields for an existing session. Updates metadata fields for an existing session.
@@ -135,7 +136,7 @@ class LogRegistry:
self.data[session_id]['whitelisted'] = whitelisted self.data[session_id]['whitelisted'] = whitelisted
self.save_registry() # Save after update self.save_registry() # Save after update
def is_session_whitelisted(self, session_id): def is_session_whitelisted(self, session_id: str) -> bool:
""" """
Checks if a specific session is marked as whitelisted. Checks if a specific session is marked as whitelisted.
@@ -209,7 +210,7 @@ class LogRegistry:
reason=reason reason=reason
) )
def get_old_non_whitelisted_sessions(self, cutoff_datetime): def get_old_non_whitelisted_sessions(self, cutoff_datetime: datetime) -> list[dict]:
""" """
Retrieves a list of sessions that are older than a specific cutoff time Retrieves a list of sessions that are older than a specific cutoff time
and are not marked as whitelisted. and are not marked as whitelisted.

View File

@@ -213,7 +213,7 @@ def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.A
return approved, modified_prompt, modified_context return approved, modified_prompt, modified_context
return False, prompt, context_md return False, prompt, context_md
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None, event_queue: events.AsyncEventQueue = None, engine: Optional['ConductorEngine'] = None, md_content: str = ""): def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] | None = None, event_queue: events.AsyncEventQueue | None = None, engine: Optional['ConductorEngine'] = None, md_content: str = "") -> None:
""" """
Simulates the lifecycle of a single agent working on a ticket. Simulates the lifecycle of a single agent working on a ticket.
Calls the AI client and updates the ticket status based on the response. Calls the AI client and updates the ticket status based on the response.

View File

@@ -6,7 +6,7 @@ import aggregate
import summarize import summarize
from pathlib import Path from pathlib import Path
CONDUCTOR_PATH = Path("conductor") CONDUCTOR_PATH: Path = Path("conductor")
def get_track_history_summary() -> str: def get_track_history_summary() -> str:
""" """

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import time import time
import psutil import psutil
import threading import threading
from typing import Any
class PerformanceMonitor: class PerformanceMonitor:
def __init__(self) -> None: def __init__(self) -> None:
@@ -98,7 +100,7 @@ class PerformanceMonitor:
self._last_alert_time = now self._last_alert_time = now
self.alert_callback("; ".join(alerts)) self.alert_callback("; ".join(alerts))
def get_metrics(self): def get_metrics(self) -> dict[str, Any]:
with self._cpu_lock: with self._cpu_lock:
cpu_usage = self._cpu_usage cpu_usage = self._cpu_usage
metrics = { metrics = {

View File

@@ -26,6 +26,7 @@ context block that replaces full file contents in the initial <context> send.
import ast import ast
import re import re
from pathlib import Path from pathlib import Path
from typing import Callable
# ------------------------------------------------------------------ per-type extractors # ------------------------------------------------------------------ per-type extractors
@@ -138,7 +139,7 @@ def _summarise_generic(path: Path, content: str) -> str:
return "\n".join(parts) return "\n".join(parts)
# ------------------------------------------------------------------ dispatch # ------------------------------------------------------------------ dispatch
_SUMMARISERS = { _SUMMARISERS: dict[str, Callable[[Path, str], str]] = {
".py": _summarise_python, ".py": _summarise_python,
".toml": _summarise_toml, ".toml": _summarise_toml,
".md": _summarise_markdown, ".md": _summarise_markdown,

View File

@@ -16,7 +16,7 @@ from pathlib import Path
# Each palette maps imgui color enum values to (R, G, B, A) floats [0..1]. # Each palette maps imgui color enum values to (R, G, B, A) floats [0..1].
# Only keys that differ from the ImGui dark defaults need to be listed. # Only keys that differ from the ImGui dark defaults need to be listed.
def _c(r, g, b, a=255): def _c(r: int, g: int, b: int, a: int = 255) -> tuple[float, float, float, float]:
"""Convert 0-255 RGBA to 0.0-1.0 floats.""" """Convert 0-255 RGBA to 0.0-1.0 floats."""
return (r / 255.0, g / 255.0, b / 255.0, a / 255.0) return (r / 255.0, g / 255.0, b / 255.0, a / 255.0)