diff --git a/src/ai_client.py b/src/ai_client.py index 97663af2..b2431dfa 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -42,8 +42,8 @@ from src import mcp_client from src import mma_prompts from src import performance_monitor from src import project_manager -from src.openai_compatible import send_openai_compatible -from src.vendor_capabilities import VendorCapabilities +from src.openai_compatible import send_openai_compatible, OpenAICompatibleRequest +from src.vendor_capabilities import VendorCapabilities, get_capabilities # TODO(Ed): Eliminate these? from src.events import EventEmitter @@ -2244,8 +2244,7 @@ def _send_grok(md_content: str, user_message: str, base_dir: str, stream_callback: Optional[Callable[[str], None]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: client = _ensure_grok_client() - from src.openai_compatible import OpenAICompatibleRequest, send_openai_compatible - from src.vendor_capabilities import get_capabilities + tools: list[dict[str, Any]] | None = _get_deepseek_tools() or None with _grok_history_lock: user_content = user_message if file_items: @@ -2256,21 +2255,22 @@ def _send_grok(md_content: str, user_message: str, base_dir: str, _grok_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}) else: _grok_history.append({"role": "user", "content": user_content}) - messages = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n\n{md_content}\n"}] - messages.extend(_grok_history) - request = OpenAICompatibleRequest( - messages=messages, - model=_model, - temperature=_temperature, - top_p=_top_p, - max_tokens=_max_tokens, - stream=stream, - stream_callback=stream_callback, - ) + def _build_grok_request(_round_idx: int) -> OpenAICompatibleRequest: + with _grok_history_lock: + messages: list[dict[str, Any]] = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n\n{md_content}\n"}] + messages.extend(_grok_history) + return OpenAICompatibleRequest( + messages=messages, model=_model, temperature=_temperature, top_p=_top_p, + max_tokens=_max_tokens, stream=stream, stream_callback=stream_callback, + tools=tools, tool_choice="auto" if tools else "auto", + ) caps = get_capabilities("grok", _model) - response = send_openai_compatible(client, request, capabilities=caps) - _grok_history.append({"role": "assistant", "content": response.text}) - return response.text + return run_with_tool_loop( + client, _build_grok_request, capabilities=caps, + pre_tool_callback=pre_tool_callback, qa_callback=qa_callback, stream_callback=stream_callback, + patch_callback=patch_callback, base_dir=base_dir, vendor_name="grok", + history_lock=_grok_history_lock, history=_grok_history, + ) def _list_grok_models() -> list[str]: from src.vendor_capabilities import list_models_for_vendor @@ -2450,8 +2450,7 @@ def _send_llama(md_content: str, user_message: str, base_dir: str, stream_callback: Optional[Callable[[str], None]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: client = _ensure_llama_client() - from src.openai_compatible import OpenAICompatibleRequest, send_openai_compatible - from src.vendor_capabilities import get_capabilities + tools: list[dict[str, Any]] | None = _get_deepseek_tools() or None with _llama_history_lock: user_content = user_message if file_items: @@ -2462,21 +2461,22 @@ def _send_llama(md_content: str, user_message: str, base_dir: str, _llama_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}) else: _llama_history.append({"role": "user", "content": user_content}) - messages = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n\n{md_content}\n"}] - messages.extend(_llama_history) - request = OpenAICompatibleRequest( - messages=messages, - model=_model, - temperature=_temperature, - top_p=_top_p, - max_tokens=_max_tokens, - stream=stream, - stream_callback=stream_callback, - ) + def _build_llama_request(_round_idx: int) -> OpenAICompatibleRequest: + with _llama_history_lock: + messages: list[dict[str, Any]] = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n\n{md_content}\n"}] + messages.extend(_llama_history) + return OpenAICompatibleRequest( + messages=messages, model=_model, temperature=_temperature, top_p=_top_p, + max_tokens=_max_tokens, stream=stream, stream_callback=stream_callback, + tools=tools, tool_choice="auto" if tools else "auto", + ) caps = get_capabilities("llama", _model) - response = send_openai_compatible(client, request, capabilities=caps) - _llama_history.append({"role": "assistant", "content": response.text}) - return response.text + return run_with_tool_loop( + client, _build_llama_request, capabilities=caps, + pre_tool_callback=pre_tool_callback, qa_callback=qa_callback, stream_callback=stream_callback, + patch_callback=patch_callback, base_dir=base_dir, vendor_name="llama", + history_lock=_llama_history_lock, history=_llama_history, + ) def _list_llama_models() -> list[str]: from src.vendor_capabilities import list_models_for_vendor