refactor(tests): Add strict type hints to first batch of test files

This commit is contained in:
2026-02-28 19:06:50 -05:00
parent e8833b6656
commit f0415a40aa
10 changed files with 59 additions and 70 deletions

View File

@@ -12,7 +12,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from api_hook_client import ApiHookClient from api_hook_client import ApiHookClient
def test_get_status_success(live_gui): def test_get_status_success(live_gui: tuple) -> None:
""" """
Test that get_status successfully retrieves the server status Test that get_status successfully retrieves the server status
when the live GUI is running. when the live GUI is running.
@@ -21,7 +21,7 @@ def test_get_status_success(live_gui):
status = client.get_status() status = client.get_status()
assert status == {'status': 'ok'} assert status == {'status': 'ok'}
def test_get_project_success(live_gui): def test_get_project_success(live_gui: tuple) -> None:
""" """
Test successful retrieval of project data from the live GUI. Test successful retrieval of project data from the live GUI.
""" """
@@ -30,7 +30,7 @@ def test_get_project_success(live_gui):
assert 'project' in response assert 'project' in response
# We don't assert specific content as it depends on the environment's active project # We don't assert specific content as it depends on the environment's active project
def test_get_session_success(live_gui): def test_get_session_success(live_gui: tuple) -> None:
""" """
Test successful retrieval of session data. Test successful retrieval of session data.
""" """
@@ -39,7 +39,7 @@ def test_get_session_success(live_gui):
assert 'session' in response assert 'session' in response
assert 'entries' in response['session'] assert 'entries' in response['session']
def test_post_gui_success(live_gui): def test_post_gui_success(live_gui: tuple) -> None:
""" """
Test successful posting of GUI data. Test successful posting of GUI data.
""" """
@@ -48,7 +48,7 @@ def test_post_gui_success(live_gui):
response = client.post_gui(gui_data) response = client.post_gui(gui_data)
assert response == {'status': 'queued'} assert response == {'status': 'queued'}
def test_get_performance_success(live_gui): def test_get_performance_success(live_gui: tuple) -> None:
""" """
Test successful retrieval of performance metrics. Test successful retrieval of performance metrics.
""" """

View File

@@ -22,8 +22,7 @@ class TestCliToolBridge(unittest.TestCase):
@patch('sys.stdin', new_callable=io.StringIO) @patch('sys.stdin', new_callable=io.StringIO)
@patch('sys.stdout', new_callable=io.StringIO) @patch('sys.stdout', new_callable=io.StringIO)
@patch('api_hook_client.ApiHookClient.request_confirmation') @patch('api_hook_client.ApiHookClient.request_confirmation')
def test_allow_decision(self, mock_request, mock_stdout, mock_stdin): def test_allow_decision(self, mock_request: MagicMock, mock_stdout: MagicMock, mock_stdin: MagicMock) -> None:
# 1. Mock stdin with a JSON string tool call
mock_stdin.write(json.dumps(self.tool_call)) mock_stdin.write(json.dumps(self.tool_call))
mock_stdin.seek(0) mock_stdin.seek(0)
# 2. Mock ApiHookClient to return approved # 2. Mock ApiHookClient to return approved
@@ -37,8 +36,7 @@ class TestCliToolBridge(unittest.TestCase):
@patch('sys.stdin', new_callable=io.StringIO) @patch('sys.stdin', new_callable=io.StringIO)
@patch('sys.stdout', new_callable=io.StringIO) @patch('sys.stdout', new_callable=io.StringIO)
@patch('api_hook_client.ApiHookClient.request_confirmation') @patch('api_hook_client.ApiHookClient.request_confirmation')
def test_deny_decision(self, mock_request, mock_stdout, mock_stdin): def test_deny_decision(self, mock_request: MagicMock, mock_stdout: MagicMock, mock_stdin: MagicMock) -> None:
# Mock stdin
mock_stdin.write(json.dumps(self.tool_call)) mock_stdin.write(json.dumps(self.tool_call))
mock_stdin.seek(0) mock_stdin.seek(0)
# 4. Mock ApiHookClient to return denied # 4. Mock ApiHookClient to return denied
@@ -51,8 +49,7 @@ class TestCliToolBridge(unittest.TestCase):
@patch('sys.stdin', new_callable=io.StringIO) @patch('sys.stdin', new_callable=io.StringIO)
@patch('sys.stdout', new_callable=io.StringIO) @patch('sys.stdout', new_callable=io.StringIO)
@patch('api_hook_client.ApiHookClient.request_confirmation') @patch('api_hook_client.ApiHookClient.request_confirmation')
def test_unreachable_hook_server(self, mock_request, mock_stdout, mock_stdin): def test_unreachable_hook_server(self, mock_request: MagicMock, mock_stdout: MagicMock, mock_stdin: MagicMock) -> None:
# Mock stdin
mock_stdin.write(json.dumps(self.tool_call)) mock_stdin.write(json.dumps(self.tool_call))
mock_stdin.seek(0) mock_stdin.seek(0)
# 5. Test case where hook server is unreachable (exception) # 5. Test case where hook server is unreachable (exception)

View File

@@ -16,7 +16,7 @@ def test_conductor_engine_initialization() -> None:
assert engine.track == track assert engine.track == track
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_conductor_engine_run_linear_executes_tickets_in_order(monkeypatch): async def test_conductor_engine_run_linear_executes_tickets_in_order(monkeypatch: pytest.MonkeyPatch) -> None:
""" """
Test that run_linear iterates through executable tickets and calls the worker lifecycle. Test that run_linear iterates through executable tickets and calls the worker lifecycle.
""" """
@@ -48,7 +48,7 @@ async def test_conductor_engine_run_linear_executes_tickets_in_order(monkeypatch
assert calls[1][0][0].id == "T2" assert calls[1][0][0].id == "T2"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_worker_lifecycle_calls_ai_client_send(monkeypatch): async def test_run_worker_lifecycle_calls_ai_client_send(monkeypatch: pytest.MonkeyPatch) -> None:
""" """
Test that run_worker_lifecycle triggers the AI client and updates ticket status on success. Test that run_worker_lifecycle triggers the AI client and updates ticket status on success.
""" """
@@ -69,7 +69,7 @@ async def test_run_worker_lifecycle_calls_ai_client_send(monkeypatch):
assert ticket.description in kwargs["user_message"] assert ticket.description in kwargs["user_message"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_worker_lifecycle_context_injection(monkeypatch): async def test_run_worker_lifecycle_context_injection(monkeypatch: pytest.MonkeyPatch) -> None:
""" """
Test that run_worker_lifecycle can take a context_files list and injects AST views into the prompt. Test that run_worker_lifecycle can take a context_files list and injects AST views into the prompt.
""" """
@@ -115,7 +115,7 @@ async def test_run_worker_lifecycle_context_injection(monkeypatch):
assert "secondary.py" in user_message assert "secondary.py" in user_message
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_worker_lifecycle_handles_blocked_response(monkeypatch): async def test_run_worker_lifecycle_handles_blocked_response(monkeypatch: pytest.MonkeyPatch) -> None:
""" """
Test that run_worker_lifecycle marks the ticket as blocked if the AI indicates it cannot proceed. Test that run_worker_lifecycle marks the ticket as blocked if the AI indicates it cannot proceed.
""" """
@@ -132,7 +132,7 @@ async def test_run_worker_lifecycle_handles_blocked_response(monkeypatch):
assert "BLOCKED" in ticket.blocked_reason assert "BLOCKED" in ticket.blocked_reason
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_worker_lifecycle_step_mode_confirmation(monkeypatch): async def test_run_worker_lifecycle_step_mode_confirmation(monkeypatch: pytest.MonkeyPatch) -> None:
""" """
Test that run_worker_lifecycle passes confirm_execution to ai_client.send when step_mode is True. Test that run_worker_lifecycle passes confirm_execution to ai_client.send when step_mode is True.
Verify that if confirm_execution is called (simulated by mocking ai_client.send to call its callback), Verify that if confirm_execution is called (simulated by mocking ai_client.send to call its callback),
@@ -162,7 +162,7 @@ async def test_run_worker_lifecycle_step_mode_confirmation(monkeypatch):
assert ticket.status == "completed" assert ticket.status == "completed"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_worker_lifecycle_step_mode_rejection(monkeypatch): async def test_run_worker_lifecycle_step_mode_rejection(monkeypatch: pytest.MonkeyPatch) -> None:
""" """
Verify that if confirm_execution returns False, the logic (in ai_client, which we simulate here) Verify that if confirm_execution returns False, the logic (in ai_client, which we simulate here)
would prevent execution. In run_worker_lifecycle, we just check if it's passed. would prevent execution. In run_worker_lifecycle, we just check if it's passed.
@@ -184,7 +184,7 @@ async def test_run_worker_lifecycle_step_mode_rejection(monkeypatch):
# here we just verify the wiring. # here we just verify the wiring.
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_conductor_engine_dynamic_parsing_and_execution(monkeypatch): async def test_conductor_engine_dynamic_parsing_and_execution(monkeypatch: pytest.MonkeyPatch) -> None:
""" """
Test that parse_json_tickets correctly populates the track and run_linear executes them in dependency order. Test that parse_json_tickets correctly populates the track and run_linear executes them in dependency order.
""" """

View File

@@ -29,7 +29,7 @@ class TestGeminiCliAdapterParity(unittest.TestCase):
self.session_logger_patcher.stop() self.session_logger_patcher.stop()
@patch('subprocess.Popen') @patch('subprocess.Popen')
def test_count_tokens_uses_estimation(self, mock_popen): def test_count_tokens_uses_estimation(self, mock_popen: MagicMock) -> None:
""" """
Test that count_tokens uses character-based estimation. Test that count_tokens uses character-based estimation.
""" """
@@ -42,7 +42,7 @@ class TestGeminiCliAdapterParity(unittest.TestCase):
mock_popen.assert_not_called() mock_popen.assert_not_called()
@patch('subprocess.Popen') @patch('subprocess.Popen')
def test_send_with_safety_settings_no_flags_added(self, mock_popen): def test_send_with_safety_settings_no_flags_added(self, mock_popen: MagicMock) -> None:
""" """
Test that the send method does NOT add --safety flags when safety_settings are provided, Test that the send method does NOT add --safety flags when safety_settings are provided,
as this functionality is no longer supported via CLI flags. as this functionality is no longer supported via CLI flags.
@@ -66,7 +66,7 @@ class TestGeminiCliAdapterParity(unittest.TestCase):
process_mock.communicate.assert_called_once_with(input=message_content) process_mock.communicate.assert_called_once_with(input=message_content)
@patch('subprocess.Popen') @patch('subprocess.Popen')
def test_send_without_safety_settings_no_flags(self, mock_popen): def test_send_without_safety_settings_no_flags(self, mock_popen: MagicMock) -> None:
""" """
Test that when safety_settings is None or an empty list, no --safety flags are added. Test that when safety_settings is None or an empty list, no --safety flags are added.
""" """
@@ -85,7 +85,7 @@ class TestGeminiCliAdapterParity(unittest.TestCase):
self.assertNotIn("--safety", args_empty[0]) self.assertNotIn("--safety", args_empty[0])
@patch('subprocess.Popen') @patch('subprocess.Popen')
def test_send_with_system_instruction_prepended_to_stdin(self, mock_popen): def test_send_with_system_instruction_prepended_to_stdin(self, mock_popen: MagicMock) -> None:
""" """
Test that the send method prepends the system instruction to the prompt Test that the send method prepends the system instruction to the prompt
sent via stdin, and does NOT add a --system flag to the command. sent via stdin, and does NOT add a --system flag to the command.
@@ -107,7 +107,7 @@ class TestGeminiCliAdapterParity(unittest.TestCase):
self.assertNotIn("--system", command) self.assertNotIn("--system", command)
@patch('subprocess.Popen') @patch('subprocess.Popen')
def test_send_with_model_parameter(self, mock_popen): def test_send_with_model_parameter(self, mock_popen: MagicMock) -> None:
""" """
Test that the send method correctly adds the -m <model> flag when a model is specified. Test that the send method correctly adds the -m <model> flag when a model is specified.
""" """
@@ -128,7 +128,7 @@ class TestGeminiCliAdapterParity(unittest.TestCase):
process_mock.communicate.assert_called_once_with(input=message_content) process_mock.communicate.assert_called_once_with(input=message_content)
@patch('subprocess.Popen') @patch('subprocess.Popen')
def test_send_kills_process_on_communicate_exception(self, mock_popen): def test_send_kills_process_on_communicate_exception(self, mock_popen: MagicMock) -> None:
""" """
Test that if subprocess.Popen().communicate() raises an exception, Test that if subprocess.Popen().communicate() raises an exception,
GeminiCliAdapter.send() kills the process and re-raises the exception. GeminiCliAdapter.send() kills the process and re-raises the exception.

View File

@@ -8,8 +8,7 @@ from pathlib import Path
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
class TestHeadlessAPI(unittest.TestCase): class TestHeadlessAPI(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
# We need an App instance to initialize the API, but we want to avoid GUI stuff
with patch('gui_2.session_logger.open_session'), \ with patch('gui_2.session_logger.open_session'), \
patch('gui_2.ai_client.set_provider'), \ patch('gui_2.ai_client.set_provider'), \
patch('gui_2.session_logger.close_session'): patch('gui_2.session_logger.close_session'):
@@ -29,14 +28,12 @@ class TestHeadlessAPI(unittest.TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"status": "ok"}) self.assertEqual(response.json(), {"status": "ok"})
def test_status_endpoint_unauthorized(self): def test_status_endpoint_unauthorized(self) -> None:
# Ensure a key is required
with patch.dict(self.app_instance.config, {"headless": {"api_key": "some-required-key"}}): with patch.dict(self.app_instance.config, {"headless": {"api_key": "some-required-key"}}):
response = self.client.get("/status") response = self.client.get("/status")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
def test_status_endpoint_authorized(self): def test_status_endpoint_authorized(self) -> None:
# We'll use a test key
headers = {"X-API-KEY": "test-secret-key"} headers = {"X-API-KEY": "test-secret-key"}
with patch.dict(self.app_instance.config, {"headless": {"api_key": "test-secret-key"}}): with patch.dict(self.app_instance.config, {"headless": {"api_key": "test-secret-key"}}):
response = self.client.get("/status", headers=headers) response = self.client.get("/status", headers=headers)
@@ -63,8 +60,7 @@ class TestHeadlessAPI(unittest.TestCase):
self.assertIn("metadata", data) self.assertIn("metadata", data)
self.assertEqual(data["usage"]["input_tokens"], 10) self.assertEqual(data["usage"]["input_tokens"], 10)
def test_pending_actions_endpoint(self): def test_pending_actions_endpoint(self) -> None:
# Manually add a pending action
with patch('gui_2.uuid.uuid4', return_value="test-action-id"): with patch('gui_2.uuid.uuid4', return_value="test-action-id"):
dialog = gui_2.ConfirmDialog("dir", ".") dialog = gui_2.ConfirmDialog("dir", ".")
self.app_instance._pending_actions[dialog._uid] = dialog self.app_instance._pending_actions[dialog._uid] = dialog
@@ -74,8 +70,7 @@ class TestHeadlessAPI(unittest.TestCase):
self.assertEqual(len(data), 1) self.assertEqual(len(data), 1)
self.assertEqual(data[0]["action_id"], "test-action-id") self.assertEqual(data[0]["action_id"], "test-action-id")
def test_confirm_action_endpoint(self): def test_confirm_action_endpoint(self) -> None:
# Manually add a pending action
with patch('gui_2.uuid.uuid4', return_value="test-confirm-id"): with patch('gui_2.uuid.uuid4', return_value="test-confirm-id"):
dialog = gui_2.ConfirmDialog("dir", ".") dialog = gui_2.ConfirmDialog("dir", ".")
self.app_instance._pending_actions[dialog._uid] = dialog self.app_instance._pending_actions[dialog._uid] = dialog
@@ -85,8 +80,7 @@ class TestHeadlessAPI(unittest.TestCase):
self.assertTrue(dialog._done) self.assertTrue(dialog._done)
self.assertTrue(dialog._approved) self.assertTrue(dialog._approved)
def test_list_sessions_endpoint(self): def test_list_sessions_endpoint(self) -> None:
# Ensure logs directory exists
Path("logs").mkdir(exist_ok=True) Path("logs").mkdir(exist_ok=True)
# Create a dummy log # Create a dummy log
dummy_log = Path("logs/test_session_api.log") dummy_log = Path("logs/test_session_api.log")
@@ -108,8 +102,7 @@ class TestHeadlessAPI(unittest.TestCase):
self.assertIn("screenshots", data) self.assertIn("screenshots", data)
self.assertIn("files_base_dir", data) self.assertIn("files_base_dir", data)
def test_endpoint_no_api_key_configured(self): def test_endpoint_no_api_key_configured(self) -> None:
# Test the security fix specifically
with patch.dict(self.app_instance.config, {"headless": {"api_key": ""}}): with patch.dict(self.app_instance.config, {"headless": {"api_key": ""}}):
response = self.client.get("/status", headers=self.headers) response = self.client.get("/status", headers=self.headers)
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
@@ -122,8 +115,7 @@ class TestHeadlessStartup(unittest.TestCase):
@patch('gui_2.save_config') @patch('gui_2.save_config')
@patch('gui_2.ai_client.cleanup') @patch('gui_2.ai_client.cleanup')
@patch('uvicorn.run') # Mock uvicorn.run to prevent hanging @patch('uvicorn.run') # Mock uvicorn.run to prevent hanging
def test_headless_flag_prevents_gui_run(self, mock_uvicorn_run, mock_cleanup, mock_save_config, mock_hook_server, mock_immapp_run): def test_headless_flag_prevents_gui_run(self, mock_uvicorn_run: MagicMock, mock_cleanup: MagicMock, mock_save_config: MagicMock, mock_hook_server: MagicMock, mock_immapp_run: MagicMock) -> None:
# Setup mock argv with --headless
test_args = ["gui_2.py", "--headless"] test_args = ["gui_2.py", "--headless"]
with patch.object(sys, 'argv', test_args): with patch.object(sys, 'argv', test_args):
with patch('gui_2.session_logger.close_session'), \ with patch('gui_2.session_logger.close_session'), \
@@ -138,7 +130,7 @@ class TestHeadlessStartup(unittest.TestCase):
mock_uvicorn_run.assert_called_once() mock_uvicorn_run.assert_called_once()
@patch('gui_2.immapp.run') @patch('gui_2.immapp.run')
def test_normal_startup_calls_gui_run(self, mock_immapp_run): def test_normal_startup_calls_gui_run(self, mock_immapp_run: MagicMock) -> None:
test_args = ["gui_2.py"] test_args = ["gui_2.py"]
with patch.object(sys, 'argv', test_args): with patch.object(sys, 'argv', test_args):
# In normal mode, it should still call immapp.run # In normal mode, it should still call immapp.run

View File

@@ -17,7 +17,7 @@ import ai_client
# --- Tests for Aggregate Module --- # --- Tests for Aggregate Module ---
def test_aggregate_includes_segregated_history(tmp_path): def test_aggregate_includes_segregated_history(tmp_path: Path) -> None:
""" """
Tests if the aggregate function correctly includes history Tests if the aggregate function correctly includes history
when it's segregated into a separate file. when it's segregated into a separate file.
@@ -38,7 +38,7 @@ def test_aggregate_includes_segregated_history(tmp_path):
assert "Show me history" in markdown assert "Show me history" in markdown
# --- Tests for MCP Client and Blacklisting --- # --- Tests for MCP Client and Blacklisting ---
def test_mcp_blacklist(tmp_path): def test_mcp_blacklist(tmp_path: Path) -> None:
""" """
Tests that the MCP client correctly blacklists specified files Tests that the MCP client correctly blacklists specified files
and prevents listing them. and prevents listing them.
@@ -57,7 +57,7 @@ def test_mcp_blacklist(tmp_path):
# The blacklisted file should not appear in the directory listing # The blacklisted file should not appear in the directory listing
assert "my_project_history.toml" not in result assert "my_project_history.toml" not in result
def test_aggregate_blacklist(tmp_path): def test_aggregate_blacklist(tmp_path: Path) -> None:
""" """
Tests that aggregate's path resolution respects blacklisting, Tests that aggregate's path resolution respects blacklisting,
ensuring history files are not included by default. ensuring history files are not included by default.
@@ -73,7 +73,7 @@ def test_aggregate_blacklist(tmp_path):
assert hist_file not in paths, "History file should be excluded even with a general glob" assert hist_file not in paths, "History file should be excluded even with a general glob"
# --- Tests for History Migration and Separation --- # --- Tests for History Migration and Separation ---
def test_migration_on_load(tmp_path): def test_migration_on_load(tmp_path: Path) -> None:
""" """
Tests that project loading migrates discussion history from manual_slop.toml Tests that project loading migrates discussion history from manual_slop.toml
to manual_slop_history.toml if it exists in the main config. to manual_slop_history.toml if it exists in the main config.
@@ -102,7 +102,7 @@ def test_migration_on_load(tmp_path):
on_disk_hist = tomllib.load(f) on_disk_hist = tomllib.load(f)
assert on_disk_hist["discussions"]["main"]["history"] == ["Hello", "World"] assert on_disk_hist["discussions"]["main"]["history"] == ["Hello", "World"]
def test_save_separation(tmp_path): def test_save_separation(tmp_path: Path) -> None:
""" """
Tests that saving project data correctly separates discussion history Tests that saving project data correctly separates discussion history
into manual_slop_history.toml. into manual_slop_history.toml.
@@ -128,7 +128,7 @@ def test_save_separation(tmp_path):
assert h_disk["discussions"]["main"]["history"] == ["Saved", "Separately"] assert h_disk["discussions"]["main"]["history"] == ["Saved", "Separately"]
# --- Tests for History Persistence Across Turns --- # --- Tests for History Persistence Across Turns ---
def test_history_persistence_across_turns(tmp_path): def test_history_persistence_across_turns(tmp_path: Path) -> None:
""" """
Tests that discussion history is correctly persisted across multiple save/load cycles. Tests that discussion history is correctly persisted across multiple save/load cycles.
""" """

View File

@@ -8,7 +8,8 @@ import gui_2
from gui_2 import App from gui_2 import App
@pytest.fixture @pytest.fixture
def mock_config(tmp_path): def mock_config(tmp_path: Path) -> Path:
config_path = tmp_path / "config.toml" config_path = tmp_path / "config.toml"
config_path.write_text("""[projects] config_path.write_text("""[projects]
paths = [] paths = []
@@ -20,7 +21,8 @@ model = "model"
return config_path return config_path
@pytest.fixture @pytest.fixture
def mock_project(tmp_path): def mock_project(tmp_path: Path) -> Path:
project_path = tmp_path / "project.toml" project_path = tmp_path / "project.toml"
project_path.write_text("""[project] project_path.write_text("""[project]
name = "test" name = "test"
@@ -33,7 +35,8 @@ history = []
return project_path return project_path
@pytest.fixture @pytest.fixture
def app_instance(mock_config, mock_project, monkeypatch): def app_instance(mock_config: Path, mock_project: Path, monkeypatch: pytest.MonkeyPatch) -> App:
monkeypatch.setattr("gui_2.CONFIG_PATH", mock_config) monkeypatch.setattr("gui_2.CONFIG_PATH", mock_config)
with patch("project_manager.load_project") as mock_load, \ with patch("project_manager.load_project") as mock_load, \
patch("session_logger.open_session"): patch("session_logger.open_session"):
@@ -54,14 +57,14 @@ def app_instance(mock_config, mock_project, monkeypatch):
# but python allows calling it directly. # but python allows calling it directly.
return app return app
def test_log_management_init(app_instance): def test_log_management_init(app_instance: App) -> None:
app = app_instance app = app_instance
assert "Log Management" in app.show_windows assert "Log Management" in app.show_windows
assert app.show_windows["Log Management"] is False assert app.show_windows["Log Management"] is False
assert hasattr(app, "_render_log_management") assert hasattr(app, "_render_log_management")
assert callable(app._render_log_management) assert callable(app._render_log_management)
def test_render_log_management_logic(app_instance): def test_render_log_management_logic(app_instance: App) -> None:
app = app_instance app = app_instance
app.show_windows["Log Management"] = True app.show_windows["Log Management"] = True
# Mock LogRegistry # Mock LogRegistry

View File

@@ -25,7 +25,7 @@ def app_instance() -> None:
if not hasattr(app, '_show_track_proposal_modal'): app._show_track_proposal_modal = False if not hasattr(app, '_show_track_proposal_modal'): app._show_track_proposal_modal = False
yield app yield app
def test_mma_ui_state_initialization(app_instance): def test_mma_ui_state_initialization(app_instance: App) -> None:
"""Verifies that the new MMA UI state variables are initialized correctly.""" """Verifies that the new MMA UI state variables are initialized correctly."""
assert hasattr(app_instance, 'ui_epic_input') assert hasattr(app_instance, 'ui_epic_input')
assert hasattr(app_instance, 'proposed_tracks') assert hasattr(app_instance, 'proposed_tracks')
@@ -36,7 +36,7 @@ def test_mma_ui_state_initialization(app_instance):
assert app_instance._show_track_proposal_modal is False assert app_instance._show_track_proposal_modal is False
assert app_instance.mma_streams == {} assert app_instance.mma_streams == {}
def test_process_pending_gui_tasks_show_track_proposal(app_instance): def test_process_pending_gui_tasks_show_track_proposal(app_instance: App) -> None:
"""Verifies that the 'show_track_proposal' action correctly updates the UI state.""" """Verifies that the 'show_track_proposal' action correctly updates the UI state."""
mock_tracks = [{"id": "track_1", "title": "Test Track"}] mock_tracks = [{"id": "track_1", "title": "Test Track"}]
task = { task = {
@@ -48,7 +48,7 @@ def test_process_pending_gui_tasks_show_track_proposal(app_instance):
assert app_instance.proposed_tracks == mock_tracks assert app_instance.proposed_tracks == mock_tracks
assert app_instance._show_track_proposal_modal is True assert app_instance._show_track_proposal_modal is True
def test_cb_plan_epic_launches_thread(app_instance): def test_cb_plan_epic_launches_thread(app_instance: App) -> None:
"""Verifies that _cb_plan_epic launches a thread and eventually queues a task.""" """Verifies that _cb_plan_epic launches a thread and eventually queues a task."""
app_instance.ui_epic_input = "Develop a new feature" app_instance.ui_epic_input = "Develop a new feature"
app_instance.active_project_path = "test_project.toml" app_instance.active_project_path = "test_project.toml"
@@ -80,7 +80,7 @@ def test_cb_plan_epic_launches_thread(app_instance):
mock_get_history.assert_called_once() mock_get_history.assert_called_once()
mock_gen_tracks.assert_called_once() mock_gen_tracks.assert_called_once()
def test_process_pending_gui_tasks_mma_spawn_approval(app_instance): def test_process_pending_gui_tasks_mma_spawn_approval(app_instance: App) -> None:
"""Verifies that the 'mma_spawn_approval' action correctly updates the UI state.""" """Verifies that the 'mma_spawn_approval' action correctly updates the UI state."""
task = { task = {
"action": "mma_spawn_approval", "action": "mma_spawn_approval",
@@ -100,7 +100,7 @@ def test_process_pending_gui_tasks_mma_spawn_approval(app_instance):
assert task["dialog_container"][0] is not None assert task["dialog_container"][0] is not None
assert task["dialog_container"][0]._ticket_id == "T1" assert task["dialog_container"][0]._ticket_id == "T1"
def test_handle_ai_response_with_stream_id(app_instance): def test_handle_ai_response_with_stream_id(app_instance: App) -> None:
"""Verifies routing to mma_streams.""" """Verifies routing to mma_streams."""
task = { task = {
"action": "handle_ai_response", "action": "handle_ai_response",
@@ -116,7 +116,7 @@ def test_handle_ai_response_with_stream_id(app_instance):
assert app_instance.ai_status == "Thinking..." assert app_instance.ai_status == "Thinking..."
assert app_instance.ai_response == "" assert app_instance.ai_response == ""
def test_handle_ai_response_fallback(app_instance): def test_handle_ai_response_fallback(app_instance: App) -> None:
"""Verifies fallback to ai_response when stream_id is missing.""" """Verifies fallback to ai_response when stream_id is missing."""
task = { task = {
"action": "handle_ai_response", "action": "handle_ai_response",

View File

@@ -19,7 +19,7 @@ class TestOrchestratorPMHistory(unittest.TestCase):
if self.test_dir.exists(): if self.test_dir.exists():
shutil.rmtree(self.test_dir) shutil.rmtree(self.test_dir)
def create_track(self, parent_dir, track_id, title, status, overview): def create_track(self, parent_dir: Path, track_id: str, title: str, status: str, overview: str) -> None:
track_path = parent_dir / track_id track_path = parent_dir / track_id
track_path.mkdir(exist_ok=True) track_path.mkdir(exist_ok=True)
metadata = {"title": title, "status": status} metadata = {"title": title, "status": status}
@@ -30,8 +30,7 @@ class TestOrchestratorPMHistory(unittest.TestCase):
f.write(spec_content) f.write(spec_content)
@patch('orchestrator_pm.CONDUCTOR_PATH', Path("test_conductor")) @patch('orchestrator_pm.CONDUCTOR_PATH', Path("test_conductor"))
def test_get_track_history_summary(self): def test_get_track_history_summary(self) -> None:
# Setup mock tracks
self.create_track(self.archive_dir, "track_001", "Initial Setup", "completed", "Setting up the project structure.") self.create_track(self.archive_dir, "track_001", "Initial Setup", "completed", "Setting up the project structure.")
self.create_track(self.tracks_dir, "track_002", "Feature A", "in_progress", "Implementing Feature A.") self.create_track(self.tracks_dir, "track_002", "Feature A", "in_progress", "Implementing Feature A.")
summary = orchestrator_pm.get_track_history_summary() summary = orchestrator_pm.get_track_history_summary()
@@ -43,8 +42,7 @@ class TestOrchestratorPMHistory(unittest.TestCase):
self.assertIn("Implementing Feature A.", summary) self.assertIn("Implementing Feature A.", summary)
@patch('orchestrator_pm.CONDUCTOR_PATH', Path("test_conductor")) @patch('orchestrator_pm.CONDUCTOR_PATH', Path("test_conductor"))
def test_get_track_history_summary_missing_files(self): def test_get_track_history_summary_missing_files(self) -> None:
# Track with missing spec.md
track_path = self.tracks_dir / "track_003" track_path = self.tracks_dir / "track_003"
track_path.mkdir(exist_ok=True) track_path.mkdir(exist_ok=True)
with open(track_path / "metadata.json", "w") as f: with open(track_path / "metadata.json", "w") as f:
@@ -56,7 +54,7 @@ class TestOrchestratorPMHistory(unittest.TestCase):
@patch('orchestrator_pm.summarize.build_summary_markdown') @patch('orchestrator_pm.summarize.build_summary_markdown')
@patch('ai_client.send') @patch('ai_client.send')
def test_generate_tracks_with_history(self, mock_send, mock_summarize): def test_generate_tracks_with_history(self, mock_send: MagicMock, mock_summarize: MagicMock) -> None:
mock_summarize.return_value = "REPO_MAP" mock_summarize.return_value = "REPO_MAP"
mock_send.return_value = "[]" mock_send.return_value = "[]"
history_summary = "PAST_HISTORY_SUMMARY" history_summary = "PAST_HISTORY_SUMMARY"

View File

@@ -7,12 +7,11 @@ import asyncio
import concurrent.futures import concurrent.futures
class MockDialog: class MockDialog:
def __init__(self, approved, final_payload=None): def __init__(self, approved: bool, final_payload: dict | None = None) -> None:
self.approved = approved self.approved = approved
self.final_payload = final_payload self.final_payload = final_payload
def wait(self): def wait(self) -> dict:
# Match the new return format: a dictionary
res = {'approved': self.approved, 'abort': False} res = {'approved': self.approved, 'abort': False}
if self.final_payload: if self.final_payload:
res.update(self.final_payload) res.update(self.final_payload)
@@ -25,7 +24,7 @@ def mock_ai_client() -> None:
yield mock_send yield mock_send
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_confirm_spawn_pushed_to_queue(): async def test_confirm_spawn_pushed_to_queue() -> None:
event_queue = events.AsyncEventQueue() event_queue = events.AsyncEventQueue()
ticket_id = "T1" ticket_id = "T1"
role = "Tier 3 Worker" role = "Tier 3 Worker"
@@ -54,7 +53,7 @@ async def test_confirm_spawn_pushed_to_queue():
assert final_context == "Modified Context" assert final_context == "Modified Context"
@patch("multi_agent_conductor.confirm_spawn") @patch("multi_agent_conductor.confirm_spawn")
def test_run_worker_lifecycle_approved(mock_confirm, mock_ai_client): def test_run_worker_lifecycle_approved(mock_confirm: MagicMock, mock_ai_client: MagicMock) -> None:
ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user") ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user")
context = WorkerContext(ticket_id="T1", model_name="model", messages=[]) context = WorkerContext(ticket_id="T1", model_name="model", messages=[])
event_queue = events.AsyncEventQueue() event_queue = events.AsyncEventQueue()
@@ -68,7 +67,7 @@ def test_run_worker_lifecycle_approved(mock_confirm, mock_ai_client):
assert ticket.status == "completed" assert ticket.status == "completed"
@patch("multi_agent_conductor.confirm_spawn") @patch("multi_agent_conductor.confirm_spawn")
def test_run_worker_lifecycle_rejected(mock_confirm, mock_ai_client): def test_run_worker_lifecycle_rejected(mock_confirm: MagicMock, mock_ai_client: MagicMock) -> None:
ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user") ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user")
context = WorkerContext(ticket_id="T1", model_name="model", messages=[]) context = WorkerContext(ticket_id="T1", model_name="model", messages=[])
event_queue = events.AsyncEventQueue() event_queue = events.AsyncEventQueue()