feat(ipc): support synchronous 'ask' requests in api_hooks
This commit is contained in:
97
api_hooks.py
97
api_hooks.py
@@ -1,10 +1,11 @@
|
||||
import json
|
||||
import threading
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
import uuid
|
||||
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
|
||||
import logging
|
||||
import session_logger
|
||||
|
||||
class HookServerInstance(HTTPServer):
|
||||
class HookServerInstance(ThreadingHTTPServer):
|
||||
"""Custom HTTPServer that carries a reference to the main App instance."""
|
||||
def __init__(self, server_address, RequestHandlerClass, app):
|
||||
super().__init__(server_address, RequestHandlerClass)
|
||||
@@ -60,10 +61,10 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
field_tag = data.get("field")
|
||||
print(f"[DEBUG] Hook Server: get_value for {field_tag}")
|
||||
|
||||
|
||||
event = threading.Event()
|
||||
result = {"value": None}
|
||||
|
||||
|
||||
def get_val():
|
||||
try:
|
||||
if field_tag in app._settable_fields:
|
||||
@@ -81,7 +82,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
"action": "custom_callback",
|
||||
"callback": get_val
|
||||
})
|
||||
|
||||
|
||||
if event.wait(timeout=2):
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
@@ -95,7 +96,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
field_tag = self.path.split('/')[-1]
|
||||
event = threading.Event()
|
||||
result = {"value": None}
|
||||
|
||||
|
||||
def get_val():
|
||||
try:
|
||||
if field_tag in app._settable_fields:
|
||||
@@ -109,7 +110,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
"action": "custom_callback",
|
||||
"callback": get_val
|
||||
})
|
||||
|
||||
|
||||
if event.wait(timeout=2):
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
@@ -122,7 +123,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
# Safe way to query multiple states at once via the main thread queue
|
||||
event = threading.Event()
|
||||
result = {}
|
||||
|
||||
|
||||
def check_all():
|
||||
try:
|
||||
# Generic state check based on App attributes (works for both DPG and ImGui versions)
|
||||
@@ -138,7 +139,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
"action": "custom_callback",
|
||||
"callback": check_all
|
||||
})
|
||||
|
||||
|
||||
if event.wait(timeout=2):
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
@@ -158,7 +159,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
body = self.rfile.read(content_length)
|
||||
body_str = body.decode('utf-8') if body else ""
|
||||
session_logger.log_api_hook("POST", self.path, body_str)
|
||||
|
||||
|
||||
try:
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
if self.path == '/api/project':
|
||||
@@ -179,12 +180,74 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
elif self.path == '/api/gui':
|
||||
with app._pending_gui_tasks_lock:
|
||||
app._pending_gui_tasks.append(data)
|
||||
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(
|
||||
json.dumps({'status': 'queued'}).encode('utf-8'))
|
||||
elif self.path == '/api/ask':
|
||||
request_id = str(uuid.uuid4())
|
||||
event = threading.Event()
|
||||
|
||||
if not hasattr(app, '_pending_asks'):
|
||||
app._pending_asks = {}
|
||||
if not hasattr(app, '_ask_responses'):
|
||||
app._ask_responses = {}
|
||||
|
||||
app._pending_asks[request_id] = event
|
||||
|
||||
# Emit event for test/client discovery
|
||||
with app._api_event_queue_lock:
|
||||
app._api_event_queue.append({
|
||||
"type": "ask_received",
|
||||
"request_id": request_id,
|
||||
"data": data
|
||||
})
|
||||
|
||||
with app._pending_gui_tasks_lock:
|
||||
app._pending_gui_tasks.append({
|
||||
"type": "ask",
|
||||
"request_id": request_id,
|
||||
"data": data
|
||||
})
|
||||
|
||||
if event.wait(timeout=60.0):
|
||||
response_data = app._ask_responses.get(request_id)
|
||||
# Clean up response after reading
|
||||
if request_id in app._ask_responses:
|
||||
del app._ask_responses[request_id]
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({'status': 'ok', 'response': response_data}).encode('utf-8'))
|
||||
else:
|
||||
if request_id in app._pending_asks:
|
||||
del app._pending_asks[request_id]
|
||||
self.send_response(504)
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({'error': 'timeout'}).encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/ask/respond':
|
||||
request_id = data.get('request_id')
|
||||
response_data = data.get('response')
|
||||
|
||||
if request_id and hasattr(app, '_pending_asks') and request_id in app._pending_asks:
|
||||
app._ask_responses[request_id] = response_data
|
||||
event = app._pending_asks[request_id]
|
||||
event.set()
|
||||
|
||||
# Clean up pending ask entry
|
||||
del app._pending_asks[request_id]
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({'status': 'ok'}).encode('utf-8'))
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
@@ -207,19 +270,25 @@ class HookServer:
|
||||
def start(self):
|
||||
if not getattr(self.app, 'test_hooks_enabled', False):
|
||||
return
|
||||
|
||||
|
||||
# Ensure the app has the task queue and lock initialized
|
||||
if not hasattr(self.app, '_pending_gui_tasks'):
|
||||
self.app._pending_gui_tasks = []
|
||||
if not hasattr(self.app, '_pending_gui_tasks_lock'):
|
||||
self.app._pending_gui_tasks_lock = threading.Lock()
|
||||
|
||||
|
||||
# Initialize ask-related dictionaries
|
||||
if not hasattr(self.app, '_pending_asks'):
|
||||
self.app._pending_asks = {}
|
||||
if not hasattr(self.app, '_ask_responses'):
|
||||
self.app._ask_responses = {}
|
||||
|
||||
# Event queue for test script subscriptions
|
||||
if not hasattr(self.app, '_api_event_queue'):
|
||||
self.app._api_event_queue = []
|
||||
if not hasattr(self.app, '_api_event_queue_lock'):
|
||||
self.app._api_event_queue_lock = threading.Lock()
|
||||
|
||||
|
||||
self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app)
|
||||
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
73
tests/test_sync_hooks.py
Normal file
73
tests/test_sync_hooks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import threading
|
||||
import time
|
||||
import requests
|
||||
import pytest
|
||||
from api_hook_client import ApiHookClient
|
||||
|
||||
def test_api_ask_synchronous_flow(live_gui):
|
||||
"""
|
||||
Tests the full synchronous lifecycle of the /api/ask endpoint:
|
||||
1. A client makes a blocking request.
|
||||
2. An event is emitted with a unique request_id.
|
||||
3. A separate agent responds to that request_id.
|
||||
4. The original blocking request completes with the provided data.
|
||||
"""
|
||||
# The live_gui fixture starts the Manual Slop application with hooks on 8999.
|
||||
client = ApiHookClient("http://127.0.0.1:8999")
|
||||
|
||||
# Drain existing events
|
||||
client.get_events()
|
||||
|
||||
results = {"response": None, "error": None}
|
||||
|
||||
def make_blocking_request():
|
||||
try:
|
||||
# This POST will block until we call /api/ask/respond
|
||||
# Note: /api/ask returns {'status': 'ok', 'response': ...}
|
||||
resp = requests.post(
|
||||
"http://127.0.0.1:8999/api/ask",
|
||||
json={"prompt": "Should we proceed with the refactor?"},
|
||||
timeout=10
|
||||
)
|
||||
results["response"] = resp.json()
|
||||
except Exception as e:
|
||||
results["error"] = str(e)
|
||||
|
||||
# Start the request in a background thread
|
||||
t = threading.Thread(target=make_blocking_request)
|
||||
t.start()
|
||||
|
||||
# Poll for the 'ask_received' event to find the generated request_id
|
||||
request_id = None
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 5:
|
||||
events = client.get_events()
|
||||
for ev in events:
|
||||
if ev.get("type") == "ask_received":
|
||||
request_id = ev.get("request_id")
|
||||
break
|
||||
if request_id:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert request_id is not None, "Timed out waiting for 'ask_received' event"
|
||||
|
||||
# Respond to the task via the respond endpoint
|
||||
expected_response = {"approved": True, "message": "Proceeding as requested."}
|
||||
resp = requests.post(
|
||||
"http://127.0.0.1:8999/api/ask/respond",
|
||||
json={
|
||||
"request_id": request_id,
|
||||
"response": expected_response
|
||||
}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Join the thread and verify the original request received the correct data
|
||||
t.join(timeout=5)
|
||||
assert not t.is_alive(), "Background thread failed to unblock"
|
||||
assert results["error"] is None, f"Request failed: {results['error']}"
|
||||
|
||||
# The /api/ask endpoint returns {'status': 'ok', 'response': expected_response}
|
||||
assert results["response"]["status"] == "ok"
|
||||
assert results["response"]["response"] == expected_response
|
||||
Reference in New Issue
Block a user