feat(bias): implement data models and storage for tool weighting and bias profiles

This commit is contained in:
2026-03-10 09:27:12 -04:00
parent ee19cc1d2a
commit 77a0b385d5
6 changed files with 264 additions and 154 deletions

View File

@@ -368,18 +368,70 @@ class Preset:
) )
@dataclass
class Tool:
    """One tool entry inside a preset category.

    ``to_dict`` and ``from_dict`` convert to and from the plain-dict form
    that is written to storage.
    """

    name: str
    # Approval mode string; 'auto' unless specified.
    approval: str = 'auto'
    # Weight assigned to the tool; 3 unless specified.
    weight: int = 3
    # Bias label per parameter name.
    parameter_bias: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of this tool.

        The ``parameter_bias`` mapping is copied so that editing the returned
        dict cannot silently change this instance.
        """
        return {
            "name": self.name,
            "approval": self.approval,
            "weight": self.weight,
            "parameter_bias": dict(self.parameter_bias),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
        """Build a ``Tool`` from its plain-dict form.

        Args:
            data: Mapping with a required ``"name"`` key and optional
                ``"approval"``, ``"weight"`` and ``"parameter_bias"`` keys.

        Raises:
            KeyError: If ``"name"`` is missing.
        """
        return cls(
            name=data["name"],
            approval=data.get("approval", "auto"),
            weight=data.get("weight", 3),
            # Copy (and tolerate a missing/None value) so the new instance
            # never shares a mutable dict with the caller's input.
            parameter_bias=dict(data.get("parameter_bias") or {}),
        )
@dataclass
class ToolPreset:
    """A named set of tools grouped by category."""

    name: str
    # Category name -> entries. Entries are Tool objects; anything else is
    # carried through unchanged in both directions.
    categories: Dict[str, List[Union[Tool, Any]]]

    def to_dict(self) -> Dict[str, Any]:
        """Return the storage form: a dict with a single ``"categories"`` key.

        ``Tool`` entries are serialized via ``Tool.to_dict``; other entries
        are emitted as they are. The preset name is not part of the result.
        """
        def dump(entry):
            return entry.to_dict() if isinstance(entry, Tool) else entry

        return {
            "categories": {
                category: [dump(entry) for entry in entries]
                for category, entries in self.categories.items()
            },
        }

    @classmethod
    def from_dict(cls, name: str, data: Dict[str, Any]) -> "ToolPreset":
        """Build a preset called ``name`` from its storage form.

        Dict entries inside each category become ``Tool`` objects; entries of
        any other type are kept verbatim. A missing ``"categories"`` key
        yields an empty preset.
        """
        categories = {}
        for category, entries in data.get("categories", {}).items():
            parsed = []
            for entry in entries:
                if isinstance(entry, dict):
                    entry = Tool.from_dict(entry)
                parsed.append(entry)
            categories[category] = parsed
        return cls(name=name, categories=categories)
@dataclass
class BiasProfile:
    """A named set of per-tool weights and per-category multipliers."""

    name: str
    # Weight per tool name.
    tool_weights: Dict[str, int] = field(default_factory=dict)
    # Multiplier per category name.
    category_multipliers: Dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of this profile.

        Both mappings are copied so that editing the returned dict cannot
        silently change this instance.
        """
        return {
            "name": self.name,
            "tool_weights": dict(self.tool_weights),
            "category_multipliers": dict(self.category_multipliers),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BiasProfile":
        """Build a ``BiasProfile`` from its plain-dict form.

        Args:
            data: Mapping with a required ``"name"`` key and optional
                ``"tool_weights"`` and ``"category_multipliers"`` keys.

        Raises:
            KeyError: If ``"name"`` is missing.
        """
        return cls(
            name=data["name"],
            # Copy (and tolerate missing/None values) so the profile never
            # shares a mutable dict with the caller's input.
            tool_weights=dict(data.get("tool_weights") or {}),
            category_multipliers=dict(data.get("category_multipliers") or {}),
        )

View File

@@ -1,91 +1,110 @@
import tomllib import tomllib
import tomli_w import tomli_w
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union, Any
from src import paths from src import paths
from src.models import ToolPreset from src.models import ToolPreset, BiasProfile
class ToolPresetManager: class ToolPresetManager:
def __init__(self, project_root: Optional[Union[str, Path]] = None): def __init__(self, project_root: Optional[Union[str, Path]] = None):
self.project_root = Path(project_root) if project_root else None self.project_root = Path(project_root) if project_root else None
def _load_from_path(self, path: Path) -> Dict[str, ToolPreset]: def _get_path(self, scope: str) -> Path:
if scope == "global":
return paths.get_global_tool_presets_path()
elif scope == "project":
if not self.project_root:
raise ValueError("Project root not set for project scope operation.")
return paths.get_project_tool_presets_path(self.project_root)
else:
raise ValueError(f"Invalid scope: {scope}")
def _read_raw(self, path: Path) -> Dict[str, Any]:
if not path.exists(): if not path.exists():
return {} return {}
try: try:
with open(path, "rb") as f: with open(path, "rb") as f:
data = tomllib.load(f) return tomllib.load(f)
presets = {}
for name, config in data.items():
if isinstance(config, dict):
presets[name] = ToolPreset.from_dict(name, config)
return presets
except Exception: except Exception:
return {} return {}
def load_all(self) -> Dict[str, ToolPreset]: def _write_raw(self, path: Path, data: Dict[str, Any]) -> None:
"""
Merges global and project presets.
Project presets override global ones if they have the same name.
"""
presets = self._load_from_path(paths.get_global_tool_presets_path())
if self.project_root:
project_presets = self._load_from_path(paths.get_project_tool_presets_path(self.project_root))
presets.update(project_presets)
return presets
def save_preset(self, preset: ToolPreset, scope: str = "project") -> None:
"""
Saves a preset to either 'global' or 'project' scope.
Scope must be 'global' or 'project'.
"""
if scope == "global":
path = paths.get_global_tool_presets_path()
elif scope == "project":
if not self.project_root:
raise ValueError("Project root not set for project scope saving.")
path = paths.get_project_tool_presets_path(self.project_root)
else:
raise ValueError(f"Invalid scope: {scope}")
data = {}
if path.exists():
try:
with open(path, "rb") as f:
data = tomllib.load(f)
except Exception:
data = {}
data[preset.name] = preset.to_dict()
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "wb") as f: with open(path, "wb") as f:
tomli_w.dump(data, f) tomli_w.dump(data, f)
def load_all_presets(self) -> Dict[str, ToolPreset]:
global_path = paths.get_global_tool_presets_path()
global_data = self._read_raw(global_path).get("presets", {})
presets = {}
for name, config in global_data.items():
if isinstance(config, dict):
presets[name] = ToolPreset.from_dict(name, config)
if self.project_root:
project_path = paths.get_project_tool_presets_path(self.project_root)
project_data = self._read_raw(project_path).get("presets", {})
for name, config in project_data.items():
if isinstance(config, dict):
presets[name] = ToolPreset.from_dict(name, config)
return presets
def load_all(self) -> Dict[str, ToolPreset]:
"""Backward compatibility for load_all()."""
return self.load_all_presets()
def save_preset(self, preset: ToolPreset, scope: str = "project") -> None:
path = self._get_path(scope)
data = self._read_raw(path)
if "presets" not in data:
data["presets"] = {}
data["presets"][preset.name] = preset.to_dict()
self._write_raw(path, data)
def delete_preset(self, name: str, scope: str = "project") -> None: def delete_preset(self, name: str, scope: str = "project") -> None:
""" path = self._get_path(scope)
Deletes a preset from the specified scope. data = self._read_raw(path)
Scope must be 'global' or 'project'. if "presets" in data and name in data["presets"]:
""" del data["presets"][name]
if scope == "global": self._write_raw(path, data)
path = paths.get_global_tool_presets_path()
elif scope == "project": def load_all_bias_profiles(self) -> Dict[str, BiasProfile]:
if not self.project_root: global_path = paths.get_global_tool_presets_path()
raise ValueError("Project root not set for project scope deletion.") global_data = self._read_raw(global_path).get("bias_profiles", {})
path = paths.get_project_tool_presets_path(self.project_root)
else:
raise ValueError(f"Invalid scope: {scope}")
if not path.exists(): profiles = {}
return for name, config in global_data.items():
if isinstance(config, dict):
cfg = dict(config)
if "name" not in cfg:
cfg["name"] = name
profiles[name] = BiasProfile.from_dict(cfg)
try: if self.project_root:
with open(path, "rb") as f: project_path = paths.get_project_tool_presets_path(self.project_root)
data = tomllib.load(f) project_data = self._read_raw(project_path).get("bias_profiles", {})
except Exception: for name, config in project_data.items():
return if isinstance(config, dict):
cfg = dict(config)
if "name" not in cfg:
cfg["name"] = name
profiles[name] = BiasProfile.from_dict(cfg)
if name in data: return profiles
del data[name]
with open(path, "wb") as f: def save_bias_profile(self, profile: BiasProfile, scope: str = "project") -> None:
tomli_w.dump(data, f) path = self._get_path(scope)
data = self._read_raw(path)
if "bias_profiles" not in data:
data["bias_profiles"] = {}
data["bias_profiles"][profile.name] = profile.to_dict()
self._write_raw(path, data)
def delete_bias_profile(self, name: str, scope: str = "project") -> None:
path = self._get_path(scope)
data = self._read_raw(path)
if "bias_profiles" in data and name in data["bias_profiles"]:
del data["bias_profiles"][name]
self._write_raw(path, data)

46
tests/test_bias_models.py Normal file
View File

@@ -0,0 +1,46 @@
import pytest
from src.models import Tool, ToolPreset, BiasProfile
def test_tool_model():
    """A Tool keeps its fields through a to_dict / from_dict round trip."""
    original = Tool(name="read_file", weight=5, parameter_bias={"path": "preferred"})

    serialized = original.to_dict()
    assert serialized["name"] == "read_file"
    assert serialized["weight"] == 5
    assert serialized["parameter_bias"]["path"] == "preferred"

    restored = Tool.from_dict(serialized)
    assert restored.name == "read_file"
    assert restored.weight == 5
    assert restored.parameter_bias["path"] == "preferred"
def test_tool_preset_extension():
    """ToolPreset turns dict entries into Tool objects and back again."""
    stored_tool = {"name": "read_file", "weight": 4, "parameter_bias": {"path": "high"}}
    stored_preset = {"categories": {"General": [stored_tool]}}

    # Parsing: the dict entry must come back as a Tool instance.
    preset = ToolPreset.from_dict("test", stored_preset)
    first_tool = preset.categories["General"][0]
    assert isinstance(first_tool, Tool)
    assert first_tool.weight == 4

    # Serialization: the Tool must be flattened to a plain dict again.
    first_serialized = preset.to_dict()["categories"]["General"][0]
    assert first_serialized["weight"] == 4
    assert first_serialized["name"] == "read_file"
def test_bias_profile_model():
    """A BiasProfile keeps all fields through a to_dict / from_dict round trip."""
    profile = BiasProfile(
        name="Execution-Focused",
        tool_weights={"run_powershell": 5},
        category_multipliers={"Surgical": 1.5}
    )
    data = profile.to_dict()
    # to_dict() must emit the name itself; asserting it here (rather than
    # re-assigning it before the round trip) keeps that contract under test.
    assert data["name"] == "Execution-Focused"
    assert data["tool_weights"]["run_powershell"] == 5
    assert data["category_multipliers"]["Surgical"] == 1.5

    # Round-trip the serialized form exactly as produced.
    profile2 = BiasProfile.from_dict(data)
    assert profile2.name == "Execution-Focused"
    assert profile2.tool_weights["run_powershell"] == 5
    assert profile2.category_multipliers["Surgical"] == 1.5

View File

@@ -2,7 +2,7 @@ import pytest
import tomli_w import tomli_w
from pathlib import Path from pathlib import Path
from src.tool_presets import ToolPresetManager from src.tool_presets import ToolPresetManager
from src.models import ToolPreset from src.models import ToolPreset, BiasProfile, Tool
from src import paths from src import paths
@pytest.fixture @pytest.fixture
@@ -25,17 +25,18 @@ def temp_paths(tmp_path, monkeypatch):
"project_presets": project_presets "project_presets": project_presets
} }
def test_load_all_merged(temp_paths): def test_load_all_presets_merged(temp_paths):
# Setup global presets # Setup global presets
global_data = { global_data = {
"default": { "presets": {
"categories": { "default": {
"file": {"read": True}, "categories": {
"shell": {"run": False} "General": [{"name": "read_file", "approval": "auto"}]
}
},
"global_only": {
"categories": {"Web": [{"name": "web_search", "approval": "ask"}]}
} }
},
"global_only": {
"categories": {"web": {"search": True}}
} }
} }
with open(temp_paths["global_presets"], "wb") as f: with open(temp_paths["global_presets"], "wb") as f:
@@ -43,98 +44,78 @@ def test_load_all_merged(temp_paths):
# Setup project presets (overrides 'default') # Setup project presets (overrides 'default')
project_data = { project_data = {
"default": { "presets": {
"categories": { "default": {
"file": {"read": True}, "categories": {
"shell": {"run": True} # Override "General": [{"name": "read_file", "approval": "auto", "weight": 5}]
}
} }
},
"project_only": {
"categories": {"git": {"commit": True}}
} }
} }
with open(temp_paths["project_presets"], "wb") as f: with open(temp_paths["project_presets"], "wb") as f:
tomli_w.dump(project_data, f) tomli_w.dump(project_data, f)
manager = ToolPresetManager(project_root=temp_paths["project_dir"]) manager = ToolPresetManager(project_root=temp_paths["project_dir"])
all_presets = manager.load_all() all_presets = manager.load_all_presets()
assert "default" in all_presets assert "default" in all_presets
assert all_presets["default"].categories["shell"]["run"] is True # Overridden assert isinstance(all_presets["default"].categories["General"][0], Tool)
assert all_presets["default"].categories["General"][0].weight == 5
assert "global_only" in all_presets assert "global_only" in all_presets
assert "project_only" in all_presets
assert all_presets["global_only"].categories["web"]["search"] is True
assert all_presets["project_only"].categories["git"]["commit"] is True
def test_bias_profiles_merged(temp_paths):
    """Bias profiles from the global and project files are both loaded."""
    # Global file: one profile with weights and multipliers.
    global_doc = {
        "bias_profiles": {
            "Discovery": {
                "tool_weights": {"web_search": 5},
                "category_multipliers": {"Web": 1.5}
            }
        }
    }
    with open(temp_paths["global_presets"], "wb") as global_file:
        tomli_w.dump(global_doc, global_file)

    # Project file: a different profile with weights only.
    project_doc = {
        "bias_profiles": {
            "Execution": {
                "tool_weights": {"run_powershell": 5}
            }
        }
    }
    with open(temp_paths["project_presets"], "wb") as project_file:
        tomli_w.dump(project_doc, project_file)

    loaded = ToolPresetManager(project_root=temp_paths["project_dir"]).load_all_bias_profiles()

    assert "Discovery" in loaded
    assert loaded["Discovery"].category_multipliers["Web"] == 1.5
    assert "Execution" in loaded
    assert loaded["Execution"].tool_weights["run_powershell"] == 5
def test_save_bias_profile(temp_paths):
    """A profile saved to the project scope is returned by the next load."""
    manager = ToolPresetManager(project_root=temp_paths["project_dir"])

    manager.save_bias_profile(
        BiasProfile(name="Custom", tool_weights={"test": 1}),
        scope="project",
    )

    reloaded = manager.load_all_bias_profiles()
    assert "Custom" in reloaded
    assert reloaded["Custom"].tool_weights["test"] == 1
def test_delete_bias_profile(temp_paths):
    """Deleting a project-scope profile removes it from later loads."""
    # Seed the project file with the profile that will be deleted.
    seeded = {
        "bias_profiles": {
            "to_delete": {"tool_weights": {}}
        }
    }
    with open(temp_paths["project_presets"], "wb") as project_file:
        tomli_w.dump(seeded, project_file)

    manager = ToolPresetManager(project_root=temp_paths["project_dir"])
    manager.delete_bias_profile("to_delete", scope="project")

    remaining = manager.load_all_bias_profiles()
    assert "to_delete" not in remaining

View File

@@ -3,14 +3,14 @@ import asyncio
from src import ai_client from src import ai_client
from src import mcp_client from src import mcp_client
from src import models from src import models
from src.models import ToolPreset from src.models import ToolPreset, Tool
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_auto_approval(): async def test_tool_auto_approval():
# Setup a preset with read_file as auto # Setup a preset with read_file as auto
preset = ToolPreset(name="AutoTest", categories={ preset = ToolPreset(name="AutoTest", categories={
"General": {"read_file": "auto"} "General": [Tool(name="read_file", approval="auto")]
}) })
with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AutoTest": preset}): with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AutoTest": preset}):
@@ -39,7 +39,7 @@ async def test_tool_auto_approval():
async def test_tool_ask_approval(): async def test_tool_ask_approval():
# Setup a preset with run_powershell as ask # Setup a preset with run_powershell as ask
preset = ToolPreset(name="AskTest", categories={ preset = ToolPreset(name="AskTest", categories={
"General": {"run_powershell": "ask"} "General": [Tool(name="run_powershell", approval="ask")]
}) })
with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AskTest": preset}): with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AskTest": preset}):
@@ -65,7 +65,7 @@ async def test_tool_ask_approval():
async def test_tool_rejection(): async def test_tool_rejection():
# Setup a preset with run_powershell as ask # Setup a preset with run_powershell as ask
preset = ToolPreset(name="AskTest", categories={ preset = ToolPreset(name="AskTest", categories={
"General": {"run_powershell": "ask"} "General": [Tool(name="run_powershell", approval="ask")]
}) })
with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AskTest": preset}): with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AskTest": preset}):

View File

@@ -23,3 +23,15 @@ categories.Python = [
{ name = "py_check_syntax", approval = "auto" }, { name = "py_check_syntax", approval = "auto" },
{ name = "py_get_hierarchy", approval = "auto" } { name = "py_get_hierarchy", approval = "auto" }
] ]
[bias_profiles.Balanced]
tool_weights = {}
category_multipliers = {}
[bias_profiles.Execution-Focused]
tool_weights = { run_powershell = 5 }
category_multipliers = { Runtime = 1.5, Surgical = 1.2 }
[bias_profiles.Discovery-Heavy]
tool_weights = { web_search = 4, search_files = 4 }
category_multipliers = { Web = 1.5, Analysis = 1.3 }