Private
Public Access
0
0

refactor(ai_client): 14 module globals → provider_state.get_history() pattern

This commit is contained in:
2026-06-24 17:17:58 -04:00
parent 20236546d7
commit 25a2205722
2 changed files with 47 additions and 30 deletions
+20 -30
View File
@@ -43,6 +43,7 @@ from src import mcp_tool_specs
from src import mma_prompts
from src import performance_monitor
from src import project_manager
from src import provider_state
from src.vendor_capabilities import VendorCapabilities, get_capabilities
# TODO(Ed): Eliminate these?
@@ -109,29 +110,29 @@ _gemini_cached_file_paths: list[str] = []
_GEMINI_CACHE_TTL: int = 3600
_anthropic_client: Optional[anthropic.Anthropic] = None
_anthropic_history: list[Metadata] = []
_anthropic_history_lock: threading.Lock = threading.Lock()
_anthropic_history = provider_state.get_history("anthropic")
_anthropic_history_lock = _anthropic_history.lock
_deepseek_client: Any = None
_deepseek_history: list[Metadata] = []
_deepseek_history_lock: threading.Lock = threading.Lock()
_deepseek_history = provider_state.get_history("deepseek")
_deepseek_history_lock = _deepseek_history.lock
_minimax_client: Any = None
_minimax_history: list[Metadata] = []
_minimax_history_lock: threading.Lock = threading.Lock()
_minimax_history = provider_state.get_history("minimax")
_minimax_history_lock = _minimax_history.lock
_qwen_client: Any = None
_qwen_history: list[Metadata] = []
_qwen_history_lock: threading.Lock = threading.Lock()
_qwen_history = provider_state.get_history("qwen")
_qwen_history_lock = _qwen_history.lock
_qwen_region: str = "china"
_grok_client: Any = None
_grok_history: list[Metadata] = []
_grok_history_lock: threading.Lock = threading.Lock()
_grok_history = provider_state.get_history("grok")
_grok_history_lock = _grok_history.lock
_llama_client: Any = None
_llama_history: list[Metadata] = []
_llama_history_lock: threading.Lock = threading.Lock()
_llama_history = provider_state.get_history("llama")
_llama_history_lock = _llama_history.lock
_llama_base_url: str = "http://localhost:11434/v1"
_llama_api_key: str = "ollama"
@@ -461,10 +462,10 @@ def reset_session() -> None:
"""Clears conversation history and resets provider-specific session state."""
global _gemini_client, _gemini_chat, _gemini_cache
global _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths
global _anthropic_client, _anthropic_history
global _deepseek_client, _deepseek_history
global _minimax_client, _minimax_history
global _qwen_client, _qwen_history
global _anthropic_client
global _deepseek_client
global _minimax_client
global _qwen_client
global _CACHED_ANTHROPIC_TOOLS, _CACHED_DEEPSEEK_TOOLS
global _gemini_cli_adapter
if _gemini_client and _gemini_cache:
@@ -475,29 +476,18 @@ def reset_session() -> None:
_gemini_cache_md_hash = None
_gemini_cache_created_at = None
_gemini_cached_file_paths = []
# Preserve binary_path if adapter exists
old_path = _gemini_cli_adapter.binary_path if _gemini_cli_adapter else "gemini"
_gemini_cli_adapter = GeminiCliAdapter(binary_path=old_path)
_anthropic_client = None
with _anthropic_history_lock:
_anthropic_history = []
provider_state.clear_all()
_deepseek_client = None
with _deepseek_history_lock:
_deepseek_history = []
_minimax_client = None
with _minimax_history_lock:
_minimax_history = []
_qwen_client = None
with _qwen_history_lock:
_qwen_history = []
_grok_client = None
with _grok_history_lock:
_grok_history = []
_llama_client = None
with _llama_history_lock:
_llama_history = []
_llama_base_url = "http://localhost:11434/v1"
_llama_api_key = "ollama"
_CACHED_ANTHROPIC_TOOLS = None
+27
View File
@@ -22,11 +22,28 @@ from dataclasses import dataclass, field
from src.type_aliases import HistoryMessage, Metadata
@dataclass
@dataclass
class ProviderHistory:
messages: list[HistoryMessage] = field(default_factory=list)
lock: threading.Lock = field(default_factory=threading.Lock)
def __bool__(self) -> bool:
with self.lock:
return bool(self.messages)
def __len__(self) -> int:
with self.lock:
return len(self.messages)
def __iter__(self):
with self.lock:
return iter(list(self.messages))
def __getitem__(self, idx):
with self.lock:
return self.messages[idx]
def append(self, message: HistoryMessage) -> None:
with self.lock:
self.messages.append(message)
@@ -54,6 +71,16 @@ _PROVIDER_HISTORIES: dict[str, ProviderHistory] = {
}
_PROVIDER_HISTORIES: dict[str, ProviderHistory] = {
"anthropic": ProviderHistory(),
"deepseek": ProviderHistory(),
"minimax": ProviderHistory(),
"qwen": ProviderHistory(),
"grok": ProviderHistory(),
"llama": ProviderHistory(),
}
def get_history(provider: str) -> ProviderHistory:
if provider not in _PROVIDER_HISTORIES:
raise KeyError(f"Unknown provider: {provider!r}")