remove(ai_client): delete unused stub and proxy files
Deleted:
- src/ai_client_stub.py
- src/ai_client_proxy.py

Fixed test imports to use ai_client instead of ai_client_stub.
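Note: the stub was imported under an alias (from src import ai_client_stub as ai_client), so call sites were already written against the real module's name. Outside of the test patch targets, the migration is therefore a one-line import change per file, sketched here:

# Old: the stub stood in for the real module under its name
from src import ai_client_stub as ai_client
# New: the alias disappears along with the stub
from src import ai_client

ai_client.cleanup()  # call sites are identical either way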
--- a/src/ai_client_proxy.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import json
-import uuid
-import threading
-import subprocess
-import sys
-import os
-from typing import Any, Optional
-
-
-class AIProxyClient:
-    def __init__(self):
-        self._process: Optional[subprocess.Popen] = None
-        self._status: str = "disconnected"
-        self._pending: dict[str, Any] = {}
-        self._reader_thread: Optional[threading.Thread] = None
-        self._pending_lock: threading.Lock = threading.Lock()
-
-    @property
-    def status(self) -> str:
-        return self._status
-
-    def start_server(self):
-        if self._process is not None:
-            return
-        self._process = subprocess.Popen(
-            [sys.executable, "-m", "src.ai_server"],
-            stdin=subprocess.PIPE,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            text=True,
-        )
-        self._status = "init"
-        self._reader_thread = threading.Thread(target=self._read_loop, daemon=True)
-        self._reader_thread.start()
-
-    def _read_loop(self):
-        if self._process is None or self._process.stdout is None:
-            return
-        try:
-            for line in self._process.stdout:
-                line = line.strip()
-                if not line:
-                    continue
-                try:
-                    response = json.loads(line)
-                    if response.get("type") == "ready" and self._status == "init":
-                        self._status = "ready"
-                        continue
-                    rid = response.get("id")
-                    if rid in self._pending:
-                        self._pending[rid] = response
-                        event_key = rid + "_event"
-                        if event_key in self._pending:
-                            self._pending[event_key].set()
-                except json.JSONDecodeError:
-                    pass
-        except Exception:
-            pass
-
-    def send_command(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
-        if self._process is None or self._process.stdin is None:
-            return {"error": "server not started"}
-
-        request_id = str(uuid.uuid4())
-        event = threading.Event()
-        self._pending[request_id] = None
-        self._pending[request_id + "_event"] = event
-
-        command = {"id": request_id, "method": method, "params": params}
-        try:
-            self._process.stdin.write(json.dumps(command) + "\n")
-            self._process.stdin.flush()
-        except Exception as e:
-            self._pending.pop(request_id, None)
-            self._pending.pop(request_id + "_event", None)
-            return {"error": str(e)}
-
-        if not event.wait(timeout=60):
-            self._pending.pop(request_id, None)
-            self._pending.pop(request_id + "_event", None)
-            return {"error": "timeout"}
-
-        result = self._pending.pop(request_id, {"error": "response not found"})
-        self._pending.pop(request_id + "_event", None)
-        return result if result else {"error": "no response"}
-
-    def stop(self):
-        if self._process:
-            try:
-                self._process.stdin.close()
-                self._process.stdout.close()
-                self._process.stderr.close()
-                self._process.terminate()
-                self._process.wait(timeout=5)
-            except Exception:
-                try:
-                    self._process.kill()
-                except Exception:
-                    pass
-            self._process = None
-            self._status = "disconnected"
-            self._pending.clear()
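For context, the deleted AIProxyClient wrapped a child "python -m src.ai_server" process behind a small blocking API. A minimal lifecycle sketch, using only the methods defined above and assuming src.ai_server exists and emits a {"type": "ready"} line on startup:

import time

client = AIProxyClient()
client.start_server()                # spawns `python -m src.ai_server`
while client.status != "ready":      # the reader thread flips this on the ready line
    time.sleep(0.1)
reply = client.send_command("list_models", {"provider": "gemini"})
print(reply.get("result", reply))    # e.g. {"models": [...]} or {"error": "..."}
client.stop()                        # terminates the child and resets state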
--- a/src/ai_client_stub.py
+++ /dev/null
@@ -1,370 +0,0 @@
-from __future__ import annotations
-import threading
-import datetime
-import time
-import os
-import json
-import hashlib
-from typing import Optional, Callable, Any, List, cast
-from collections import deque
-from pathlib import Path
-
-from src.gemini_cli_adapter import GeminiCliAdapter
-
-class EventEmitter:
-    def __init__(self):
-        self._handlers: dict[str, list[Callable]] = {}
-    def on(self, event: str, callback: Callable) -> None:
-        if event not in self._handlers:
-            self._handlers[event] = []
-        self._handlers[event].append(callback)
-    def emit(self, event: str, **kwargs: Any) -> None:
-        for cb in self._handlers.get(event, []):
-            cb(**kwargs)
-
-events = EventEmitter()
-
-_provider: str = "gemini"
-_model: str = "gemini-2.5-flash-lite"
-_temperature: float = 0.0
-_top_p: float = 1.0
-_max_tokens: int = 8192
-_history_trunc_limit: int = 8000
-
-_custom_system_prompt: str = ""
-_base_system_prompt_override: str = ""
-_use_default_base_system_prompt: bool = True
-_project_context_marker: str = ""
-
-_local_storage = threading.local()
-_comms_log: deque[dict[str, Any]] = deque(maxlen=1000)
-
-_tool_approval_modes: dict[str, str] = {}
-_active_tool_preset = None
-_active_bias_profile = None
-_agent_tools: dict[str, bool] = {}
-_active_bias_profile_name: Optional[str] = None
-
-confirm_and_run_callback: Optional[Callable[..., Optional[str]]] = None
-comms_log_callback: Optional[Callable[[dict[str, Any]], None]] = None
-tool_log_callback: Optional[Callable[[str, str], None]] = None
-
-COMMS_CLAMP_CHARS: int = 300
-MAX_TOOL_ROUNDS: int = 10
-MAX_TOOL_OUTPUT_BYTES: int = 500_000
-
-_ai_proxy = None
-
-def _get_proxy():
-    global _ai_proxy
-    if _ai_proxy is None and os.environ.get("AI_SERVER_ENABLED"):
-        try:
-            from src.ai_client_proxy import AIProxyClient
-            _ai_proxy = AIProxyClient()
-            _ai_proxy.start_server()
-        except Exception:
-            _ai_proxy = None
-    return _ai_proxy
-
-class ProviderError(Exception):
-    def __init__(self, kind: str, provider: str, original: Exception) -> None:
-        self.kind = kind
-        self.provider = provider
-        self.original = original
-        super().__init__(str(original))
-    def ui_message(self) -> str:
-        labels = {"quota": "QUOTA EXHAUSTED", "rate_limit": "RATE LIMITED", "auth": "AUTH / API KEY ERROR", "balance": "BALANCE / BILLING ERROR", "network": "NETWORK / CONNECTION ERROR", "unknown": "API ERROR"}
-        label = labels.get(self.kind, "API ERROR")
-        return f"[{self.provider.upper()} {label}]\n\n{self.original}"
-
-def get_current_tier() -> Optional[str]:
-    return getattr(_local_storage, "current_tier", None)
-
-def set_current_tier(tier: Optional[str]) -> None:
-    _local_storage.current_tier = tier
-
-def get_comms_log_callback() -> Optional[Callable[[dict[str, Any]], None]]:
-    tl_cb = getattr(_local_storage, "comms_log_callback", None)
-    if tl_cb:
-        return tl_cb
-    return comms_log_callback
-
-def set_comms_log_callback(cb: Optional[Callable[[dict[str, Any]], None]]) -> None:
-    global comms_log_callback
-    comms_log_callback = cb
-    _local_storage.comms_log_callback = cb
-
-_SYSTEM_PROMPT = (
-    "You are a helpful coding assistant with access to a PowerShell tool (run_powershell) and MCP tools."
-)
-
-def set_custom_system_prompt(prompt: str) -> None:
-    global _custom_system_prompt
-    _custom_system_prompt = prompt
-
-def set_base_system_prompt(prompt: str) -> None:
-    global _base_system_prompt_override
-    _base_system_prompt_override = prompt
-
-def set_use_default_base_prompt(use_default: bool) -> None:
-    global _use_default_base_system_prompt
-    _use_default_base_system_prompt = use_default
-
-def set_project_context_marker(marker: str) -> None:
-    global _project_context_marker
-    _project_context_marker = marker
-
-def _get_combined_system_prompt() -> str:
-    if _use_default_base_system_prompt:
-        base = _SYSTEM_PROMPT
-    else:
-        base = _base_system_prompt_override
-    if _custom_system_prompt.strip():
-        base = f"{base}\n\n[USER SYSTEM PROMPT]\n{_custom_system_prompt}"
-    return base
-
-def get_combined_system_prompt() -> str:
-    return _get_combined_system_prompt()
-
-def _append_comms(direction: str, kind: str, payload: dict[str, Any]) -> None:
-    entry: dict[str, Any] = {"ts": datetime.datetime.now().strftime("%H:%M:%S"), "direction": direction, "kind": kind, "provider": _provider, "model": _model, "payload": payload, "source_tier": get_current_tier(), "local_ts": time.time()}
-    _comms_log.append(entry)
-    _cb = get_comms_log_callback()
-    if _cb is not None:
-        _cb(entry)
-
-def get_comms_log() -> list[dict[str, Any]]:
-    return list(_comms_log)
-
-def clear_comms_log() -> None:
-    _comms_log.clear()
-
-def get_credentials_path() -> Path:
-    return Path(os.environ.get("SLOP_CREDENTIALS", str(Path(__file__).parent.parent / "credentials.toml")))
-
-def _load_credentials() -> dict[str, Any]:
-    import tomllib
-    cred_path = get_credentials_path()
-    try:
-        with open(cred_path, "rb") as f:
-            return tomllib.load(f)
-    except FileNotFoundError:
-        raise FileNotFoundError(f"Credentials file not found: {cred_path}")
-
-def set_provider(provider: str, model: str) -> None:
-    global _provider, _model
-    _provider = provider
-    if provider == "gemini_cli":
-        if model != "mock" and not any(m in model for m in ["deepseek"]):
-            _model = model
-        else:
-            _model = "gemini-3-flash-preview"
-    else:
-        _model = model
-
-def get_provider() -> str:
-    return _provider
-
-def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000, top_p: float = 1.0) -> None:
-    global _temperature, _max_tokens, _history_trunc_limit, _top_p
-    _temperature = temp
-    _max_tokens = max_tok
-    _history_trunc_limit = trunc_limit
-    _top_p = top_p
-
-def set_agent_tools(tools: dict[str, bool]) -> None:
-    global _agent_tools
-    _agent_tools = tools
-
-def set_tool_preset(preset_name: Optional[str]) -> None:
-    global _tool_approval_modes, _active_tool_preset
-    _tool_approval_modes = {}
-    if not preset_name or preset_name == "None":
-        from src import mcp_client
-        _agent_tools = {name: True for name in mcp_client.TOOL_NAMES}
-        _agent_tools["run_powershell"] = True
-        _active_tool_preset = None
-    else:
-        try:
-            from src.tool_presets import ToolPresetManager
-            manager = ToolPresetManager()
-            presets = manager.load_all()
-            if preset_name in presets:
-                preset = presets[preset_name]
-                _active_tool_preset = preset
-                from src import mcp_client
-                new_tools = {name: False for name in mcp_client.TOOL_NAMES}
-                new_tools["run_powershell"] = False
-                for cat in preset.categories.values():
-                    for tool in cat:
-                        name = tool.name
-                        new_tools[name] = True
-                        _tool_approval_modes[name] = tool.approval
-                _agent_tools = new_tools
-        except Exception:
-            pass
-
-def set_bias_profile(profile_name: Optional[str]) -> None:
-    global _active_bias_profile, _active_bias_profile_name
-    if not profile_name or profile_name == "None":
-        _active_bias_profile = None
-        _active_bias_profile_name = None
-    else:
-        try:
-            from src.tool_presets import ToolPresetManager
-            manager = ToolPresetManager()
-            profiles = manager.load_all_bias_profiles()
-            if profile_name in profiles:
-                _active_bias_profile = profiles[profile_name]
-                _active_bias_profile_name = profile_name
-            else:
-                _active_bias_profile = None
-                _active_bias_profile_name = None
-        except Exception:
-            _active_bias_profile = None
-            _active_bias_profile_name = None
-
-def get_bias_profile() -> Optional[str]:
-    return _active_bias_profile_name
-
-_gemini_cli_adapter = None
-
-def cleanup() -> None:
-    global _gemini_cli_adapter
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        proxy.send_command("cleanup", {})
-    if _gemini_cli_adapter:
-        old_path = _gemini_cli_adapter.binary_path
-        _gemini_cli_adapter = None
-    else:
-        old_path = "gemini"
-    from src.gemini_cli_adapter import GeminiCliAdapter
-    _gemini_cli_adapter = GeminiCliAdapter(binary_path=old_path)
-
-def reset_session() -> None:
-    global _gemini_cli_adapter
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        proxy.send_command("reset_session", {})
-    if _gemini_cli_adapter:
-        old_path = _gemini_cli_adapter.binary_path
-    else:
-        old_path = "gemini"
-    from src.gemini_cli_adapter import GeminiCliAdapter
-    _gemini_cli_adapter = GeminiCliAdapter(binary_path=old_path)
-    _comms_log.clear()
-
-def get_gemini_cache_stats() -> dict[str, Any]:
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        result = proxy.send_command("get_gemini_cache_stats", {})
-        if "result" in result:
-            return result["result"]
-    return {"cache_count": 0, "total_size_bytes": 0, "cached_files": []}
-
-def list_models(provider: str) -> list[str]:
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        result = proxy.send_command("list_models", {"provider": provider})
-        if "result" in result:
-            return result["result"].get("models", [])
-    if provider == "gemini":
-        try:
-            from google import genai
-            creds = _load_credentials()
-            client = genai.Client(api_key=creds["gemini"]["api_key"])
-            models = []
-            for m in client.models.list():
-                name = m.name
-                if name and name.startswith("models/"):
-                    name = name[len("models/"):]
-                if name and "gemini" in name.lower():
-                    models.append(name)
-            return sorted(models)
-        except Exception:
-            return []
-    elif provider == "anthropic":
-        try:
-            import anthropic
-            creds = _load_credentials()
-            client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])
-            return sorted([m.id for m in client.models.list()])
-        except Exception:
-            return []
-    elif provider == "deepseek":
-        return ["deepseek-chat", "deepseek-reasoner"]
-    elif provider == "gemini_cli":
-        return ["gemini-3-flash-preview", "gemini-3.1-pro-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash", "gemini-2.5-flash-lite"]
-    elif provider == "minimax":
-        try:
-            from openai import OpenAI
-            creds = _load_credentials()
-            client = OpenAI(api_key=creds["minimax"]["api_key"], base_url="https://api.minimax.io/v1")
-            return sorted([m.id for m in client.models.list()])
-        except Exception:
-            return ["MiniMax-M2.7", "MiniMax-M2.5", "MiniMax-M2.1", "MiniMax-M2"]
-    return []
-
-def send(md_content: str, user_message: str, base_dir: str,
-         file_items: Optional[list[dict[str, Any]]] = None,
-         discussion_history: str = "",
-         pre_tool_callback: Optional[Callable] = None,
-         qa_callback: Optional[Callable] = None,
-         enable_tools: bool = True,
-         stream_callback: Optional[Callable[[str], None]] = None,
-         patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        result = proxy.send_command("send", {
-            "md_content": md_content,
-            "user_message": user_message,
-            "base_dir": base_dir,
-            "file_items": file_items or [],
-            "discussion_history": discussion_history,
-            "pre_tool_callback": pre_tool_callback is not None,
-            "enable_tools": enable_tools,
-        })
-        if "result" in result:
-            return result["result"].get("response", "")
-    return "ERROR: AI server not available"
-
-def get_token_stats(md_content: str) -> dict[str, Any]:
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        result = proxy.send_command("get_token_stats", {"md_content": md_content})
-        if "result" in result:
-            return result["result"]
-    return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0, "cached_tokens": 0}
-
-def run_tier4_analysis(error: str) -> str:
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        result = proxy.send_command("run_tier4_analysis", {"error": error})
-        if "result" in result:
-            return result["result"].get("analysis", "")
-    return ""
-
-def run_tier4_patch_callback(script: str, base_dir: str) -> Optional[str]:
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        result = proxy.send_command("run_tier4_patch_callback", {"script": script, "base_dir": base_dir})
-        if "result" in result:
-            return result["result"].get("output")
-    return None
-
-def run_tier4_patch_generation(error: str, context: str) -> str:
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        result = proxy.send_command("run_tier4_patch_generation", {"error": error, "context": context})
-        if "result" in result:
-            return result["result"].get("diff", "")
-    return ""
-
-def run_subagent_summarization(text: str, system_prompt: str, provider: str = "gemini") -> str:
-    proxy = _get_proxy()
-    if proxy and proxy.status == "ready":
-        result = proxy.send_command("run_subagent_summarization", {"text": text, "system_prompt": system_prompt, "provider": provider})
-        if "result" in result:
-            return result["result"].get("summary", "")
-    return ""
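Every public function above delegated to the server through the same envelope over line-delimited JSON on stdin/stdout. A sketch of one round trip, with field names taken from the deleted code (payload values are illustrative):

import json
import uuid

# Written to the server's stdin, one JSON object per line:
request = {"id": str(uuid.uuid4()),
           "method": "get_token_stats",
           "params": {"md_content": "# project notes"}}
wire_line = json.dumps(request) + "\n"

# Read back from stdout; the proxy's _read_loop routes it to the waiter by "id":
response = {"id": request["id"],
            "result": {"input_tokens": 120, "output_tokens": 0,
                       "total_tokens": 120, "cached_tokens": 0}}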
@@ -18,7 +18,7 @@ from pathlib import Path
 from pydantic import BaseModel
 from typing import Any, List, Dict, Optional, Callable
 from src import aggregate
-from src import ai_client_stub as ai_client
+from src import ai_client
 from src import conductor_tech_lead
 from src import events
 from src import mcp_client
@@ -1625,7 +1625,7 @@ class AppController:
         Stops background threads and cleans up resources.
         [C: src/gui_2.py:App.run, src/gui_2.py:App.shutdown, tests/conftest.py:app_instance, tests/conftest.py:mock_app]
         """
-        from src import ai_client_stub as ai_client
+        from src import ai_client
         ai_client.cleanup()
         if hasattr(self, 'hook_server') and self.hook_server:
             self.hook_server.stop()
@@ -3009,7 +3009,7 @@ class AppController:
         self._update_cached_stats()

     def _update_cached_stats(self) -> None:
-        from src import ai_client_stub as ai_client
+        from src import ai_client
         self._cached_cache_stats = ai_client.get_gemini_cache_stats()
         self._cached_tool_stats = dict(self._tool_stats)

@@ -3017,7 +3017,7 @@ class AppController:
         """
         [C: src/gui_2.py:App._render_cache_panel]
         """
-        from src import ai_client_stub as ai_client
+        from src import ai_client
         ai_client.cleanup()
         self._update_cached_stats()

@@ -88,7 +88,7 @@ def test_on_tool_log_offloading(app_controller, tmp_session_dir):
     script = "Get-Process"
     result = "Process list..."

-    with patch("src.ai_client_stub.get_current_tier", return_value="Tier 3"):
+    with patch("src.ai_client.get_current_tier", return_value="Tier 3"):
         app_controller._on_tool_log(script, result)

     # Verify files were created in session directory
@@ -17,8 +17,8 @@ def app_instance():
         patch('src.app_controller.AppController._prune_old_logs'),
         patch('src.app_controller.AppController.start_services'),
         patch('src.api_hooks.HookServer'),
-        patch('src.ai_client_stub.set_provider'),
-        patch('src.ai_client_stub.reset_session')
+        patch('src.ai_client.set_provider'),
+        patch('src.ai_client.reset_session')
     ):
         app = App()
         app.project = {
@@ -31,7 +31,7 @@ def test_telemetry_data_updates_correctly(app_instance: Any) -> None:
         "percentage": 75.0,
     }
     # 3. Patch the dependencies
-    with patch('src.ai_client_stub.get_token_stats', return_value=mock_stats) as mock_get_stats:
+    with patch('src.ai_client.get_token_stats', return_value=mock_stats) as mock_get_stats:
         # 4. Call the method under test
         app_instance._refresh_api_metrics({}, md_content="test content")
         # 5. Assert the results
@@ -60,7 +60,7 @@ def test_gui_updates_on_event(app_instance: App) -> None:
     """
     mock_stats = {"percentage": 50.0, "current": 500, "limit": 1000}
     app_instance.last_md = "mock_md"
-    with patch('src.ai_client_stub.get_token_stats', return_value=mock_stats):
+    with patch('src.ai_client.get_token_stats', return_value=mock_stats):
         # Drain the queue
         while not app_instance.event_queue.empty():
             app_instance.event_queue.get()
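The patch targets above and below have to track the import fix exactly: unittest.mock.patch imports the module named in the target string and replaces the attribute there, so once the stub file is gone, a stale patch('src.ai_client_stub....') raises ModuleNotFoundError instead of intercepting anything. A hypothetical minimal test illustrating the rule:

from unittest.mock import patch

def test_patch_lands_where_the_lookup_happens() -> None:
    # Call sites do `from src import ai_client` and then
    # ai_client.get_token_stats(...), so src.ai_client is the module
    # where the attribute lookup happens, and where the patch must go.
    with patch("src.ai_client.get_token_stats",
               return_value={"total_tokens": 1}) as mock_stats:
        from src import ai_client
        assert ai_client.get_token_stats("md")["total_tokens"] == 1
    assert mock_stats.call_count == 1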
@@ -1,7 +1,7 @@
 from typing import Generator
 import pytest
 from unittest.mock import patch, Mock
-from src import ai_client_stub
+from src import ai_client
 from src.gui_2 import App

 @pytest.fixture
@@ -20,8 +20,8 @@ def app_instance() -> Generator[App, None, None]:
         patch('src.app_controller.AppController.start_services'),
         # Do not patch _init_ai_and_hooks to ensure _settable_fields is initialized
         patch('src.api_hooks.HookServer'),
-        patch('src.ai_client_stub.set_provider'),
-        patch('src.ai_client_stub.reset_session')
+        patch('src.ai_client.set_provider'),
+        patch('src.ai_client.reset_session')
     ):
         app = App()
         yield app
@@ -30,8 +30,8 @@ def test_redundant_calls_in_process_pending_gui_tasks(app_instance: App) -> None
     app_instance.controller._pending_gui_tasks = [
         {'action': 'set_value', 'item': 'current_provider', 'value': 'anthropic'}
     ]
-    with patch('src.ai_client_stub.set_provider') as mock_set_provider, \
-         patch('src.ai_client_stub.reset_session') as mock_reset_session:
+    with patch('src.ai_client.set_provider') as mock_set_provider, \
+         patch('src.ai_client.reset_session') as mock_reset_session:
         app_instance.controller._process_pending_gui_tasks()
         assert mock_set_provider.call_count == 1
         assert mock_reset_session.call_count == 1
@@ -42,10 +42,10 @@ def test_gcli_path_updates_adapter(app_instance: App) -> None:
         {'action': 'set_value', 'item': 'gcli_path', 'value': '/new/path/to/gemini'}
     ]
     # Initialize adapter if it doesn't exist (it shouldn't in mock env)
-    ai_client_stub._gemini_cli_adapter = None
+    ai_client._gemini_cli_adapter = None
     app_instance.controller._process_pending_gui_tasks()
-    assert ai_client_stub._gemini_cli_adapter is not None
-    assert ai_client_stub._gemini_cli_adapter.binary_path == '/new/path/to/gemini'
+    assert ai_client._gemini_cli_adapter is not None
+    assert ai_client._gemini_cli_adapter.binary_path == '/new/path/to/gemini'

 def test_process_pending_gui_tasks_drag(app_instance: App) -> None:
     """Test that the drag action is correctly processed and dispatches to the registered callback."""