feat(ai): implement DeepSeek provider with streaming and reasoning support
This commit is contained in:
268
ai_client.py
268
ai_client.py
@@ -18,6 +18,7 @@ import datetime
|
||||
import hashlib
|
||||
import difflib
|
||||
import threading
|
||||
import requests
|
||||
from pathlib import Path
|
||||
import os
|
||||
import project_manager
|
||||
@@ -1434,22 +1435,233 @@ def _ensure_deepseek_client():
|
||||
|
||||
def _send_deepseek(md_content: str, user_message: str, base_dir: str,
                   file_items: list[dict] | None = None,
                   discussion_history: str = "",
                   stream: bool = False) -> str:
    """
    Send a message to the DeepSeek API, handling tool calls and history.

    Supports streaming (SSE) responses and surfaces DeepSeek-R1 reasoning
    traces (``reasoning_content``) as ``<thinking>`` tags in the returned
    text. Tool calls are executed in a bounded loop (``MAX_TOOL_ROUNDS``)
    and their results fed back to the model.

    Parameters
    ----------
    md_content : str
        Markdown context injected into the system prompt for every round.
    user_message : str
        The user's message for this turn.
    base_dir : str
        Project base directory used for tool/script execution.
    file_items : list[dict] | None
        File descriptors passed to the MCP client; re-read after tool calls
        so the model sees file diffs.
    discussion_history : str
        Prior discussion text; when present it is prepended to the user
        message inside a ``[DISCUSSION HISTORY]`` banner.
    stream : bool
        Whether to request a streaming response from the API.

    Returns
    -------
    str
        Concatenated assistant text from all rounds, or a placeholder
        string when the model returned no text.

    Raises
    ------
    Exception
        Any failure is classified via ``_classify_deepseek_error`` and
        re-raised with the original as ``__cause__``.
    """
    try:
        _ensure_deepseek_client()
        mcp_client.configure(file_items or [], [base_dir])
        creds = _load_credentials()
        api_key = creds.get("deepseek", {}).get("api_key")
        if not api_key:
            raise ValueError("DeepSeek API key not found in credentials.toml")

        # DeepSeek exposes an OpenAI-compatible chat-completions endpoint.
        api_url = "https://api.deepseek.com/chat/completions"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        # Build this turn's user content, folding in discussion history.
        initial_user_message_content = user_message
        if discussion_history:
            initial_user_message_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"

        # FIX: persist the user turn to the shared history. Without this,
        # messages rebuilt from _deepseek_history for round 2+ (and for all
        # later calls) would silently drop every user message.
        with _deepseek_history_lock:
            _deepseek_history.append({"role": "user", "content": initial_user_message_content})
            current_api_messages = list(_deepseek_history)

        # Construct the full request payload.
        request_payload = {
            "model": _model,
            "messages": current_api_messages,
            "temperature": _temperature,
            "max_tokens": _max_tokens,
            "stream": stream,
        }

        # System prompt (with markdown context) always leads the message list.
        sys_msg = {"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}
        request_payload["messages"].insert(0, sys_msg)

        all_text_parts = []
        _cumulative_tool_bytes = 0
        round_idx = 0

        while round_idx <= MAX_TOOL_ROUNDS + 1:
            events.emit("request_start", payload={"provider": "deepseek", "model": _model, "round": round_idx, "streaming": stream})

            try:
                response = requests.post(api_url, headers=headers, json=request_payload, timeout=60, stream=stream)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                raise _classify_deepseek_error(e) from e

            # ---- Process response ------------------------------------------
            if stream:
                aggregated_content = ""
                aggregated_tool_calls = []
                aggregated_reasoning = ""
                current_usage = {}
                final_finish_reason = "stop"

                for line in response.iter_lines():
                    if not line:
                        continue
                    decoded = line.decode('utf-8')
                    if decoded.startswith('data: '):
                        chunk_str = decoded[len('data: '):]
                        if chunk_str.strip() == '[DONE]':
                            continue
                        try:
                            chunk = json.loads(chunk_str)
                            delta = chunk.get("choices", [{}])[0].get("delta", {})

                            if delta.get("content"):
                                aggregated_content += delta["content"]

                            # R1 models stream their reasoning separately.
                            if delta.get("reasoning_content"):
                                aggregated_reasoning += delta["reasoning_content"]

                            if delta.get("tool_calls"):
                                # Tool calls arrive as indexed fragments;
                                # merge each fragment into its slot.
                                for tc_delta in delta["tool_calls"]:
                                    idx = tc_delta.get("index", 0)
                                    while len(aggregated_tool_calls) <= idx:
                                        aggregated_tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})

                                    target = aggregated_tool_calls[idx]
                                    if tc_delta.get("id"):
                                        target["id"] = tc_delta["id"]
                                    if tc_delta.get("function", {}).get("name"):
                                        target["function"]["name"] += tc_delta["function"]["name"]
                                    if tc_delta.get("function", {}).get("arguments"):
                                        target["function"]["arguments"] += tc_delta["function"]["arguments"]

                            if chunk.get("choices", [{}])[0].get("finish_reason"):
                                final_finish_reason = chunk["choices"][0]["finish_reason"]

                            if chunk.get("usage"):
                                current_usage = chunk["usage"]
                        except json.JSONDecodeError:
                            # Malformed SSE chunk: skip it rather than abort.
                            continue

                assistant_text = aggregated_content
                tool_calls_raw = aggregated_tool_calls
                reasoning_content = aggregated_reasoning
                finish_reason = final_finish_reason
                usage = current_usage
            else:
                response_data = response.json()
                choices = response_data.get("choices", [])
                if not choices:
                    _append_comms("IN", "response", {"round": round_idx, "text": "(No choices returned)", "usage": response_data.get("usage", {})})
                    break

                choice = choices[0]
                message = choice.get("message", {})
                assistant_text = message.get("content", "")
                tool_calls_raw = message.get("tool_calls", [])
                reasoning_content = message.get("reasoning_content", "")
                finish_reason = choice.get("finish_reason", "stop")
                usage = response_data.get("usage", {})

            # Surface R1 reasoning traces as <thinking> tags in the output.
            thinking_tags = ""
            if reasoning_content:
                thinking_tags = f"<thinking>\n{reasoning_content}\n</thinking>\n"

            full_assistant_text = thinking_tags + assistant_text

            # Persist the assistant turn (reasoning and tool calls included).
            with _deepseek_history_lock:
                msg_to_store = {"role": "assistant", "content": assistant_text}
                if reasoning_content:
                    msg_to_store["reasoning_content"] = reasoning_content
                if tool_calls_raw:
                    msg_to_store["tool_calls"] = tool_calls_raw
                _deepseek_history.append(msg_to_store)

            if full_assistant_text:
                all_text_parts.append(full_assistant_text)

            _append_comms("IN", "response", {
                "round": round_idx,
                "stop_reason": finish_reason,
                "text": full_assistant_text,
                "tool_calls": tool_calls_raw,
                "usage": usage,
                "streaming": stream
            })

            # Done unless the model asked for tools.
            if finish_reason != "tool_calls" and not tool_calls_raw:
                break

            if round_idx > MAX_TOOL_ROUNDS:
                break

            # ---- Execute requested tools -----------------------------------
            tool_results_for_history = []
            for i, tc_raw in enumerate(tool_calls_raw):
                tool_info = tc_raw.get("function", {})
                tool_name = tool_info.get("name")
                tool_args_str = tool_info.get("arguments", "{}")
                tool_id = tc_raw.get("id")

                try:
                    tool_args = json.loads(tool_args_str)
                # FIX: was a bare `except:` — catch only argument-parse
                # failures instead of swallowing e.g. KeyboardInterrupt.
                except (TypeError, ValueError):
                    tool_args = {}

                events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx})

                if tool_name in mcp_client.TOOL_NAMES:
                    _append_comms("OUT", "tool_call", {"name": tool_name, "id": tool_id, "args": tool_args})
                    tool_output = mcp_client.dispatch(tool_name, tool_args)
                elif tool_name == TOOL_NAME:
                    script = tool_args.get("script", "")
                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script})
                    tool_output = _run_script(script, base_dir)
                else:
                    tool_output = f"ERROR: unknown tool '{tool_name}'"

                # After the last tool of the round, refresh file context and
                # warn the model when the round budget is exhausted.
                if i == len(tool_calls_raw) - 1:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            tool_output += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if round_idx == MAX_TOOL_ROUNDS:
                        tool_output += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"

                tool_output = _truncate_tool_output(tool_output)
                _cumulative_tool_bytes += len(tool_output)

                tool_results_for_history.append({
                    "role": "tool",
                    "tool_call_id": tool_id,
                    "content": tool_output,
                })

                _append_comms("IN", "tool_result", {"name": tool_name, "id": tool_id, "output": tool_output})
                events.emit("tool_execution", payload={"status": "completed", "tool": tool_name, "result": tool_output, "round": round_idx})

            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                tool_results_for_history.append({
                    "role": "user",
                    "content": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                })
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})

            with _deepseek_history_lock:
                _deepseek_history.extend(tool_results_for_history)
                # Rebuild the message list for the next round from history
                # (which now includes this turn's user message — see FIX above).
                next_messages = list(_deepseek_history)

            next_messages.insert(0, sys_msg)
            request_payload["messages"] = next_messages
            round_idx += 1

        return "\n\n".join(all_text_parts) if all_text_parts else "(No text returned)"

    except Exception as e:
        raise _classify_deepseek_error(e) from e
|
||||
@@ -1463,6 +1675,7 @@ def send(
|
||||
base_dir: str = ".",
|
||||
file_items: list[dict] | None = None,
|
||||
discussion_history: str = "",
|
||||
stream: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Send a message to the active provider.
|
||||
@@ -1475,6 +1688,7 @@ def send(
|
||||
dynamic context refresh after tool calls
|
||||
discussion_history : discussion history text (used by Gemini to inject as
|
||||
conversation message instead of caching it)
|
||||
stream : Whether to use streaming (supported by DeepSeek)
|
||||
"""
|
||||
with _send_lock:
|
||||
if _provider == "gemini":
|
||||
@@ -1484,7 +1698,7 @@ def send(
|
||||
elif _provider == "anthropic":
|
||||
return _send_anthropic(md_content, user_message, base_dir, file_items, discussion_history)
|
||||
elif _provider == "deepseek":
|
||||
return _send_deepseek(md_content, user_message, base_dir, file_items, discussion_history)
|
||||
return _send_deepseek(md_content, user_message, base_dir, file_items, discussion_history, stream=stream)
|
||||
raise ValueError(f"unknown provider: {_provider}")
|
||||
|
||||
def get_history_bleed_stats(md_content: str | None = None) -> dict:
|
||||
@@ -1597,12 +1811,36 @@ def get_history_bleed_stats(md_content: str | None = None) -> dict:
|
||||
"percentage": percentage,
|
||||
}
|
||||
elif _provider == "deepseek":
|
||||
# Placeholder for DeepSeek token estimation
|
||||
limit_tokens = 64000
|
||||
current_tokens = 0
|
||||
with _deepseek_history_lock:
|
||||
for msg in _deepseek_history:
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
current_tokens += len(content)
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict):
|
||||
text = block.get("text", "")
|
||||
if isinstance(text, str):
|
||||
current_tokens += len(text)
|
||||
inp = block.get("input")
|
||||
if isinstance(inp, dict):
|
||||
import json as _json
|
||||
current_tokens += len(_json.dumps(inp, ensure_ascii=False))
|
||||
|
||||
if md_content:
|
||||
current_tokens += len(md_content)
|
||||
if user_message:
|
||||
current_tokens += len(user_message)
|
||||
|
||||
current_tokens = max(1, int(current_tokens / _CHARS_PER_TOKEN))
|
||||
percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
|
||||
return {
|
||||
"provider": "deepseek",
|
||||
"limit": 64000, # Common limit for deepseek
|
||||
"current": 0,
|
||||
"percentage": 0,
|
||||
"limit": limit_tokens,
|
||||
"current": current_tokens,
|
||||
"percentage": percentage,
|
||||
}
|
||||
|
||||
# Default empty state
|
||||
|
||||
Reference in New Issue
Block a user