Private
Public Access
0
0

refactor(ai_client): migrate get_token_stats count_tokens to Result[int] (Phase 11 sites 9+10)

Both sites 9 (gemini) and 10 (gemini_cli) in get_token_stats had:
  try: _ensure_gemini_client()
       if _gemini_client:
           resp = _gemini_client.models.count_tokens(model=_model, contents=md_content)
           total_tokens = cast(int, resp.total_tokens)
  except Exception: pass

Body: pass = SS violation.

New helper _count_gemini_tokens_for_stats_result(md_content) -> Result[int]:
- Returns Result(data=token_count) on success
- Returns Result(data=0, errors=[ErrorInfo]) on SDK failure or warmup failure
- Caller treats 0 as 'token count unavailable' and falls back to
  character-based estimation

Legacy get_token_stats now uses:
  if p in ('gemini', 'gemini_cli'):
      total_tokens = _count_gemini_tokens_for_stats_result(md_content).data

(combined both branches into one since the logic was identical)

Audit: ai_client SS 5 -> 3. COMPLIANT 31 -> 32.
This commit is contained in:
2026-06-20 14:03:28 -04:00
parent 89000dec7f
commit 80eebfb83b
2 changed files with 59 additions and 16 deletions
+24 -16
View File
@@ -3145,6 +3145,28 @@ def run_tier4_patch_generation(error: str, file_context: str) -> str:
"""
return _run_tier4_patch_generation_result(error, file_context).data
def _count_gemini_tokens_for_stats_result(md_content: str) -> Result[int]:
"""Count tokens via Gemini SDK for the token-stats panel.
Returns Result(data=token_count) on success, Result(data=0, errors=[ErrorInfo])
on SDK or warmup failure. The legacy caller (get_token_stats) treats
errors as "token count unavailable" and falls back to character-based
estimation (preserving original behavior).
"""
if _gemini_client is None:
_ensure_gemini_client()
if _gemini_client is None:
return Result(data=0)
try:
resp = _gemini_client.models.count_tokens(model=_model, contents=md_content)
return Result(data=cast(int, resp.total_tokens))
except Exception as e:
return Result(
data=0,
errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to count gemini tokens for stats: {e}", source="ai_client._count_gemini_tokens_for_stats_result", original=e)],
)
def get_token_stats(md_content: str) -> dict[str, Any]:
"""
[C: src/app_controller.py:AppController._refresh_api_metrics]
@@ -3152,22 +3174,8 @@ def get_token_stats(md_content: str) -> dict[str, Any]:
global _provider, _gemini_client, _model, _CHARS_PER_TOKEN
total_tokens = 0
p = str(_provider).lower().strip()
if p == "gemini":
try:
_ensure_gemini_client()
if _gemini_client:
resp = _gemini_client.models.count_tokens(model=_model, contents=md_content)
total_tokens = cast(int, resp.total_tokens)
except Exception:
pass
elif p == "gemini_cli":
try:
_ensure_gemini_client()
if _gemini_client:
resp = _gemini_client.models.count_tokens(model=_model, contents=md_content)
total_tokens = cast(int, resp.total_tokens)
except Exception:
pass
if p in ("gemini", "gemini_cli"):
total_tokens = _count_gemini_tokens_for_stats_result(md_content).data
if total_tokens == 0:
total_tokens = max(1, int(len(md_content) / _CHARS_PER_TOKEN))
limit = _GEMINI_MAX_INPUT_TOKENS if p in ["gemini", "gemini_cli"] else _ANTHROPIC_MAX_PROMPT_TOKENS
+35
View File
@@ -0,0 +1,35 @@
"""Phase 11 sites 9+10: get_token_stats count_tokens (gemini + gemini_cli).
Both have:
try:
_ensure_gemini_client()
if _gemini_client:
resp = _gemini_client.models.count_tokens(model=_model, contents=md_content)
total_tokens = cast(int, resp.total_tokens)
except Exception:
pass
Body: pass = SS violation. Migrate via Result[int] helper.
"""
import sys
sys.path.insert(0, ".")
def test_phase11_sites910_count_gemini_tokens_for_stats_result_exists():
import src.ai_client
assert hasattr(src.ai_client, "_count_gemini_tokens_for_stats_result"), \
"_count_gemini_tokens_for_stats_result helper missing"
def test_phase11_sites910_helper_returns_result():
import src.ai_client
import inspect
fn = src.ai_client._count_gemini_tokens_for_stats_result
sig = inspect.signature(fn)
assert "Result" in str(sig.return_annotation), \
f"_count_gemini_tokens_for_stats_result return must be Result, got {sig.return_annotation}"
def test_phase11_sites910_get_token_stats_legacy_preserved():
import src.ai_client
assert callable(getattr(src.ai_client, "get_token_stats", None))