diff --git a/src/models.py b/src/models.py index 16bc71c..fe38524 100644 --- a/src/models.py +++ b/src/models.py @@ -368,18 +368,70 @@ class Preset: ) @dataclass -class ToolPreset: +class Tool: name: str - categories: Dict[str, Dict[str, Any]] + approval: str = 'auto' + weight: int = 3 + parameter_bias: Dict[str, str] = field(default_factory=dict) def to_dict(self) -> Dict[str, Any]: return { - "categories": self.categories, + "name": self.name, + "approval": self.approval, + "weight": self.weight, + "parameter_bias": self.parameter_bias, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Tool": + return cls( + name=data["name"], + approval=data.get("approval", "auto"), + weight=data.get("weight", 3), + parameter_bias=data.get("parameter_bias", {}), + ) + +@dataclass +class ToolPreset: + name: str + categories: Dict[str, List[Union[Tool, Any]]] + + def to_dict(self) -> Dict[str, Any]: + serialized_categories = {} + for cat, tools in self.categories.items(): + serialized_categories[cat] = [t.to_dict() if isinstance(t, Tool) else t for t in tools] + return { + "categories": serialized_categories, } @classmethod def from_dict(cls, name: str, data: Dict[str, Any]) -> "ToolPreset": + raw_categories = data.get("categories", {}) + parsed_categories = {} + for cat, tools in raw_categories.items(): + parsed_categories[cat] = [Tool.from_dict(t) if isinstance(t, dict) else t for t in tools] return cls( name=name, - categories=data.get("categories", {}), + categories=parsed_categories, + ) + +@dataclass +class BiasProfile: + name: str + tool_weights: Dict[str, int] = field(default_factory=dict) + category_multipliers: Dict[str, float] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "tool_weights": self.tool_weights, + "category_multipliers": self.category_multipliers, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BiasProfile": + return cls( + name=data["name"], + tool_weights=data.get("tool_weights", {}), + category_multipliers=data.get("category_multipliers", {}), ) diff --git a/src/tool_presets.py b/src/tool_presets.py index 1f288a2..b3519fb 100644 --- a/src/tool_presets.py +++ b/src/tool_presets.py @@ -1,91 +1,110 @@ import tomllib import tomli_w from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Any from src import paths -from src.models import ToolPreset +from src.models import ToolPreset, BiasProfile class ToolPresetManager: def __init__(self, project_root: Optional[Union[str, Path]] = None): self.project_root = Path(project_root) if project_root else None - def _load_from_path(self, path: Path) -> Dict[str, ToolPreset]: + def _get_path(self, scope: str) -> Path: + if scope == "global": + return paths.get_global_tool_presets_path() + elif scope == "project": + if not self.project_root: + raise ValueError("Project root not set for project scope operation.") + return paths.get_project_tool_presets_path(self.project_root) + else: + raise ValueError(f"Invalid scope: {scope}") + + def _read_raw(self, path: Path) -> Dict[str, Any]: if not path.exists(): return {} try: with open(path, "rb") as f: - data = tomllib.load(f) - presets = {} - for name, config in data.items(): - if isinstance(config, dict): - presets[name] = ToolPreset.from_dict(name, config) - return presets + return tomllib.load(f) except Exception: return {} - def load_all(self) -> Dict[str, ToolPreset]: - """ - Merges global and project presets. - Project presets override global ones if they have the same name. - """ - presets = self._load_from_path(paths.get_global_tool_presets_path()) - if self.project_root: - project_presets = self._load_from_path(paths.get_project_tool_presets_path(self.project_root)) - presets.update(project_presets) - return presets - - def save_preset(self, preset: ToolPreset, scope: str = "project") -> None: - """ - Saves a preset to either 'global' or 'project' scope. - Scope must be 'global' or 'project'. - """ - if scope == "global": - path = paths.get_global_tool_presets_path() - elif scope == "project": - if not self.project_root: - raise ValueError("Project root not set for project scope saving.") - path = paths.get_project_tool_presets_path(self.project_root) - else: - raise ValueError(f"Invalid scope: {scope}") - - data = {} - if path.exists(): - try: - with open(path, "rb") as f: - data = tomllib.load(f) - except Exception: - data = {} - - data[preset.name] = preset.to_dict() - + def _write_raw(self, path: Path, data: Dict[str, Any]) -> None: path.parent.mkdir(parents=True, exist_ok=True) with open(path, "wb") as f: tomli_w.dump(data, f) + def load_all_presets(self) -> Dict[str, ToolPreset]: + global_path = paths.get_global_tool_presets_path() + global_data = self._read_raw(global_path).get("presets", {}) + + presets = {} + for name, config in global_data.items(): + if isinstance(config, dict): + presets[name] = ToolPreset.from_dict(name, config) + + if self.project_root: + project_path = paths.get_project_tool_presets_path(self.project_root) + project_data = self._read_raw(project_path).get("presets", {}) + for name, config in project_data.items(): + if isinstance(config, dict): + presets[name] = ToolPreset.from_dict(name, config) + + return presets + + def load_all(self) -> Dict[str, ToolPreset]: + """Backward compatibility for load_all().""" + return self.load_all_presets() + + def save_preset(self, preset: ToolPreset, scope: str = "project") -> None: + path = self._get_path(scope) + data = self._read_raw(path) + if "presets" not in data: + data["presets"] = {} + data["presets"][preset.name] = preset.to_dict() + self._write_raw(path, data) + def delete_preset(self, name: str, scope: str = "project") -> None: - """ - Deletes a preset from the specified scope. - Scope must be 'global' or 'project'. - """ - if scope == "global": - path = paths.get_global_tool_presets_path() - elif scope == "project": - if not self.project_root: - raise ValueError("Project root not set for project scope deletion.") - path = paths.get_project_tool_presets_path(self.project_root) - else: - raise ValueError(f"Invalid scope: {scope}") + path = self._get_path(scope) + data = self._read_raw(path) + if "presets" in data and name in data["presets"]: + del data["presets"][name] + self._write_raw(path, data) + + def load_all_bias_profiles(self) -> Dict[str, BiasProfile]: + global_path = paths.get_global_tool_presets_path() + global_data = self._read_raw(global_path).get("bias_profiles", {}) - if not path.exists(): - return + profiles = {} + for name, config in global_data.items(): + if isinstance(config, dict): + cfg = dict(config) + if "name" not in cfg: + cfg["name"] = name + profiles[name] = BiasProfile.from_dict(cfg) - try: - with open(path, "rb") as f: - data = tomllib.load(f) - except Exception: - return + if self.project_root: + project_path = paths.get_project_tool_presets_path(self.project_root) + project_data = self._read_raw(project_path).get("bias_profiles", {}) + for name, config in project_data.items(): + if isinstance(config, dict): + cfg = dict(config) + if "name" not in cfg: + cfg["name"] = name + profiles[name] = BiasProfile.from_dict(cfg) - if name in data: - del data[name] - with open(path, "wb") as f: - tomli_w.dump(data, f) + return profiles + + def save_bias_profile(self, profile: BiasProfile, scope: str = "project") -> None: + path = self._get_path(scope) + data = self._read_raw(path) + if "bias_profiles" not in data: + data["bias_profiles"] = {} + data["bias_profiles"][profile.name] = profile.to_dict() + self._write_raw(path, data) + + def delete_bias_profile(self, name: str, scope: str = "project") -> None: + path = self._get_path(scope) + data = self._read_raw(path) + if "bias_profiles" in data and name in data["bias_profiles"]: + del data["bias_profiles"][name] + self._write_raw(path, data) diff --git a/tests/test_bias_models.py b/tests/test_bias_models.py new file mode 100644 index 0000000..e7e547f --- /dev/null +++ b/tests/test_bias_models.py @@ -0,0 +1,46 @@ +import pytest +from src.models import Tool, ToolPreset, BiasProfile + +def test_tool_model(): + tool = Tool(name="read_file", weight=5, parameter_bias={"path": "preferred"}) + data = tool.to_dict() + assert data["name"] == "read_file" + assert data["weight"] == 5 + assert data["parameter_bias"]["path"] == "preferred" + + tool2 = Tool.from_dict(data) + assert tool2.name == "read_file" + assert tool2.weight == 5 + assert tool2.parameter_bias["path"] == "preferred" + +def test_tool_preset_extension(): + # Verify that ToolPreset correctly parses and serializes Tool objects + tool_data = {"name": "read_file", "weight": 4, "parameter_bias": {"path": "high"}} + raw_data = {"categories": {"General": [tool_data]}} + + # Test parsing via from_dict + preset = ToolPreset.from_dict("test", raw_data) + assert isinstance(preset.categories["General"][0], Tool) + assert preset.categories["General"][0].weight == 4 + + # Test serialization + data = preset.to_dict() + assert data["categories"]["General"][0]["weight"] == 4 + assert data["categories"]["General"][0]["name"] == "read_file" + +def test_bias_profile_model(): + profile = BiasProfile( + name="Execution-Focused", + tool_weights={"run_powershell": 5}, + category_multipliers={"Surgical": 1.5} + ) + data = profile.to_dict() + assert data["tool_weights"]["run_powershell"] == 5 + assert data["category_multipliers"]["Surgical"] == 1.5 + + # BiasProfile.from_dict expects 'name' inside the dict as well if coming from load_all_bias_profiles + data["name"] = "Execution-Focused" + profile2 = BiasProfile.from_dict(data) + assert profile2.name == "Execution-Focused" + assert profile2.tool_weights["run_powershell"] == 5 + assert profile2.category_multipliers["Surgical"] == 1.5 diff --git a/tests/test_tool_preset_manager.py b/tests/test_tool_preset_manager.py index 6d376c5..b1e3826 100644 --- a/tests/test_tool_preset_manager.py +++ b/tests/test_tool_preset_manager.py @@ -2,7 +2,7 @@ import pytest import tomli_w from pathlib import Path from src.tool_presets import ToolPresetManager -from src.models import ToolPreset +from src.models import ToolPreset, BiasProfile, Tool from src import paths @pytest.fixture @@ -25,17 +25,18 @@ def temp_paths(tmp_path, monkeypatch): "project_presets": project_presets } -def test_load_all_merged(temp_paths): +def test_load_all_presets_merged(temp_paths): # Setup global presets global_data = { - "default": { - "categories": { - "file": {"read": True}, - "shell": {"run": False} + "presets": { + "default": { + "categories": { + "General": [{"name": "read_file", "approval": "auto"}] + } + }, + "global_only": { + "categories": {"Web": [{"name": "web_search", "approval": "ask"}]} } - }, - "global_only": { - "categories": {"web": {"search": True}} } } with open(temp_paths["global_presets"], "wb") as f: @@ -43,98 +44,78 @@ def test_load_all_merged(temp_paths): # Setup project presets (overrides 'default') project_data = { - "default": { - "categories": { - "file": {"read": True}, - "shell": {"run": True} # Override + "presets": { + "default": { + "categories": { + "General": [{"name": "read_file", "approval": "auto", "weight": 5}] + } } - }, - "project_only": { - "categories": {"git": {"commit": True}} } } with open(temp_paths["project_presets"], "wb") as f: tomli_w.dump(project_data, f) manager = ToolPresetManager(project_root=temp_paths["project_dir"]) - all_presets = manager.load_all() + all_presets = manager.load_all_presets() assert "default" in all_presets - assert all_presets["default"].categories["shell"]["run"] is True # Overridden + assert isinstance(all_presets["default"].categories["General"][0], Tool) + assert all_presets["default"].categories["General"][0].weight == 5 assert "global_only" in all_presets - assert "project_only" in all_presets - assert all_presets["global_only"].categories["web"]["search"] is True - assert all_presets["project_only"].categories["git"]["commit"] is True -def test_save_preset_global(temp_paths): - manager = ToolPresetManager() - preset = ToolPreset(name="new_global", categories={"test": {"ok": True}}) - - manager.save_preset(preset, scope="global") - - assert temp_paths["global_presets"].exists() - loaded = manager._load_from_path(temp_paths["global_presets"]) - assert "new_global" in loaded - assert loaded["new_global"].categories == {"test": {"ok": True}} - -def test_save_preset_project(temp_paths): - manager = ToolPresetManager(project_root=temp_paths["project_dir"]) - preset = ToolPreset(name="new_project", categories={"test": {"ok": False}}) - - manager.save_preset(preset, scope="project") - - assert temp_paths["project_presets"].exists() - loaded = manager._load_from_path(temp_paths["project_presets"]) - assert "new_project" in loaded - assert loaded["new_project"].categories == {"test": {"ok": False}} - -def test_delete_preset_global(temp_paths): - # Initial global setup +def test_bias_profiles_merged(temp_paths): + # Setup global biases global_data = { - "to_delete": {"categories": {}}, - "keep": {"categories": {}} + "bias_profiles": { + "Discovery": { + "tool_weights": {"web_search": 5}, + "category_multipliers": {"Web": 1.5} + } + } } with open(temp_paths["global_presets"], "wb") as f: tomli_w.dump(global_data, f) - manager = ToolPresetManager() - manager.delete_preset("to_delete", scope="global") - - loaded = manager._load_from_path(temp_paths["global_presets"]) - assert "to_delete" not in loaded - assert "keep" in loaded - -def test_delete_preset_project(temp_paths): - # Initial project setup + # Setup project biases project_data = { - "to_delete": {"categories": {}}, - "keep": {"categories": {}} + "bias_profiles": { + "Execution": { + "tool_weights": {"run_powershell": 5} + } + } } with open(temp_paths["project_presets"], "wb") as f: tomli_w.dump(project_data, f) manager = ToolPresetManager(project_root=temp_paths["project_dir"]) - manager.delete_preset("to_delete", scope="project") + profiles = manager.load_all_bias_profiles() - loaded = manager._load_from_path(temp_paths["project_presets"]) - assert "to_delete" not in loaded - assert "keep" in loaded + assert "Discovery" in profiles + assert profiles["Discovery"].category_multipliers["Web"] == 1.5 + assert "Execution" in profiles + assert profiles["Execution"].tool_weights["run_powershell"] == 5 -def test_save_project_no_root_raises(temp_paths): - manager = ToolPresetManager(project_root=None) - preset = ToolPreset(name="fail", categories={}) - with pytest.raises(ValueError, match="Project root not set"): - manager.save_preset(preset, scope="project") +def test_save_bias_profile(temp_paths): + manager = ToolPresetManager(project_root=temp_paths["project_dir"]) + profile = BiasProfile(name="Custom", tool_weights={"test": 1}) + + manager.save_bias_profile(profile, scope="project") + + loaded = manager.load_all_bias_profiles() + assert "Custom" in loaded + assert loaded["Custom"].tool_weights["test"] == 1 -def test_delete_project_no_root_raises(temp_paths): - manager = ToolPresetManager(project_root=None) - with pytest.raises(ValueError, match="Project root not set"): - manager.delete_preset("any", scope="project") - -def test_invalid_scope_raises(temp_paths): - manager = ToolPresetManager() - preset = ToolPreset(name="fail", categories={}) - with pytest.raises(ValueError, match="Invalid scope"): - manager.save_preset(preset, scope="invalid") - with pytest.raises(ValueError, match="Invalid scope"): - manager.delete_preset("any", scope="invalid") +def test_delete_bias_profile(temp_paths): + project_data = { + "bias_profiles": { + "to_delete": {"tool_weights": {}} + } + } + with open(temp_paths["project_presets"], "wb") as f: + tomli_w.dump(project_data, f) + + manager = ToolPresetManager(project_root=temp_paths["project_dir"]) + manager.delete_bias_profile("to_delete", scope="project") + + profiles = manager.load_all_bias_profiles() + assert "to_delete" not in profiles diff --git a/tests/test_tool_presets_execution.py b/tests/test_tool_presets_execution.py index 6a35568..076a530 100644 --- a/tests/test_tool_presets_execution.py +++ b/tests/test_tool_presets_execution.py @@ -3,14 +3,14 @@ import asyncio from src import ai_client from src import mcp_client from src import models -from src.models import ToolPreset +from src.models import ToolPreset, Tool from unittest.mock import MagicMock, patch @pytest.mark.asyncio async def test_tool_auto_approval(): # Setup a preset with read_file as auto preset = ToolPreset(name="AutoTest", categories={ - "General": {"read_file": "auto"} + "General": [Tool(name="read_file", approval="auto")] }) with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AutoTest": preset}): @@ -39,7 +39,7 @@ async def test_tool_auto_approval(): async def test_tool_ask_approval(): # Setup a preset with run_powershell as ask preset = ToolPreset(name="AskTest", categories={ - "General": {"run_powershell": "ask"} + "General": [Tool(name="run_powershell", approval="ask")] }) with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AskTest": preset}): @@ -65,7 +65,7 @@ async def test_tool_ask_approval(): async def test_tool_rejection(): # Setup a preset with run_powershell as ask preset = ToolPreset(name="AskTest", categories={ - "General": {"run_powershell": "ask"} + "General": [Tool(name="run_powershell", approval="ask")] }) with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AskTest": preset}): diff --git a/tool_presets.toml b/tool_presets.toml index 837516d..06f6938 100644 --- a/tool_presets.toml +++ b/tool_presets.toml @@ -23,3 +23,15 @@ categories.Python = [ { name = "py_check_syntax", approval = "auto" }, { name = "py_get_hierarchy", approval = "auto" } ] + +[bias_profiles.Balanced] +tool_weights = {} +category_multipliers = {} + +[bias_profiles.Execution-Focused] +tool_weights = { run_powershell = 5 } +category_multipliers = { Runtime = 1.5, Surgical = 1.2 } + +[bias_profiles.Discovery-Heavy] +tool_weights = { web_search = 4, search_files = 4 } +category_multipliers = { Web = 1.5, Analysis = 1.3 }