refactor(ai_client): 14 module globals → provider_state.get_history() pattern
This commit is contained in:
+20
-30
@@ -43,6 +43,7 @@ from src import mcp_tool_specs
|
||||
from src import mma_prompts
|
||||
from src import performance_monitor
|
||||
from src import project_manager
|
||||
from src import provider_state
|
||||
from src.vendor_capabilities import VendorCapabilities, get_capabilities
|
||||
|
||||
# TODO(Ed): Eliminate these?
|
||||
@@ -109,29 +110,29 @@ _gemini_cached_file_paths: list[str] = []
|
||||
_GEMINI_CACHE_TTL: int = 3600
|
||||
|
||||
_anthropic_client: Optional[anthropic.Anthropic] = None
|
||||
_anthropic_history: list[Metadata] = []
|
||||
_anthropic_history_lock: threading.Lock = threading.Lock()
|
||||
_anthropic_history = provider_state.get_history("anthropic")
|
||||
_anthropic_history_lock = _anthropic_history.lock
|
||||
|
||||
_deepseek_client: Any = None
|
||||
_deepseek_history: list[Metadata] = []
|
||||
_deepseek_history_lock: threading.Lock = threading.Lock()
|
||||
_deepseek_history = provider_state.get_history("deepseek")
|
||||
_deepseek_history_lock = _deepseek_history.lock
|
||||
|
||||
_minimax_client: Any = None
|
||||
_minimax_history: list[Metadata] = []
|
||||
_minimax_history_lock: threading.Lock = threading.Lock()
|
||||
_minimax_history = provider_state.get_history("minimax")
|
||||
_minimax_history_lock = _minimax_history.lock
|
||||
|
||||
_qwen_client: Any = None
|
||||
_qwen_history: list[Metadata] = []
|
||||
_qwen_history_lock: threading.Lock = threading.Lock()
|
||||
_qwen_history = provider_state.get_history("qwen")
|
||||
_qwen_history_lock = _qwen_history.lock
|
||||
_qwen_region: str = "china"
|
||||
|
||||
_grok_client: Any = None
|
||||
_grok_history: list[Metadata] = []
|
||||
_grok_history_lock: threading.Lock = threading.Lock()
|
||||
_grok_history = provider_state.get_history("grok")
|
||||
_grok_history_lock = _grok_history.lock
|
||||
|
||||
_llama_client: Any = None
|
||||
_llama_history: list[Metadata] = []
|
||||
_llama_history_lock: threading.Lock = threading.Lock()
|
||||
_llama_history = provider_state.get_history("llama")
|
||||
_llama_history_lock = _llama_history.lock
|
||||
_llama_base_url: str = "http://localhost:11434/v1"
|
||||
_llama_api_key: str = "ollama"
|
||||
|
||||
@@ -461,10 +462,10 @@ def reset_session() -> None:
|
||||
"""Clears conversation history and resets provider-specific session state."""
|
||||
global _gemini_client, _gemini_chat, _gemini_cache
|
||||
global _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths
|
||||
global _anthropic_client, _anthropic_history
|
||||
global _deepseek_client, _deepseek_history
|
||||
global _minimax_client, _minimax_history
|
||||
global _qwen_client, _qwen_history
|
||||
global _anthropic_client
|
||||
global _deepseek_client
|
||||
global _minimax_client
|
||||
global _qwen_client
|
||||
global _CACHED_ANTHROPIC_TOOLS, _CACHED_DEEPSEEK_TOOLS
|
||||
global _gemini_cli_adapter
|
||||
if _gemini_client and _gemini_cache:
|
||||
@@ -475,29 +476,18 @@ def reset_session() -> None:
|
||||
_gemini_cache_md_hash = None
|
||||
_gemini_cache_created_at = None
|
||||
_gemini_cached_file_paths = []
|
||||
|
||||
|
||||
# Preserve binary_path if adapter exists
|
||||
old_path = _gemini_cli_adapter.binary_path if _gemini_cli_adapter else "gemini"
|
||||
_gemini_cli_adapter = GeminiCliAdapter(binary_path=old_path)
|
||||
|
||||
|
||||
_anthropic_client = None
|
||||
with _anthropic_history_lock:
|
||||
_anthropic_history = []
|
||||
provider_state.clear_all()
|
||||
_deepseek_client = None
|
||||
with _deepseek_history_lock:
|
||||
_deepseek_history = []
|
||||
_minimax_client = None
|
||||
with _minimax_history_lock:
|
||||
_minimax_history = []
|
||||
_qwen_client = None
|
||||
with _qwen_history_lock:
|
||||
_qwen_history = []
|
||||
_grok_client = None
|
||||
with _grok_history_lock:
|
||||
_grok_history = []
|
||||
_llama_client = None
|
||||
with _llama_history_lock:
|
||||
_llama_history = []
|
||||
_llama_base_url = "http://localhost:11434/v1"
|
||||
_llama_api_key = "ollama"
|
||||
_CACHED_ANTHROPIC_TOOLS = None
|
||||
|
||||
@@ -22,11 +22,28 @@ from dataclasses import dataclass, field
|
||||
from src.type_aliases import HistoryMessage, Metadata
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class ProviderHistory:
|
||||
messages: list[HistoryMessage] = field(default_factory=list)
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
with self.lock:
|
||||
return bool(self.messages)
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self.lock:
|
||||
return len(self.messages)
|
||||
|
||||
def __iter__(self):
|
||||
with self.lock:
|
||||
return iter(list(self.messages))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with self.lock:
|
||||
return self.messages[idx]
|
||||
|
||||
def append(self, message: HistoryMessage) -> None:
|
||||
with self.lock:
|
||||
self.messages.append(message)
|
||||
@@ -54,6 +71,16 @@ _PROVIDER_HISTORIES: dict[str, ProviderHistory] = {
|
||||
}
|
||||
|
||||
|
||||
_PROVIDER_HISTORIES: dict[str, ProviderHistory] = {
|
||||
"anthropic": ProviderHistory(),
|
||||
"deepseek": ProviderHistory(),
|
||||
"minimax": ProviderHistory(),
|
||||
"qwen": ProviderHistory(),
|
||||
"grok": ProviderHistory(),
|
||||
"llama": ProviderHistory(),
|
||||
}
|
||||
|
||||
|
||||
def get_history(provider: str) -> ProviderHistory:
|
||||
if provider not in _PROVIDER_HISTORIES:
|
||||
raise KeyError(f"Unknown provider: {provider!r}")
|
||||
|
||||
Reference in New Issue
Block a user