feat(bias): implement data models and storage for tool weighting and bias profiles

This commit is contained in:
2026-03-10 09:27:12 -04:00
parent ee19cc1d2a
commit 77a0b385d5
6 changed files with 264 additions and 154 deletions

View File

@@ -368,18 +368,70 @@ class Preset:
)
@dataclass
class ToolPreset:
class Tool:
name: str
categories: Dict[str, Dict[str, Any]]
approval: str = 'auto'
weight: int = 3
parameter_bias: Dict[str, str] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
"categories": self.categories,
"name": self.name,
"approval": self.approval,
"weight": self.weight,
"parameter_bias": self.parameter_bias,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Tool":
return cls(
name=data["name"],
approval=data.get("approval", "auto"),
weight=data.get("weight", 3),
parameter_bias=data.get("parameter_bias", {}),
)
@dataclass
class ToolPreset:
name: str
categories: Dict[str, List[Union[Tool, Any]]]
def to_dict(self) -> Dict[str, Any]:
serialized_categories = {}
for cat, tools in self.categories.items():
serialized_categories[cat] = [t.to_dict() if isinstance(t, Tool) else t for t in tools]
return {
"categories": serialized_categories,
}
@classmethod
def from_dict(cls, name: str, data: Dict[str, Any]) -> "ToolPreset":
raw_categories = data.get("categories", {})
parsed_categories = {}
for cat, tools in raw_categories.items():
parsed_categories[cat] = [Tool.from_dict(t) if isinstance(t, dict) else t for t in tools]
return cls(
name=name,
categories=data.get("categories", {}),
categories=parsed_categories,
)
@dataclass
class BiasProfile:
name: str
tool_weights: Dict[str, int] = field(default_factory=dict)
category_multipliers: Dict[str, float] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
"name": self.name,
"tool_weights": self.tool_weights,
"category_multipliers": self.category_multipliers,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BiasProfile":
return cls(
name=data["name"],
tool_weights=data.get("tool_weights", {}),
category_multipliers=data.get("category_multipliers", {}),
)

View File

@@ -1,91 +1,110 @@
import tomllib
import tomli_w
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Any
from src import paths
from src.models import ToolPreset
from src.models import ToolPreset, BiasProfile
class ToolPresetManager:
def __init__(self, project_root: Optional[Union[str, Path]] = None):
self.project_root = Path(project_root) if project_root else None
def _load_from_path(self, path: Path) -> Dict[str, ToolPreset]:
def _get_path(self, scope: str) -> Path:
if scope == "global":
return paths.get_global_tool_presets_path()
elif scope == "project":
if not self.project_root:
raise ValueError("Project root not set for project scope operation.")
return paths.get_project_tool_presets_path(self.project_root)
else:
raise ValueError(f"Invalid scope: {scope}")
def _read_raw(self, path: Path) -> Dict[str, Any]:
if not path.exists():
return {}
try:
with open(path, "rb") as f:
data = tomllib.load(f)
presets = {}
for name, config in data.items():
if isinstance(config, dict):
presets[name] = ToolPreset.from_dict(name, config)
return presets
return tomllib.load(f)
except Exception:
return {}
def load_all(self) -> Dict[str, ToolPreset]:
"""
Merges global and project presets.
Project presets override global ones if they have the same name.
"""
presets = self._load_from_path(paths.get_global_tool_presets_path())
if self.project_root:
project_presets = self._load_from_path(paths.get_project_tool_presets_path(self.project_root))
presets.update(project_presets)
return presets
def save_preset(self, preset: ToolPreset, scope: str = "project") -> None:
"""
Saves a preset to either 'global' or 'project' scope.
Scope must be 'global' or 'project'.
"""
if scope == "global":
path = paths.get_global_tool_presets_path()
elif scope == "project":
if not self.project_root:
raise ValueError("Project root not set for project scope saving.")
path = paths.get_project_tool_presets_path(self.project_root)
else:
raise ValueError(f"Invalid scope: {scope}")
data = {}
if path.exists():
try:
with open(path, "rb") as f:
data = tomllib.load(f)
except Exception:
data = {}
data[preset.name] = preset.to_dict()
def _write_raw(self, path: Path, data: Dict[str, Any]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "wb") as f:
tomli_w.dump(data, f)
def load_all_presets(self) -> Dict[str, ToolPreset]:
global_path = paths.get_global_tool_presets_path()
global_data = self._read_raw(global_path).get("presets", {})
presets = {}
for name, config in global_data.items():
if isinstance(config, dict):
presets[name] = ToolPreset.from_dict(name, config)
if self.project_root:
project_path = paths.get_project_tool_presets_path(self.project_root)
project_data = self._read_raw(project_path).get("presets", {})
for name, config in project_data.items():
if isinstance(config, dict):
presets[name] = ToolPreset.from_dict(name, config)
return presets
def load_all(self) -> Dict[str, ToolPreset]:
"""Backward compatibility for load_all()."""
return self.load_all_presets()
def save_preset(self, preset: ToolPreset, scope: str = "project") -> None:
path = self._get_path(scope)
data = self._read_raw(path)
if "presets" not in data:
data["presets"] = {}
data["presets"][preset.name] = preset.to_dict()
self._write_raw(path, data)
def delete_preset(self, name: str, scope: str = "project") -> None:
"""
Deletes a preset from the specified scope.
Scope must be 'global' or 'project'.
"""
if scope == "global":
path = paths.get_global_tool_presets_path()
elif scope == "project":
if not self.project_root:
raise ValueError("Project root not set for project scope deletion.")
path = paths.get_project_tool_presets_path(self.project_root)
else:
raise ValueError(f"Invalid scope: {scope}")
path = self._get_path(scope)
data = self._read_raw(path)
if "presets" in data and name in data["presets"]:
del data["presets"][name]
self._write_raw(path, data)
def load_all_bias_profiles(self) -> Dict[str, BiasProfile]:
global_path = paths.get_global_tool_presets_path()
global_data = self._read_raw(global_path).get("bias_profiles", {})
if not path.exists():
return
profiles = {}
for name, config in global_data.items():
if isinstance(config, dict):
cfg = dict(config)
if "name" not in cfg:
cfg["name"] = name
profiles[name] = BiasProfile.from_dict(cfg)
try:
with open(path, "rb") as f:
data = tomllib.load(f)
except Exception:
return
if self.project_root:
project_path = paths.get_project_tool_presets_path(self.project_root)
project_data = self._read_raw(project_path).get("bias_profiles", {})
for name, config in project_data.items():
if isinstance(config, dict):
cfg = dict(config)
if "name" not in cfg:
cfg["name"] = name
profiles[name] = BiasProfile.from_dict(cfg)
if name in data:
del data[name]
with open(path, "wb") as f:
tomli_w.dump(data, f)
return profiles
def save_bias_profile(self, profile: BiasProfile, scope: str = "project") -> None:
path = self._get_path(scope)
data = self._read_raw(path)
if "bias_profiles" not in data:
data["bias_profiles"] = {}
data["bias_profiles"][profile.name] = profile.to_dict()
self._write_raw(path, data)
def delete_bias_profile(self, name: str, scope: str = "project") -> None:
path = self._get_path(scope)
data = self._read_raw(path)
if "bias_profiles" in data and name in data["bias_profiles"]:
del data["bias_profiles"][name]
self._write_raw(path, data)