Private
Public Access
0
0

fix(optional): NG2 fixed - 7 Optional[T] return-type violations migrated to Result[T]

This commit is contained in:
2026-06-24 17:37:17 -04:00
parent ee4287ae4d
commit 99e0c77dcd
4 changed files with 65 additions and 34 deletions
+49 -27
View File
@@ -158,9 +158,9 @@ _local_storage = threading.local()
_tool_approval_modes: dict[str, str] = {}
def get_current_tier() -> Optional[str]:
"""Returns the current tier from thread-local storage."""
return getattr(_local_storage, "current_tier", None)
def get_current_tier_result() -> Result[str]:
"""Returns the current tier from thread-local storage as a Result."""
return Result(data=getattr(_local_storage, "current_tier", None))
def set_current_tier(tier: Optional[str]) -> None:
"""Sets the current tier in thread-local storage."""
@@ -246,10 +246,10 @@ COMMS_CLAMP_CHARS: int = 300
#region: Comms Log
def get_comms_log_callback() -> Optional[CommsLogCallback]:
def get_comms_log_callback_result() -> Result[CommsLogCallback]:
tl_cb = getattr(_local_storage, "comms_log_callback", None)
if tl_cb: return tl_cb
return comms_log_callback
if tl_cb: return Result(data=tl_cb)
return Result(data=comms_log_callback)
def set_comms_log_callback(cb: Optional[CommsLogCallback]) -> None:
global comms_log_callback
@@ -264,11 +264,11 @@ def _append_comms(direction: str, kind: str, payload: Metadata) -> None:
"provider": _provider,
"model": _model,
"payload": payload,
"source_tier": get_current_tier(),
"source_tier": get_current_tier_result().data,
"local_ts": time.time(),
}
_comms_log.append(entry)
_cb = get_comms_log_callback()
_cb = get_comms_log_callback_result().data
if _cb is not None:
_cb(entry)
@@ -607,9 +607,9 @@ def set_bias_profile(profile_name: Optional[str]) -> None:
else:
_set_bias_profile_result(profile_name)
def get_bias_profile() -> Optional[str]:
def get_bias_profile_result() -> Result[str]:
"""Returns the name of the currently active bias profile."""
return _active_bias_profile.name if _active_bias_profile else None
return Result(data=_active_bias_profile.name if _active_bias_profile else None)
def _build_anthropic_tools() -> list[ToolDefinition]:
"""
@@ -661,10 +661,9 @@ def _get_anthropic_tools() -> list[Metadata]:
_CACHED_ANTHROPIC_TOOLS = _build_anthropic_tools()
return _CACHED_ANTHROPIC_TOOLS
def _gemini_tool_declaration() -> Optional[types.Tool]:
"""
[C: tests/test_tool_access_exclusion.py:test_gemini_tool_declaration_excludes_disabled]
"""
def _gemini_tool_declaration_result() -> Result[types.Tool]:
"""Result-returning variant of _gemini_tool_declaration."""
# Note: We look up the PARENT package `google.genai` and access `.types`
# as an attribute, not `_require_warmed("google.genai.types")` directly.
# The latter triggers a latent circular-import bug in google-genai's
@@ -723,7 +722,23 @@ def _gemini_tool_declaration() -> Optional[types.Tool]:
required = params.get("required", []),
),
))
return types.Tool(function_declarations=declarations) if declarations else None
if not declarations:
return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.NOT_FOUND, message="No tool declarations to build", source="ai_client._gemini_tool_declaration_result")])
return Result(data=types.Tool(function_declarations=declarations))
def _gemini_tool_declaration_result_legacy_compat() -> Optional[types.Tool]:
"""
LEGACY: prefer _gemini_tool_declaration_result() (returns Result[types.Tool]).
This wrapper is retained for tests that call _gemini_tool_declaration() directly;
it returns Optional[types.Tool] for backward compat only.
[C: tests/test_tool_access_exclusion.py:test_gemini_tool_declaration_excludes_disabled]
"""
r = _gemini_tool_declaration_result()
return r.data if r.ok else None
def _gemini_tool_declaration() -> Optional[types.Tool]:
"""Backward-compat alias for _gemini_tool_declaration_result_legacy_compat."""
return _gemini_tool_declaration_result_legacy_compat()
#endregion: Tool Configuration
@@ -787,7 +802,7 @@ async def _execute_tool_calls_concurrently(
"""
monitor = performance_monitor.get_monitor()
if monitor.enabled: monitor.start_component("ai_client._execute_tool_calls_concurrently")
tier = get_current_tier()
tier = get_current_tier_result().data
file_errors: list[ErrorInfo] = []
tasks = []
for fc in calls:
@@ -1814,7 +1829,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
try:
_ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
td = _gemini_tool_declaration() if enable_tools else None
td = _gemini_tool_declaration_result().data if enable_tools else None
tools_decl = [td] if td else None
current_md_hash = hashlib.md5(md_content.encode()).hexdigest()
old_history = None
@@ -1883,9 +1898,9 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
r["output"] = val
for r_idx in range(MAX_TOOL_ROUNDS + 2):
events.emit("request_start", payload={"provider": "gemini", "model": _model, "round": r_idx})
# Shared config for this round
td = _gemini_tool_declaration() if enable_tools else None
td = _gemini_tool_declaration_result().data if enable_tools else None
config = types.GenerateContentConfig(
tools=[td] if td else [],
temperature=_temperature,
@@ -2068,7 +2083,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
"usage": usage
})
if txt and calls:
cb = get_comms_log_callback()
cb = get_comms_log_callback_result().data
if cb:
cb({
"ts": project_manager.now_ts(),
@@ -3078,13 +3093,14 @@ def run_tier4_analysis(stderr: str) -> str:
#region: Session & Public API
def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[Optional[str]]:
def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[str]:
"""Tier 4 QA agent: propose a unified-diff patch for the stderr.
Returns Result(data=patch) when a valid diff is produced, Result(data=None)
when no valid diff, Result(data=None, errors=[ErrorInfo]) on SDK failure.
Returns Result(data=patch) when a valid diff is produced, Result(data="")
when no valid diff, Result(data="", errors=[ErrorInfo]) on SDK failure.
The legacy caller (run_tier4_patch_callback) returns result.data
(preserving the original Optional[str] signature).
(preserving the original Optional[str] signature; empty string is treated
as "no patch" by callers).
"""
try:
file_items = project_manager.get_current_file_items()
@@ -3096,16 +3112,22 @@ def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[Optio
patch = run_tier4_patch_generation(stderr, file_context)
if patch and "---" in patch and "+++" in patch:
return Result(data=patch)
return Result(data=None)
return Result(data="")
except Exception as e:
return Result(
data=None,
data="",
errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"tier4 patch callback failed: {e}", source="ai_client._run_tier4_patch_callback_result", original=e)],
)
def run_tier4_patch_callback_legacy_compat(stderr: str, base_dir: str) -> Optional[str]:
"""LEGACY: prefer _run_tier4_patch_callback_result() (returns Result[str])."""
r = _run_tier4_patch_callback_result(stderr, base_dir)
return r.data if r.ok and r.data else None
def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]:
return _run_tier4_patch_callback_result(stderr, base_dir).data
"""Backward-compat alias for run_tier4_patch_callback_legacy_compat."""
return run_tier4_patch_callback_legacy_compat(stderr, base_dir)
def _run_tier4_patch_generation_result(error: str, file_context: str) -> Result[str]:
"""Tier 4 QA agent: generate a unified-diff patch for the given error.
+1 -1
View File
@@ -4233,7 +4233,7 @@ class AppController:
"""
session_logger.log_tool_call(script, result, None)
session_logger.log_tool_output(result)
source_tier = ai_client.get_current_tier()
source_tier = ai_client.get_current_tier_result().data
with self._pending_tool_calls_lock:
self._pending_tool_calls.append({"script": script, "result": result, "ts": time.time(), "source_tier": source_tier})
+14 -5
View File
@@ -1283,11 +1283,20 @@ def ts_cpp_update_definition(path: str, name: str, new_content: str) -> str:
#region: Python AST
def _get_symbol_node(tree: ast.AST, name: str) -> Optional[ast.AST]:
"""Helper to find an AST node by name (Class, Function, or Variable). Supports dot notation."""
def _get_symbol_node_legacy_compat(tree: ast.AST, name: str) -> ast.AST | None:
"""LEGACY: prefer _get_symbol_node_result() (returns Result[ast.AST])."""
r = _get_symbol_node_result(tree, name)
return r.data if r.ok else None
def _get_symbol_node(tree: ast.AST, name: str) -> ast.AST | None:
"""Backward-compat alias for _get_symbol_node_legacy_compat."""
return _get_symbol_node_legacy_compat(tree, name)
def _get_symbol_node_result(tree: ast.AST, name: str) -> Result[ast.AST]:
"""Result-returning variant of _get_symbol_node."""
parts = name.split(".")
def find_in_scope(scope_node: Any, target_name: str) -> Optional[ast.AST]:
def find_in_scope(scope_node: Any, target_name: str) -> ast.AST | None:
# scope_node could be Module, ClassDef, or FunctionDef
body = getattr(scope_node, "body", [])
for node in body:
@@ -1305,9 +1314,9 @@ def _get_symbol_node(tree: ast.AST, name: str) -> Optional[ast.AST]:
for part in parts:
found = find_in_scope(current, part)
if not found:
return None
return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.NOT_FOUND, message=f"Symbol {part!r} not found in scope", source="mcp_client._get_symbol_node_result")])
current = found
return current
return Result(data=current)
def py_get_skeleton(path: str) -> str:
"""Returns a skeleton of a Python file (preserving docstrings, stripping function bodies).
+1 -1
View File
@@ -570,7 +570,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
if event_queue:
_queue_put(event_queue, 'mma_stream', {'stream_id': f'Tier 3 (Worker): {ticket.id}', 'text': chunk})
old_comms_cb = ai_client.get_comms_log_callback()
old_comms_cb = ai_client.get_comms_log_callback_result().data
def worker_comms_callback(entry: dict) -> None:
entry["mma_ticket_id"] = ticket.id
if event_queue: