diff --git a/src/ai_client.py b/src/ai_client.py index 24922a32..583f0406 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -43,6 +43,7 @@ from src import mcp_tool_specs from src import mma_prompts from src import performance_monitor from src import project_manager +from src import provider_state from src.vendor_capabilities import VendorCapabilities, get_capabilities # TODO(Ed): Eliminate these? @@ -109,29 +110,29 @@ _gemini_cached_file_paths: list[str] = [] _GEMINI_CACHE_TTL: int = 3600 _anthropic_client: Optional[anthropic.Anthropic] = None -_anthropic_history: list[Metadata] = [] -_anthropic_history_lock: threading.Lock = threading.Lock() +_anthropic_history = provider_state.get_history("anthropic") +_anthropic_history_lock = _anthropic_history.lock _deepseek_client: Any = None -_deepseek_history: list[Metadata] = [] -_deepseek_history_lock: threading.Lock = threading.Lock() +_deepseek_history = provider_state.get_history("deepseek") +_deepseek_history_lock = _deepseek_history.lock _minimax_client: Any = None -_minimax_history: list[Metadata] = [] -_minimax_history_lock: threading.Lock = threading.Lock() +_minimax_history = provider_state.get_history("minimax") +_minimax_history_lock = _minimax_history.lock _qwen_client: Any = None -_qwen_history: list[Metadata] = [] -_qwen_history_lock: threading.Lock = threading.Lock() +_qwen_history = provider_state.get_history("qwen") +_qwen_history_lock = _qwen_history.lock _qwen_region: str = "china" _grok_client: Any = None -_grok_history: list[Metadata] = [] -_grok_history_lock: threading.Lock = threading.Lock() +_grok_history = provider_state.get_history("grok") +_grok_history_lock = _grok_history.lock _llama_client: Any = None -_llama_history: list[Metadata] = [] -_llama_history_lock: threading.Lock = threading.Lock() +_llama_history = provider_state.get_history("llama") +_llama_history_lock = _llama_history.lock _llama_base_url: str = "http://localhost:11434/v1" _llama_api_key: str = "ollama" @@ -461,10 +462,10 @@ def reset_session() -> None: """Clears conversation history and resets provider-specific session state.""" global _gemini_client, _gemini_chat, _gemini_cache global _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths - global _anthropic_client, _anthropic_history - global _deepseek_client, _deepseek_history - global _minimax_client, _minimax_history - global _qwen_client, _qwen_history + global _anthropic_client + global _deepseek_client + global _minimax_client + global _qwen_client global _CACHED_ANTHROPIC_TOOLS, _CACHED_DEEPSEEK_TOOLS global _gemini_cli_adapter if _gemini_client and _gemini_cache: @@ -475,29 +476,18 @@ def reset_session() -> None: _gemini_cache_md_hash = None _gemini_cache_created_at = None _gemini_cached_file_paths = [] - + # Preserve binary_path if adapter exists old_path = _gemini_cli_adapter.binary_path if _gemini_cli_adapter else "gemini" _gemini_cli_adapter = GeminiCliAdapter(binary_path=old_path) - + _anthropic_client = None - with _anthropic_history_lock: - _anthropic_history = [] + provider_state.clear_all() _deepseek_client = None - with _deepseek_history_lock: - _deepseek_history = [] _minimax_client = None - with _minimax_history_lock: - _minimax_history = [] _qwen_client = None - with _qwen_history_lock: - _qwen_history = [] _grok_client = None - with _grok_history_lock: - _grok_history = [] _llama_client = None - with _llama_history_lock: - _llama_history = [] _llama_base_url = "http://localhost:11434/v1" _llama_api_key = "ollama" _CACHED_ANTHROPIC_TOOLS = None diff --git a/src/provider_state.py b/src/provider_state.py index 78e374b4..c1302b22 100644 --- a/src/provider_state.py +++ b/src/provider_state.py @@ -22,11 +22,28 @@ from dataclasses import dataclass, field from src.type_aliases import HistoryMessage, Metadata +@dataclass @dataclass class ProviderHistory: messages: list[HistoryMessage] = field(default_factory=list) lock: threading.Lock = field(default_factory=threading.Lock) + def __bool__(self) -> bool: + with self.lock: + return bool(self.messages) + + def __len__(self) -> int: + with self.lock: + return len(self.messages) + + def __iter__(self): + with self.lock: + return iter(list(self.messages)) + + def __getitem__(self, idx): + with self.lock: + return self.messages[idx] + def append(self, message: HistoryMessage) -> None: with self.lock: self.messages.append(message) @@ -54,6 +71,16 @@ _PROVIDER_HISTORIES: dict[str, ProviderHistory] = { } +_PROVIDER_HISTORIES: dict[str, ProviderHistory] = { + "anthropic": ProviderHistory(), + "deepseek": ProviderHistory(), + "minimax": ProviderHistory(), + "qwen": ProviderHistory(), + "grok": ProviderHistory(), + "llama": ProviderHistory(), +} + + def get_history(provider: str) -> ProviderHistory: if provider not in _PROVIDER_HISTORIES: raise KeyError(f"Unknown provider: {provider!r}")