refactor(tests): Add strict type hints to second batch of test files

This commit is contained in:
2026-02-28 19:11:23 -05:00
parent f0415a40aa
commit 579ee8394f
10 changed files with 358 additions and 351 deletions

View File

@@ -1,4 +1,5 @@
import pytest
from typing import Any
from unittest.mock import MagicMock, patch
import ai_client
@@ -10,22 +11,22 @@ class MockUsage:
self.cached_content_token_count = 0
class MockPart:
def __init__(self, text, function_call):
def __init__(self, text: Any, function_call: Any) -> None:
self.text = text
self.function_call = function_call
class MockContent:
def __init__(self, parts):
def __init__(self, parts: Any) -> None:
self.parts = parts
class MockCandidate:
def __init__(self, parts):
def __init__(self, parts: Any) -> None:
self.content = MockContent(parts)
self.finish_reason = MagicMock()
self.finish_reason.name = "STOP"
def test_ai_client_event_emitter_exists():
# This should fail initially because 'events' won't exist on ai_client
def test_ai_client_event_emitter_exists() -> None:
# This should fail initially because 'events' won't exist on ai_client
assert hasattr(ai_client, 'events')
def test_event_emission() -> None:

View File

@@ -1,17 +1,18 @@
import os
import pytest
from typing import Any
from datetime import datetime
from log_registry import LogRegistry
@pytest.fixture
def registry_setup(tmp_path):
def registry_setup(tmp_path: Any) -> Any:
registry_path = tmp_path / "log_registry.toml"
logs_dir = tmp_path / "logs"
logs_dir.mkdir()
registry = LogRegistry(str(registry_path))
return registry, logs_dir
def test_auto_whitelist_keywords(registry_setup):
def test_auto_whitelist_keywords(registry_setup: Any) -> None:
registry, logs_dir = registry_setup
session_id = "test_kw"
session_dir = logs_dir / session_id
@@ -24,7 +25,7 @@ def test_auto_whitelist_keywords(registry_setup):
assert registry.is_session_whitelisted(session_id)
assert "ERROR" in registry.data[session_id]["metadata"]["reason"]
def test_auto_whitelist_message_count(registry_setup):
def test_auto_whitelist_message_count(registry_setup: Any) -> None:
registry, logs_dir = registry_setup
session_id = "test_msg_count"
session_dir = logs_dir / session_id
@@ -37,7 +38,7 @@ def test_auto_whitelist_message_count(registry_setup):
assert registry.is_session_whitelisted(session_id)
assert registry.data[session_id]["metadata"]["message_count"] == 15
def test_auto_whitelist_large_size(registry_setup):
def test_auto_whitelist_large_size(registry_setup: Any) -> None:
registry, logs_dir = registry_setup
session_id = "test_large"
session_dir = logs_dir / session_id
@@ -50,7 +51,7 @@ def test_auto_whitelist_large_size(registry_setup):
assert registry.is_session_whitelisted(session_id)
assert "Large session size" in registry.data[session_id]["metadata"]["reason"]
def test_no_auto_whitelist_insignificant(registry_setup):
def test_no_auto_whitelist_insignificant(registry_setup: Any) -> None:
registry, logs_dir = registry_setup
session_id = "test_insignificant"
session_dir = logs_dir / session_id

View File

@@ -1,4 +1,5 @@
import unittest
from typing import Any
from unittest.mock import patch, MagicMock
import json
import conductor_tech_lead
@@ -7,8 +8,7 @@ class TestConductorTechLead(unittest.TestCase):
@patch('ai_client.send')
@patch('ai_client.set_provider')
@patch('ai_client.reset_session')
def test_generate_tickets_success(self, mock_reset_session, mock_set_provider, mock_send):
# Setup mock response
def test_generate_tickets_success(self, mock_reset_session: Any, mock_set_provider: Any, mock_send: Any) -> None:
mock_tickets = [
{
"id": "ticket_1",
@@ -39,7 +39,7 @@ class TestConductorTechLead(unittest.TestCase):
@patch('ai_client.send')
@patch('ai_client.set_provider')
@patch('ai_client.reset_session')
def test_generate_tickets_parse_error(self, mock_reset_session, mock_set_provider, mock_send):
def test_generate_tickets_parse_error(self, mock_reset_session: Any, mock_set_provider: Any, mock_send: Any) -> None:
# Setup mock invalid response
mock_send.return_value = "Invalid JSON"
# Call the function
@@ -63,7 +63,7 @@ class TestTopologicalSort(unittest.TestCase):
ids = [t["id"] for t in sorted_tickets]
self.assertEqual(ids, ["t1", "t2", "t3"])
def test_topological_sort_complex(self):
def test_topological_sort_complex(self) -> None:
# t1
# | \
# t2 t3
@@ -91,7 +91,7 @@ class TestTopologicalSort(unittest.TestCase):
conductor_tech_lead.topological_sort(tickets)
self.assertIn("Circular dependency detected", str(cm.exception))
def test_topological_sort_missing_dependency(self):
def test_topological_sort_missing_dependency(self) -> None:
# If a ticket depends on something not in the list, we should probably handle it or let it fail.
# Usually in our context, we only care about dependencies within the same track.
tickets = [

View File

@@ -1,4 +1,5 @@
import pytest
from typing import Any
import time
import sys
import os
@@ -13,7 +14,7 @@ from simulation.sim_tools import ToolsSimulation
from simulation.sim_execution import ExecutionSimulation
@pytest.mark.integration
def test_context_sim_live(live_gui):
def test_context_sim_live(live_gui: Any) -> None:
"""Run the Context & Chat simulation against a live GUI."""
client = ApiHookClient()
assert client.wait_for_server(timeout=10)
@@ -23,7 +24,7 @@ def test_context_sim_live(live_gui):
sim.teardown()
@pytest.mark.integration
def test_ai_settings_sim_live(live_gui):
def test_ai_settings_sim_live(live_gui: Any) -> None:
"""Run the AI Settings simulation against a live GUI."""
client = ApiHookClient()
assert client.wait_for_server(timeout=10)
@@ -33,7 +34,7 @@ def test_ai_settings_sim_live(live_gui):
sim.teardown()
@pytest.mark.integration
def test_tools_sim_live(live_gui):
def test_tools_sim_live(live_gui: Any) -> None:
"""Run the Tools & Search simulation against a live GUI."""
client = ApiHookClient()
assert client.wait_for_server(timeout=10)
@@ -43,7 +44,7 @@ def test_tools_sim_live(live_gui):
sim.teardown()
@pytest.mark.integration
def test_execution_sim_live(live_gui):
def test_execution_sim_live(live_gui: Any) -> None:
"""Run the Execution & Modals simulation against a live GUI."""
client = ApiHookClient()
assert client.wait_for_server(timeout=10)

View File

@@ -1,4 +1,5 @@
import unittest
from typing import Any
from unittest.mock import patch, MagicMock
import json
import subprocess
@@ -16,7 +17,7 @@ class TestGeminiCliAdapter(unittest.TestCase):
self.adapter = GeminiCliAdapter(binary_path="gemini")
@patch('subprocess.Popen')
def test_send_starts_subprocess_with_correct_args(self, mock_popen):
def test_send_starts_subprocess_with_correct_args(self, mock_popen: Any) -> None:
"""
Verify that send(message) correctly starts the subprocess with
--output-format stream-json and the provided message via stdin using communicate.
@@ -48,7 +49,7 @@ class TestGeminiCliAdapter(unittest.TestCase):
self.assertEqual(kwargs.get('text'), True)
@patch('subprocess.Popen')
def test_send_parses_jsonl_output(self, mock_popen):
def test_send_parses_jsonl_output(self, mock_popen: Any) -> None:
"""
Verify that it correctly parses multiple JSONL 'message' events
and returns the combined text.
@@ -69,7 +70,7 @@ class TestGeminiCliAdapter(unittest.TestCase):
self.assertEqual(result["tool_calls"], [])
@patch('subprocess.Popen')
def test_send_handles_tool_use_events(self, mock_popen):
def test_send_handles_tool_use_events(self, mock_popen: Any) -> None:
"""
Verify that it correctly handles 'tool_use' events in the stream
by continuing to read until the final 'result' event.
@@ -93,7 +94,7 @@ class TestGeminiCliAdapter(unittest.TestCase):
self.assertEqual(result["tool_calls"][0]["name"], "read_file")
@patch('subprocess.Popen')
def test_send_captures_usage_metadata(self, mock_popen):
def test_send_captures_usage_metadata(self, mock_popen: Any) -> None:
"""
Verify that usage data is extracted from the 'result' event.
"""

View File

@@ -1,4 +1,5 @@
import pytest
from typing import Any
import time
import json
import os
@@ -22,7 +23,7 @@ def cleanup_callback_file() -> None:
if TEST_CALLBACK_FILE.exists():
TEST_CALLBACK_FILE.unlink()
def test_gui2_set_value_hook_works(live_gui):
def test_gui2_set_value_hook_works(live_gui: Any) -> None:
"""
Tests that the 'set_value' GUI hook is correctly implemented.
"""
@@ -37,7 +38,7 @@ def test_gui2_set_value_hook_works(live_gui):
current_value = client.get_value('ai_input')
assert current_value == test_value
def test_gui2_click_hook_works(live_gui):
def test_gui2_click_hook_works(live_gui: Any) -> None:
"""
Tests that the 'click' GUI hook for the 'Reset' button is implemented.
"""
@@ -54,7 +55,7 @@ def test_gui2_click_hook_works(live_gui):
# Verify it was reset
assert client.get_value('ai_input') == ""
def test_gui2_custom_callback_hook_works(live_gui):
def test_gui2_custom_callback_hook_works(live_gui: Any) -> None:
"""
Tests that the 'custom_callback' GUI hook is correctly implemented.
"""

View File

@@ -1,4 +1,5 @@
import pytest
from typing import Any
import sys
import os
import importlib.util
@@ -40,7 +41,7 @@ def test_new_hubs_defined_in_window_info() -> None:
assert l == label or label in l, f"Label mismatch for {tag}: expected {label}, found {l}"
assert found, f"Expected window label {label} not found in window_info"
def test_old_windows_removed_from_window_info(app_instance_simple):
def test_old_windows_removed_from_window_info(app_instance_simple: Any) -> None:
"""
Verifies that the old fragmented windows are removed from window_info.
"""
@@ -54,14 +55,14 @@ def test_old_windows_removed_from_window_info(app_instance_simple):
assert tag not in app_instance_simple.window_info.values(), f"Old window tag {tag} should have been removed from window_info"
@pytest.fixture
def app_instance_simple():
def app_instance_simple() -> Any:
from unittest.mock import patch
from gui_legacy import App
with patch('gui_legacy.load_config', return_value={}):
app = App()
return app
def test_hub_windows_have_correct_flags(app_instance_simple):
def test_hub_windows_have_correct_flags(app_instance_simple: Any) -> None:
"""
Verifies that the new Hub windows have appropriate flags for a professional workspace.
(e.g., no_collapse should be True for main hubs).
@@ -80,7 +81,7 @@ def test_hub_windows_have_correct_flags(app_instance_simple):
# but we can check if it's been configured if we mock dpg.window or check it manually
dpg.destroy_context()
def test_indicators_exist(app_instance_simple):
def test_indicators_exist(app_instance_simple: Any) -> None:
"""
Verifies that the new thinking and live indicators exist in the UI.
"""

View File

@@ -1,4 +1,5 @@
import unittest
from typing import Any
from unittest.mock import patch, MagicMock
import json
import orchestrator_pm
@@ -8,7 +9,7 @@ class TestOrchestratorPM(unittest.TestCase):
@patch('summarize.build_summary_markdown')
@patch('ai_client.send')
def test_generate_tracks_success(self, mock_send, mock_summarize):
def test_generate_tracks_success(self, mock_send: Any, mock_summarize: Any) -> None:
# Setup mocks
mock_summarize.return_value = "REPO_MAP_CONTENT"
mock_response_data = [
@@ -44,7 +45,7 @@ class TestOrchestratorPM(unittest.TestCase):
@patch('summarize.build_summary_markdown')
@patch('ai_client.send')
def test_generate_tracks_markdown_wrapped(self, mock_send, mock_summarize):
def test_generate_tracks_markdown_wrapped(self, mock_send: Any, mock_summarize: Any) -> None:
mock_summarize.return_value = "REPO_MAP"
mock_response_data = [{"id": "track_1"}]
expected_result = [{"id": "track_1", "title": "Untitled Track"}]
@@ -59,7 +60,7 @@ class TestOrchestratorPM(unittest.TestCase):
@patch('summarize.build_summary_markdown')
@patch('ai_client.send')
def test_generate_tracks_malformed_json(self, mock_send, mock_summarize):
def test_generate_tracks_malformed_json(self, mock_send: Any, mock_summarize: Any) -> None:
mock_summarize.return_value = "REPO_MAP"
mock_send.return_value = "NOT A JSON"
# Should return empty list and print error (we can mock print if we want to be thorough)

View File

@@ -1,15 +1,15 @@
import pytest
from typing import Any
import json
from pathlib import Path
from project_manager import get_all_tracks, save_track_state
from models import TrackState, Metadata, Ticket
from datetime import datetime
def test_get_all_tracks_empty(tmp_path):
# conductor/tracks directory doesn't exist
def test_get_all_tracks_empty(tmp_path: Any) -> None:
assert get_all_tracks(tmp_path) == []
def test_get_all_tracks_with_state(tmp_path):
def test_get_all_tracks_with_state(tmp_path: Any) -> None:
tracks_dir = tmp_path / "conductor" / "tracks"
tracks_dir.mkdir(parents=True)
track_id = "test_track_1"
@@ -34,7 +34,7 @@ def test_get_all_tracks_with_state(tmp_path):
assert track["total"] == 2
assert track["progress"] == 0.5
def test_get_all_tracks_with_metadata_json(tmp_path):
def test_get_all_tracks_with_metadata_json(tmp_path: Any) -> None:
tracks_dir = tmp_path / "conductor" / "tracks"
tracks_dir.mkdir(parents=True)
track_id = "test_track_2"
@@ -66,7 +66,7 @@ def test_get_all_tracks_with_metadata_json(tmp_path):
assert track["total"] == 3
assert pytest.approx(track["progress"]) == 0.333333
def test_get_all_tracks_malformed(tmp_path):
def test_get_all_tracks_malformed(tmp_path: Any) -> None:
tracks_dir = tmp_path / "conductor" / "tracks"
tracks_dir.mkdir(parents=True)
track_id = "malformed_track"

View File

@@ -1,9 +1,9 @@
import pytest
from typing import Any
from pathlib import Path
from aggregate import build_tier1_context, build_tier2_context, build_tier3_context
def test_build_tier1_context_exists():
# This should fail if the function is not defined
def test_build_tier1_context_exists() -> None:
file_items = [
{"path": Path("conductor/product.md"), "entry": "conductor/product.md", "content": "Product content", "error": False},
{"path": Path("other.py"), "entry": "other.py", "content": "Other content", "error": False}
@@ -22,7 +22,7 @@ def test_build_tier2_context_exists() -> None:
result = build_tier2_context(file_items, Path("."), [], history)
assert "Other content" in result
def test_build_tier3_context_ast_skeleton(monkeypatch):
def test_build_tier3_context_ast_skeleton(monkeypatch: Any) -> None:
from unittest.mock import MagicMock
import aggregate
import file_cache
@@ -59,7 +59,7 @@ def test_build_tier3_context_exists() -> None:
assert "other.py" in result
assert "AST Skeleton" in result
def test_build_file_items_with_tiers(tmp_path):
def test_build_file_items_with_tiers(tmp_path: Any) -> None:
from aggregate import build_file_items
# Create some dummy files
file1 = tmp_path / "file1.txt"
@@ -80,7 +80,7 @@ def test_build_file_items_with_tiers(tmp_path):
assert item2["content"] == "content2"
assert item2["tier"] == 3
def test_build_files_section_with_dicts(tmp_path):
def test_build_files_section_with_dicts(tmp_path: Any) -> None:
from aggregate import build_files_section
file1 = tmp_path / "file1.txt"
file1.write_text("content1")