checkpoint: finished test curation
This commit is contained in:
163
ai_client.py
163
ai_client.py
@@ -20,6 +20,7 @@ import difflib
|
||||
import threading
|
||||
from pathlib import Path
|
||||
import os
|
||||
import project_manager
|
||||
import file_cache
|
||||
import mcp_client
|
||||
import anthropic
|
||||
@@ -44,6 +45,13 @@ def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000):
|
||||
_max_tokens = max_tok
|
||||
_history_trunc_limit = trunc_limit
|
||||
|
||||
def get_history_trunc_limit() -> int:
|
||||
return _history_trunc_limit
|
||||
|
||||
def set_history_trunc_limit(val: int):
|
||||
global _history_trunc_limit
|
||||
_history_trunc_limit = val
|
||||
|
||||
_gemini_client = None
|
||||
_gemini_chat = None
|
||||
_gemini_cache = None
|
||||
@@ -800,11 +808,10 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
try:
|
||||
if _gemini_cli_adapter is None:
|
||||
_gemini_cli_adapter = GeminiCliAdapter(binary_path="gemini")
|
||||
|
||||
events.emit("request_start", payload={"provider": "gemini_cli", "model": _model, "round": 0})
|
||||
|
||||
|
||||
mcp_client.configure(file_items or [], [base_dir])
|
||||
|
||||
# If it's a new session (session_id is None), we should ideally send the context.
|
||||
# For now, following the simple pattern:
|
||||
payload = user_message
|
||||
if _gemini_cli_adapter.session_id is None:
|
||||
# Prepend context and discussion history to the first message
|
||||
@@ -814,23 +821,104 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
full_prompt += user_message
|
||||
payload = full_prompt
|
||||
|
||||
_append_comms("OUT", "request", {"message": f"[CLI] [msg {len(payload)}]"})
|
||||
|
||||
result_text = _gemini_cli_adapter.send(payload)
|
||||
|
||||
usage = _gemini_cli_adapter.last_usage or {}
|
||||
latency = _gemini_cli_adapter.last_latency
|
||||
events.emit("response_received", payload={"provider": "gemini_cli", "model": _model, "usage": usage, "latency": latency, "round": 0})
|
||||
|
||||
_append_comms("IN", "response", {
|
||||
"round": 0,
|
||||
"stop_reason": "STOP",
|
||||
"text": result_text,
|
||||
"tool_calls": [],
|
||||
"usage": usage
|
||||
})
|
||||
|
||||
return result_text
|
||||
all_text = []
|
||||
_cumulative_tool_bytes = 0
|
||||
|
||||
for r_idx in range(MAX_TOOL_ROUNDS + 2):
|
||||
events.emit("request_start", payload={"provider": "gemini_cli", "model": _model, "round": r_idx})
|
||||
_append_comms("OUT", "request", {"message": f"[CLI] [round {r_idx}] [msg {len(payload)}]"})
|
||||
|
||||
resp_data = _gemini_cli_adapter.send(payload)
|
||||
txt = resp_data.get("text", "")
|
||||
if txt: all_text.append(txt)
|
||||
|
||||
calls = resp_data.get("tool_calls", [])
|
||||
usage = _gemini_cli_adapter.last_usage or {}
|
||||
latency = _gemini_cli_adapter.last_latency
|
||||
|
||||
events.emit("response_received", payload={"provider": "gemini_cli", "model": _model, "usage": usage, "latency": latency, "round": r_idx})
|
||||
|
||||
# Clean up the tool calls format to match comms log expectation
|
||||
log_calls = []
|
||||
for c in calls:
|
||||
log_calls.append({"name": c.get("name"), "args": c.get("args")})
|
||||
|
||||
_append_comms("IN", "response", {
|
||||
"round": r_idx,
|
||||
"stop_reason": "TOOL_USE" if calls else "STOP",
|
||||
"text": txt,
|
||||
"tool_calls": log_calls,
|
||||
"usage": usage
|
||||
})
|
||||
|
||||
# If there's text and we're not done, push it to the history immediately
|
||||
# so it appears as a separate entry in the GUI.
|
||||
if txt and calls and comms_log_callback:
|
||||
# Use kind='history_add' to push a new entry into the disc_entries list
|
||||
comms_log_callback({
|
||||
"ts": project_manager.now_ts(),
|
||||
"direction": "IN",
|
||||
"kind": "history_add",
|
||||
"payload": {
|
||||
"role": "AI",
|
||||
"content": txt
|
||||
}
|
||||
})
|
||||
|
||||
if not calls or r_idx > MAX_TOOL_ROUNDS:
|
||||
break
|
||||
|
||||
tool_results_for_cli = []
|
||||
for i, fc in enumerate(calls):
|
||||
name = fc.get("name")
|
||||
args = fc.get("args", {})
|
||||
call_id = fc.get("id")
|
||||
|
||||
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
|
||||
if name in mcp_client.TOOL_NAMES:
|
||||
_append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
|
||||
out = mcp_client.dispatch(name, args)
|
||||
elif name == TOOL_NAME:
|
||||
scr = args.get("script", "")
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
|
||||
out = _run_script(scr, base_dir)
|
||||
else:
|
||||
out = f"ERROR: unknown tool '{name}'"
|
||||
|
||||
if i == len(calls) - 1:
|
||||
if file_items:
|
||||
file_items, changed = _reread_file_items(file_items)
|
||||
ctx = _build_file_diff_text(changed)
|
||||
if ctx:
|
||||
out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
|
||||
if r_idx == MAX_TOOL_ROUNDS:
|
||||
out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
|
||||
|
||||
out = _truncate_tool_output(out)
|
||||
_cumulative_tool_bytes += len(out)
|
||||
|
||||
tool_results_for_cli.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"name": name,
|
||||
"content": out
|
||||
})
|
||||
|
||||
_append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
|
||||
events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
|
||||
|
||||
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
|
||||
_append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
|
||||
# We should ideally tell the model here, but for CLI we just append to payload
|
||||
|
||||
# For Gemini CLI, we send the tool results as a JSON array of messages (or similar)
|
||||
# The adapter expects a string, so we'll pass the JSON string of the results.
|
||||
payload = json.dumps(tool_results_for_cli)
|
||||
|
||||
# Return only the text from the last round, because intermediate
|
||||
# text chunks were already pushed to history via comms_log_callback.
|
||||
final_text = all_text[-1] if all_text else "(No text returned)"
|
||||
return final_text
|
||||
except Exception as e:
|
||||
# Basic error classification for CLI
|
||||
raise ProviderError("unknown", "gemini_cli", e)
|
||||
@@ -1348,6 +1436,7 @@ def get_history_bleed_stats(md_content: str | None = None) -> dict:
|
||||
"percentage": percentage,
|
||||
}
|
||||
elif _provider == "gemini":
|
||||
effective_limit = _history_trunc_limit if _history_trunc_limit > 0 else _GEMINI_MAX_INPUT_TOKENS
|
||||
if _gemini_chat:
|
||||
try:
|
||||
_ensure_gemini_client()
|
||||
@@ -1368,7 +1457,7 @@ def get_history_bleed_stats(md_content: str | None = None) -> dict:
|
||||
print("[DEBUG] Gemini count_tokens skipped: no history or md_content")
|
||||
return {
|
||||
"provider": "gemini",
|
||||
"limit": _GEMINI_MAX_INPUT_TOKENS,
|
||||
"limit": effective_limit,
|
||||
"current": 0,
|
||||
"percentage": 0,
|
||||
}
|
||||
@@ -1379,12 +1468,11 @@ def get_history_bleed_stats(md_content: str | None = None) -> dict:
|
||||
contents=history
|
||||
)
|
||||
current_tokens = resp.total_tokens
|
||||
limit_tokens = _GEMINI_MAX_INPUT_TOKENS
|
||||
percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
|
||||
percentage = (current_tokens / effective_limit) * 100 if effective_limit > 0 else 0
|
||||
print(f"[DEBUG] Gemini current_tokens={current_tokens}, percentage={percentage:.4f}%")
|
||||
return {
|
||||
"provider": "gemini",
|
||||
"limit": limit_tokens,
|
||||
"limit": effective_limit,
|
||||
"current": current_tokens,
|
||||
"percentage": percentage,
|
||||
}
|
||||
@@ -1400,12 +1488,11 @@ def get_history_bleed_stats(md_content: str | None = None) -> dict:
|
||||
contents=[types.Content(role="user", parts=[types.Part.from_text(text=md_content)])]
|
||||
)
|
||||
current_tokens = resp.total_tokens
|
||||
limit_tokens = _GEMINI_MAX_INPUT_TOKENS
|
||||
percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
|
||||
percentage = (current_tokens / effective_limit) * 100 if effective_limit > 0 else 0
|
||||
print(f"[DEBUG] Gemini (MD ONLY) current_tokens={current_tokens}, percentage={percentage:.4f}%")
|
||||
return {
|
||||
"provider": "gemini",
|
||||
"limit": limit_tokens,
|
||||
"limit": effective_limit,
|
||||
"current": current_tokens,
|
||||
"percentage": percentage,
|
||||
}
|
||||
@@ -1415,10 +1502,28 @@ def get_history_bleed_stats(md_content: str | None = None) -> dict:
|
||||
|
||||
return {
|
||||
"provider": "gemini",
|
||||
"limit": _GEMINI_MAX_INPUT_TOKENS,
|
||||
"limit": effective_limit,
|
||||
"current": 0,
|
||||
"percentage": 0,
|
||||
}
|
||||
elif _provider == "gemini_cli":
|
||||
effective_limit = _history_trunc_limit if _history_trunc_limit > 0 else _GEMINI_MAX_INPUT_TOKENS
|
||||
# For Gemini CLI, we don't have direct count_tokens access without making a call,
|
||||
# so we report the limit and current usage from the last run if available.
|
||||
limit_tokens = effective_limit
|
||||
current_tokens = 0
|
||||
if _gemini_cli_adapter and _gemini_cli_adapter.last_usage:
|
||||
# Stats from CLI use 'input_tokens' or 'input'
|
||||
u = _gemini_cli_adapter.last_usage
|
||||
current_tokens = u.get("input_tokens") or u.get("input", 0)
|
||||
|
||||
percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
|
||||
return {
|
||||
"provider": "gemini_cli",
|
||||
"limit": limit_tokens,
|
||||
"current": current_tokens,
|
||||
"percentage": percentage,
|
||||
}
|
||||
|
||||
# Default empty state
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user