feat(bias): implement data models and storage for tool weighting and bias profiles
This commit is contained in:
@@ -368,18 +368,70 @@ class Preset:
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class ToolPreset:
|
||||
class Tool:
|
||||
name: str
|
||||
categories: Dict[str, Dict[str, Any]]
|
||||
approval: str = 'auto'
|
||||
weight: int = 3
|
||||
parameter_bias: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"categories": self.categories,
|
||||
"name": self.name,
|
||||
"approval": self.approval,
|
||||
"weight": self.weight,
|
||||
"parameter_bias": self.parameter_bias,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Tool":
|
||||
return cls(
|
||||
name=data["name"],
|
||||
approval=data.get("approval", "auto"),
|
||||
weight=data.get("weight", 3),
|
||||
parameter_bias=data.get("parameter_bias", {}),
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class ToolPreset:
|
||||
name: str
|
||||
categories: Dict[str, List[Union[Tool, Any]]]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
serialized_categories = {}
|
||||
for cat, tools in self.categories.items():
|
||||
serialized_categories[cat] = [t.to_dict() if isinstance(t, Tool) else t for t in tools]
|
||||
return {
|
||||
"categories": serialized_categories,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, name: str, data: Dict[str, Any]) -> "ToolPreset":
|
||||
raw_categories = data.get("categories", {})
|
||||
parsed_categories = {}
|
||||
for cat, tools in raw_categories.items():
|
||||
parsed_categories[cat] = [Tool.from_dict(t) if isinstance(t, dict) else t for t in tools]
|
||||
return cls(
|
||||
name=name,
|
||||
categories=data.get("categories", {}),
|
||||
categories=parsed_categories,
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class BiasProfile:
|
||||
name: str
|
||||
tool_weights: Dict[str, int] = field(default_factory=dict)
|
||||
category_multipliers: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"tool_weights": self.tool_weights,
|
||||
"category_multipliers": self.category_multipliers,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BiasProfile":
|
||||
return cls(
|
||||
name=data["name"],
|
||||
tool_weights=data.get("tool_weights", {}),
|
||||
category_multipliers=data.get("category_multipliers", {}),
|
||||
)
|
||||
|
||||
@@ -1,91 +1,110 @@
|
||||
import tomllib
|
||||
import tomli_w
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union, Any
|
||||
from src import paths
|
||||
from src.models import ToolPreset
|
||||
from src.models import ToolPreset, BiasProfile
|
||||
|
||||
class ToolPresetManager:
|
||||
def __init__(self, project_root: Optional[Union[str, Path]] = None):
|
||||
self.project_root = Path(project_root) if project_root else None
|
||||
|
||||
def _load_from_path(self, path: Path) -> Dict[str, ToolPreset]:
|
||||
def _get_path(self, scope: str) -> Path:
|
||||
if scope == "global":
|
||||
return paths.get_global_tool_presets_path()
|
||||
elif scope == "project":
|
||||
if not self.project_root:
|
||||
raise ValueError("Project root not set for project scope operation.")
|
||||
return paths.get_project_tool_presets_path(self.project_root)
|
||||
else:
|
||||
raise ValueError(f"Invalid scope: {scope}")
|
||||
|
||||
def _read_raw(self, path: Path) -> Dict[str, Any]:
|
||||
if not path.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
presets = {}
|
||||
for name, config in data.items():
|
||||
if isinstance(config, dict):
|
||||
presets[name] = ToolPreset.from_dict(name, config)
|
||||
return presets
|
||||
return tomllib.load(f)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def load_all(self) -> Dict[str, ToolPreset]:
|
||||
"""
|
||||
Merges global and project presets.
|
||||
Project presets override global ones if they have the same name.
|
||||
"""
|
||||
presets = self._load_from_path(paths.get_global_tool_presets_path())
|
||||
if self.project_root:
|
||||
project_presets = self._load_from_path(paths.get_project_tool_presets_path(self.project_root))
|
||||
presets.update(project_presets)
|
||||
return presets
|
||||
|
||||
def save_preset(self, preset: ToolPreset, scope: str = "project") -> None:
|
||||
"""
|
||||
Saves a preset to either 'global' or 'project' scope.
|
||||
Scope must be 'global' or 'project'.
|
||||
"""
|
||||
if scope == "global":
|
||||
path = paths.get_global_tool_presets_path()
|
||||
elif scope == "project":
|
||||
if not self.project_root:
|
||||
raise ValueError("Project root not set for project scope saving.")
|
||||
path = paths.get_project_tool_presets_path(self.project_root)
|
||||
else:
|
||||
raise ValueError(f"Invalid scope: {scope}")
|
||||
|
||||
data = {}
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
except Exception:
|
||||
data = {}
|
||||
|
||||
data[preset.name] = preset.to_dict()
|
||||
|
||||
def _write_raw(self, path: Path, data: Dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "wb") as f:
|
||||
tomli_w.dump(data, f)
|
||||
|
||||
def load_all_presets(self) -> Dict[str, ToolPreset]:
|
||||
global_path = paths.get_global_tool_presets_path()
|
||||
global_data = self._read_raw(global_path).get("presets", {})
|
||||
|
||||
presets = {}
|
||||
for name, config in global_data.items():
|
||||
if isinstance(config, dict):
|
||||
presets[name] = ToolPreset.from_dict(name, config)
|
||||
|
||||
if self.project_root:
|
||||
project_path = paths.get_project_tool_presets_path(self.project_root)
|
||||
project_data = self._read_raw(project_path).get("presets", {})
|
||||
for name, config in project_data.items():
|
||||
if isinstance(config, dict):
|
||||
presets[name] = ToolPreset.from_dict(name, config)
|
||||
|
||||
return presets
|
||||
|
||||
def load_all(self) -> Dict[str, ToolPreset]:
|
||||
"""Backward compatibility for load_all()."""
|
||||
return self.load_all_presets()
|
||||
|
||||
def save_preset(self, preset: ToolPreset, scope: str = "project") -> None:
|
||||
path = self._get_path(scope)
|
||||
data = self._read_raw(path)
|
||||
if "presets" not in data:
|
||||
data["presets"] = {}
|
||||
data["presets"][preset.name] = preset.to_dict()
|
||||
self._write_raw(path, data)
|
||||
|
||||
def delete_preset(self, name: str, scope: str = "project") -> None:
|
||||
"""
|
||||
Deletes a preset from the specified scope.
|
||||
Scope must be 'global' or 'project'.
|
||||
"""
|
||||
if scope == "global":
|
||||
path = paths.get_global_tool_presets_path()
|
||||
elif scope == "project":
|
||||
if not self.project_root:
|
||||
raise ValueError("Project root not set for project scope deletion.")
|
||||
path = paths.get_project_tool_presets_path(self.project_root)
|
||||
else:
|
||||
raise ValueError(f"Invalid scope: {scope}")
|
||||
path = self._get_path(scope)
|
||||
data = self._read_raw(path)
|
||||
if "presets" in data and name in data["presets"]:
|
||||
del data["presets"][name]
|
||||
self._write_raw(path, data)
|
||||
|
||||
def load_all_bias_profiles(self) -> Dict[str, BiasProfile]:
|
||||
global_path = paths.get_global_tool_presets_path()
|
||||
global_data = self._read_raw(global_path).get("bias_profiles", {})
|
||||
|
||||
if not path.exists():
|
||||
return
|
||||
profiles = {}
|
||||
for name, config in global_data.items():
|
||||
if isinstance(config, dict):
|
||||
cfg = dict(config)
|
||||
if "name" not in cfg:
|
||||
cfg["name"] = name
|
||||
profiles[name] = BiasProfile.from_dict(cfg)
|
||||
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
except Exception:
|
||||
return
|
||||
if self.project_root:
|
||||
project_path = paths.get_project_tool_presets_path(self.project_root)
|
||||
project_data = self._read_raw(project_path).get("bias_profiles", {})
|
||||
for name, config in project_data.items():
|
||||
if isinstance(config, dict):
|
||||
cfg = dict(config)
|
||||
if "name" not in cfg:
|
||||
cfg["name"] = name
|
||||
profiles[name] = BiasProfile.from_dict(cfg)
|
||||
|
||||
if name in data:
|
||||
del data[name]
|
||||
with open(path, "wb") as f:
|
||||
tomli_w.dump(data, f)
|
||||
return profiles
|
||||
|
||||
def save_bias_profile(self, profile: BiasProfile, scope: str = "project") -> None:
|
||||
path = self._get_path(scope)
|
||||
data = self._read_raw(path)
|
||||
if "bias_profiles" not in data:
|
||||
data["bias_profiles"] = {}
|
||||
data["bias_profiles"][profile.name] = profile.to_dict()
|
||||
self._write_raw(path, data)
|
||||
|
||||
def delete_bias_profile(self, name: str, scope: str = "project") -> None:
|
||||
path = self._get_path(scope)
|
||||
data = self._read_raw(path)
|
||||
if "bias_profiles" in data and name in data["bias_profiles"]:
|
||||
del data["bias_profiles"][name]
|
||||
self._write_raw(path, data)
|
||||
|
||||
Reference in New Issue
Block a user