amazing
This commit is contained in:
116
ai_client.py
116
ai_client.py
@@ -1,5 +1,7 @@
|
||||
# ai_client.py
|
||||
import tomllib
|
||||
import json
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
_provider: str = "gemini"
|
||||
@@ -16,8 +18,56 @@ _anthropic_history: list[dict] = []
|
||||
# Returns the output string if approved, None if rejected.
|
||||
confirm_and_run_callback = None
|
||||
|
||||
# Injected by gui.py - called whenever a comms entry is appended.
|
||||
# Signature: (entry: dict) -> None
|
||||
comms_log_callback = None
|
||||
|
||||
MAX_TOOL_ROUNDS = 5
|
||||
|
||||
# ------------------------------------------------------------------ comms log
|
||||
|
||||
_comms_log: list[dict] = []
|
||||
|
||||
MAX_FIELD_CHARS = 400 # beyond this we show a truncated preview in the UI
|
||||
|
||||
def _clamp(value, max_chars: int = MAX_FIELD_CHARS) -> tuple[str, bool]:
|
||||
"""Return (display_str, was_truncated)."""
|
||||
if isinstance(value, (dict, list)):
|
||||
s = json.dumps(value, ensure_ascii=False, indent=2)
|
||||
else:
|
||||
s = str(value)
|
||||
if len(s) > max_chars:
|
||||
return s[:max_chars], True
|
||||
return s, False
|
||||
|
||||
|
||||
def _append_comms(direction: str, kind: str, payload: dict):
|
||||
"""
|
||||
direction : "OUT" | "IN"
|
||||
kind : "request" | "response" | "tool_call" | "tool_result"
|
||||
payload : raw dict describing the event
|
||||
"""
|
||||
entry = {
|
||||
"ts": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"direction": direction,
|
||||
"kind": kind,
|
||||
"provider": _provider,
|
||||
"model": _model,
|
||||
"payload": payload,
|
||||
}
|
||||
_comms_log.append(entry)
|
||||
if comms_log_callback is not None:
|
||||
comms_log_callback(entry)
|
||||
|
||||
|
||||
def get_comms_log() -> list[dict]:
|
||||
return list(_comms_log)
|
||||
|
||||
|
||||
def clear_comms_log():
|
||||
_comms_log.clear()
|
||||
|
||||
|
||||
def _load_credentials() -> dict:
|
||||
with open("credentials.toml", "rb") as f:
|
||||
return tomllib.load(f)
|
||||
@@ -264,15 +314,33 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str) -> str:
|
||||
|
||||
full_message = f"<context>\n{md_content}\n</context>\n\n{user_message}"
|
||||
|
||||
_append_comms("OUT", "request", {
|
||||
"message": full_message,
|
||||
})
|
||||
|
||||
response = _gemini_chat.send_message(full_message)
|
||||
|
||||
for _ in range(MAX_TOOL_ROUNDS):
|
||||
for round_idx in range(MAX_TOOL_ROUNDS):
|
||||
# Log the raw response candidates as text summary
|
||||
text_parts_raw = [
|
||||
part.text
|
||||
for candidate in response.candidates
|
||||
for part in candidate.content.parts
|
||||
if hasattr(part, "text") and part.text
|
||||
]
|
||||
tool_calls = [
|
||||
part.function_call
|
||||
for candidate in response.candidates
|
||||
for part in candidate.content.parts
|
||||
if part.function_call is not None
|
||||
]
|
||||
|
||||
_append_comms("IN", "response", {
|
||||
"round": round_idx,
|
||||
"text": "\n".join(text_parts_raw),
|
||||
"tool_calls": [{"name": fc.name, "args": dict(fc.args)} for fc in tool_calls],
|
||||
})
|
||||
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
@@ -280,7 +348,15 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str) -> str:
|
||||
for fc in tool_calls:
|
||||
if fc.name == TOOL_NAME:
|
||||
script = fc.args.get("script", "")
|
||||
_append_comms("OUT", "tool_call", {
|
||||
"name": TOOL_NAME,
|
||||
"script": script,
|
||||
})
|
||||
output = _run_script(script, base_dir)
|
||||
_append_comms("IN", "tool_result", {
|
||||
"name": TOOL_NAME,
|
||||
"output": output,
|
||||
})
|
||||
function_responses.append(
|
||||
types.Part.from_function_response(
|
||||
name=TOOL_NAME,
|
||||
@@ -325,7 +401,11 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str) -> str:
|
||||
full_message = f"<context>\n{md_content}\n</context>\n\n{user_message}"
|
||||
_anthropic_history.append({"role": "user", "content": full_message})
|
||||
|
||||
for _ in range(MAX_TOOL_ROUNDS):
|
||||
_append_comms("OUT", "request", {
|
||||
"message": full_message,
|
||||
})
|
||||
|
||||
for round_idx in range(MAX_TOOL_ROUNDS):
|
||||
response = _anthropic_client.messages.create(
|
||||
model=_model,
|
||||
max_tokens=8096,
|
||||
@@ -338,6 +418,24 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str) -> str:
|
||||
"content": response.content
|
||||
})
|
||||
|
||||
# Summarise the response content for the log
|
||||
text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text]
|
||||
tool_use_blocks = [
|
||||
{"id": b.id, "name": b.name, "input": b.input}
|
||||
for b in response.content
|
||||
if b.type == "tool_use"
|
||||
]
|
||||
_append_comms("IN", "response", {
|
||||
"round": round_idx,
|
||||
"stop_reason": response.stop_reason,
|
||||
"text": "\n".join(text_blocks),
|
||||
"tool_calls": tool_use_blocks,
|
||||
"usage": {
|
||||
"input_tokens": response.usage.input_tokens,
|
||||
"output_tokens": response.usage.output_tokens,
|
||||
} if response.usage else {},
|
||||
})
|
||||
|
||||
if response.stop_reason != "tool_use":
|
||||
break
|
||||
|
||||
@@ -345,7 +443,17 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str) -> str:
|
||||
for block in response.content:
|
||||
if block.type == "tool_use" and block.name == TOOL_NAME:
|
||||
script = block.input.get("script", "")
|
||||
_append_comms("OUT", "tool_call", {
|
||||
"name": TOOL_NAME,
|
||||
"id": block.id,
|
||||
"script": script,
|
||||
})
|
||||
output = _run_script(script, base_dir)
|
||||
_append_comms("IN", "tool_result", {
|
||||
"name": TOOL_NAME,
|
||||
"id": block.id,
|
||||
"output": output,
|
||||
})
|
||||
tool_results.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": block.id,
|
||||
@@ -360,6 +468,10 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str) -> str:
|
||||
"content": tool_results
|
||||
})
|
||||
|
||||
_append_comms("OUT", "tool_result_send", {
|
||||
"results": [{"tool_use_id": r["tool_use_id"], "content": r["content"]} for r in tool_results],
|
||||
})
|
||||
|
||||
text_parts = [
|
||||
block.text
|
||||
for block in response.content
|
||||
|
||||
Reference in New Issue
Block a user