diff --git a/src/provider_state.py b/src/provider_state.py
new file mode 100644
index 00000000..78e374b4
--- /dev/null
+++ b/src/provider_state.py
@@ -0,0 +1,86 @@
+"""Per-provider history state for the AI client layer.
+
+Promotes 12 module globals in src/ai_client.py:
+- 6x `__history: list[Metadata]` (anthropic/deepseek/minimax/qwen/grok/llama)
+- 6x `__history_lock: threading.Lock`
+
+To a single `_PROVIDER_HISTORIES: dict[str, ProviderHistory]` keyed by
+provider name. Each `ProviderHistory` owns its own lock and message list;
+the cross-provider pattern is encapsulated behind a 4-method interface.
+
+SDK client holders (`_gemini_chat`, `_deepseek_client`, etc.) stay as
+module-level `Any` variables per Pattern 3 (heterogeneous SDK types,
+lazy-initialized). Only the homogeneous history aspect is unified.
+
+CONVENTION: 1-space indentation. NO COMMENTS.
+"""
+from __future__ import annotations
+
+import threading
+from dataclasses import dataclass, field
+
+from src.type_aliases import HistoryMessage, Metadata
+
+
+@dataclass
+class ProviderHistory:
+ """Message history for one provider, guarded by its own lock.
+
+ Every method holds `lock` while it touches `messages`. `get_all` and
+ `replace_all` copy the list, so callers never share the internal list
+ object with this instance.
+ """
+ messages: list[HistoryMessage] = field(default_factory=list)
+ lock: threading.Lock = field(default_factory=threading.Lock)
+
+ def append(self, message: HistoryMessage) -> None:
+  """Add `message` to the end of the history."""
+  with self.lock:
+   self.messages.append(message)
+
+ def get_all(self) -> list[HistoryMessage]:
+  """Return a shallow copy of the history."""
+  with self.lock:
+   return list(self.messages)
+
+ def replace_all(self, messages: list[HistoryMessage]) -> None:
+  """Replace the history with a shallow copy of `messages`."""
+  with self.lock:
+   self.messages = list(messages)
+
+ def clear(self) -> None:
+  """Drop every stored message."""
+  with self.lock:
+   self.messages = []
+
+
+_PROVIDER_HISTORIES: dict[str, ProviderHistory] = {
+ "anthropic": ProviderHistory(),
+ "deepseek": ProviderHistory(),
+ "minimax": ProviderHistory(),
+ "qwen": ProviderHistory(),
+ "grok": ProviderHistory(),
+ "llama": ProviderHistory(),
+}
+
+
+def get_history(provider: str) -> ProviderHistory:
+ """Return the shared `ProviderHistory` registered for `provider`.
+
+ Raises KeyError if `provider` is not a registered provider name.
+ """
+ try:
+  return _PROVIDER_HISTORIES[provider]
+ except KeyError:
+  raise KeyError(f"Unknown provider: {provider!r}") from None
+
+
+def clear_all() -> None:
+ """Clear the history of every registered provider."""
+ for history in _PROVIDER_HISTORIES.values():
+  history.clear()
+
+
+def providers() -> tuple[str, ...]:
+ """Return the registered provider names in registration order."""
+ return tuple(_PROVIDER_HISTORIES)
diff --git a/tests/test_provider_state.py b/tests/test_provider_state.py
new file mode 100644
index 00000000..5bd689e9
--- /dev/null
+++ b/tests/test_provider_state.py
@@ -0,0 +1,131 @@
+"""Tests for src/provider_state.py
+
+Phase 3 of any_type_componentization_20260621. Verifies:
+- 6 ProviderHistory instances pre-registered
+- get_history() returns singleton instance per provider
+- ProviderHistory.append() / get_all() / replace_all() / clear() are thread-safe
+- clear_all() resets all 6
+- providers() returns the expected 6-tuple
+
+CONVENTION: 1-space indentation. NO COMMENTS.
+"""
+from __future__ import annotations
+
+import threading
+
+import pytest
+from src import provider_state
+
+
+EXPECTED_PROVIDERS: tuple[str, ...] = ("anthropic", "deepseek", "minimax", "qwen", "grok", "llama")
+
+
+def test_six_providers_registered() -> None:
+ assert provider_state.providers() == EXPECTED_PROVIDERS
+
+
+def test_get_history_returns_singleton_per_provider() -> None:
+ a1 = provider_state.get_history("anthropic")
+ a2 = provider_state.get_history("anthropic")
+ assert a1 is a2
+ g1 = provider_state.get_history("grok")
+ g2 = provider_state.get_history("grok")
+ assert g1 is g2
+ assert a1 is not g1
+
+
+def test_get_history_raises_for_unknown() -> None:
+ with pytest.raises(KeyError):
+  provider_state.get_history("nonexistent_provider")
+
+
+def test_provider_history_starts_empty() -> None:
+ provider_state.clear_all()
+ h = provider_state.get_history("anthropic")
+ assert h.get_all() == []
+
+
+def test_provider_history_append() -> None:
+ provider_state.clear_all()
+ h = provider_state.get_history("deepseek")
+ h.append({"role": "user", "content": "hello"})
+ h.append({"role": "assistant", "content": "world"})
+ assert h.get_all() == [
+  {"role": "user", "content": "hello"},
+  {"role": "assistant", "content": "world"},
+ ]
+
+
+def test_provider_history_get_all_returns_copy() -> None:
+ h = provider_state.get_history("qwen")
+ h.clear()
+ h.append({"role": "user", "content": "hi"})
+ snapshot = h.get_all()
+ snapshot.append({"role": "user", "content": "leaked"})
+ assert h.get_all() == [{"role": "user", "content": "hi"}]
+
+
+def test_provider_history_replace_all() -> None:
+ h = provider_state.get_history("minimax")
+ h.clear()
+ h.append({"role": "user", "content": "old"})
+ h.replace_all([{"role": "user", "content": "new"}])
+ assert h.get_all() == [{"role": "user", "content": "new"}]
+
+
+def test_provider_history_replace_all_takes_copy() -> None:
+ h = provider_state.get_history("llama")
+ h.clear()
+ new_messages = [{"role": "user", "content": "x"}]
+ h.replace_all(new_messages)
+ new_messages.append({"role": "user", "content": "leaked"})
+ assert h.get_all() == [{"role": "user", "content": "x"}]
+
+
+def test_provider_history_clear() -> None:
+ h = provider_state.get_history("grok")
+ h.append({"role": "user", "content": "x"})
+ h.clear()
+ assert h.get_all() == []
+
+
+def test_clear_all_resets_every_provider() -> None:
+ for p in EXPECTED_PROVIDERS:
+  provider_state.get_history(p).append({"role": "user", "content": f"{p}-msg"})
+ provider_state.clear_all()
+ for p in EXPECTED_PROVIDERS:
+  assert provider_state.get_history(p).get_all() == []
+
+
+def test_provider_history_thread_safety() -> None:
+ h = provider_state.get_history("anthropic")
+ h.clear()
+ num_threads = 10
+ per_thread = 100
+ barrier = threading.Barrier(num_threads)
+ def worker() -> None:
+  barrier.wait()
+  for i in range(per_thread):
+   h.append({"role": "user", "content": f"msg-{i}"})
+ threads = [threading.Thread(target=worker) for _ in range(num_threads)]
+ for t in threads:
+  t.start()
+ for t in threads:
+  t.join()
+ assert len(h.get_all()) == num_threads * per_thread
+
+
+def test_independent_locks_per_provider() -> None:
+ """Holding one provider's lock must not block acquiring another's."""
+ h1 = provider_state.get_history("anthropic")
+ h2 = provider_state.get_history("deepseek")
+ assert h1.lock is not h2.lock
+ acquired_both = []
+ with h1.lock:
+  acquired_both.append("h1")
+  assert h2.lock.acquire(blocking=False)
+  try:
+   acquired_both.append("h2")
+  finally:
+   h2.lock.release()
+ assert acquired_both == ["h1", "h2"]