- Add cost tracking with new cost_tracker.py module - Enhance Track Proposal modal with editable titles and goals - Add Conductor Setup summary and New Track creation form to MMA Dashboard - Implement Task DAG editing (add/delete tickets) and track-scoped discussion - Add visual polish: color-coded statuses, tinted progress bars, and node indicators - Support live worker streaming from AI providers to GUI panels - Fix numerous integration test regressions and stabilize headless service
584 lines
28 KiB
Python
584 lines
28 KiB
Python
|
|
import os

# One-off maintenance script: repairs ai_client.py after a partially applied
# edit left duplicated function definitions at the end of the file.
path = 'ai_client.py'

with open(path, 'r', encoding='utf-8') as f:
    lines = f.readlines()

# Basic cleanup: drop everything starting at the duplicated `def send(` block
# that was accidentally appended at the end of the file (it is recognisable by
# the stray `import json` line immediately before it).  The authoritative
# function bodies are re-inserted further below via replace_func().
new_lines = []
skip = False
prev_line = ""
for line in lines:
    # Track the previous line directly instead of `lines[lines.index(line) - 1]`:
    # index() returns the FIRST occurrence (wrong line for duplicates) and
    # would wrap to lines[-1] when the match is at index 0.
    if 'def send(' in line and 'import json' in prev_line:
        skip = True  # keep skipping until end of file — the whole tail is duplicated
    if not skip:
        new_lines.append(line)
    prev_line = line

content = "".join(new_lines)
|
# Replacement source for ai_client._send_gemini, spliced in verbatim by
# replace_func() below.  Stored as a RAW string so the `\n` escapes inside the
# generated code survive into the written file instead of being interpreted
# here (the previous version contained literal newlines inside single-line
# f-strings, which would have produced an unparseable ai_client.py).
# NOTE(review): reconstructed from a whitespace-mangled copy — indentation and
# escapes were restored, and the accidental double-nested
# _get_gemini_history_list(_get_gemini_history_list(...)) in the TTL branch was
# reduced to a single call (matching the context-change branch); confirm
# against version control.
_SEND_GEMINI_NEW = r'''def _send_gemini(md_content: str, user_message: str, base_dir: str,
                 file_items: list[dict[str, Any]] | None = None,
                 discussion_history: str = "",
                 pre_tool_callback: Optional[Callable[[str], bool]] = None,
                 qa_callback: Optional[Callable[[str], str]] = None,
                 enable_tools: bool = True,
                 stream_callback: Optional[Callable[[str], None]] = None) -> str:
    global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
    try:
        _ensure_gemini_client()
        mcp_client.configure(file_items or [], [base_dir])
        # Only stable content (files + screenshots) goes in the cached system instruction.
        # Discussion history is sent as conversation messages so the cache isn't invalidated every turn.
        sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
        td = _gemini_tool_declaration() if enable_tools else None
        tools_decl = [td] if td else None
        # DYNAMIC CONTEXT: Check if files/context changed mid-session
        current_md_hash = hashlib.md5(md_content.encode()).hexdigest()
        old_history = None
        if _gemini_chat and _gemini_cache_md_hash != current_md_hash:
            old_history = list(_get_gemini_history_list(_gemini_chat)) if _get_gemini_history_list(_gemini_chat) else []
            if _gemini_cache:
                try: _gemini_client.caches.delete(name=_gemini_cache.name)
                except Exception as e: _append_comms("OUT", "request", {"message": f"[CACHE DELETE WARN] {e}"})
            _gemini_chat = None
            _gemini_cache = None
            _gemini_cache_created_at = None
            _append_comms("OUT", "request", {"message": "[CONTEXT CHANGED] Rebuilding cache and chat session..."})
        if _gemini_chat and _gemini_cache and _gemini_cache_created_at:
            elapsed = time.time() - _gemini_cache_created_at
            if elapsed > _GEMINI_CACHE_TTL * 0.9:
                old_history = list(_get_gemini_history_list(_gemini_chat)) if _get_gemini_history_list(_gemini_chat) else []
                try: _gemini_client.caches.delete(name=_gemini_cache.name)
                except Exception as e: _append_comms("OUT", "request", {"message": f"[CACHE DELETE WARN] {e}"})
                _gemini_chat = None
                _gemini_cache = None
                _gemini_cache_created_at = None
                _append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."})
        if not _gemini_chat:
            chat_config = types.GenerateContentConfig(
                system_instruction=sys_instr,
                tools=tools_decl,
                temperature=_temperature,
                max_output_tokens=_max_tokens,
                safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
            )
            should_cache = False
            try:
                count_resp = _gemini_client.models.count_tokens(model=_model, contents=[sys_instr])
                if count_resp.total_tokens >= 2048:
                    should_cache = True
                else:
                    _append_comms("OUT", "request", {"message": f"[CACHING SKIPPED] Context too small ({count_resp.total_tokens} tokens < 2048)"})
            except Exception as e:
                _append_comms("OUT", "request", {"message": f"[COUNT FAILED] {e}"})
            if should_cache:
                try:
                    _gemini_cache = _gemini_client.caches.create(
                        model=_model,
                        config=types.CreateCachedContentConfig(
                            system_instruction=sys_instr,
                            tools=tools_decl,
                            ttl=f"{_GEMINI_CACHE_TTL}s",
                        )
                    )
                    _gemini_cache_created_at = time.time()
                    chat_config = types.GenerateContentConfig(
                        cached_content=_gemini_cache.name,
                        temperature=_temperature,
                        max_output_tokens=_max_tokens,
                        safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
                    )
                    _append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"})
                except Exception as e:
                    _gemini_cache = None
                    _gemini_cache_created_at = None
                    _append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} \u2014 falling back to inline system_instruction"})
            kwargs = {"model": _model, "config": chat_config}
            if old_history:
                kwargs["history"] = old_history
            _gemini_chat = _gemini_client.chats.create(**kwargs)
            _gemini_cache_md_hash = current_md_hash
            if discussion_history and not old_history:
                _gemini_chat.send_message(f"[DISCUSSION HISTORY]\n\n{discussion_history}")
                _append_comms("OUT", "request", {"message": f"[HISTORY INJECTED] {len(discussion_history)} chars"})
        _append_comms("OUT", "request", {"message": f"[ctx {len(md_content)} + msg {len(user_message)}]"})
        payload: str | list[types.Part] = user_message
        all_text: list[str] = []
        _cumulative_tool_bytes = 0
        if _gemini_chat and _get_gemini_history_list(_gemini_chat):
            for msg in _get_gemini_history_list(_gemini_chat):
                if msg.role == "user" and hasattr(msg, "parts"):
                    for p in msg.parts:
                        if hasattr(p, "function_response") and p.function_response and hasattr(p.function_response, "response"):
                            r = p.function_response.response
                            if isinstance(r, dict) and "output" in r:
                                val = r["output"]
                                if isinstance(val, str):
                                    if "[SYSTEM: FILES UPDATED]" in val:
                                        val = val.split("[SYSTEM: FILES UPDATED]")[0].strip()
                                    if _history_trunc_limit > 0 and len(val) > _history_trunc_limit:
                                        val = val[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS.]"
                                    r["output"] = val
        for r_idx in range(MAX_TOOL_ROUNDS + 2):
            events.emit("request_start", payload={"provider": "gemini", "model": _model, "round": r_idx})
            if stream_callback:
                resp = _gemini_chat.send_message_stream(payload)
                txt_chunks = []
                for chunk in resp:
                    c_txt = chunk.text
                    if c_txt:
                        txt_chunks.append(c_txt)
                        stream_callback(c_txt)
                txt = "".join(txt_chunks)
                calls = [p.function_call for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if hasattr(p, "function_call") and p.function_call]
                usage = {"input_tokens": getattr(resp.usage_metadata, "prompt_token_count", 0), "output_tokens": getattr(resp.usage_metadata, "candidates_token_count", 0)}
                cached_tokens = getattr(resp.usage_metadata, "cached_content_token_count", None)
                if cached_tokens: usage["cache_read_input_tokens"] = cached_tokens
            else:
                resp = _gemini_chat.send_message(payload)
                txt = "\n".join(p.text for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if hasattr(p, "text") and p.text)
                calls = [p.function_call for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if hasattr(p, "function_call") and p.function_call]
                usage = {"input_tokens": getattr(resp.usage_metadata, "prompt_token_count", 0), "output_tokens": getattr(resp.usage_metadata, "candidates_token_count", 0)}
                cached_tokens = getattr(resp.usage_metadata, "cached_content_token_count", None)
                if cached_tokens: usage["cache_read_input_tokens"] = cached_tokens
            if txt: all_text.append(txt)
            events.emit("response_received", payload={"provider": "gemini", "model": _model, "usage": usage, "round": r_idx})
            reason = resp.candidates[0].finish_reason.name if resp.candidates and hasattr(resp.candidates[0], "finish_reason") else "STOP"
            _append_comms("IN", "response", {"round": r_idx, "stop_reason": reason, "text": txt, "tool_calls": [{"name": c.name, "args": dict(c.args)} for c in calls], "usage": usage})
            total_in = usage.get("input_tokens", 0)
            if total_in > _GEMINI_MAX_INPUT_TOKENS * 0.4 and _gemini_chat and _get_gemini_history_list(_gemini_chat):
                hist = _get_gemini_history_list(_gemini_chat)
                dropped = 0
                while len(hist) > 4 and total_in > _GEMINI_MAX_INPUT_TOKENS * 0.3:
                    saved = 0
                    for _ in range(2):
                        if not hist: break
                        for p in hist[0].parts:
                            if hasattr(p, "text") and p.text: saved += int(len(p.text) / _CHARS_PER_TOKEN)
                            elif hasattr(p, "function_response") and p.function_response:
                                r = getattr(p.function_response, "response", {})
                                if isinstance(r, dict): saved += int(len(str(r.get("output", ""))) / _CHARS_PER_TOKEN)
                        hist.pop(0)
                        dropped += 1
                    total_in -= max(saved, 200)
                if dropped > 0: _append_comms("OUT", "request", {"message": f"[GEMINI HISTORY TRIMMED: dropped {dropped} old entries]"})
            if not calls or r_idx > MAX_TOOL_ROUNDS: break
            f_resps: list[types.Part] = []
            log: list[dict[str, Any]] = []
            for i, fc in enumerate(calls):
                name, args = fc.name, dict(fc.args)
                if pre_tool_callback:
                    payload_str = json.dumps({"tool": name, "args": args})
                    if not pre_tool_callback(payload_str):
                        out = "USER REJECTED: tool execution cancelled"
                        f_resps.append(types.Part.from_function_response(name=name, response={"output": out}))
                        log.append({"tool_use_id": name, "content": out})
                        continue
                events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
                if name in mcp_client.TOOL_NAMES:
                    _append_comms("OUT", "tool_call", {"name": name, "args": args})
                    out = mcp_client.dispatch(name, args)
                elif name == TOOL_NAME:
                    scr = args.get("script", "")
                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr})
                    out = _run_script(scr, base_dir, qa_callback)
                else: out = f"ERROR: unknown tool '{name}'"
                if i == len(calls) - 1:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx: out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if r_idx == MAX_TOOL_ROUNDS: out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                out = _truncate_tool_output(out)
                _cumulative_tool_bytes += len(out)
                f_resps.append(types.Part.from_function_response(name=name, response={"output": out}))
                log.append({"tool_use_id": name, "content": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                f_resps.append(types.Part.from_text(f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget."))
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
            _append_comms("OUT", "tool_result_send", {"results": log})
            payload = f_resps
        return "\n\n".join(all_text) if all_text else "(No text returned)"
    except Exception as e: raise _classify_gemini_error(e) from e
'''
|
|
|
|
# Replacement source for ai_client._send_anthropic, spliced in by
# replace_func() below.  Stored as a RAW string so `\n` escapes reach the
# written file intact (the previous copy had literal newlines inside
# single-line string literals, which would not parse once written out).
# NOTE(review): reconstructed from a whitespace-mangled copy — indentation and
# escapes restored; confirm against version control.
_SEND_ANTHROPIC_NEW = r'''def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict[str, Any]] | None = None, discussion_history: str = "", pre_tool_callback: Optional[Callable[[str], bool]] = None, qa_callback: Optional[Callable[[str], str]] = None, stream_callback: Optional[Callable[[str], None]] = None) -> str:
    try:
        _ensure_anthropic_client()
        mcp_client.configure(file_items or [], [base_dir])
        stable_prompt = _get_combined_system_prompt()
        stable_blocks = [{"type": "text", "text": stable_prompt, "cache_control": {"type": "ephemeral"}}]
        context_text = f"\n\n<context>\n{md_content}\n</context>"
        context_blocks = _build_chunked_context_blocks(context_text)
        system_blocks = stable_blocks + context_blocks
        if discussion_history and not _anthropic_history:
            user_content: list[dict[str, Any]] = [{"type": "text", "text": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}]
        else:
            user_content = [{"type": "text", "text": user_message}]
        for msg in _anthropic_history:
            if msg.get("role") == "user" and isinstance(msg.get("content"), list):
                modified = False
                for block in msg["content"]:
                    if isinstance(block, dict) and block.get("type") == "tool_result":
                        t_content = block.get("content", "")
                        if _history_trunc_limit > 0 and isinstance(t_content, str) and len(t_content) > _history_trunc_limit:
                            block["content"] = t_content[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM]"
                            modified = True
                if modified: _invalidate_token_estimate(msg)
        _strip_cache_controls(_anthropic_history)
        _repair_anthropic_history(_anthropic_history)
        _anthropic_history.append({"role": "user", "content": user_content})
        _add_history_cache_breakpoint(_anthropic_history)
        all_text_parts: list[str] = []
        _cumulative_tool_bytes = 0
        def _strip_private_keys(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
            return [{k: v for k, v in m.items() if not k.startswith("_")} for m in history]
        for round_idx in range(MAX_TOOL_ROUNDS + 2):
            dropped = _trim_anthropic_history(system_blocks, _anthropic_history)
            if dropped > 0:
                est_tokens = _estimate_prompt_tokens(system_blocks, _anthropic_history)
                _append_comms("OUT", "request", {"message": f"[HISTORY TRIMMED: dropped {dropped} old messages]"})
            events.emit("request_start", payload={"provider": "anthropic", "model": _model, "round": round_idx})
            if stream_callback:
                with _anthropic_client.messages.stream(
                    model=_model,
                    max_tokens=_max_tokens,
                    temperature=_temperature,
                    system=system_blocks,
                    tools=_get_anthropic_tools(),
                    messages=_strip_private_keys(_anthropic_history),
                ) as stream:
                    for event in stream:
                        if event.type == "content_block_delta" and event.delta.type == "text_delta":
                            stream_callback(event.delta.text)
                    response = stream.get_final_message()
            else:
                response = _anthropic_client.messages.create(
                    model=_model,
                    max_tokens=_max_tokens,
                    temperature=_temperature,
                    system=system_blocks,
                    tools=_get_anthropic_tools(),
                    messages=_strip_private_keys(_anthropic_history),
                )
            serialised_content = [_content_block_to_dict(b) for b in response.content]
            _anthropic_history.append({"role": "assistant", "content": serialised_content})
            text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text]
            if text_blocks: all_text_parts.append("\n".join(text_blocks))
            tool_use_blocks = [{"id": b.id, "name": b.name, "input": b.input} for b in response.content if getattr(b, "type", None) == "tool_use"]
            usage_dict: dict[str, Any] = {}
            if response.usage:
                usage_dict["input_tokens"] = response.usage.input_tokens
                usage_dict["output_tokens"] = response.usage.output_tokens
                for k in ["cache_creation_input_tokens", "cache_read_input_tokens"]:
                    val = getattr(response.usage, k, None)
                    if val is not None: usage_dict[k] = val
            events.emit("response_received", payload={"provider": "anthropic", "model": _model, "usage": usage_dict, "round": round_idx})
            _append_comms("IN", "response", {"round": round_idx, "stop_reason": response.stop_reason, "text": "\n".join(text_blocks), "tool_calls": tool_use_blocks, "usage": usage_dict})
            if response.stop_reason != "tool_use" or not tool_use_blocks: break
            if round_idx > MAX_TOOL_ROUNDS: break
            tool_results: list[dict[str, Any]] = []
            for block in response.content:
                if getattr(block, "type", None) != "tool_use": continue
                b_name, b_id, b_input = block.name, block.id, block.input
                if pre_tool_callback:
                    if not pre_tool_callback(json.dumps({"tool": b_name, "args": b_input})):
                        tool_results.append({"type": "tool_result", "tool_use_id": b_id, "content": "USER REJECTED: tool execution cancelled"})
                        continue
                events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx})
                if b_name in mcp_client.TOOL_NAMES:
                    _append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
                    output = mcp_client.dispatch(b_name, b_input)
                elif b_name == TOOL_NAME:
                    scr = b_input.get("script", "")
                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": b_id, "script": scr})
                    output = _run_script(scr, base_dir, qa_callback)
                else: output = f"ERROR: unknown tool '{b_name}'"
                truncated = _truncate_tool_output(output)
                _cumulative_tool_bytes += len(truncated)
                tool_results.append({"type": "tool_result", "tool_use_id": b_id, "content": truncated})
                _append_comms("IN", "tool_result", {"name": b_name, "id": b_id, "output": output})
                events.emit("tool_execution", payload={"status": "completed", "tool": b_name, "result": output, "round": round_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                tool_results.append({"type": "text", "text": "SYSTEM WARNING: Cumulative tool output exceeded budget."})
            if file_items:
                file_items, changed = _reread_file_items(file_items)
                refreshed_ctx = _build_file_diff_text(changed)
                if refreshed_ctx: tool_results.append({"type": "text", "text": f"[FILES UPDATED]\n\n{refreshed_ctx}"})
            if round_idx == MAX_TOOL_ROUNDS: tool_results.append({"type": "text", "text": "SYSTEM WARNING: MAX TOOL ROUNDS REACHED."})
            _anthropic_history.append({"role": "user", "content": tool_results})
            _append_comms("OUT", "tool_result_send", {"results": [{"tool_use_id": r["tool_use_id"], "content": r["content"]} for r in tool_results if r.get("type") == "tool_result"]})
        return "\n\n".join(all_text_parts) if all_text_parts else "(No text returned)"
    except Exception as exc: raise _classify_anthropic_error(exc) from exc
'''
|
|
|
|
# Replacement source for ai_client._send_deepseek, spliced in by
# replace_func() below.  Stored as a RAW string so `\n` escapes reach the
# written file intact.  The bare `except:` around json.loads was narrowed to
# `except Exception:` so the generated code no longer swallows
# KeyboardInterrupt/SystemExit.
# NOTE(review): reconstructed from a whitespace-mangled copy — indentation and
# escapes restored; confirm against version control.
_SEND_DEEPSEEK_NEW = r'''def _send_deepseek(md_content: str, user_message: str, base_dir: str,
                   file_items: list[dict[str, Any]] | None = None,
                   discussion_history: str = "",
                   stream: bool = False,
                   pre_tool_callback: Optional[Callable[[str], bool]] = None,
                   qa_callback: Optional[Callable[[str], str]] = None,
                   stream_callback: Optional[Callable[[str], None]] = None) -> str:
    try:
        mcp_client.configure(file_items or [], [base_dir])
        creds = _load_credentials()
        api_key = creds.get("deepseek", {}).get("api_key")
        if not api_key: raise ValueError("DeepSeek API key not found")
        api_url = "https://api.deepseek.com/chat/completions"
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        current_api_messages: list[dict[str, Any]] = []
        with _deepseek_history_lock:
            for msg in _deepseek_history: current_api_messages.append(msg)
        initial_user_message_content = user_message
        if discussion_history: initial_user_message_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
        current_api_messages.append({"role": "user", "content": initial_user_message_content})
        request_payload: dict[str, Any] = {"model": _model, "messages": current_api_messages, "temperature": _temperature, "max_tokens": _max_tokens, "stream": stream}
        sys_msg = {"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}
        request_payload["messages"].insert(0, sys_msg)
        all_text_parts: list[str] = []
        _cumulative_tool_bytes = 0
        round_idx = 0
        while round_idx <= MAX_TOOL_ROUNDS + 1:
            events.emit("request_start", payload={"provider": "deepseek", "model": _model, "round": round_idx, "streaming": stream})
            try:
                response = requests.post(api_url, headers=headers, json=request_payload, timeout=60, stream=stream)
                response.raise_for_status()
            except requests.exceptions.RequestException as e: raise _classify_deepseek_error(e) from e
            if stream:
                aggregated_content, aggregated_tool_calls, aggregated_reasoning = "", [], ""
                current_usage, final_finish_reason = {}, "stop"
                for line in response.iter_lines():
                    if not line: continue
                    decoded = line.decode('utf-8')
                    if decoded.startswith('data: '):
                        chunk_str = decoded[len('data: '):]
                        if chunk_str.strip() == '[DONE]': continue
                        try:
                            chunk = json.loads(chunk_str)
                            delta = chunk.get("choices", [{}])[0].get("delta", {})
                            if delta.get("content"):
                                aggregated_content += delta["content"]
                                if stream_callback: stream_callback(delta["content"])
                            if delta.get("reasoning_content"): aggregated_reasoning += delta["reasoning_content"]
                            if delta.get("tool_calls"):
                                for tc_delta in delta["tool_calls"]:
                                    idx = tc_delta.get("index", 0)
                                    while len(aggregated_tool_calls) <= idx: aggregated_tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
                                    target = aggregated_tool_calls[idx]
                                    if tc_delta.get("id"): target["id"] = tc_delta["id"]
                                    if tc_delta.get("function", {}).get("name"): target["function"]["name"] += tc_delta["function"]["name"]
                                    if tc_delta.get("function", {}).get("arguments"): target["function"]["arguments"] += tc_delta["function"]["arguments"]
                            if chunk.get("choices", [{}])[0].get("finish_reason"): final_finish_reason = chunk["choices"][0]["finish_reason"]
                            if chunk.get("usage"): current_usage = chunk["usage"]
                        except json.JSONDecodeError: continue
                assistant_text, tool_calls_raw, reasoning_content, finish_reason, usage = aggregated_content, aggregated_tool_calls, aggregated_reasoning, final_finish_reason, current_usage
            else:
                response_data = response.json()
                choices = response_data.get("choices", [])
                if not choices: break
                choice = choices[0]
                message = choice.get("message", {})
                assistant_text, tool_calls_raw, reasoning_content, finish_reason, usage = message.get("content", ""), message.get("tool_calls", []), message.get("reasoning_content", ""), choice.get("finish_reason", "stop"), response_data.get("usage", {})
            full_assistant_text = (f"<thinking>\n{reasoning_content}\n</thinking>\n" if reasoning_content else "") + assistant_text
            with _deepseek_history_lock:
                msg_to_store = {"role": "assistant", "content": assistant_text}
                if reasoning_content: msg_to_store["reasoning_content"] = reasoning_content
                if tool_calls_raw: msg_to_store["tool_calls"] = tool_calls_raw
                _deepseek_history.append(msg_to_store)
            if full_assistant_text: all_text_parts.append(full_assistant_text)
            _append_comms("IN", "response", {"round": round_idx, "stop_reason": finish_reason, "text": full_assistant_text, "tool_calls": tool_calls_raw, "usage": usage, "streaming": stream})
            if finish_reason != "tool_calls" and not tool_calls_raw: break
            if round_idx > MAX_TOOL_ROUNDS: break
            tool_results_for_history: list[dict[str, Any]] = []
            for i, tc_raw in enumerate(tool_calls_raw):
                tool_info = tc_raw.get("function", {})
                tool_name, tool_args_str, tool_id = tool_info.get("name"), tool_info.get("arguments", "{}"), tc_raw.get("id")
                try: tool_args = json.loads(tool_args_str)
                except Exception: tool_args = {}
                if pre_tool_callback:
                    if not pre_tool_callback(json.dumps({"tool": tool_name, "args": tool_args})):
                        tool_output = "USER REJECTED: tool execution cancelled"
                        tool_results_for_history.append({"role": "tool", "tool_call_id": tool_id, "content": tool_output})
                        continue
                events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx})
                if tool_name in mcp_client.TOOL_NAMES:
                    _append_comms("OUT", "tool_call", {"name": tool_name, "id": tool_id, "args": tool_args})
                    tool_output = mcp_client.dispatch(tool_name, tool_args)
                elif tool_name == TOOL_NAME:
                    script = tool_args.get("script", "")
                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script})
                    tool_output = _run_script(script, base_dir, qa_callback)
                else: tool_output = f"ERROR: unknown tool '{tool_name}'"
                if i == len(tool_calls_raw) - 1:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx: tool_output += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if round_idx == MAX_TOOL_ROUNDS: tool_output += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                tool_output = _truncate_tool_output(tool_output)
                _cumulative_tool_bytes += len(tool_output)
                tool_results_for_history.append({"role": "tool", "tool_call_id": tool_id, "content": tool_output})
                _append_comms("IN", "tool_result", {"name": tool_name, "id": tool_id, "output": tool_output})
                events.emit("tool_execution", payload={"status": "completed", "tool": tool_name, "result": tool_output, "round": round_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                tool_results_for_history.append({"role": "user", "content": "SYSTEM WARNING: Cumulative tool output exceeded budget."})
            with _deepseek_history_lock:
                for tr in tool_results_for_history: _deepseek_history.append(tr)
            next_messages: list[dict[str, Any]] = []
            with _deepseek_history_lock:
                for msg in _deepseek_history: next_messages.append(msg)
            next_messages.insert(0, sys_msg)
            request_payload["messages"] = next_messages
            round_idx += 1
        return "\n\n".join(all_text_parts) if all_text_parts else "(No text returned)"
    except Exception as e: raise _classify_deepseek_error(e) from e
'''
|
|
|
|
# Replacement source for ai_client.send (the provider dispatcher), spliced in
# by replace_func() below.
# NOTE(review): reconstructed from a whitespace-mangled copy — indentation
# restored so the generated file parses; confirm against version control.
_SEND_NEW = '''def send(
    md_content: str,
    user_message: str,
    base_dir: str = ".",
    file_items: list[dict[str, Any]] | None = None,
    discussion_history: str = "",
    stream: bool = False,
    pre_tool_callback: Optional[Callable[[str], bool]] = None,
    qa_callback: Optional[Callable[[str], str]] = None,
    enable_tools: bool = True,
    stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
    """
    Sends a prompt with the full markdown context to the current AI provider.
    Returns the final text response.
    """
    with _send_lock:
        if _provider == "gemini":
            return _send_gemini(
                md_content, user_message, base_dir, file_items, discussion_history,
                pre_tool_callback, qa_callback, enable_tools, stream_callback
            )
        elif _provider == "gemini_cli":
            return _send_gemini_cli(
                md_content, user_message, base_dir, file_items, discussion_history,
                pre_tool_callback, qa_callback
            )
        elif _provider == "anthropic":
            return _send_anthropic(
                md_content, user_message, base_dir, file_items, discussion_history,
                pre_tool_callback, qa_callback, stream_callback=stream_callback
            )
        elif _provider == "deepseek":
            return _send_deepseek(
                md_content, user_message, base_dir, file_items, discussion_history,
                stream, pre_tool_callback, qa_callback, stream_callback
            )
        else:
            raise ValueError(f"Unknown provider: {_provider}")
'''
|
|
|
|
# Use regex to replace each old function definition with its new body.
import re


def replace_func(content, func_name, new_body):
    """Replace the top-level function `func_name` in `content` with `new_body`.

    The header match is anchored at column 0 (re.MULTILINE) so that indented
    methods with the same name, or occurrences inside string literals, are not
    touched — `content.find('def name(')` had no such guarantee.  The function
    is assumed to end where the next blank-line-separated top-level `def`
    begins; otherwise it runs to end of file.  Returns `content` unchanged when
    the function is not found.
    """
    header = re.compile(rf'^def {re.escape(func_name)}\(', re.MULTILINE)
    m = header.search(content)
    if m is None:
        return content
    start_idx = m.start()

    # Rough end-of-function estimation: the next `def ` at column 0 preceded
    # by a blank line (same heuristic as the original implementation).
    next_def = re.search(r'\n\ndef ', content[start_idx + 1:])
    if next_def:
        end_idx = start_idx + 1 + next_def.start()
    else:
        end_idx = len(content)

    return content[:start_idx] + new_body + content[end_idx:]
|
|
|
|
# Final content construction: splice the corrected function bodies into the
# cleaned file content, then write it back.
content = replace_func(content, '_send_gemini', _SEND_GEMINI_NEW)
content = replace_func(content, '_send_anthropic', _SEND_ANTHROPIC_NEW)
content = replace_func(content, '_send_deepseek', _SEND_DEEPSEEK_NEW)
content = replace_func(content, 'send', _SEND_NEW)

# Remove any remaining duplicated tail: it starts at a stray re-import block.
# (The marker must be a single-line-safe string — the previous copy had a raw
# newline inside the quotes, which does not parse.)
marker = 'import json\nfrom typing import Any, Callable, Optional, List'
if marker in content:
    # Truncate everything from the stray import block onwards.
    content = content[:content.find(marker)]

with open(path, 'w', encoding='utf-8') as f:
    f.write(content)