phase 2 checkpoint
This commit is contained in:
@@ -435,3 +435,53 @@ class BiasProfile:
|
||||
tool_weights=data.get("tool_weights", {}),
|
||||
category_multipliers=data.get("category_multipliers", {}),
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class Persona:
|
||||
name: str
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
preferred_models: List[str] = field(default_factory=list)
|
||||
system_prompt: str = ''
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
tool_preset: Optional[str] = None
|
||||
bias_profile: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
res = {
|
||||
"system_prompt": self.system_prompt,
|
||||
}
|
||||
if self.provider is not None:
|
||||
res["provider"] = self.provider
|
||||
if self.model is not None:
|
||||
res["model"] = self.model
|
||||
if self.preferred_models:
|
||||
res["preferred_models"] = self.preferred_models
|
||||
if self.temperature is not None:
|
||||
res["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
res["top_p"] = self.top_p
|
||||
if self.max_output_tokens is not None:
|
||||
res["max_output_tokens"] = self.max_output_tokens
|
||||
if self.tool_preset is not None:
|
||||
res["tool_preset"] = self.tool_preset
|
||||
if self.bias_profile is not None:
|
||||
res["bias_profile"] = self.bias_profile
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, name: str, data: Dict[str, Any]) -> "Persona":
|
||||
return cls(
|
||||
name=name,
|
||||
provider=data.get("provider"),
|
||||
model=data.get("model"),
|
||||
preferred_models=data.get("preferred_models", []),
|
||||
system_prompt=data.get("system_prompt", ""),
|
||||
temperature=data.get("temperature"),
|
||||
top_p=data.get("top_p"),
|
||||
max_output_tokens=data.get("max_output_tokens"),
|
||||
tool_preset=data.get("tool_preset"),
|
||||
bias_profile=data.get("bias_profile"),
|
||||
)
|
||||
|
||||
@@ -64,6 +64,13 @@ def get_global_tool_presets_path() -> Path:
|
||||
def get_project_tool_presets_path(project_root: Path) -> Path:
|
||||
return project_root / "project_tool_presets.toml"
|
||||
|
||||
def get_global_personas_path() -> Path:
|
||||
root_dir = Path(__file__).resolve().parent.parent
|
||||
return Path(os.environ.get("SLOP_GLOBAL_PERSONAS", root_dir / "personas.toml"))
|
||||
|
||||
def get_project_personas_path(project_root: Path) -> Path:
|
||||
return project_root / "project_personas.toml"
|
||||
|
||||
def _resolve_path(env_var: str, config_key: str, default: str) -> Path:
|
||||
if env_var in os.environ:
|
||||
return Path(os.environ[env_var])
|
||||
|
||||
69
src/personas.py
Normal file
69
src/personas.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import tomllib
|
||||
import tomli_w
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from src.models import Persona
|
||||
from src import paths
|
||||
|
||||
class PersonaManager:
|
||||
"""Manages Persona profiles across global and project-specific files."""
|
||||
|
||||
def __init__(self, project_root: Optional[Path] = None):
|
||||
self.project_root = project_root
|
||||
|
||||
def _get_path(self, scope: str) -> Path:
|
||||
if scope == "global":
|
||||
return paths.get_global_personas_path()
|
||||
elif scope == "project":
|
||||
if not self.project_root:
|
||||
raise ValueError("Project root is not set, cannot resolve project scope.")
|
||||
return paths.get_project_personas_path(self.project_root)
|
||||
else:
|
||||
raise ValueError("Invalid scope, must be 'global' or 'project'")
|
||||
|
||||
def load_all(self) -> Dict[str, Persona]:
|
||||
"""Merges global and project personas into a single dictionary."""
|
||||
personas = {}
|
||||
|
||||
global_path = paths.get_global_personas_path()
|
||||
global_data = self._load_file(global_path)
|
||||
for name, data in global_data.get("personas", {}).items():
|
||||
personas[name] = Persona.from_dict(name, data)
|
||||
|
||||
if self.project_root:
|
||||
project_path = paths.get_project_personas_path(self.project_root)
|
||||
project_data = self._load_file(project_path)
|
||||
for name, data in project_data.get("personas", {}).items():
|
||||
personas[name] = Persona.from_dict(name, data)
|
||||
|
||||
return personas
|
||||
|
||||
def save_persona(self, persona: Persona, scope: str = "project") -> None:
|
||||
path = self._get_path(scope)
|
||||
data = self._load_file(path)
|
||||
if "personas" not in data:
|
||||
data["personas"] = {}
|
||||
|
||||
data["personas"][persona.name] = persona.to_dict()
|
||||
self._save_file(path, data)
|
||||
|
||||
def delete_persona(self, name: str, scope: str = "project") -> None:
|
||||
path = self._get_path(scope)
|
||||
data = self._load_file(path)
|
||||
if "personas" in data and name in data["personas"]:
|
||||
del data["personas"][name]
|
||||
self._save_file(path, data)
|
||||
|
||||
def _load_file(self, path: Path) -> Dict[str, Any]:
|
||||
if not path.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
return tomllib.load(f)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def _save_file(self, path: Path, data: Dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "wb") as f:
|
||||
tomli_w.dump(data, f)
|
||||
Reference in New Issue
Block a user