diff --git a/src/ai_client.py b/src/ai_client.py index 583f0406..39a61a68 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -158,9 +158,9 @@ _local_storage = threading.local() _tool_approval_modes: dict[str, str] = {} -def get_current_tier() -> Optional[str]: - """Returns the current tier from thread-local storage.""" - return getattr(_local_storage, "current_tier", None) +def get_current_tier_result() -> Result[str]: + """Returns the current tier from thread-local storage as a Result.""" + return Result(data=getattr(_local_storage, "current_tier", None)) def set_current_tier(tier: Optional[str]) -> None: """Sets the current tier in thread-local storage.""" @@ -246,10 +246,10 @@ COMMS_CLAMP_CHARS: int = 300 #region: Comms Log -def get_comms_log_callback() -> Optional[CommsLogCallback]: +def get_comms_log_callback_result() -> Result[CommsLogCallback]: tl_cb = getattr(_local_storage, "comms_log_callback", None) - if tl_cb: return tl_cb - return comms_log_callback + if tl_cb: return Result(data=tl_cb) + return Result(data=comms_log_callback) def set_comms_log_callback(cb: Optional[CommsLogCallback]) -> None: global comms_log_callback @@ -264,11 +264,11 @@ def _append_comms(direction: str, kind: str, payload: Metadata) -> None: "provider": _provider, "model": _model, "payload": payload, - "source_tier": get_current_tier(), + "source_tier": get_current_tier_result().data, "local_ts": time.time(), } _comms_log.append(entry) - _cb = get_comms_log_callback() + _cb = get_comms_log_callback_result().data if _cb is not None: _cb(entry) @@ -607,9 +607,9 @@ def set_bias_profile(profile_name: Optional[str]) -> None: else: _set_bias_profile_result(profile_name) -def get_bias_profile() -> Optional[str]: +def get_bias_profile_result() -> Result[str]: """Returns the name of the currently active bias profile.""" - return _active_bias_profile.name if _active_bias_profile else None + return Result(data=_active_bias_profile.name if _active_bias_profile else None) def _build_anthropic_tools() -> list[ToolDefinition]: """ @@ -661,10 +661,9 @@ def _get_anthropic_tools() -> list[Metadata]: _CACHED_ANTHROPIC_TOOLS = _build_anthropic_tools() return _CACHED_ANTHROPIC_TOOLS -def _gemini_tool_declaration() -> Optional[types.Tool]: - """ - [C: tests/test_tool_access_exclusion.py:test_gemini_tool_declaration_excludes_disabled] - """ + +def _gemini_tool_declaration_result() -> Result[types.Tool]: + """Result-returning variant of _gemini_tool_declaration.""" # Note: We look up the PARENT package `google.genai` and access `.types` # as an attribute, not `_require_warmed("google.genai.types")` directly. # The latter triggers a latent circular-import bug in google-genai's @@ -723,7 +722,23 @@ def _gemini_tool_declaration() -> Optional[types.Tool]: required = params.get("required", []), ), )) - return types.Tool(function_declarations=declarations) if declarations else None + if not declarations: + return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.NOT_FOUND, message="No tool declarations to build", source="ai_client._gemini_tool_declaration_result")]) + return Result(data=types.Tool(function_declarations=declarations)) + +def _gemini_tool_declaration_result_legacy_compat() -> Optional[types.Tool]: + """ + LEGACY: prefer _gemini_tool_declaration_result() (returns Result[types.Tool]). + This wrapper is retained for tests that call _gemini_tool_declaration() directly; + it returns Optional[types.Tool] for backward compat only. + [C: tests/test_tool_access_exclusion.py:test_gemini_tool_declaration_excludes_disabled] + """ + r = _gemini_tool_declaration_result() + return r.data if r.ok else None + +def _gemini_tool_declaration() -> Optional[types.Tool]: + """Backward-compat alias for _gemini_tool_declaration_result_legacy_compat.""" + return _gemini_tool_declaration_result_legacy_compat() #endregion: Tool Configuration @@ -787,7 +802,7 @@ async def _execute_tool_calls_concurrently( """ monitor = performance_monitor.get_monitor() if monitor.enabled: monitor.start_component("ai_client._execute_tool_calls_concurrently") - tier = get_current_tier() + tier = get_current_tier_result().data file_errors: list[ErrorInfo] = [] tasks = [] for fc in calls: @@ -1814,7 +1829,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, try: _ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir]) sys_instr = f"{_get_combined_system_prompt()}\n\n\n{md_content}\n" - td = _gemini_tool_declaration() if enable_tools else None + td = _gemini_tool_declaration_result().data if enable_tools else None tools_decl = [td] if td else None current_md_hash = hashlib.md5(md_content.encode()).hexdigest() old_history = None @@ -1883,9 +1898,9 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, r["output"] = val for r_idx in range(MAX_TOOL_ROUNDS + 2): events.emit("request_start", payload={"provider": "gemini", "model": _model, "round": r_idx}) - + # Shared config for this round - td = _gemini_tool_declaration() if enable_tools else None + td = _gemini_tool_declaration_result().data if enable_tools else None config = types.GenerateContentConfig( tools=[td] if td else [], temperature=_temperature, @@ -2068,7 +2083,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, "usage": usage }) if txt and calls: - cb = get_comms_log_callback() + cb = get_comms_log_callback_result().data if cb: cb({ "ts": project_manager.now_ts(), @@ -3078,13 +3093,14 @@ def run_tier4_analysis(stderr: str) -> str: #region: Session & Public API -def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[Optional[str]]: +def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[str]: """Tier 4 QA agent: propose a unified-diff patch for the stderr. - Returns Result(data=patch) when a valid diff is produced, Result(data=None) - when no valid diff, Result(data=None, errors=[ErrorInfo]) on SDK failure. + Returns Result(data=patch) when a valid diff is produced, Result(data="") + when no valid diff, Result(data="", errors=[ErrorInfo]) on SDK failure. The legacy caller (run_tier4_patch_callback) returns result.data - (preserving the original Optional[str] signature). + (preserving the original Optional[str] signature; empty string is treated + as "no patch" by callers). """ try: file_items = project_manager.get_current_file_items() @@ -3096,16 +3112,22 @@ def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[Optio patch = run_tier4_patch_generation(stderr, file_context) if patch and "---" in patch and "+++" in patch: return Result(data=patch) - return Result(data=None) + return Result(data="") except Exception as e: return Result( - data=None, + data="", errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"tier4 patch callback failed: {e}", source="ai_client._run_tier4_patch_callback_result", original=e)], ) +def run_tier4_patch_callback_legacy_compat(stderr: str, base_dir: str) -> Optional[str]: + """LEGACY: prefer _run_tier4_patch_callback_result() (returns Result[str]).""" + r = _run_tier4_patch_callback_result(stderr, base_dir) + return r.data if r.ok and r.data else None + def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]: - return _run_tier4_patch_callback_result(stderr, base_dir).data + """Backward-compat alias for run_tier4_patch_callback_legacy_compat.""" + return run_tier4_patch_callback_legacy_compat(stderr, base_dir) def _run_tier4_patch_generation_result(error: str, file_context: str) -> Result[str]: """Tier 4 QA agent: generate a unified-diff patch for the given error. diff --git a/src/app_controller.py b/src/app_controller.py index a8913759..3c6bdfcd 100644 --- a/src/app_controller.py +++ b/src/app_controller.py @@ -4233,7 +4233,7 @@ class AppController: """ session_logger.log_tool_call(script, result, None) session_logger.log_tool_output(result) - source_tier = ai_client.get_current_tier() + source_tier = ai_client.get_current_tier_result().data with self._pending_tool_calls_lock: self._pending_tool_calls.append({"script": script, "result": result, "ts": time.time(), "source_tier": source_tier}) diff --git a/src/mcp_client.py b/src/mcp_client.py index 9b457eee..5ef4b75a 100644 --- a/src/mcp_client.py +++ b/src/mcp_client.py @@ -1283,11 +1283,20 @@ def ts_cpp_update_definition(path: str, name: str, new_content: str) -> str: #region: Python AST -def _get_symbol_node(tree: ast.AST, name: str) -> Optional[ast.AST]: - """Helper to find an AST node by name (Class, Function, or Variable). Supports dot notation.""" +def _get_symbol_node_legacy_compat(tree: ast.AST, name: str) -> ast.AST | None: + """LEGACY: prefer _get_symbol_node_result() (returns Result[ast.AST]).""" + r = _get_symbol_node_result(tree, name) + return r.data if r.ok else None + +def _get_symbol_node(tree: ast.AST, name: str) -> ast.AST | None: + """Backward-compat alias for _get_symbol_node_legacy_compat.""" + return _get_symbol_node_legacy_compat(tree, name) + +def _get_symbol_node_result(tree: ast.AST, name: str) -> Result[ast.AST]: + """Result-returning variant of _get_symbol_node.""" parts = name.split(".") - def find_in_scope(scope_node: Any, target_name: str) -> Optional[ast.AST]: + def find_in_scope(scope_node: Any, target_name: str) -> ast.AST | None: # scope_node could be Module, ClassDef, or FunctionDef body = getattr(scope_node, "body", []) for node in body: @@ -1305,9 +1314,9 @@ def _get_symbol_node(tree: ast.AST, name: str) -> Optional[ast.AST]: for part in parts: found = find_in_scope(current, part) if not found: - return None + return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.NOT_FOUND, message=f"Symbol {part!r} not found in scope", source="mcp_client._get_symbol_node_result")]) current = found - return current + return Result(data=current) def py_get_skeleton(path: str) -> str: """Returns a skeleton of a Python file (preserving docstrings, stripping function bodies). diff --git a/src/multi_agent_conductor.py b/src/multi_agent_conductor.py index 2b77af34..5ee804df 100644 --- a/src/multi_agent_conductor.py +++ b/src/multi_agent_conductor.py @@ -570,7 +570,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: if event_queue: _queue_put(event_queue, 'mma_stream', {'stream_id': f'Tier 3 (Worker): {ticket.id}', 'text': chunk}) - old_comms_cb = ai_client.get_comms_log_callback() + old_comms_cb = ai_client.get_comms_log_callback_result().data def worker_comms_callback(entry: dict) -> None: entry["mma_ticket_id"] = ticket.id if event_queue: