feat(bias): implement data models and storage for tool weighting and bias profiles

This commit is contained in:
2026-03-10 09:27:12 -04:00
parent ee19cc1d2a
commit 77a0b385d5
6 changed files with 264 additions and 154 deletions

View File

@@ -368,18 +368,70 @@ class Preset:
)
@dataclass
class Tool:
    """A single tool entry inside a preset category.

    Instances round-trip through plain dicts (``to_dict``/``from_dict``)
    for TOML storage.
    """

    name: str  # tool identifier, e.g. "read_file"
    approval: str = 'auto'  # approval mode; SOURCE shows 'auto' and 'ask' in use
    weight: int = 3  # selection-bias weight; 3 is the neutral default
    parameter_bias: Dict[str, str] = field(default_factory=dict)  # per-parameter bias hints

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this tool to a plain dict suitable for TOML dumping."""
        return {
            "name": self.name,
            "approval": self.approval,
            "weight": self.weight,
            "parameter_bias": self.parameter_bias,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
        """Build a Tool from a dict; missing optional keys fall back to defaults.

        Raises KeyError if "name" is absent.
        """
        return cls(
            name=data["name"],
            approval=data.get("approval", "auto"),
            weight=data.get("weight", 3),
            parameter_bias=data.get("parameter_bias", {}),
        )
@dataclass
class ToolPreset:
    """A named preset grouping tools by category.

    Category values are lists that may hold parsed ``Tool`` objects or raw
    entries; serialization and parsing pass unknown entries through untouched.
    """

    name: str
    categories: Dict[str, List[Union[Tool, Any]]]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for TOML: Tool instances become dicts, other entries pass through.

        The preset name is intentionally omitted — it is the key of the
        enclosing TOML table (see ``from_dict`` taking ``name`` separately).
        """
        serialized_categories = {}
        for cat, tools in self.categories.items():
            serialized_categories[cat] = [t.to_dict() if isinstance(t, Tool) else t for t in tools]
        return {
            "categories": serialized_categories,
        }

    @classmethod
    def from_dict(cls, name: str, data: Dict[str, Any]) -> "ToolPreset":
        """Build a preset from raw TOML data; dict entries are parsed into Tool objects."""
        raw_categories = data.get("categories", {})
        parsed_categories = {}
        for cat, tools in raw_categories.items():
            parsed_categories[cat] = [Tool.from_dict(t) if isinstance(t, dict) else t for t in tools]
        return cls(
            name=name,
            categories=parsed_categories,
        )
@dataclass
class BiasProfile:
    """A named bias profile: per-tool weights plus per-category multipliers."""

    name: str
    tool_weights: Dict[str, int] = field(default_factory=dict)  # tool name -> weight
    category_multipliers: Dict[str, float] = field(default_factory=dict)  # category -> multiplier

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict for TOML storage (includes the name)."""
        payload = {"name": self.name}
        payload["tool_weights"] = self.tool_weights
        payload["category_multipliers"] = self.category_multipliers
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BiasProfile":
        """Rebuild a profile from a dict; "name" is required, the rest defaults to empty."""
        weights = data.get("tool_weights", {})
        multipliers = data.get("category_multipliers", {})
        return cls(name=data["name"], tool_weights=weights, category_multipliers=multipliers)

View File

@@ -1,91 +1,110 @@
import tomllib
import tomli_w
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Any
from src import paths
from src.models import ToolPreset
from src.models import ToolPreset, BiasProfile
class ToolPresetManager:
    """Loads and saves tool presets and bias profiles from TOML files.

    Global and project scopes each have one file; both kinds of data live in
    the same file under the top-level tables ``"presets"`` and
    ``"bias_profiles"``. Project entries override global ones by name.
    """

    def __init__(self, project_root: Optional[Union[str, Path]] = None):
        # project_root is optional; project-scope operations require it and
        # raise ValueError otherwise (see _get_path).
        self.project_root = Path(project_root) if project_root else None

    def _get_path(self, scope: str) -> Path:
        """Resolve the TOML file path for a scope.

        Raises ValueError for an unknown scope or for 'project' without a root.
        """
        if scope == "global":
            return paths.get_global_tool_presets_path()
        elif scope == "project":
            if not self.project_root:
                raise ValueError("Project root not set for project scope operation.")
            return paths.get_project_tool_presets_path(self.project_root)
        else:
            raise ValueError(f"Invalid scope: {scope}")

    def _read_raw(self, path: Path) -> Dict[str, Any]:
        """Read a TOML file; a missing or unparsable file yields an empty dict."""
        if not path.exists():
            return {}
        try:
            with open(path, "rb") as f:
                return tomllib.load(f)
        except Exception:
            # Best-effort: a corrupt file is treated as empty rather than
            # crashing preset loading.
            return {}

    def _write_raw(self, path: Path, data: Dict[str, Any]) -> None:
        """Write TOML data to *path*, creating parent directories as needed."""
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            tomli_w.dump(data, f)

    @staticmethod
    def _parse_presets(raw: Dict[str, Any], presets: Dict[str, ToolPreset]) -> None:
        """Parse raw 'presets' table entries into *presets* in place (later wins)."""
        for name, config in raw.items():
            if isinstance(config, dict):
                presets[name] = ToolPreset.from_dict(name, config)

    def load_all_presets(self) -> Dict[str, ToolPreset]:
        """Merge global and project presets.

        Project presets override global ones if they have the same name.
        """
        presets: Dict[str, ToolPreset] = {}
        global_path = paths.get_global_tool_presets_path()
        self._parse_presets(self._read_raw(global_path).get("presets", {}), presets)
        if self.project_root:
            project_path = paths.get_project_tool_presets_path(self.project_root)
            self._parse_presets(self._read_raw(project_path).get("presets", {}), presets)
        return presets

    def load_all(self) -> Dict[str, ToolPreset]:
        """Backward compatibility for load_all()."""
        return self.load_all_presets()

    def save_preset(self, preset: ToolPreset, scope: str = "project") -> None:
        """Save *preset* under the 'presets' table in the given scope's file."""
        path = self._get_path(scope)
        data = self._read_raw(path)
        if "presets" not in data:
            data["presets"] = {}
        data["presets"][preset.name] = preset.to_dict()
        self._write_raw(path, data)

    def delete_preset(self, name: str, scope: str = "project") -> None:
        """Delete a preset by name from the given scope; unknown names are a no-op."""
        path = self._get_path(scope)
        data = self._read_raw(path)
        if "presets" in data and name in data["presets"]:
            del data["presets"][name]
            # Only rewrite when something actually changed, so a no-op delete
            # does not create an empty file.
            self._write_raw(path, data)

    @staticmethod
    def _parse_bias_profiles(raw: Dict[str, Any], profiles: Dict[str, BiasProfile]) -> None:
        """Parse raw 'bias_profiles' table entries into *profiles* in place (later wins)."""
        for name, config in raw.items():
            if isinstance(config, dict):
                cfg = dict(config)
                # TOML tables key profiles by name; inject it for the model.
                if "name" not in cfg:
                    cfg["name"] = name
                profiles[name] = BiasProfile.from_dict(cfg)

    def load_all_bias_profiles(self) -> Dict[str, BiasProfile]:
        """Merge global and project bias profiles; project entries win on name clashes."""
        profiles: Dict[str, BiasProfile] = {}
        global_path = paths.get_global_tool_presets_path()
        self._parse_bias_profiles(self._read_raw(global_path).get("bias_profiles", {}), profiles)
        if self.project_root:
            project_path = paths.get_project_tool_presets_path(self.project_root)
            self._parse_bias_profiles(self._read_raw(project_path).get("bias_profiles", {}), profiles)
        return profiles

    def save_bias_profile(self, profile: BiasProfile, scope: str = "project") -> None:
        """Save *profile* under the 'bias_profiles' table in the given scope's file."""
        path = self._get_path(scope)
        data = self._read_raw(path)
        if "bias_profiles" not in data:
            data["bias_profiles"] = {}
        data["bias_profiles"][profile.name] = profile.to_dict()
        self._write_raw(path, data)

    def delete_bias_profile(self, name: str, scope: str = "project") -> None:
        """Delete a bias profile by name from the given scope; unknown names are a no-op."""
        path = self._get_path(scope)
        data = self._read_raw(path)
        if "bias_profiles" in data and name in data["bias_profiles"]:
            del data["bias_profiles"][name]
            self._write_raw(path, data)

46
tests/test_bias_models.py Normal file
View File

@@ -0,0 +1,46 @@
import pytest
from src.models import Tool, ToolPreset, BiasProfile
def test_tool_model():
    """Round-trip a Tool through to_dict/from_dict and check the fields survive."""
    original = Tool(name="read_file", weight=5, parameter_bias={"path": "preferred"})
    serialized = original.to_dict()
    assert serialized["name"] == "read_file"
    assert serialized["weight"] == 5
    assert serialized["parameter_bias"]["path"] == "preferred"
    restored = Tool.from_dict(serialized)
    assert restored.name == "read_file"
    assert restored.weight == 5
    assert restored.parameter_bias["path"] == "preferred"
def test_tool_preset_extension():
    """ToolPreset.from_dict parses raw dicts into Tool objects; to_dict serializes them back."""
    tool_entry = {"name": "read_file", "weight": 4, "parameter_bias": {"path": "high"}}
    raw = {"categories": {"General": [tool_entry]}}
    preset = ToolPreset.from_dict("test", raw)
    parsed = preset.categories["General"][0]
    assert isinstance(parsed, Tool)
    assert parsed.weight == 4
    round_tripped = preset.to_dict()
    assert round_tripped["categories"]["General"][0]["weight"] == 4
    assert round_tripped["categories"]["General"][0]["name"] == "read_file"
def test_bias_profile_model():
    """BiasProfile serializes its weights/multipliers and rebuilds from the same dict."""
    profile = BiasProfile(
        name="Execution-Focused",
        tool_weights={"run_powershell": 5},
        category_multipliers={"Surgical": 1.5},
    )
    serialized = profile.to_dict()
    assert serialized["tool_weights"]["run_powershell"] == 5
    assert serialized["category_multipliers"]["Surgical"] == 1.5
    # from_dict needs a 'name' key, as produced by load_all_bias_profiles.
    serialized["name"] = "Execution-Focused"
    rebuilt = BiasProfile.from_dict(serialized)
    assert rebuilt.name == "Execution-Focused"
    assert rebuilt.tool_weights["run_powershell"] == 5
    assert rebuilt.category_multipliers["Surgical"] == 1.5

View File

@@ -2,7 +2,7 @@ import pytest
import tomli_w
from pathlib import Path
from src.tool_presets import ToolPresetManager
from src.models import ToolPreset
from src.models import ToolPreset, BiasProfile, Tool
from src import paths
@pytest.fixture
@@ -25,17 +25,18 @@ def temp_paths(tmp_path, monkeypatch):
"project_presets": project_presets
}
def test_load_all_presets_merged(temp_paths):
    """Project presets override global presets of the same name; others merge."""
    # Setup global presets (new "presets" table format: categories hold tool dicts).
    global_data = {
        "presets": {
            "default": {
                "categories": {
                    "General": [{"name": "read_file", "approval": "auto"}]
                }
            },
            "global_only": {
                "categories": {"Web": [{"name": "web_search", "approval": "ask"}]}
            }
        }
    }
    with open(temp_paths["global_presets"], "wb") as f:
        tomli_w.dump(global_data, f)
    # Setup project presets (overrides 'default' with a higher weight).
    project_data = {
        "presets": {
            "default": {
                "categories": {
                    "General": [{"name": "read_file", "approval": "auto", "weight": 5}]
                }
            }
        }
    }
    with open(temp_paths["project_presets"], "wb") as f:
        tomli_w.dump(project_data, f)
    manager = ToolPresetManager(project_root=temp_paths["project_dir"])
    all_presets = manager.load_all_presets()
    assert "default" in all_presets
    # Project version won: entry parsed to a Tool carrying the overridden weight.
    assert isinstance(all_presets["default"].categories["General"][0], Tool)
    assert all_presets["default"].categories["General"][0].weight == 5
    assert "global_only" in all_presets
def test_save_preset_global(temp_paths):
manager = ToolPresetManager()
preset = ToolPreset(name="new_global", categories={"test": {"ok": True}})
manager.save_preset(preset, scope="global")
assert temp_paths["global_presets"].exists()
loaded = manager._load_from_path(temp_paths["global_presets"])
assert "new_global" in loaded
assert loaded["new_global"].categories == {"test": {"ok": True}}
def test_save_preset_project(temp_paths):
manager = ToolPresetManager(project_root=temp_paths["project_dir"])
preset = ToolPreset(name="new_project", categories={"test": {"ok": False}})
manager.save_preset(preset, scope="project")
assert temp_paths["project_presets"].exists()
loaded = manager._load_from_path(temp_paths["project_presets"])
assert "new_project" in loaded
assert loaded["new_project"].categories == {"test": {"ok": False}}
def test_delete_preset_global(temp_paths):
# Initial global setup
def test_bias_profiles_merged(temp_paths):
    """Bias profiles from global and project files are merged into one mapping."""
    # Setup global biases
    global_data = {
        "bias_profiles": {
            "Discovery": {
                "tool_weights": {"web_search": 5},
                "category_multipliers": {"Web": 1.5}
            }
        }
    }
    with open(temp_paths["global_presets"], "wb") as f:
        tomli_w.dump(global_data, f)
    # Setup project biases
    project_data = {
        "bias_profiles": {
            "Execution": {
                "tool_weights": {"run_powershell": 5}
            }
        }
    }
    with open(temp_paths["project_presets"], "wb") as f:
        tomli_w.dump(project_data, f)
    manager = ToolPresetManager(project_root=temp_paths["project_dir"])
    profiles = manager.load_all_bias_profiles()
    assert "Discovery" in profiles
    assert profiles["Discovery"].category_multipliers["Web"] == 1.5
    assert "Execution" in profiles
    assert profiles["Execution"].tool_weights["run_powershell"] == 5
def test_save_project_no_root_raises(temp_paths):
    """Project-scope save without a project root must raise ValueError."""
    manager = ToolPresetManager(project_root=None)
    preset = ToolPreset(name="fail", categories={})
    with pytest.raises(ValueError, match="Project root not set"):
        manager.save_preset(preset, scope="project")

def test_delete_project_no_root_raises(temp_paths):
    """Project-scope delete without a project root must raise ValueError."""
    manager = ToolPresetManager(project_root=None)
    with pytest.raises(ValueError, match="Project root not set"):
        manager.delete_preset("any", scope="project")

def test_invalid_scope_raises(temp_paths):
    """Unknown scopes are rejected by both save and delete."""
    manager = ToolPresetManager()
    preset = ToolPreset(name="fail", categories={})
    with pytest.raises(ValueError, match="Invalid scope"):
        manager.save_preset(preset, scope="invalid")
    with pytest.raises(ValueError, match="Invalid scope"):
        manager.delete_preset("any", scope="invalid")

def test_save_bias_profile(temp_paths):
    """A saved bias profile round-trips through load_all_bias_profiles."""
    manager = ToolPresetManager(project_root=temp_paths["project_dir"])
    profile = BiasProfile(name="Custom", tool_weights={"test": 1})
    manager.save_bias_profile(profile, scope="project")
    loaded = manager.load_all_bias_profiles()
    assert "Custom" in loaded
    assert loaded["Custom"].tool_weights["test"] == 1
def test_delete_bias_profile(temp_paths):
    """Deleting a bias profile removes it from subsequent loads."""
    seeded = {"bias_profiles": {"to_delete": {"tool_weights": {}}}}
    with open(temp_paths["project_presets"], "wb") as fh:
        tomli_w.dump(seeded, fh)
    mgr = ToolPresetManager(project_root=temp_paths["project_dir"])
    mgr.delete_bias_profile("to_delete", scope="project")
    remaining = mgr.load_all_bias_profiles()
    assert "to_delete" not in remaining

View File

@@ -3,14 +3,14 @@ import asyncio
from src import ai_client
from src import mcp_client
from src import models
from src.models import ToolPreset
from src.models import ToolPreset, Tool
from unittest.mock import MagicMock, patch
@pytest.mark.asyncio
async def test_tool_auto_approval():
# Setup a preset with read_file as auto
preset = ToolPreset(name="AutoTest", categories={
"General": {"read_file": "auto"}
"General": [Tool(name="read_file", approval="auto")]
})
with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AutoTest": preset}):
@@ -39,7 +39,7 @@ async def test_tool_auto_approval():
async def test_tool_ask_approval():
# Setup a preset with run_powershell as ask
preset = ToolPreset(name="AskTest", categories={
"General": {"run_powershell": "ask"}
"General": [Tool(name="run_powershell", approval="ask")]
})
with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AskTest": preset}):
@@ -65,7 +65,7 @@ async def test_tool_ask_approval():
async def test_tool_rejection():
# Setup a preset with run_powershell as ask
preset = ToolPreset(name="AskTest", categories={
"General": {"run_powershell": "ask"}
"General": [Tool(name="run_powershell", approval="ask")]
})
with patch("src.tool_presets.ToolPresetManager.load_all", return_value={"AskTest": preset}):

View File

@@ -23,3 +23,15 @@ categories.Python = [
{ name = "py_check_syntax", approval = "auto" },
{ name = "py_get_hierarchy", approval = "auto" }
]
# Built-in bias profiles. tool_weights maps tool names to integer selection
# weights; category_multipliers scales every tool in a category by a float.

# Neutral profile: no per-tool or per-category bias applied.
[bias_profiles.Balanced]
tool_weights = {}
category_multipliers = {}

# Favors running things: boosts run_powershell and runtime/surgical categories.
[bias_profiles.Execution-Focused]
tool_weights = { run_powershell = 5 }
category_multipliers = { Runtime = 1.5, Surgical = 1.2 }

# Favors exploration: boosts search tools and web/analysis categories.
[bias_profiles.Discovery-Heavy]
tool_weights = { web_search = 4, search_files = 4 }
category_multipliers = { Web = 1.5, Analysis = 1.3 }