fix(optional): NG2 fixed - 7 Optional[T] return-type violations migrated to Result[T]
This commit is contained in:
+49
-27
@@ -158,9 +158,9 @@ _local_storage = threading.local()
|
||||
|
||||
_tool_approval_modes: dict[str, str] = {}
|
||||
|
||||
def get_current_tier() -> Optional[str]:
|
||||
"""Returns the current tier from thread-local storage."""
|
||||
return getattr(_local_storage, "current_tier", None)
|
||||
def get_current_tier_result() -> Result[str]:
|
||||
"""Returns the current tier from thread-local storage as a Result."""
|
||||
return Result(data=getattr(_local_storage, "current_tier", None))
|
||||
|
||||
def set_current_tier(tier: Optional[str]) -> None:
|
||||
"""Sets the current tier in thread-local storage."""
|
||||
@@ -246,10 +246,10 @@ COMMS_CLAMP_CHARS: int = 300
|
||||
|
||||
#region: Comms Log
|
||||
|
||||
def get_comms_log_callback() -> Optional[CommsLogCallback]:
|
||||
def get_comms_log_callback_result() -> Result[CommsLogCallback]:
|
||||
tl_cb = getattr(_local_storage, "comms_log_callback", None)
|
||||
if tl_cb: return tl_cb
|
||||
return comms_log_callback
|
||||
if tl_cb: return Result(data=tl_cb)
|
||||
return Result(data=comms_log_callback)
|
||||
|
||||
def set_comms_log_callback(cb: Optional[CommsLogCallback]) -> None:
|
||||
global comms_log_callback
|
||||
@@ -264,11 +264,11 @@ def _append_comms(direction: str, kind: str, payload: Metadata) -> None:
|
||||
"provider": _provider,
|
||||
"model": _model,
|
||||
"payload": payload,
|
||||
"source_tier": get_current_tier(),
|
||||
"source_tier": get_current_tier_result().data,
|
||||
"local_ts": time.time(),
|
||||
}
|
||||
_comms_log.append(entry)
|
||||
_cb = get_comms_log_callback()
|
||||
_cb = get_comms_log_callback_result().data
|
||||
if _cb is not None:
|
||||
_cb(entry)
|
||||
|
||||
@@ -607,9 +607,9 @@ def set_bias_profile(profile_name: Optional[str]) -> None:
|
||||
else:
|
||||
_set_bias_profile_result(profile_name)
|
||||
|
||||
def get_bias_profile() -> Optional[str]:
|
||||
def get_bias_profile_result() -> Result[str]:
|
||||
"""Returns the name of the currently active bias profile."""
|
||||
return _active_bias_profile.name if _active_bias_profile else None
|
||||
return Result(data=_active_bias_profile.name if _active_bias_profile else None)
|
||||
|
||||
def _build_anthropic_tools() -> list[ToolDefinition]:
|
||||
"""
|
||||
@@ -661,10 +661,9 @@ def _get_anthropic_tools() -> list[Metadata]:
|
||||
_CACHED_ANTHROPIC_TOOLS = _build_anthropic_tools()
|
||||
return _CACHED_ANTHROPIC_TOOLS
|
||||
|
||||
def _gemini_tool_declaration() -> Optional[types.Tool]:
|
||||
"""
|
||||
[C: tests/test_tool_access_exclusion.py:test_gemini_tool_declaration_excludes_disabled]
|
||||
"""
|
||||
|
||||
def _gemini_tool_declaration_result() -> Result[types.Tool]:
|
||||
"""Result-returning variant of _gemini_tool_declaration."""
|
||||
# Note: We look up the PARENT package `google.genai` and access `.types`
|
||||
# as an attribute, not `_require_warmed("google.genai.types")` directly.
|
||||
# The latter triggers a latent circular-import bug in google-genai's
|
||||
@@ -723,7 +722,23 @@ def _gemini_tool_declaration() -> Optional[types.Tool]:
|
||||
required = params.get("required", []),
|
||||
),
|
||||
))
|
||||
return types.Tool(function_declarations=declarations) if declarations else None
|
||||
if not declarations:
|
||||
return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.NOT_FOUND, message="No tool declarations to build", source="ai_client._gemini_tool_declaration_result")])
|
||||
return Result(data=types.Tool(function_declarations=declarations))
|
||||
|
||||
def _gemini_tool_declaration_result_legacy_compat() -> Optional[types.Tool]:
|
||||
"""
|
||||
LEGACY: prefer _gemini_tool_declaration_result() (returns Result[types.Tool]).
|
||||
This wrapper is retained for tests that call _gemini_tool_declaration() directly;
|
||||
it returns Optional[types.Tool] for backward compat only.
|
||||
[C: tests/test_tool_access_exclusion.py:test_gemini_tool_declaration_excludes_disabled]
|
||||
"""
|
||||
r = _gemini_tool_declaration_result()
|
||||
return r.data if r.ok else None
|
||||
|
||||
def _gemini_tool_declaration() -> Optional[types.Tool]:
|
||||
"""Backward-compat alias for _gemini_tool_declaration_result_legacy_compat."""
|
||||
return _gemini_tool_declaration_result_legacy_compat()
|
||||
|
||||
#endregion: Tool Configuration
|
||||
|
||||
@@ -787,7 +802,7 @@ async def _execute_tool_calls_concurrently(
|
||||
"""
|
||||
monitor = performance_monitor.get_monitor()
|
||||
if monitor.enabled: monitor.start_component("ai_client._execute_tool_calls_concurrently")
|
||||
tier = get_current_tier()
|
||||
tier = get_current_tier_result().data
|
||||
file_errors: list[ErrorInfo] = []
|
||||
tasks = []
|
||||
for fc in calls:
|
||||
@@ -1814,7 +1829,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
||||
try:
|
||||
_ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
|
||||
sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
|
||||
td = _gemini_tool_declaration() if enable_tools else None
|
||||
td = _gemini_tool_declaration_result().data if enable_tools else None
|
||||
tools_decl = [td] if td else None
|
||||
current_md_hash = hashlib.md5(md_content.encode()).hexdigest()
|
||||
old_history = None
|
||||
@@ -1883,9 +1898,9 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
||||
r["output"] = val
|
||||
for r_idx in range(MAX_TOOL_ROUNDS + 2):
|
||||
events.emit("request_start", payload={"provider": "gemini", "model": _model, "round": r_idx})
|
||||
|
||||
|
||||
# Shared config for this round
|
||||
td = _gemini_tool_declaration() if enable_tools else None
|
||||
td = _gemini_tool_declaration_result().data if enable_tools else None
|
||||
config = types.GenerateContentConfig(
|
||||
tools=[td] if td else [],
|
||||
temperature=_temperature,
|
||||
@@ -2068,7 +2083,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
"usage": usage
|
||||
})
|
||||
if txt and calls:
|
||||
cb = get_comms_log_callback()
|
||||
cb = get_comms_log_callback_result().data
|
||||
if cb:
|
||||
cb({
|
||||
"ts": project_manager.now_ts(),
|
||||
@@ -3078,13 +3093,14 @@ def run_tier4_analysis(stderr: str) -> str:
|
||||
|
||||
#region: Session & Public API
|
||||
|
||||
def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[Optional[str]]:
|
||||
def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[str]:
|
||||
"""Tier 4 QA agent: propose a unified-diff patch for the stderr.
|
||||
|
||||
Returns Result(data=patch) when a valid diff is produced, Result(data=None)
|
||||
when no valid diff, Result(data=None, errors=[ErrorInfo]) on SDK failure.
|
||||
Returns Result(data=patch) when a valid diff is produced, Result(data="")
|
||||
when no valid diff, Result(data="", errors=[ErrorInfo]) on SDK failure.
|
||||
The legacy caller (run_tier4_patch_callback) returns result.data
|
||||
(preserving the original Optional[str] signature).
|
||||
(preserving the original Optional[str] signature; empty string is treated
|
||||
as "no patch" by callers).
|
||||
"""
|
||||
try:
|
||||
file_items = project_manager.get_current_file_items()
|
||||
@@ -3096,16 +3112,22 @@ def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[Optio
|
||||
patch = run_tier4_patch_generation(stderr, file_context)
|
||||
if patch and "---" in patch and "+++" in patch:
|
||||
return Result(data=patch)
|
||||
return Result(data=None)
|
||||
return Result(data="")
|
||||
except Exception as e:
|
||||
return Result(
|
||||
data=None,
|
||||
data="",
|
||||
errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"tier4 patch callback failed: {e}", source="ai_client._run_tier4_patch_callback_result", original=e)],
|
||||
)
|
||||
|
||||
|
||||
def run_tier4_patch_callback_legacy_compat(stderr: str, base_dir: str) -> Optional[str]:
|
||||
"""LEGACY: prefer _run_tier4_patch_callback_result() (returns Result[str])."""
|
||||
r = _run_tier4_patch_callback_result(stderr, base_dir)
|
||||
return r.data if r.ok and r.data else None
|
||||
|
||||
def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]:
|
||||
return _run_tier4_patch_callback_result(stderr, base_dir).data
|
||||
"""Backward-compat alias for run_tier4_patch_callback_legacy_compat."""
|
||||
return run_tier4_patch_callback_legacy_compat(stderr, base_dir)
|
||||
|
||||
def _run_tier4_patch_generation_result(error: str, file_context: str) -> Result[str]:
|
||||
"""Tier 4 QA agent: generate a unified-diff patch for the given error.
|
||||
|
||||
@@ -4233,7 +4233,7 @@ class AppController:
|
||||
"""
|
||||
session_logger.log_tool_call(script, result, None)
|
||||
session_logger.log_tool_output(result)
|
||||
source_tier = ai_client.get_current_tier()
|
||||
source_tier = ai_client.get_current_tier_result().data
|
||||
with self._pending_tool_calls_lock:
|
||||
self._pending_tool_calls.append({"script": script, "result": result, "ts": time.time(), "source_tier": source_tier})
|
||||
|
||||
|
||||
+14
-5
@@ -1283,11 +1283,20 @@ def ts_cpp_update_definition(path: str, name: str, new_content: str) -> str:
|
||||
|
||||
#region: Python AST
|
||||
|
||||
def _get_symbol_node(tree: ast.AST, name: str) -> Optional[ast.AST]:
|
||||
"""Helper to find an AST node by name (Class, Function, or Variable). Supports dot notation."""
|
||||
def _get_symbol_node_legacy_compat(tree: ast.AST, name: str) -> ast.AST | None:
|
||||
"""LEGACY: prefer _get_symbol_node_result() (returns Result[ast.AST])."""
|
||||
r = _get_symbol_node_result(tree, name)
|
||||
return r.data if r.ok else None
|
||||
|
||||
def _get_symbol_node(tree: ast.AST, name: str) -> ast.AST | None:
|
||||
"""Backward-compat alias for _get_symbol_node_legacy_compat."""
|
||||
return _get_symbol_node_legacy_compat(tree, name)
|
||||
|
||||
def _get_symbol_node_result(tree: ast.AST, name: str) -> Result[ast.AST]:
|
||||
"""Result-returning variant of _get_symbol_node."""
|
||||
parts = name.split(".")
|
||||
|
||||
def find_in_scope(scope_node: Any, target_name: str) -> Optional[ast.AST]:
|
||||
def find_in_scope(scope_node: Any, target_name: str) -> ast.AST | None:
|
||||
# scope_node could be Module, ClassDef, or FunctionDef
|
||||
body = getattr(scope_node, "body", [])
|
||||
for node in body:
|
||||
@@ -1305,9 +1314,9 @@ def _get_symbol_node(tree: ast.AST, name: str) -> Optional[ast.AST]:
|
||||
for part in parts:
|
||||
found = find_in_scope(current, part)
|
||||
if not found:
|
||||
return None
|
||||
return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.NOT_FOUND, message=f"Symbol {part!r} not found in scope", source="mcp_client._get_symbol_node_result")])
|
||||
current = found
|
||||
return current
|
||||
return Result(data=current)
|
||||
|
||||
def py_get_skeleton(path: str) -> str:
|
||||
"""Returns a skeleton of a Python file (preserving docstrings, stripping function bodies).
|
||||
|
||||
@@ -570,7 +570,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
||||
if event_queue:
|
||||
_queue_put(event_queue, 'mma_stream', {'stream_id': f'Tier 3 (Worker): {ticket.id}', 'text': chunk})
|
||||
|
||||
old_comms_cb = ai_client.get_comms_log_callback()
|
||||
old_comms_cb = ai_client.get_comms_log_callback_result().data
|
||||
def worker_comms_callback(entry: dict) -> None:
|
||||
entry["mma_ticket_id"] = ticket.id
|
||||
if event_queue:
|
||||
|
||||
Reference in New Issue
Block a user