feat(testing): stabilize simulation suite and fix gemini caching
Changed file: ai_client.py (85 changed lines)
@@ -617,7 +617,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
     if _gemini_chat and _gemini_cache and _gemini_cache_created_at:
         elapsed = time.time() - _gemini_cache_created_at
         if elapsed > _GEMINI_CACHE_TTL * 0.9:
-            old_history = list(_get_gemini_history_list(_gemini_chat)) if _get_gemini_history_list(_get_gemini_history_list(_gemini_chat)) else []
+            old_history = list(_get_gemini_history_list(_gemini_chat)) if _get_gemini_history_list(_gemini_chat) else []
             try: _gemini_client.caches.delete(name=_gemini_cache.name)
             except Exception as e: _append_comms("OUT", "request", {"message": f"[CACHE DELETE WARN] {e}"})
             _gemini_chat = None
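Note on the hunk above: the 0.9 factor retires the explicit cache at 90% of its TTL, so a request never lands on a cache that expires mid-flight, and the history is captured first so the rebuilt chat keeps its context. A minimal sketch of that staleness check, assuming the module-level TTL constant used above (the helper name is hypothetical):

    import time

    def _cache_is_stale(created_at: float | None, ttl_s: float, margin: float = 0.9) -> bool:
        """Report True once an explicit Gemini cache is within `margin` of its
        TTL, i.e. old enough that it should be deleted and recreated."""
        if created_at is None:
            return True
        return (time.time() - created_at) > ttl_s * margin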
@@ -633,28 +633,42 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
             max_output_tokens=_max_tokens,
             safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
         )
 
-        try:
-            # Gemini requires 1024 (Flash) or 4096 (Pro) tokens to cache.
-            _gemini_cache = _gemini_client.caches.create(
-                model=_model,
-                config=types.CreateCachedContentConfig(
-                    system_instruction=sys_instr,
-                    tools=tools_decl,
-                    ttl=f"{_GEMINI_CACHE_TTL}s",
-                )
-            )
-            _gemini_cache_created_at = time.time()
-            chat_config = types.GenerateContentConfig(
-                cached_content=_gemini_cache.name,
-                temperature=_temperature,
-                max_output_tokens=_max_tokens,
-                safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
-            )
-            _append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"})
-        except Exception as e:
-            _gemini_cache = None
-            _gemini_cache_created_at = None
-            _append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} — falling back to inline system_instruction"})
+        # Check if context is large enough to warrant caching (min 2048 tokens usually)
+        should_cache = False
+        try:
+            count_resp = _gemini_client.models.count_tokens(model=_model, contents=[sys_instr])
+            # We use a 2048 threshold to be safe across models
+            if count_resp.total_tokens >= 2048:
+                should_cache = True
+            else:
+                _append_comms("OUT", "request", {"message": f"[CACHING SKIPPED] Context too small ({count_resp.total_tokens} tokens < 2048)"})
+        except Exception as e:
+            _append_comms("OUT", "request", {"message": f"[COUNT FAILED] {e}"})
+
+        if should_cache:
+            try:
+                # Gemini requires 1024 (Flash) or 4096 (Pro) tokens to cache.
+                _gemini_cache = _gemini_client.caches.create(
+                    model=_model,
+                    config=types.CreateCachedContentConfig(
+                        system_instruction=sys_instr,
+                        tools=tools_decl,
+                        ttl=f"{_GEMINI_CACHE_TTL}s",
+                    )
+                )
+                _gemini_cache_created_at = time.time()
+                chat_config = types.GenerateContentConfig(
+                    cached_content=_gemini_cache.name,
+                    temperature=_temperature,
+                    max_output_tokens=_max_tokens,
+                    safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
+                )
+                _append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"})
+            except Exception as e:
+                _gemini_cache = None
+                _gemini_cache_created_at = None
+                _append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} — falling back to inline system_instruction"})
+
         kwargs = {"model": _model, "config": chat_config}
         if old_history:
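Note on the hunk above: explicit Gemini caches have a minimum size (1,024 tokens on Flash models, 4,096 on Pro, per the inline comment), so counting tokens before calling caches.create turns a guaranteed API error on small contexts into a logged skip. A self-contained sketch of the same count-then-cache gate, assuming the google-genai SDK; the 2048 threshold comes from the diff, the function name is illustrative:

    from google import genai
    from google.genai import types

    def maybe_create_cache(client: genai.Client, model: str, sys_instr: str,
                           ttl_s: int, min_tokens: int = 2048):
        """Create an explicit content cache only if the system instruction is
        big enough to be accepted; return None to signal inline fallback."""
        try:
            count = client.models.count_tokens(model=model, contents=[sys_instr])
            if count.total_tokens < min_tokens:
                return None  # below the safe threshold: use inline system_instruction
            return client.caches.create(
                model=model,
                config=types.CreateCachedContentConfig(
                    system_instruction=sys_instr,
                    ttl=f"{ttl_s}s",
                ),
            )
        except Exception:
            return None  # count or create failed: fall back to inline context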
@@ -1290,11 +1304,29 @@ def get_history_bleed_stats(md_content: str | None = None) -> dict:
     if _gemini_chat:
         try:
             _ensure_gemini_client()
-            history = list(_get_gemini_history_list(_gemini_chat))
+            raw_history = list(_get_gemini_history_list(_gemini_chat))
+
+            # Copy and correct roles for counting
+            history = []
+            for c in raw_history:
+                # Gemini roles MUST be 'user' or 'model'
+                role = "model" if c.role in ["assistant", "model"] else "user"
+                history.append(types.Content(role=role, parts=c.parts))
+
             if md_content:
                 # Prepend context as a user part for counting
                 history.insert(0, types.Content(role="user", parts=[types.Part.from_text(text=md_content)]))
 
+            if not history:
+                print("[DEBUG] Gemini count_tokens skipped: no history or md_content")
+                return {
+                    "provider": "gemini",
+                    "limit": _GEMINI_MAX_INPUT_TOKENS,
+                    "current": 0,
+                    "percentage": 0,
+                }
+
+            print(f"[DEBUG] Gemini count_tokens on {len(history)} messages using model {_model}")
             resp = _gemini_client.models.count_tokens(
                 model=_model,
                 contents=history
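Note on the hunk above: count_tokens validates Content.role the same way generate_content does, so an "assistant" role carried over from another provider's history makes the whole call raise and the stats silently read zero. The fix copies the history and re-tags every message with one of the two roles Gemini accepts. As a standalone sketch:

    from google.genai import types

    def normalize_roles(history: list[types.Content]) -> list[types.Content]:
        """Return a copy of the history in which every message uses 'user' or
        'model', the only roles the Gemini API accepts, leaving the live chat
        object untouched."""
        fixed = []
        for content in history:
            role = "model" if content.role in ("assistant", "model") else "user"
            fixed.append(types.Content(role=role, parts=content.parts))
        return fixed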
@@ -1302,17 +1334,20 @@ def get_history_bleed_stats(md_content: str | None = None) -> dict:
             current_tokens = resp.total_tokens
             limit_tokens = _GEMINI_MAX_INPUT_TOKENS
             percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
+            print(f"[DEBUG] Gemini current_tokens={current_tokens}, percentage={percentage:.4f}%")
             return {
                 "provider": "gemini",
                 "limit": limit_tokens,
                 "current": current_tokens,
                 "percentage": percentage,
             }
-        except Exception:
+        except Exception as e:
+            print(f"[DEBUG] Gemini count_tokens error: {e}")
             pass
     elif md_content:
         try:
             _ensure_gemini_client()
+            print(f"[DEBUG] Gemini count_tokens (MD ONLY) using model {_model}")
             resp = _gemini_client.models.count_tokens(
                 model=_model,
                 contents=[types.Content(role="user", parts=[types.Part.from_text(text=md_content)])]
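Note on the hunk above: replacing the bare `except Exception:` with one that logs before `pass` is what makes failures here diagnosable; previously any counting error fell through to the generic return at the bottom and the stats showed 0%. The dict built in both branches is identical, so it could be factored into one hypothetical helper, sketched here:

    def _bleed_stats(current_tokens: int, limit_tokens: int) -> dict:
        """Package token usage in the shape get_history_bleed_stats returns;
        guards against a zero limit to avoid division errors."""
        percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
        return {
            "provider": "gemini",
            "limit": limit_tokens,
            "current": current_tokens,
            "percentage": percentage,
        }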
@@ -1320,13 +1355,15 @@ def get_history_bleed_stats(md_content: str | None = None) -> dict:
             current_tokens = resp.total_tokens
             limit_tokens = _GEMINI_MAX_INPUT_TOKENS
             percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
+            print(f"[DEBUG] Gemini (MD ONLY) current_tokens={current_tokens}, percentage={percentage:.4f}%")
             return {
                 "provider": "gemini",
                 "limit": limit_tokens,
                 "current": current_tokens,
                 "percentage": percentage,
             }
-        except Exception:
+        except Exception as e:
+            print(f"[DEBUG] Gemini count_tokens (MD ONLY) error: {e}")
             pass
 
     return {
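Note on the final hunk: the MD-only branch has no chat history, so the raw markdown must still be wrapped as a user Content, because count_tokens takes the same contents shape as generate_content. A minimal sketch of that path, assuming a google-genai client:

    from google import genai
    from google.genai import types

    def count_markdown_tokens(client: genai.Client, model: str, md: str) -> int:
        """Count tokens for bare markdown context by wrapping it as a single
        user message, mirroring the MD-only branch."""
        contents = [types.Content(role="user", parts=[types.Part.from_text(text=md)])]
        return client.models.count_tokens(model=model, contents=contents).total_tokens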