diff --git a/api_hooks.py b/api_hooks.py index 287248d..c792647 100644 --- a/api_hooks.py +++ b/api_hooks.py @@ -1,10 +1,11 @@ import json import threading -from http.server import HTTPServer, BaseHTTPRequestHandler +import uuid +from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler import logging import session_logger -class HookServerInstance(HTTPServer): +class HookServerInstance(ThreadingHTTPServer): """Custom HTTPServer that carries a reference to the main App instance.""" def __init__(self, server_address, RequestHandlerClass, app): super().__init__(server_address, RequestHandlerClass) @@ -60,10 +61,10 @@ class HookHandler(BaseHTTPRequestHandler): data = json.loads(body.decode('utf-8')) field_tag = data.get("field") print(f"[DEBUG] Hook Server: get_value for {field_tag}") - + event = threading.Event() result = {"value": None} - + def get_val(): try: if field_tag in app._settable_fields: @@ -81,7 +82,7 @@ class HookHandler(BaseHTTPRequestHandler): "action": "custom_callback", "callback": get_val }) - + if event.wait(timeout=2): self.send_response(200) self.send_header('Content-Type', 'application/json') @@ -95,7 +96,7 @@ class HookHandler(BaseHTTPRequestHandler): field_tag = self.path.split('/')[-1] event = threading.Event() result = {"value": None} - + def get_val(): try: if field_tag in app._settable_fields: @@ -109,7 +110,7 @@ class HookHandler(BaseHTTPRequestHandler): "action": "custom_callback", "callback": get_val }) - + if event.wait(timeout=2): self.send_response(200) self.send_header('Content-Type', 'application/json') @@ -122,7 +123,7 @@ class HookHandler(BaseHTTPRequestHandler): # Safe way to query multiple states at once via the main thread queue event = threading.Event() result = {} - + def check_all(): try: # Generic state check based on App attributes (works for both DPG and ImGui versions) @@ -138,7 +139,7 @@ class HookHandler(BaseHTTPRequestHandler): "action": "custom_callback", "callback": check_all }) - + if event.wait(timeout=2): self.send_response(200) self.send_header('Content-Type', 'application/json') @@ -158,7 +159,7 @@ class HookHandler(BaseHTTPRequestHandler): body = self.rfile.read(content_length) body_str = body.decode('utf-8') if body else "" session_logger.log_api_hook("POST", self.path, body_str) - + try: data = json.loads(body_str) if body_str else {} if self.path == '/api/project': @@ -179,12 +180,74 @@ class HookHandler(BaseHTTPRequestHandler): elif self.path == '/api/gui': with app._pending_gui_tasks_lock: app._pending_gui_tasks.append(data) - + self.send_response(200) self.send_header('Content-Type', 'application/json') self.end_headers() self.wfile.write( json.dumps({'status': 'queued'}).encode('utf-8')) + elif self.path == '/api/ask': + request_id = str(uuid.uuid4()) + event = threading.Event() + + if not hasattr(app, '_pending_asks'): + app._pending_asks = {} + if not hasattr(app, '_ask_responses'): + app._ask_responses = {} + + app._pending_asks[request_id] = event + + # Emit event for test/client discovery + with app._api_event_queue_lock: + app._api_event_queue.append({ + "type": "ask_received", + "request_id": request_id, + "data": data + }) + + with app._pending_gui_tasks_lock: + app._pending_gui_tasks.append({ + "type": "ask", + "request_id": request_id, + "data": data + }) + + if event.wait(timeout=60.0): + response_data = app._ask_responses.get(request_id) + # Clean up response after reading + if request_id in app._ask_responses: + del app._ask_responses[request_id] + + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'status': 'ok', 'response': response_data}).encode('utf-8')) + else: + if request_id in app._pending_asks: + del app._pending_asks[request_id] + self.send_response(504) + self.end_headers() + self.wfile.write(json.dumps({'error': 'timeout'}).encode('utf-8')) + + elif self.path == '/api/ask/respond': + request_id = data.get('request_id') + response_data = data.get('response') + + if request_id and hasattr(app, '_pending_asks') and request_id in app._pending_asks: + app._ask_responses[request_id] = response_data + event = app._pending_asks[request_id] + event.set() + + # Clean up pending ask entry + del app._pending_asks[request_id] + + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'status': 'ok'}).encode('utf-8')) + else: + self.send_response(404) + self.end_headers() else: self.send_response(404) self.end_headers() @@ -207,19 +270,25 @@ class HookServer: def start(self): if not getattr(self.app, 'test_hooks_enabled', False): return - + # Ensure the app has the task queue and lock initialized if not hasattr(self.app, '_pending_gui_tasks'): self.app._pending_gui_tasks = [] if not hasattr(self.app, '_pending_gui_tasks_lock'): self.app._pending_gui_tasks_lock = threading.Lock() - + + # Initialize ask-related dictionaries + if not hasattr(self.app, '_pending_asks'): + self.app._pending_asks = {} + if not hasattr(self.app, '_ask_responses'): + self.app._ask_responses = {} + # Event queue for test script subscriptions if not hasattr(self.app, '_api_event_queue'): self.app._api_event_queue = [] if not hasattr(self.app, '_api_event_queue_lock'): self.app._api_event_queue_lock = threading.Lock() - + self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app) self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.thread.start() diff --git a/tests/test_sync_hooks.py b/tests/test_sync_hooks.py new file mode 100644 index 0000000..954b2ff --- /dev/null +++ b/tests/test_sync_hooks.py @@ -0,0 +1,73 @@ +import threading +import time +import requests +import pytest +from api_hook_client import ApiHookClient + +def test_api_ask_synchronous_flow(live_gui): + """ + Tests the full synchronous lifecycle of the /api/ask endpoint: + 1. A client makes a blocking request. + 2. An event is emitted with a unique request_id. + 3. A separate agent responds to that request_id. + 4. The original blocking request completes with the provided data. + """ + # The live_gui fixture starts the Manual Slop application with hooks on 8999. + client = ApiHookClient("http://127.0.0.1:8999") + + # Drain existing events + client.get_events() + + results = {"response": None, "error": None} + + def make_blocking_request(): + try: + # This POST will block until we call /api/ask/respond + # Note: /api/ask returns {'status': 'ok', 'response': ...} + resp = requests.post( + "http://127.0.0.1:8999/api/ask", + json={"prompt": "Should we proceed with the refactor?"}, + timeout=10 + ) + results["response"] = resp.json() + except Exception as e: + results["error"] = str(e) + + # Start the request in a background thread + t = threading.Thread(target=make_blocking_request) + t.start() + + # Poll for the 'ask_received' event to find the generated request_id + request_id = None + start_time = time.time() + while time.time() - start_time < 5: + events = client.get_events() + for ev in events: + if ev.get("type") == "ask_received": + request_id = ev.get("request_id") + break + if request_id: + break + time.sleep(0.1) + + assert request_id is not None, "Timed out waiting for 'ask_received' event" + + # Respond to the task via the respond endpoint + expected_response = {"approved": True, "message": "Proceeding as requested."} + resp = requests.post( + "http://127.0.0.1:8999/api/ask/respond", + json={ + "request_id": request_id, + "response": expected_response + } + ) + assert resp.status_code == 200 + + # Join the thread and verify the original request received the correct data + t.join(timeout=5) + assert not t.is_alive(), "Background thread failed to unblock" + assert results["error"] is None, f"Request failed: {results['error']}" + + # The /api/ask endpoint returns {'status': 'ok', 'response': expected_response} + assert results["response"]["status"] == "ok" + assert results["response"]["response"] == expected_response