diff --git a/ai_client.py b/ai_client.py
index 89c2db0..c63d625 100644
--- a/ai_client.py
+++ b/ai_client.py
@@ -21,6 +21,7 @@ import difflib
 import threading
 import requests
 from pathlib import Path
+from typing import Optional, Callable
 import os
 import project_manager
 import file_cache
@@ -522,10 +523,10 @@ def _gemini_tool_declaration():
     return types.Tool(function_declarations=declarations) if declarations else None
 
 
-def _run_script(script: str, base_dir: str) -> str:
+def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str:
     if confirm_and_run_callback is None:
         return "ERROR: no confirmation handler registered"
-    result = confirm_and_run_callback(script, base_dir)
+    result = confirm_and_run_callback(script, base_dir, qa_callback)
     if result is None:
         output = "USER REJECTED: command was not executed"
     else:
@@ -669,7 +670,8 @@ def _get_gemini_history_list(chat):
 
 def _send_gemini(md_content: str, user_message: str, base_dir: str,
                  file_items: list[dict] | None = None,
                  discussion_history: str = "",
-                 pre_tool_callback = None) -> str:
+                 pre_tool_callback = None,
+                 qa_callback: Optional[Callable[[str], str]] = None) -> str:
     global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
     try:
@@ -848,7 +850,7 @@
                 elif name == TOOL_NAME:
                     scr = args.get("script", "")
                     _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr})
-                    out = _run_script(scr, base_dir)
+                    out = _run_script(scr, base_dir, qa_callback)
                 else:
                     out = f"ERROR: unknown tool '{name}'"
                 if i == len(calls) - 1:
@@ -880,7 +882,8 @@
 def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
                      file_items: list[dict] | None = None,
                      discussion_history: str = "",
-                     pre_tool_callback = None) -> str:
+                     pre_tool_callback = None,
+                     qa_callback: Optional[Callable[[str], str]] = None) -> str:
     global _gemini_cli_adapter
     try:
         if _gemini_cli_adapter is None:
@@ -984,7 +987,7 @@
                 elif name == TOOL_NAME:
                     scr = args.get("script", "")
                     _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
-                    out = _run_script(scr, base_dir)
+                    out = _run_script(scr, base_dir, qa_callback)
                 else:
                     out = f"ERROR: unknown tool '{name}'"
 
@@ -1277,7 +1280,7 @@ def _repair_anthropic_history(history: list[dict]):
         })
 
 
-def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict] | None = None, discussion_history: str = "", pre_tool_callback = None) -> str:
+def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict] | None = None, discussion_history: str = "", pre_tool_callback = None, qa_callback: Optional[Callable[[str], str]] = None) -> str:
     try:
         _ensure_anthropic_client()
         mcp_client.configure(file_items or [], [base_dir])
@@ -1441,7 +1444,7 @@
                         "id": b_id,
                         "script": script,
                     })
-                    output = _run_script(script, base_dir)
+                    output = _run_script(script, base_dir, qa_callback)
                     _append_comms("IN", "tool_result", {
                         "name": TOOL_NAME,
                         "id": b_id,
@@ -1520,7 +1523,8 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
                    file_items: list[dict] | None = None,
                    discussion_history: str = "",
                    stream: bool = False,
-                   pre_tool_callback = None) -> str:
+                   pre_tool_callback = None,
+                   qa_callback: Optional[Callable[[str], str]] = None) -> str:
     """
     Sends a message to the DeepSeek API, handling tool calls and history.
     Supports streaming responses.
@@ -1713,7 +1717,7 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
                 elif tool_name == TOOL_NAME:
                     script = tool_args.get("script", "")
                     _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script})
-                    tool_output = _run_script(script, base_dir)
+                    tool_output = _run_script(script, base_dir, qa_callback)
                 else:
                     tool_output = f"ERROR: unknown tool '{tool_name}'"
 
@@ -1811,6 +1815,7 @@ def send(
     discussion_history: str = "",
     stream: bool = False,
     pre_tool_callback = None,
+    qa_callback: Optional[Callable[[str], str]] = None,
 ) -> str:
     """
     Send a message to the active provider.
@@ -1825,16 +1830,17 @@ def send(
                          conversation message instead of caching it)
     stream : Whether to use streaming (supported by DeepSeek)
     pre_tool_callback : Optional callback (payload: str) -> bool called before tool execution
+    qa_callback : Optional callback (stderr: str) -> str called for Tier 4 error analysis
     """
     with _send_lock:
         if _provider == "gemini":
-            return _send_gemini(md_content, user_message, base_dir, file_items, discussion_history, pre_tool_callback)
+            return _send_gemini(md_content, user_message, base_dir, file_items, discussion_history, pre_tool_callback, qa_callback)
         elif _provider == "gemini_cli":
-            return _send_gemini_cli(md_content, user_message, base_dir, file_items, discussion_history, pre_tool_callback)
+            return _send_gemini_cli(md_content, user_message, base_dir, file_items, discussion_history, pre_tool_callback, qa_callback)
         elif _provider == "anthropic":
-            return _send_anthropic(md_content, user_message, base_dir, file_items, discussion_history, pre_tool_callback)
+            return _send_anthropic(md_content, user_message, base_dir, file_items, discussion_history, pre_tool_callback, qa_callback)
         elif _provider == "deepseek":
-            return _send_deepseek(md_content, user_message, base_dir, file_items, discussion_history, stream=stream, pre_tool_callback=pre_tool_callback)
+            return _send_deepseek(md_content, user_message, base_dir, file_items, discussion_history, stream=stream, pre_tool_callback=pre_tool_callback, qa_callback=qa_callback)
         raise ValueError(f"unknown provider: {_provider}")
 
 def get_history_bleed_stats(md_content: str | None = None) -> dict:
diff --git a/gui_2.py b/gui_2.py
index e2a5d7f..72a9a55 100644
--- a/gui_2.py
+++ b/gui_2.py
@@ -1066,7 +1066,7 @@ class App:
         self.is_viewing_prior_session = True
         self.ai_status = f"viewing prior session: {Path(path).name} ({len(entries)} entries)"
 
-    def _confirm_and_run(self, script: str, base_dir: str) -> str | None:
+    def _confirm_and_run(self, script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str | None:
         print(f"[DEBUG] _confirm_and_run triggered for script length: {len(script)}")
 
         dialog = ConfirmDialog(script, base_dir)
@@ -1106,7 +1106,7 @@
 
         self.ai_status = "running powershell..."
         print(f"[DEBUG] Running powershell in {base_dir}")
-        output = shell_runner.run_powershell(final_script, base_dir)
+        output = shell_runner.run_powershell(final_script, base_dir, qa_callback=qa_callback)
         self._append_tool_log(final_script, output)
         self.ai_status = "powershell done, awaiting AI..."
         return output
diff --git a/gui_legacy.py b/gui_legacy.py
index 49ad3d4..fd44622 100644
--- a/gui_legacy.py
+++ b/gui_legacy.py
@@ -910,7 +910,7 @@ class App:
 
     # ---------------------------------------------------------------- tool execution
 
-    def _confirm_and_run(self, script: str, base_dir: str) -> str | None:
+    def _confirm_and_run(self, script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str | None:
         dialog = ConfirmDialog(script, base_dir)
 
         with self._pending_dialog_lock:
@@ -923,7 +923,7 @@
             return None
 
         self._update_status("running powershell...")
-        output = shell_runner.run_powershell(final_script, base_dir)
+        output = shell_runner.run_powershell(final_script, base_dir, qa_callback=qa_callback)
         self._append_tool_log(final_script, output)
         self._update_status("powershell done, awaiting AI...")
         return output
diff --git a/multi_agent_conductor.py b/multi_agent_conductor.py
index 338e8f5..2f9467f 100644
--- a/multi_agent_conductor.py
+++ b/multi_agent_conductor.py
@@ -82,7 +82,8 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
         md_content="",
         user_message=user_message,
         base_dir=".",
-        pre_tool_callback=confirm_execution if ticket.step_mode else None
+        pre_tool_callback=confirm_execution if ticket.step_mode else None,
+        qa_callback=ai_client.run_tier4_analysis
     )
 
     if "BLOCKED" in response.upper():
diff --git a/tests/test_tier4_interceptor.py b/tests/test_tier4_interceptor.py
index d03b61d..f29d331 100644
--- a/tests/test_tier4_interceptor.py
+++ b/tests/test_tier4_interceptor.py
@@ -126,3 +126,100 @@ def test_end_to_end_tier4_integration():
     mock_analysis.assert_called_once_with(stderr_content)
 
     assert f"QA ANALYSIS:\n{expected_analysis}" in output
+
+def test_ai_client_passes_qa_callback():
+    """
+    Verifies that ai_client.send passes the qa_callback down to the provider function.
+    """
+    import ai_client
+
+    # Mock the provider function to avoid actual API calls
+    mock_send_gemini = MagicMock(return_value="AI Response")
+
+    qa_callback = MagicMock(return_value="QA Analysis")
+
+    # Force provider to gemini and mock its send function
+    with patch("ai_client._provider", "gemini"), \
+         patch("ai_client._send_gemini", mock_send_gemini):
+
+        ai_client.send(
+            md_content="Context",
+            user_message="Hello",
+            base_dir=".",
+            qa_callback=qa_callback
+        )
+
+    # Verify the provider received the qa_callback
+    mock_send_gemini.assert_called_once()
+    args, kwargs = mock_send_gemini.call_args
+    # qa_callback is the 7th positional argument in _send_gemini
+    assert args[6] == qa_callback
+
+def test_gemini_provider_passes_qa_callback_to_run_script():
+    """
+    Verifies that _send_gemini passes the qa_callback to _run_script.
+    """
+    import ai_client
+
+    # Mock Gemini chat and client
+    mock_client = MagicMock()
+    mock_chat = MagicMock()
+
+    # Simulate a tool call response
+    mock_part = MagicMock()
+    mock_part.text = ""
+    mock_part.function_call = MagicMock()
+    mock_part.function_call.name = "run_powershell"
+    mock_part.function_call.args = {"script": "dir"}
+
+    mock_candidate = MagicMock()
+    mock_candidate.content.parts = [mock_part]
+    mock_candidate.finish_reason.name = "STOP"
+
+    mock_response = MagicMock()
+    mock_response.candidates = [mock_candidate]
+    mock_response.usage_metadata.prompt_token_count = 10
+    mock_response.usage_metadata.candidates_token_count = 5
+
+    # Second call returns a stop response to break the loop
+    mock_stop_part = MagicMock()
+    mock_stop_part.text = "Done"
+    mock_stop_part.function_call = None
+
+    mock_stop_candidate = MagicMock()
+    mock_stop_candidate.content.parts = [mock_stop_part]
+    mock_stop_candidate.finish_reason.name = "STOP"
+
+    mock_stop_response = MagicMock()
+    mock_stop_response.candidates = [mock_stop_candidate]
+    mock_stop_response.usage_metadata.prompt_token_count = 5
+    mock_stop_response.usage_metadata.candidates_token_count = 2
+
+    mock_chat.send_message.side_effect = [mock_response, mock_stop_response]
+
+    # Mock count_tokens to avoid chat creation failure
+    mock_count_resp = MagicMock()
+    mock_count_resp.total_tokens = 100
+    mock_client.models.count_tokens.return_value = mock_count_resp
+
+    qa_callback = MagicMock()
+
+    # Set global state for the test
+    with patch("ai_client._gemini_client", mock_client), \
+         patch("ai_client._gemini_chat", None), \
+         patch("ai_client._ensure_gemini_client"), \
+         patch("ai_client._run_script", return_value="output") as mock_run_script, \
+         patch("ai_client._get_gemini_history_list", return_value=[]):
+
+        # Ensure chats.create returns our mock_chat
+        mock_client.chats.create.return_value = mock_chat
+
+        ai_client._send_gemini(
+            md_content="Context",
+            user_message="Run dir",
+            base_dir=".",
+            qa_callback=qa_callback
+        )
+
+    # Verify _run_script received the qa_callback
+    mock_run_script.assert_called_once_with("dir", ".", qa_callback)