feat(provider): add src/provider_state.py + tests (t3_2, t3_3)
Phase 3 of any_type_componentization_20260621 (PARTIAL). Adds the
ProviderHistory abstraction and 6-provider registry.
NEW src/provider_state.py (60 lines):
- ProviderHistory dataclass (messages: list[HistoryMessage], lock: Lock,
append / get_all / replace_all / clear methods)
- _PROVIDER_HISTORIES: dict[str, ProviderHistory] for anthropic / deepseek /
minimax / qwen / grok / llama
- get_history(provider) factory + clear_all() + providers()
- SDK client holders (_gemini_chat, _anthropic_client, etc.) NOT touched
per Pattern 3 (heterogeneous SDK types)
NEW tests/test_provider_state.py (12 tests, all pass):
- test_six_providers_registered
- test_get_history_returns_singleton_per_provider
- test_get_history_raises_for_unknown
- test_provider_history_starts_empty
- test_provider_history_append / get_all_returns_copy / replace_all /
replace_all_takes_copy / clear
- test_clear_all_resets_every_provider
- test_provider_history_thread_safety (10 threads x 100 messages)
- test_independent_locks_per_provider (lock on one doesn't block another)
DEFERRED:
- t3_4 (Remove 14 globals from ai_client.py:111-133)
- t3_5 through t3_13 (Update call sites in _send_<provider> functions)
- t3_14 (Run full regression suite on test_ai_client*.py)
These call-site updates require careful per-function refactoring of the
~27 sites in _send_anthropic, _send_deepseek, _send_minimax, _send_qwen,
_send_grok, _send_llama. The ai_client.py file is 3432 lines; a single
regex pass risks subtle indentation regressions in nested constructs
(see the 7 `not :` orphan lines from a previous attempt).
The provider_state module is independently usable and tested. Future
track: provider_state_migration_2026MMDD to wire up the call sites
mechanically, OR integrate into a Phase 3 retry pass.
Verified:
uv run pytest tests/test_provider_state.py --timeout=30
12 passed in 2.99s
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
"""Per-provider history state for the AI client layer.
|
||||
|
||||
Promotes 14 module globals in src/ai_client.py:
|
||||
- 7x `_<provider>_history: list[Metadata]` (anthropic/deepseek/minimax/qwen/grok/llama)
|
||||
- 7x `_<provider>_history_lock: threading.Lock`
|
||||
|
||||
To a single `_PROVIDER_HISTORIES: dict[str, ProviderHistory]` keyed by
|
||||
provider name. Each `ProviderHistory` owns its own lock and message list;
|
||||
the cross-provider pattern is encapsulated behind a 4-method interface.
|
||||
|
||||
SDK client holders (`_gemini_chat`, `_deepseek_client`, etc.) stay as
|
||||
module-level `Any` variables per Pattern 3 (heterogeneous SDK types,
|
||||
lazy-initialized). Only the homogeneous history aspect is unified.
|
||||
|
||||
CONVENTION: 1-space indentation. NO COMMENTS.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from src.type_aliases import HistoryMessage, Metadata
|
||||
|
||||
|
||||
@dataclass
class ProviderHistory:
 """Message history for one provider, guarded by its own lock.

 Attributes:
  messages: The stored messages, in insertion order.
  lock: Held by every method below while it reads or rebinds `messages`.
 """
 messages: list[HistoryMessage] = field(default_factory=list)
 lock: threading.Lock = field(default_factory=threading.Lock)

 def append(self, message: HistoryMessage) -> None:
  """Add `message` to the end of the history while holding the lock."""
  with self.lock:
   self.messages.append(message)

 def get_all(self) -> list[HistoryMessage]:
  """Return a shallow copy of the history.

  The returned list is a new list object, so appending to or removing
  from it leaves this history untouched. The message objects themselves
  are shared, not copied.
  """
  with self.lock:
   return list(self.messages)

 def replace_all(self, messages: list[HistoryMessage]) -> None:
  """Swap the history for a shallow copy of `messages`.

  Later changes to the caller's list do not reach this history.
  """
  with self.lock:
   self.messages = list(messages)

 def clear(self) -> None:
  """Rebind the history to a new empty list while holding the lock."""
  with self.lock:
   self.messages = []
|
||||
|
||||
|
||||
_PROVIDER_HISTORIES: dict[str, ProviderHistory] = {
 "anthropic": ProviderHistory(),
 "deepseek": ProviderHistory(),
 "minimax": ProviderHistory(),
 "qwen": ProviderHistory(),
 "grok": ProviderHistory(),
 "llama": ProviderHistory(),
}
|
||||
|
||||
|
||||
def get_history(provider: str) -> ProviderHistory:
 """Return the registered `ProviderHistory` for `provider`.

 Repeated calls with the same name return the same object.

 Raises:
  KeyError: If `provider` is not a key of `_PROVIDER_HISTORIES`.
 """
 if provider not in _PROVIDER_HISTORIES:
  raise KeyError(f"Unknown provider: {provider!r}")
 return _PROVIDER_HISTORIES[provider]
|
||||
|
||||
|
||||
def clear_all() -> None:
 """Call `clear()` on every history in `_PROVIDER_HISTORIES`."""
 for history in _PROVIDER_HISTORIES.values():
  history.clear()
|
||||
|
||||
|
||||
def providers() -> tuple[str, ...]:
 """Return the registered provider names in registry insertion order."""
 return tuple(_PROVIDER_HISTORIES)
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Tests for src/provider_state.py
|
||||
|
||||
Phase 3 of any_type_componentization_20260621. Verifies:
|
||||
- 6 ProviderHistory instances pre-registered
|
||||
- get_history() returns singleton instance per provider
|
||||
- ProviderHistory.append() / get_all() / replace_all() / clear() are thread-safe
|
||||
- clear_all() resets all 6
|
||||
- providers() returns the expected 6-tuple
|
||||
|
||||
CONVENTION: 1-space indentation. NO COMMENTS.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from src import provider_state
|
||||
|
||||
|
||||
EXPECTED_PROVIDERS: tuple[str, ...] = (
 "anthropic",
 "deepseek",
 "minimax",
 "qwen",
 "grok",
 "llama",
)
|
||||
|
||||
|
||||
def test_six_providers_registered() -> None:
 """providers() yields exactly the expected names, in order."""
 assert provider_state.providers() == EXPECTED_PROVIDERS
|
||||
|
||||
|
||||
def test_get_history_returns_singleton_per_provider() -> None:
 """Same name gives the same object; different names give different objects."""
 a1 = provider_state.get_history("anthropic")
 a2 = provider_state.get_history("anthropic")
 assert a1 is a2
 g1 = provider_state.get_history("grok")
 g2 = provider_state.get_history("grok")
 assert g1 is g2
 assert a1 is not g1
|
||||
|
||||
|
||||
def test_get_history_raises_for_unknown() -> None:
 """An unregistered provider name raises KeyError."""
 with pytest.raises(KeyError):
  provider_state.get_history("nonexistent_provider")
|
||||
|
||||
|
||||
def test_provider_history_starts_empty() -> None:
 """After clear_all(), a history reports no messages."""
 provider_state.clear_all()
 h = provider_state.get_history("anthropic")
 assert h.get_all() == []
|
||||
|
||||
|
||||
def test_provider_history_append() -> None:
 """Appended messages come back from get_all() in insertion order."""
 provider_state.clear_all()
 h = provider_state.get_history("deepseek")
 h.append({"role": "user", "content": "hello"})
 h.append({"role": "assistant", "content": "world"})
 assert h.get_all() == [
  {"role": "user", "content": "hello"},
  {"role": "assistant", "content": "world"},
 ]
|
||||
|
||||
|
||||
def test_provider_history_get_all_returns_copy() -> None:
 """Appending to the list returned by get_all() does not change the history."""
 h = provider_state.get_history("qwen")
 h.clear()
 h.append({"role": "user", "content": "hi"})
 snapshot = h.get_all()
 snapshot.append({"role": "user", "content": "leaked"})
 assert h.get_all() == [{"role": "user", "content": "hi"}]
|
||||
|
||||
|
||||
def test_provider_history_replace_all() -> None:
 """replace_all() discards the old messages and stores the new ones."""
 h = provider_state.get_history("minimax")
 h.clear()
 h.append({"role": "user", "content": "old"})
 h.replace_all([{"role": "user", "content": "new"}])
 assert h.get_all() == [{"role": "user", "content": "new"}]
|
||||
|
||||
|
||||
def test_provider_history_replace_all_takes_copy() -> None:
 """Mutating the list passed to replace_all() afterwards does not change the history."""
 h = provider_state.get_history("llama")
 h.clear()
 new_messages = [{"role": "user", "content": "x"}]
 h.replace_all(new_messages)
 new_messages.append({"role": "user", "content": "leaked"})
 assert h.get_all() == [{"role": "user", "content": "x"}]
|
||||
|
||||
|
||||
def test_provider_history_clear() -> None:
 """clear() leaves the history empty."""
 h = provider_state.get_history("grok")
 h.append({"role": "user", "content": "x"})
 h.clear()
 assert h.get_all() == []
|
||||
|
||||
|
||||
def test_clear_all_resets_every_provider() -> None:
 """clear_all() empties every expected provider's history."""
 for p in EXPECTED_PROVIDERS:
  provider_state.get_history(p).append({"role": "user", "content": f"{p}-msg"})
 provider_state.clear_all()
 for p in EXPECTED_PROVIDERS:
  assert provider_state.get_history(p).get_all() == []
|
||||
|
||||
|
||||
def test_provider_history_thread_safety() -> None:
 """Concurrent appends from 10 threads x 100 messages lose no message.

 The barrier releases all workers together so their appends overlap.
 """
 h = provider_state.get_history("anthropic")
 h.clear()
 num_threads = 10
 per_thread = 100
 barrier = threading.Barrier(num_threads)

 def worker() -> None:
  barrier.wait()
  for i in range(per_thread):
   h.append({"role": "user", "content": f"msg-{i}"})

 threads = [threading.Thread(target=worker) for _ in range(num_threads)]
 for t in threads:
  t.start()
 for t in threads:
  t.join()
 assert len(h.get_all()) == num_threads * per_thread
|
||||
|
||||
|
||||
def test_independent_locks_per_provider() -> None:
 """Two providers hold distinct locks, so h2's lock can be taken while h1's is held."""
 h1 = provider_state.get_history("anthropic")
 h2 = provider_state.get_history("deepseek")
 assert h1.lock is not h2.lock
 acquired_both = []

 def lock_h1() -> None:
  with h1.lock:
   acquired_both.append("h1")
   lock_h2()

 def lock_h2() -> None:
  with h2.lock:
   acquired_both.append("h2")

 lock_h1()
 assert acquired_both == ["h1", "h2"]
|
||||
Reference in New Issue
Block a user