feat(ui): AI Settings Overhaul - add dual sliders for model params including top_p

This commit is contained in:
2026-03-11 20:22:06 -04:00
parent 55475b80e7
commit 09902701b4
3 changed files with 55 additions and 8 deletions

View File

@@ -42,6 +42,7 @@ from src.events import EventEmitter
_provider: str = "gemini" _provider: str = "gemini"
_model: str = "gemini-2.5-flash-lite" _model: str = "gemini-2.5-flash-lite"
_temperature: float = 0.0 _temperature: float = 0.0
_top_p: float = 1.0
_max_tokens: int = 8192 _max_tokens: int = 8192
_history_trunc_limit: int = 8000 _history_trunc_limit: int = 8000
@@ -49,11 +50,12 @@ _history_trunc_limit: int = 8000
# Global event emitter for API lifecycle events # Global event emitter for API lifecycle events
events: EventEmitter = EventEmitter() events: EventEmitter = EventEmitter()
def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000) -> None: def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000, top_p: float = 1.0) -> None:
global _temperature, _max_tokens, _history_trunc_limit global _temperature, _max_tokens, _history_trunc_limit, _top_p
_temperature = temp _temperature = temp
_max_tokens = max_tok _max_tokens = max_tok
_history_trunc_limit = trunc_limit _history_trunc_limit = trunc_limit
_top_p = top_p
def get_history_trunc_limit() -> int: def get_history_trunc_limit() -> int:
return _history_trunc_limit return _history_trunc_limit
@@ -939,6 +941,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
system_instruction=sys_instr, system_instruction=sys_instr,
tools=cast(Any, tools_decl), tools=cast(Any, tools_decl),
temperature=_temperature, temperature=_temperature,
top_p=_top_p,
max_output_tokens=_max_tokens, max_output_tokens=_max_tokens,
safety_settings=[types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=types.HarmBlockThreshold.BLOCK_ONLY_HIGH)] safety_settings=[types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=types.HarmBlockThreshold.BLOCK_ONLY_HIGH)]
) )
@@ -1010,6 +1013,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
config = types.GenerateContentConfig( config = types.GenerateContentConfig(
tools=[td] if td else [], tools=[td] if td else [],
temperature=_temperature, temperature=_temperature,
top_p=_top_p,
max_output_tokens=_max_tokens, max_output_tokens=_max_tokens,
) )
@@ -1455,6 +1459,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
model=_model, model=_model,
max_tokens=_max_tokens, max_tokens=_max_tokens,
temperature=_temperature, temperature=_temperature,
top_p=_top_p,
system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks), system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks),
tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()), tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()),
messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)), messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)),
@@ -1468,6 +1473,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
model=_model, model=_model,
max_tokens=_max_tokens, max_tokens=_max_tokens,
temperature=_temperature, temperature=_temperature,
top_p=_top_p,
system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks), system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks),
tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()), tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()),
messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)), messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)),
@@ -1696,6 +1702,7 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
if not is_reasoner: if not is_reasoner:
request_payload["temperature"] = _temperature request_payload["temperature"] = _temperature
request_payload["top_p"] = _top_p
# DeepSeek max_tokens is for the output, clamp to 8192 which is their hard limit for V3/Chat # DeepSeek max_tokens is for the output, clamp to 8192 which is their hard limit for V3/Chat
request_payload["max_tokens"] = min(_max_tokens, 8192) request_payload["max_tokens"] = min(_max_tokens, 8192)
tools = _get_deepseek_tools() tools = _get_deepseek_tools()
@@ -1927,6 +1934,7 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str,
request_payload["stream_options"] = {"include_usage": True} request_payload["stream_options"] = {"include_usage": True}
request_payload["temperature"] = 1.0 request_payload["temperature"] = 1.0
request_payload["top_p"] = _top_p
request_payload["max_tokens"] = min(_max_tokens, 8192) request_payload["max_tokens"] = min(_max_tokens, 8192)
tools = _get_deepseek_tools() tools = _get_deepseek_tools()

View File

@@ -61,8 +61,8 @@ class GenerateRequest(BaseModel):
prompt: str prompt: str
auto_add_history: bool = True auto_add_history: bool = True
temperature: float | None = None temperature: float | None = None
top_p: float | None = None
max_tokens: int | None = None max_tokens: int | None = None
class ConfirmRequest(BaseModel): class ConfirmRequest(BaseModel):
approved: bool approved: bool
script: Optional[str] = None script: Optional[str] = None
@@ -199,6 +199,7 @@ class AppController:
self._current_provider: str = "gemini" self._current_provider: str = "gemini"
self._current_model: str = "gemini-2.5-flash-lite" self._current_model: str = "gemini-2.5-flash-lite"
self.temperature: float = 0.0 self.temperature: float = 0.0
self.top_p: float = 1.0
self.max_tokens: int = 8192 self.max_tokens: int = 8192
self.history_trunc_limit: int = 8000 self.history_trunc_limit: int = 8000
# UI-related state moved to controller # UI-related state moved to controller
@@ -484,6 +485,7 @@ class AppController:
self._predefined_callbacks: dict[str, Callable[..., Any]] = { self._predefined_callbacks: dict[str, Callable[..., Any]] = {
'_test_callback_func_write_to_file': self._test_callback_func_write_to_file, '_test_callback_func_write_to_file': self._test_callback_func_write_to_file,
'_set_env_var': lambda k, v: os.environ.update({k: v}), '_set_env_var': lambda k, v: os.environ.update({k: v}),
'_set_attr': lambda k, v: setattr(self, k, v),
'_apply_preset': self._apply_preset, '_apply_preset': self._apply_preset,
'_cb_save_preset': self._cb_save_preset, '_cb_save_preset': self._cb_save_preset,
'_cb_delete_preset': self._cb_delete_preset, '_cb_delete_preset': self._cb_delete_preset,
@@ -835,6 +837,7 @@ class AppController:
self._current_provider = ai_cfg.get("provider", "gemini") self._current_provider = ai_cfg.get("provider", "gemini")
self._current_model = ai_cfg.get("model", "gemini-2.5-flash-lite") self._current_model = ai_cfg.get("model", "gemini-2.5-flash-lite")
self.temperature = ai_cfg.get("temperature", 0.0) self.temperature = ai_cfg.get("temperature", 0.0)
self.top_p = ai_cfg.get("top_p", 1.0)
self.max_tokens = ai_cfg.get("max_tokens", 8192) self.max_tokens = ai_cfg.get("max_tokens", 8192)
self.history_trunc_limit = ai_cfg.get("history_trunc_limit", 8000) self.history_trunc_limit = ai_cfg.get("history_trunc_limit", 8000)
projects_cfg = self.config.get("projects", {}) projects_cfg = self.config.get("projects", {})
@@ -1246,7 +1249,7 @@ class AppController:
self.ai_response = "" self.ai_response = ""
csp = filter(bool, [self.ui_global_system_prompt.strip(), self.ui_project_system_prompt.strip()]) csp = filter(bool, [self.ui_global_system_prompt.strip(), self.ui_project_system_prompt.strip()])
ai_client.set_custom_system_prompt("\n\n".join(csp)) ai_client.set_custom_system_prompt("\n\n".join(csp))
ai_client.set_model_params(self.temperature, self.max_tokens, self.history_trunc_limit) ai_client.set_model_params(self.temperature, self.max_tokens, self.history_trunc_limit, self.top_p)
ai_client.set_agent_tools(self.ui_agent_tools) ai_client.set_agent_tools(self.ui_agent_tools)
# Force update adapter path right before send to bypass potential duplication issues # Force update adapter path right before send to bypass potential duplication issues
self._update_gcli_adapter(self.ui_gemini_cli_path) self._update_gcli_adapter(self.ui_gemini_cli_path)
@@ -1633,8 +1636,9 @@ class AppController:
csp = filter(bool, [self.ui_global_system_prompt.strip(), self.ui_project_system_prompt.strip()]) csp = filter(bool, [self.ui_global_system_prompt.strip(), self.ui_project_system_prompt.strip()])
ai_client.set_custom_system_prompt("\n\n".join(csp)) ai_client.set_custom_system_prompt("\n\n".join(csp))
temp = req.temperature if req.temperature is not None else self.temperature temp = req.temperature if req.temperature is not None else self.temperature
top_p = req.top_p if req.top_p is not None else self.top_p
tokens = req.max_tokens if req.max_tokens is not None else self.max_tokens tokens = req.max_tokens if req.max_tokens is not None else self.max_tokens
ai_client.set_model_params(temp, tokens, self.history_trunc_limit) ai_client.set_model_params(temp, tokens, self.history_trunc_limit, top_p)
ai_client.set_agent_tools(self.ui_agent_tools) ai_client.set_agent_tools(self.ui_agent_tools)
if req.auto_add_history: if req.auto_add_history:
with self._pending_history_adds_lock: with self._pending_history_adds_lock:
@@ -2265,6 +2269,7 @@ class AppController:
"provider": self.current_provider, "provider": self.current_provider,
"model": self.current_model, "model": self.current_model,
"temperature": self.temperature, "temperature": self.temperature,
"top_p": self.top_p,
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
"history_trunc_limit": self.history_trunc_limit, "history_trunc_limit": self.history_trunc_limit,
"active_preset": self.ui_global_preset_name, "active_preset": self.ui_global_preset_name,

View File

@@ -78,6 +78,7 @@ class GenerateRequest(BaseModel):
prompt: str prompt: str
auto_add_history: bool = True auto_add_history: bool = True
temperature: float | None = None temperature: float | None = None
top_p: float | None = None
max_tokens: int | None = None max_tokens: int | None = None
class ConfirmRequest(BaseModel): class ConfirmRequest(BaseModel):
@@ -108,6 +109,7 @@ class App:
self._editing_persona_model = "" self._editing_persona_model = ""
self._editing_persona_system_prompt = "" self._editing_persona_system_prompt = ""
self._editing_persona_temperature = 0.7 self._editing_persona_temperature = 0.7
self._editing_persona_top_p = 1.0
self._editing_persona_max_tokens = 4096 self._editing_persona_max_tokens = 4096
self._editing_persona_tool_preset_id = "" self._editing_persona_tool_preset_id = ""
self._editing_persona_bias_profile_id = "" self._editing_persona_bias_profile_id = ""
@@ -1345,7 +1347,14 @@ class App:
self._persona_pref_models_expanded[i] = not is_expanded self._persona_pref_models_expanded[i] = not is_expanded
imgui.same_line() imgui.same_line()
imgui.text(f"{i+1}. {prov} - {mod}") imgui.text(f"{i+1}.")
imgui.same_line()
imgui.text_colored(C_LBL, f"{prov}")
imgui.same_line()
imgui.text("-")
imgui.same_line()
imgui.text_colored(C_IN, f"{mod}")
imgui.same_line(imgui.get_content_region_avail().x - 30) imgui.same_line(imgui.get_content_region_avail().x - 30)
if imgui.button("x"): if imgui.button("x"):
to_remove.append(i) to_remove.append(i)
@@ -2366,8 +2375,33 @@ def hello():
imgui.end_list_box() imgui.end_list_box()
imgui.separator() imgui.separator()
imgui.text("Parameters") imgui.text("Parameters")
ch, self.temperature = imgui.slider_float("Temperature", self.temperature, 0.0, 2.0, "%.2f") # Temperature
ch, self.max_tokens = imgui.input_int("Max Tokens (Output)", self.max_tokens, 1024) imgui.push_id("temp")
imgui.set_next_item_width(imgui.get_content_region_avail().x * 0.6)
_, self.temperature = imgui.slider_float("##slider", self.temperature, 0.0, 2.0, "%.2f")
imgui.same_line()
imgui.set_next_item_width(-1)
_, self.temperature = imgui.input_float("Temp", self.temperature, 0.0, 0.0, "%.2f")
imgui.pop_id()
# Top-P
imgui.push_id("top_p")
imgui.set_next_item_width(imgui.get_content_region_avail().x * 0.6)
_, self.top_p = imgui.slider_float("##slider", self.top_p, 0.0, 1.0, "%.2f")
imgui.same_line()
imgui.set_next_item_width(-1)
_, self.top_p = imgui.input_float("Top-P", self.top_p, 0.0, 0.0, "%.2f")
imgui.pop_id()
# Max Tokens
imgui.push_id("max_tokens")
imgui.set_next_item_width(imgui.get_content_region_avail().x * 0.6)
_, self.max_tokens = imgui.slider_int("##slider", self.max_tokens, 1, 32768)
imgui.same_line()
imgui.set_next_item_width(-1)
_, self.max_tokens = imgui.input_int("MaxTok", self.max_tokens)
imgui.pop_id()
ch, self.history_trunc_limit = imgui.input_int("History Truncation Limit", self.history_trunc_limit, 1024) ch, self.history_trunc_limit = imgui.input_int("History Truncation Limit", self.history_trunc_limit, 1024)