105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
import sys
|
|
import tomllib
|
|
import tomli_w
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Optional
|
|
from src.models import Preset
|
|
from src.paths import get_global_presets_path, get_project_presets_path
|
|
|
|
class PresetManager:
|
|
"""Manages system prompt presets across global and project-specific files."""
|
|
|
|
def __init__(self, project_root: Optional[Path] = None):
|
|
self.project_root = project_root
|
|
self.global_path = get_global_presets_path()
|
|
|
|
@property
|
|
def project_path(self) -> Optional[Path]:
|
|
return get_project_presets_path(self.project_root) if self.project_root else None
|
|
|
|
def load_all(self) -> Dict[str, Preset]:
|
|
"""Merges global and project presets into a single dictionary."""
|
|
presets: Dict[str, Preset] = {}
|
|
|
|
# Load global presets
|
|
data_global = self._load_file(self.global_path)
|
|
for name, p_data in data_global.get("presets", {}).items():
|
|
try:
|
|
presets[name] = Preset.from_dict(name, p_data)
|
|
except Exception as e:
|
|
print(f"Error parsing global preset '{name}': {e}", file=sys.stderr)
|
|
|
|
# Load project presets (overwriting global ones if names conflict)
|
|
if self.project_path:
|
|
data_project = self._load_file(self.project_path)
|
|
for name, p_data in data_project.get("presets", {}).items():
|
|
try:
|
|
presets[name] = Preset.from_dict(name, p_data)
|
|
except Exception as e:
|
|
print(f"Error parsing project preset '{name}': {e}", file=sys.stderr)
|
|
|
|
return presets
|
|
|
|
def save_preset(self, preset: Preset, scope: str = "project") -> None:
|
|
"""Saves a preset to either the global or project-specific TOML file."""
|
|
path = self.global_path if scope == "global" else self.project_path
|
|
if not path:
|
|
if scope == "project":
|
|
raise ValueError("Project scope requested but no project_root provided.")
|
|
path = self.global_path
|
|
|
|
data = self._load_file(path)
|
|
if "presets" not in data:
|
|
data["presets"] = {}
|
|
|
|
data["presets"][preset.name] = preset.to_dict()
|
|
self._save_file(path, data)
|
|
|
|
def delete_preset(self, name: str, scope: str) -> None:
|
|
if scope == "project" and self.project_root:
|
|
path = get_project_presets_path(self.project_root)
|
|
else:
|
|
path = get_global_presets_path()
|
|
|
|
data = self._load_file(path)
|
|
if name in data.get("presets", {}):
|
|
del data["presets"][name]
|
|
self._save_file(path, data)
|
|
|
|
def get_preset_scope(self, name: str) -> str:
|
|
"""Returns the scope ('global' or 'project') of a preset by name."""
|
|
if self.project_root:
|
|
project_p = get_project_presets_path(self.project_root)
|
|
project_data = self._load_file(project_p)
|
|
if name in project_data.get("presets", {}):
|
|
return "project"
|
|
|
|
global_p = get_global_presets_path()
|
|
global_data = self._load_file(global_p)
|
|
if name in global_data.get("presets", {}):
|
|
return "global"
|
|
|
|
return "project"
|
|
|
|
def _load_file(self, path: Path) -> Dict[str, Any]:
|
|
if not path.exists():
|
|
return {"presets": {}}
|
|
try:
|
|
with open(path, "rb") as f:
|
|
data = tomllib.load(f)
|
|
if not isinstance(data, dict):
|
|
return {"presets": {}}
|
|
if "presets" not in data:
|
|
data["presets"] = {}
|
|
return data
|
|
except Exception as e:
|
|
print(f"Error loading presets from {path}: {e}", file=sys.stderr)
|
|
return {"presets": {}}
|
|
|
|
def _save_file(self, path: Path, data: Dict[str, Any]) -> None:
|
|
if path.parent.exists() and path.parent.is_file():
|
|
raise ValueError(f"Cannot save to {path}: Parent directory {path.parent} is a file.")
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
with open(path, "wb") as f:
|
|
f.write(tomli_w.dumps(data).encode("utf-8"))
|