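# ai_client.py
# Small wrapper around the Gemini and Anthropic chat APIs. The active
# provider, the model name and the conversation state are kept in
# module-level globals.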
|
|
import tomllib
|
|
from pathlib import Path
|
|
|
|
# Backend selection used by send(); changed through set_provider().
_provider: str = "gemini"
_model: str = "gemini-2.0-flash"

# Gemini client and chat session; both stay None until first needed.
_gemini_client = None
_gemini_chat = None

# Anthropic client (None until first needed) and the list of
# {"role": ..., "content": ...} messages exchanged so far.
_anthropic_client = None
_anthropic_history: list[dict] = []
|
|
|
|
def _load_credentials() -> dict:
|
|
with open("credentials.toml", "rb") as f:
|
|
return tomllib.load(f)
|
|
|
|
# ------------------------------------------------------------------ provider setup
|
|
|
|
def set_provider(provider: str, model: str):
    """Select the provider and model used by subsequent requests.

    Only the two module-level selectors are updated. Any client, chat
    session or history that already exists is left as it is.
    """
    global _provider, _model
    _provider, _model = provider, model
|
|
|
|
def reset_session():
    """Discard the cached clients and conversation state of both providers."""
    global _gemini_client, _gemini_chat, _anthropic_client, _anthropic_history
    _gemini_client = _gemini_chat = _anthropic_client = None
    # Rebind to a fresh list instead of clearing the existing one in place.
    _anthropic_history = []
|
|
|
|
# ------------------------------------------------------------------ model listing
|
|
|
|
def list_models(provider: str) -> list[str]:
    """Return the model names offered by *provider*.

    Args:
        provider: ``"gemini"`` or ``"anthropic"``.

    Returns:
        The list produced by the provider-specific helper, or an empty
        list when *provider* is not one of the names above.

    Raises:
        Whatever the credentials loader or the provider helper raises.
        An unrecognised provider never touches the credentials file.
    """
    if provider == "gemini":
        # Only this branch needs the key here, so load credentials lazily.
        api_key = _load_credentials()["gemini"]["api_key"]
        return _list_gemini_models(api_key)
    if provider == "anthropic":
        # The Anthropic helper reads its own credentials.
        return _list_anthropic_models()
    return []
|
|
|
|
def _list_gemini_models(api_key: str) -> list[str]:
    """Return the sorted names of the Gemini models visible to *api_key*.

    A leading ``"models/"`` is stripped from each name, and only names
    containing ``"gemini"`` (case-insensitive) are kept.
    """
    from google import genai

    client = genai.Client(api_key=api_key)
    short_names = (
        entry.name.removeprefix("models/") for entry in client.models.list()
    )
    return sorted(name for name in short_names if "gemini" in name.lower())
|
|
|
|
def _list_anthropic_models() -> list[str]:
    """Return the sorted ids of the models listed by the Anthropic API.

    The API key is read from the ``anthropic`` table of the credentials file.
    """
    import anthropic

    api_key = _load_credentials()["anthropic"]["api_key"]
    client = anthropic.Anthropic(api_key=api_key)
    return sorted(entry.id for entry in client.models.list())
|
|
|
|
# ------------------------------------------------------------------ gemini
|
|
|
|
def _ensure_gemini_chat():
    """Create the Gemini client and chat session if no chat exists yet.

    The chat is created for the model named by ``_model`` at the time of
    creation. Later calls return immediately while ``_gemini_chat`` is set.
    """
    global _gemini_client, _gemini_chat
    if _gemini_chat is not None:
        return

    from google import genai

    api_key = _load_credentials()["gemini"]["api_key"]
    _gemini_client = genai.Client(api_key=api_key)
    _gemini_chat = _gemini_client.chats.create(model=_model)
|
|
|
|
def _send_gemini(md_content: str, user_message: str) -> str:
    """Send one message through the Gemini chat and return the reply text.

    *md_content* is wrapped in ``<context>`` tags and placed ahead of
    *user_message* in the text that is sent.
    """
    _ensure_gemini_chat()
    prompt = f"<context>\n{md_content}\n</context>\n\n{user_message}"
    return _gemini_chat.send_message(prompt).text
|
|
|
|
# ------------------------------------------------------------------ anthropic
|
|
|
|
def _ensure_anthropic_client():
    """Create the Anthropic client on first use; later calls do nothing."""
    global _anthropic_client
    if _anthropic_client is not None:
        return

    import anthropic

    api_key = _load_credentials()["anthropic"]["api_key"]
    _anthropic_client = anthropic.Anthropic(api_key=api_key)
|
|
|
|
def _send_anthropic(md_content: str, user_message: str) -> str:
    """Send one message to Anthropic and return the text of the reply.

    *md_content* is wrapped in ``<context>`` tags and placed ahead of
    *user_message*. The request carries the whole stored conversation plus
    the new user turn.

    Args:
        md_content: Context text placed inside the ``<context>`` tags.
        user_message: The user's message, appended after the context.

    Returns:
        The ``text`` of the first content block of the response.

    Raises:
        Whatever the client raises for a failed request, or ``IndexError`` /
        ``AttributeError`` when the response has no usable first block. In
        every such case ``_anthropic_history`` is left exactly as it was.
    """
    _ensure_anthropic_client()
    full_message = f"<context>\n{md_content}\n</context>\n\n{user_message}"
    user_turn = {"role": "user", "content": full_message}

    # Send a copy of the history plus the new turn. Appending to the shared
    # list first would leave a dangling user message behind whenever the call
    # fails, and that message would then be re-sent with every later request.
    response = _anthropic_client.messages.create(
        model=_model,
        max_tokens=8096,
        messages=[*_anthropic_history, user_turn],
    )
    reply = response.content[0].text

    # Commit both turns together, only once the exchange has fully succeeded.
    _anthropic_history.extend(
        (user_turn, {"role": "assistant", "content": reply})
    )
    return reply
|
|
|
|
# ------------------------------------------------------------------ unified send
|
|
|
|
def send(md_content: str, user_message: str) -> str:
|
|
if _provider == "gemini":
|
|
return _send_gemini(md_content, user_message)
|
|
elif _provider == "anthropic":
|
|
return _send_anthropic(md_content, user_message)
|
|
raise ValueError(f"unknown provider: {_provider}")
|