refactor(types): Phase 4 type hint sweep — core modules

This commit is contained in:
2026-02-28 15:13:55 -05:00
parent ca04026db5
commit 46c2f9a0ca
10 changed files with 52 additions and 42 deletions

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
import requests
import json
import time
from typing import Any
class ApiHookClient:
def __init__(self, base_url="http://127.0.0.1:8999", max_retries=5, retry_delay=0.2):
def __init__(self, base_url: str = "http://127.0.0.1:8999", max_retries: int = 5, retry_delay: float = 0.2) -> None:
self.base_url = base_url
self.max_retries = max_retries
self.retry_delay = retry_delay
def wait_for_server(self, timeout=3):
def wait_for_server(self, timeout: float = 3) -> bool:
"""
Polls the /status endpoint until the server is ready or timeout is reached.
"""
@@ -21,7 +23,7 @@ class ApiHookClient:
time.sleep(0.1)
return False
def _make_request(self, method, endpoint, data=None, timeout=None):
def _make_request(self, method: str, endpoint: str, data: dict | None = None, timeout: float | None = None) -> dict | None:
url = f"{self.base_url}{endpoint}"
headers = {'Content-Type': 'application/json'}
last_exception = None
@@ -54,7 +56,7 @@ class ApiHookClient:
if last_exception:
raise last_exception
def get_status(self):
def get_status(self) -> dict:
"""Checks the health of the hook server."""
url = f"{self.base_url}/status"
try:
@@ -64,37 +66,37 @@ class ApiHookClient:
except Exception:
raise requests.exceptions.ConnectionError(f"Could not reach /status at {self.base_url}")
def get_project(self):
def get_project(self) -> dict | None:
return self._make_request('GET', '/api/project')
def post_project(self, project_data):
def post_project(self, project_data: dict) -> dict | None:
return self._make_request('POST', '/api/project', data={'project': project_data})
def get_session(self):
def get_session(self) -> dict | None:
return self._make_request('GET', '/api/session')
def get_mma_status(self):
def get_mma_status(self) -> dict | None:
"""Retrieves current MMA status (track, tickets, tier, etc.)"""
return self._make_request('GET', '/api/gui/mma_status')
def push_event(self, event_type, payload):
def push_event(self, event_type: str, payload: dict) -> dict | None:
"""Pushes an event to the GUI's AsyncEventQueue via the /api/gui endpoint."""
return self.post_gui({
"action": event_type,
"payload": payload
})
def get_performance(self):
def get_performance(self) -> dict | None:
"""Retrieves UI performance metrics."""
return self._make_request('GET', '/api/performance')
def post_session(self, session_entries):
def post_session(self, session_entries: list) -> dict | None:
return self._make_request('POST', '/api/session', data={'session': {'entries': session_entries}})
def post_gui(self, gui_data):
def post_gui(self, gui_data: dict) -> dict | None:
return self._make_request('POST', '/api/gui', data=gui_data)
def select_tab(self, tab_bar, tab):
def select_tab(self, tab_bar: str, tab: str) -> dict | None:
"""Tells the GUI to switch to a specific tab in a tab bar."""
return self.post_gui({
"action": "select_tab",
@@ -102,7 +104,7 @@ class ApiHookClient:
"tab": tab
})
def select_list_item(self, listbox, item_value):
def select_list_item(self, listbox: str, item_value: str) -> dict | None:
"""Tells the GUI to select an item in a listbox by its value."""
return self.post_gui({
"action": "select_list_item",
@@ -110,7 +112,7 @@ class ApiHookClient:
"item_value": item_value
})
def set_value(self, item, value):
def set_value(self, item: str, value: Any) -> dict | None:
"""Sets the value of a GUI item."""
return self.post_gui({
"action": "set_value",
@@ -118,7 +120,7 @@ class ApiHookClient:
"value": value
})
def get_value(self, item):
def get_value(self, item: str) -> Any:
"""Gets the value of a GUI item via its mapped field."""
try:
# First try direct field querying via POST
@@ -156,12 +158,12 @@ class ApiHookClient:
pass
return None
def get_text_value(self, item_tag):
def get_text_value(self, item_tag: str) -> str | None:
"""Wraps get_value and returns its string representation, or None."""
val = self.get_value(item_tag)
return str(val) if val is not None else None
def get_node_status(self, node_tag):
def get_node_status(self, node_tag: str) -> Any:
"""Wraps get_value for a DAG node or queries the diagnostic endpoint for its status."""
val = self.get_value(node_tag)
if val is not None:
@@ -176,7 +178,7 @@ class ApiHookClient:
pass
return None
def click(self, item, *args, **kwargs):
def click(self, item: str, *args: Any, **kwargs: Any) -> dict | None:
"""Simulates a click on a GUI button or item."""
user_data = kwargs.pop('user_data', None)
return self.post_gui({
@@ -187,7 +189,7 @@ class ApiHookClient:
"user_data": user_data
})
def get_indicator_state(self, tag):
def get_indicator_state(self, tag: str) -> dict:
"""Checks if an indicator is shown using the diagnostics endpoint."""
# Mapping tag to the keys used in diagnostics endpoint
mapping = {
@@ -202,14 +204,14 @@ class ApiHookClient:
except Exception as e:
return {"tag": tag, "shown": False, "error": str(e)}
def get_events(self):
def get_events(self) -> list:
"""Fetches and clears the event queue from the server."""
try:
return self._make_request('GET', '/api/events').get("events", [])
except Exception:
return []
def wait_for_event(self, event_type, timeout=5):
def wait_for_event(self, event_type: str, timeout: float = 5) -> dict | None:
"""Polls for a specific event type."""
start = time.time()
while time.time() - start < timeout:
@@ -220,7 +222,7 @@ class ApiHookClient:
time.sleep(0.1) # Fast poll
return None
def wait_for_value(self, item, expected, timeout=5):
def wait_for_value(self, item: str, expected: Any, timeout: float = 5) -> bool:
"""Polls until get_value(item) == expected."""
start = time.time()
while time.time() - start < timeout:
@@ -229,11 +231,11 @@ class ApiHookClient:
time.sleep(0.1) # Fast poll
return False
def reset_session(self):
def reset_session(self) -> dict | None:
"""Simulates clicking the 'Reset Session' button in the GUI."""
return self.click("btn_reset")
def request_confirmation(self, tool_name, args):
def request_confirmation(self, tool_name: str, args: dict) -> Any:
"""Asks the user for confirmation via the GUI (blocking call)."""
# Using a long timeout as this waits for human input (60 seconds)
res = self._make_request('POST', '/api/ask',

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
import json
import threading
import uuid
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from typing import Any
import logging
import session_logger
class HookServerInstance(ThreadingHTTPServer):
"""Custom HTTPServer that carries a reference to the main App instance."""
def __init__(self, server_address, RequestHandlerClass, app):
def __init__(self, server_address: tuple[str, int], RequestHandlerClass: type, app: Any) -> None:
super().__init__(server_address, RequestHandlerClass)
self.app = app
@@ -273,11 +275,11 @@ class HookHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(json.dumps({'error': str(e)}).encode('utf-8'))
def log_message(self, format, *args):
def log_message(self, format: str, *args: Any) -> None:
logging.info("Hook API: " + format % args)
class HookServer:
def __init__(self, app, port=8999):
def __init__(self, app: Any, port: int = 8999) -> None:
self.app = app
self.port = port
self.server = None

View File

@@ -1,11 +1,13 @@
# gemini.py
from __future__ import annotations
import tomllib
from pathlib import Path
from typing import Any
from google import genai
from google.genai import types
_client = None
_chat = None
_client: genai.Client | None = None
_chat: Any = None
def _load_key() -> str:
with open("credentials.toml", "rb") as f:

View File

@@ -6,7 +6,7 @@ import os
import session_logger # Import session_logger
class GeminiCliAdapter:
def __init__(self, binary_path="gemini"):
def __init__(self, binary_path: str = "gemini") -> None:
self.binary_path = binary_path
self.last_usage = None
self.session_id = None
@@ -23,7 +23,7 @@ class GeminiCliAdapter:
estimated_tokens = total_chars // 4
return estimated_tokens
def send(self, message, safety_settings=None, system_instruction=None, model: str = None):
def send(self, message: str, safety_settings: list | None = None, system_instruction: str | None = None, model: str | None = None) -> str:
"""
Sends a message to the Gemini CLI and processes the streaming JSON output.
Logs the CLI call details using session_logger.log_cli_call.

View File

@@ -1,3 +1,4 @@
from __future__ import annotations
import tomli_w
import tomllib
from datetime import datetime
@@ -9,7 +10,7 @@ class LogRegistry:
Tracks session paths, start times, whitelisting status, and metadata.
"""
def __init__(self, registry_path):
def __init__(self, registry_path: str) -> None:
"""
Initializes the LogRegistry with a path to the registry file.
@@ -81,7 +82,7 @@ class LogRegistry:
except Exception as e:
print(f"Error saving registry to {self.registry_path}: {e}")
def register_session(self, session_id, path, start_time):
def register_session(self, session_id: str, path: str, start_time: datetime | str) -> None:
"""
Registers a new session in the registry.
@@ -105,7 +106,7 @@ class LogRegistry:
}
self.save_registry()
def update_session_metadata(self, session_id, message_count, errors, size_kb, whitelisted, reason):
def update_session_metadata(self, session_id: str, message_count: int, errors: int, size_kb: int, whitelisted: bool, reason: str) -> None:
"""
Updates metadata fields for an existing session.
@@ -135,7 +136,7 @@ class LogRegistry:
self.data[session_id]['whitelisted'] = whitelisted
self.save_registry() # Save after update
def is_session_whitelisted(self, session_id):
def is_session_whitelisted(self, session_id: str) -> bool:
"""
Checks if a specific session is marked as whitelisted.
@@ -209,7 +210,7 @@ class LogRegistry:
reason=reason
)
def get_old_non_whitelisted_sessions(self, cutoff_datetime):
def get_old_non_whitelisted_sessions(self, cutoff_datetime: datetime) -> list[dict]:
"""
Retrieves a list of sessions that are older than a specific cutoff time
and are not marked as whitelisted.

View File

@@ -213,7 +213,7 @@ def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.A
return approved, modified_prompt, modified_context
return False, prompt, context_md
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None, event_queue: events.AsyncEventQueue = None, engine: Optional['ConductorEngine'] = None, md_content: str = ""):
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] | None = None, event_queue: events.AsyncEventQueue | None = None, engine: Optional['ConductorEngine'] = None, md_content: str = "") -> None:
"""
Simulates the lifecycle of a single agent working on a ticket.
Calls the AI client and updates the ticket status based on the response.

View File

@@ -6,7 +6,7 @@ import aggregate
import summarize
from pathlib import Path
CONDUCTOR_PATH = Path("conductor")
CONDUCTOR_PATH: Path = Path("conductor")
def get_track_history_summary() -> str:
"""

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import time
import psutil
import threading
from typing import Any
class PerformanceMonitor:
def __init__(self) -> None:
@@ -98,7 +100,7 @@ class PerformanceMonitor:
self._last_alert_time = now
self.alert_callback("; ".join(alerts))
def get_metrics(self):
def get_metrics(self) -> dict[str, Any]:
with self._cpu_lock:
cpu_usage = self._cpu_usage
metrics = {

View File

@@ -26,6 +26,7 @@ context block that replaces full file contents in the initial <context> send.
import ast
import re
from pathlib import Path
from typing import Callable
# ------------------------------------------------------------------ per-type extractors
@@ -138,7 +139,7 @@ def _summarise_generic(path: Path, content: str) -> str:
return "\n".join(parts)
# ------------------------------------------------------------------ dispatch
_SUMMARISERS = {
_SUMMARISERS: dict[str, Callable[[Path, str], str]] = {
".py": _summarise_python,
".toml": _summarise_toml,
".md": _summarise_markdown,

View File

@@ -16,7 +16,7 @@ from pathlib import Path
# Each palette maps imgui color enum values to (R, G, B, A) floats [0..1].
# Only keys that differ from the ImGui dark defaults need to be listed.
def _c(r, g, b, a=255):
def _c(r: int, g: int, b: int, a: int = 255) -> tuple[float, float, float, float]:
"""Convert 0-255 RGBA to 0.0-1.0 floats."""
return (r / 255.0, g / 255.0, b / 255.0, a / 255.0)