vibin
This commit is contained in:
205
ai_client.py
205
ai_client.py
@@ -11,6 +11,13 @@ _gemini_chat = None
|
||||
_anthropic_client = None
|
||||
_anthropic_history: list[dict] = []
|
||||
|
||||
# Injected by gui.py - called when AI wants to run a command.
|
||||
# Signature: (script: str) -> str | None
|
||||
# Returns the output string if approved, None if rejected.
|
||||
confirm_and_run_callback = None
|
||||
|
||||
MAX_TOOL_ROUNDS = 5
|
||||
|
||||
def _load_credentials() -> dict:
|
||||
with open("credentials.toml", "rb") as f:
|
||||
return tomllib.load(f)
|
||||
@@ -61,21 +68,139 @@ def _list_anthropic_models() -> list[str]:
|
||||
models.append(m.id)
|
||||
return sorted(models)
|
||||
|
||||
|
||||
# --------------------------------------------------------- tool definition
|
||||
|
||||
TOOL_NAME = "run_powershell"
|
||||
|
||||
_ANTHROPIC_TOOLS = [
|
||||
{
|
||||
"name": TOOL_NAME,
|
||||
"description": (
|
||||
"Run a PowerShell script within the project base_dir. "
|
||||
"Use this to create, edit, rename, or delete files and directories. "
|
||||
"The working directory is set to base_dir automatically. "
|
||||
"Always prefer targeted edits over full rewrites where possible. "
|
||||
"stdout and stderr are returned to you as the result."
|
||||
),
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "The PowerShell script to execute."
|
||||
}
|
||||
},
|
||||
"required": ["script"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
def _gemini_tool_declaration():
|
||||
from google.genai import types
|
||||
return types.Tool(
|
||||
function_declarations=[
|
||||
types.FunctionDeclaration(
|
||||
name=TOOL_NAME,
|
||||
description=(
|
||||
"Run a PowerShell script within the project base_dir. "
|
||||
"Use this to create, edit, rename, or delete files and directories. "
|
||||
"The working directory is set to base_dir automatically. "
|
||||
"stdout and stderr are returned to you as the result."
|
||||
),
|
||||
parameters=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
"script": types.Schema(
|
||||
type=types.Type.STRING,
|
||||
description="The PowerShell script to execute."
|
||||
)
|
||||
},
|
||||
required=["script"]
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def _run_script(script: str, base_dir: str) -> str:
|
||||
"""
|
||||
Delegate to the GUI confirmation callback.
|
||||
Returns result string (stdout/stderr) or a rejection message.
|
||||
"""
|
||||
if confirm_and_run_callback is None:
|
||||
return "ERROR: no confirmation handler registered"
|
||||
result = confirm_and_run_callback(script, base_dir)
|
||||
if result is None:
|
||||
return "USER REJECTED: command was not executed"
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------ gemini
|
||||
|
||||
def _ensure_gemini_chat():
|
||||
global _gemini_client, _gemini_chat
|
||||
if _gemini_chat is None:
|
||||
def _ensure_gemini_client():
|
||||
global _gemini_client
|
||||
if _gemini_client is None:
|
||||
from google import genai
|
||||
creds = _load_credentials()
|
||||
_gemini_client = genai.Client(api_key=creds["gemini"]["api_key"])
|
||||
_gemini_chat = _gemini_client.chats.create(model=_model)
|
||||
|
||||
def _send_gemini(md_content: str, user_message: str) -> str:
|
||||
_ensure_gemini_chat()
|
||||
def _send_gemini(md_content: str, user_message: str, base_dir: str) -> str:
|
||||
global _gemini_chat
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
_ensure_gemini_client()
|
||||
|
||||
# Gemini chats don't support mutating tools after creation,
|
||||
# so we recreate if None (reset_session clears it).
|
||||
if _gemini_chat is None:
|
||||
_gemini_chat = _gemini_client.chats.create(
|
||||
model=_model,
|
||||
config=types.GenerateContentConfig(
|
||||
tools=[_gemini_tool_declaration()]
|
||||
)
|
||||
)
|
||||
|
||||
full_message = f"<context>\n{md_content}\n</context>\n\n{user_message}"
|
||||
|
||||
response = _gemini_chat.send_message(full_message)
|
||||
return response.text
|
||||
|
||||
for _ in range(MAX_TOOL_ROUNDS):
|
||||
# Collect all function calls in this response
|
||||
tool_calls = [
|
||||
part.function_call
|
||||
for candidate in response.candidates
|
||||
for part in candidate.content.parts
|
||||
if part.function_call is not None
|
||||
]
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
# Execute each tool call and collect results
|
||||
function_responses = []
|
||||
for fc in tool_calls:
|
||||
if fc.name == TOOL_NAME:
|
||||
script = fc.args.get("script", "")
|
||||
output = _run_script(script, base_dir)
|
||||
function_responses.append(
|
||||
types.Part.from_function_response(
|
||||
name=TOOL_NAME,
|
||||
response={"output": output}
|
||||
)
|
||||
)
|
||||
|
||||
if not function_responses:
|
||||
break
|
||||
|
||||
response = _gemini_chat.send_message(function_responses)
|
||||
|
||||
# Extract text from final response
|
||||
text_parts = [
|
||||
part.text
|
||||
for candidate in response.candidates
|
||||
for part in candidate.content.parts
|
||||
if hasattr(part, "text") and part.text
|
||||
]
|
||||
return "\n".join(text_parts)
|
||||
|
||||
# ------------------------------------------------------------------ anthropic
|
||||
|
||||
@@ -86,25 +211,65 @@ def _ensure_anthropic_client():
|
||||
creds = _load_credentials()
|
||||
_anthropic_client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])
|
||||
|
||||
def _send_anthropic(md_content: str, user_message: str) -> str:
|
||||
def _send_anthropic(md_content: str, user_message: str, base_dir: str) -> str:
|
||||
global _anthropic_history
|
||||
import anthropic
|
||||
|
||||
_ensure_anthropic_client()
|
||||
|
||||
full_message = f"<context>\n{md_content}\n</context>\n\n{user_message}"
|
||||
_anthropic_history.append({"role": "user", "content": full_message})
|
||||
response = _anthropic_client.messages.create(
|
||||
model=_model,
|
||||
max_tokens=8096,
|
||||
messages=_anthropic_history
|
||||
)
|
||||
reply = response.content[0].text
|
||||
_anthropic_history.append({"role": "assistant", "content": reply})
|
||||
return reply
|
||||
|
||||
for _ in range(MAX_TOOL_ROUNDS):
|
||||
response = _anthropic_client.messages.create(
|
||||
model=_model,
|
||||
max_tokens=8096,
|
||||
tools=_ANTHROPIC_TOOLS,
|
||||
messages=_anthropic_history
|
||||
)
|
||||
|
||||
# Always record the assistant turn
|
||||
_anthropic_history.append({
|
||||
"role": "assistant",
|
||||
"content": response.content
|
||||
})
|
||||
|
||||
if response.stop_reason != "tool_use":
|
||||
break
|
||||
|
||||
# Process tool calls
|
||||
tool_results = []
|
||||
for block in response.content:
|
||||
if block.type == "tool_use" and block.name == TOOL_NAME:
|
||||
script = block.input.get("script", "")
|
||||
output = _run_script(script, base_dir)
|
||||
tool_results.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": block.id,
|
||||
"content": output
|
||||
})
|
||||
|
||||
if not tool_results:
|
||||
break
|
||||
|
||||
_anthropic_history.append({
|
||||
"role": "user",
|
||||
"content": tool_results
|
||||
})
|
||||
|
||||
# Extract final text
|
||||
text_parts = [
|
||||
block.text
|
||||
for block in response.content
|
||||
if hasattr(block, "text") and block.text
|
||||
]
|
||||
return "\n".join(text_parts)
|
||||
|
||||
# ------------------------------------------------------------------ unified send
|
||||
|
||||
def send(md_content: str, user_message: str) -> str:
|
||||
def send(md_content: str, user_message: str, base_dir: str = ".") -> str:
|
||||
if _provider == "gemini":
|
||||
return _send_gemini(md_content, user_message)
|
||||
return _send_gemini(md_content, user_message, base_dir)
|
||||
elif _provider == "anthropic":
|
||||
return _send_anthropic(md_content, user_message)
|
||||
raise ValueError(f"unknown provider: {_provider}")
|
||||
return _send_anthropic(md_content, user_message, base_dir)
|
||||
raise ValueError(f"unknown provider: {_provider}")
|
||||
Reference in New Issue
Block a user