feat(ipc): support synchronous 'ask' requests in api_hooks
This commit is contained in:
97
api_hooks.py
97
api_hooks.py
@@ -1,10 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
import uuid
|
||||||
|
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
|
||||||
import logging
|
import logging
|
||||||
import session_logger
|
import session_logger
|
||||||
|
|
||||||
class HookServerInstance(HTTPServer):
|
class HookServerInstance(ThreadingHTTPServer):
|
||||||
"""Custom HTTPServer that carries a reference to the main App instance."""
|
"""Custom HTTPServer that carries a reference to the main App instance."""
|
||||||
def __init__(self, server_address, RequestHandlerClass, app):
|
def __init__(self, server_address, RequestHandlerClass, app):
|
||||||
super().__init__(server_address, RequestHandlerClass)
|
super().__init__(server_address, RequestHandlerClass)
|
||||||
@@ -60,10 +61,10 @@ class HookHandler(BaseHTTPRequestHandler):
|
|||||||
data = json.loads(body.decode('utf-8'))
|
data = json.loads(body.decode('utf-8'))
|
||||||
field_tag = data.get("field")
|
field_tag = data.get("field")
|
||||||
print(f"[DEBUG] Hook Server: get_value for {field_tag}")
|
print(f"[DEBUG] Hook Server: get_value for {field_tag}")
|
||||||
|
|
||||||
event = threading.Event()
|
event = threading.Event()
|
||||||
result = {"value": None}
|
result = {"value": None}
|
||||||
|
|
||||||
def get_val():
|
def get_val():
|
||||||
try:
|
try:
|
||||||
if field_tag in app._settable_fields:
|
if field_tag in app._settable_fields:
|
||||||
@@ -81,7 +82,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
|||||||
"action": "custom_callback",
|
"action": "custom_callback",
|
||||||
"callback": get_val
|
"callback": get_val
|
||||||
})
|
})
|
||||||
|
|
||||||
if event.wait(timeout=2):
|
if event.wait(timeout=2):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
@@ -95,7 +96,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
|||||||
field_tag = self.path.split('/')[-1]
|
field_tag = self.path.split('/')[-1]
|
||||||
event = threading.Event()
|
event = threading.Event()
|
||||||
result = {"value": None}
|
result = {"value": None}
|
||||||
|
|
||||||
def get_val():
|
def get_val():
|
||||||
try:
|
try:
|
||||||
if field_tag in app._settable_fields:
|
if field_tag in app._settable_fields:
|
||||||
@@ -109,7 +110,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
|||||||
"action": "custom_callback",
|
"action": "custom_callback",
|
||||||
"callback": get_val
|
"callback": get_val
|
||||||
})
|
})
|
||||||
|
|
||||||
if event.wait(timeout=2):
|
if event.wait(timeout=2):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
@@ -122,7 +123,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
|||||||
# Safe way to query multiple states at once via the main thread queue
|
# Safe way to query multiple states at once via the main thread queue
|
||||||
event = threading.Event()
|
event = threading.Event()
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
def check_all():
|
def check_all():
|
||||||
try:
|
try:
|
||||||
# Generic state check based on App attributes (works for both DPG and ImGui versions)
|
# Generic state check based on App attributes (works for both DPG and ImGui versions)
|
||||||
@@ -138,7 +139,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
|||||||
"action": "custom_callback",
|
"action": "custom_callback",
|
||||||
"callback": check_all
|
"callback": check_all
|
||||||
})
|
})
|
||||||
|
|
||||||
if event.wait(timeout=2):
|
if event.wait(timeout=2):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
@@ -158,7 +159,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
|||||||
body = self.rfile.read(content_length)
|
body = self.rfile.read(content_length)
|
||||||
body_str = body.decode('utf-8') if body else ""
|
body_str = body.decode('utf-8') if body else ""
|
||||||
session_logger.log_api_hook("POST", self.path, body_str)
|
session_logger.log_api_hook("POST", self.path, body_str)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(body_str) if body_str else {}
|
data = json.loads(body_str) if body_str else {}
|
||||||
if self.path == '/api/project':
|
if self.path == '/api/project':
|
||||||
@@ -179,12 +180,74 @@ class HookHandler(BaseHTTPRequestHandler):
|
|||||||
elif self.path == '/api/gui':
|
elif self.path == '/api/gui':
|
||||||
with app._pending_gui_tasks_lock:
|
with app._pending_gui_tasks_lock:
|
||||||
app._pending_gui_tasks.append(data)
|
app._pending_gui_tasks.append(data)
|
||||||
|
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.wfile.write(
|
self.wfile.write(
|
||||||
json.dumps({'status': 'queued'}).encode('utf-8'))
|
json.dumps({'status': 'queued'}).encode('utf-8'))
|
||||||
|
elif self.path == '/api/ask':
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
event = threading.Event()
|
||||||
|
|
||||||
|
if not hasattr(app, '_pending_asks'):
|
||||||
|
app._pending_asks = {}
|
||||||
|
if not hasattr(app, '_ask_responses'):
|
||||||
|
app._ask_responses = {}
|
||||||
|
|
||||||
|
app._pending_asks[request_id] = event
|
||||||
|
|
||||||
|
# Emit event for test/client discovery
|
||||||
|
with app._api_event_queue_lock:
|
||||||
|
app._api_event_queue.append({
|
||||||
|
"type": "ask_received",
|
||||||
|
"request_id": request_id,
|
||||||
|
"data": data
|
||||||
|
})
|
||||||
|
|
||||||
|
with app._pending_gui_tasks_lock:
|
||||||
|
app._pending_gui_tasks.append({
|
||||||
|
"type": "ask",
|
||||||
|
"request_id": request_id,
|
||||||
|
"data": data
|
||||||
|
})
|
||||||
|
|
||||||
|
if event.wait(timeout=60.0):
|
||||||
|
response_data = app._ask_responses.get(request_id)
|
||||||
|
# Clean up response after reading
|
||||||
|
if request_id in app._ask_responses:
|
||||||
|
del app._ask_responses[request_id]
|
||||||
|
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(json.dumps({'status': 'ok', 'response': response_data}).encode('utf-8'))
|
||||||
|
else:
|
||||||
|
if request_id in app._pending_asks:
|
||||||
|
del app._pending_asks[request_id]
|
||||||
|
self.send_response(504)
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(json.dumps({'error': 'timeout'}).encode('utf-8'))
|
||||||
|
|
||||||
|
elif self.path == '/api/ask/respond':
|
||||||
|
request_id = data.get('request_id')
|
||||||
|
response_data = data.get('response')
|
||||||
|
|
||||||
|
if request_id and hasattr(app, '_pending_asks') and request_id in app._pending_asks:
|
||||||
|
app._ask_responses[request_id] = response_data
|
||||||
|
event = app._pending_asks[request_id]
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
# Clean up pending ask entry
|
||||||
|
del app._pending_asks[request_id]
|
||||||
|
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(json.dumps({'status': 'ok'}).encode('utf-8'))
|
||||||
|
else:
|
||||||
|
self.send_response(404)
|
||||||
|
self.end_headers()
|
||||||
else:
|
else:
|
||||||
self.send_response(404)
|
self.send_response(404)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
@@ -207,19 +270,25 @@ class HookServer:
|
|||||||
def start(self):
|
def start(self):
|
||||||
if not getattr(self.app, 'test_hooks_enabled', False):
|
if not getattr(self.app, 'test_hooks_enabled', False):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Ensure the app has the task queue and lock initialized
|
# Ensure the app has the task queue and lock initialized
|
||||||
if not hasattr(self.app, '_pending_gui_tasks'):
|
if not hasattr(self.app, '_pending_gui_tasks'):
|
||||||
self.app._pending_gui_tasks = []
|
self.app._pending_gui_tasks = []
|
||||||
if not hasattr(self.app, '_pending_gui_tasks_lock'):
|
if not hasattr(self.app, '_pending_gui_tasks_lock'):
|
||||||
self.app._pending_gui_tasks_lock = threading.Lock()
|
self.app._pending_gui_tasks_lock = threading.Lock()
|
||||||
|
|
||||||
|
# Initialize ask-related dictionaries
|
||||||
|
if not hasattr(self.app, '_pending_asks'):
|
||||||
|
self.app._pending_asks = {}
|
||||||
|
if not hasattr(self.app, '_ask_responses'):
|
||||||
|
self.app._ask_responses = {}
|
||||||
|
|
||||||
# Event queue for test script subscriptions
|
# Event queue for test script subscriptions
|
||||||
if not hasattr(self.app, '_api_event_queue'):
|
if not hasattr(self.app, '_api_event_queue'):
|
||||||
self.app._api_event_queue = []
|
self.app._api_event_queue = []
|
||||||
if not hasattr(self.app, '_api_event_queue_lock'):
|
if not hasattr(self.app, '_api_event_queue_lock'):
|
||||||
self.app._api_event_queue_lock = threading.Lock()
|
self.app._api_event_queue_lock = threading.Lock()
|
||||||
|
|
||||||
self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app)
|
self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app)
|
||||||
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
|
|||||||
73
tests/test_sync_hooks.py
Normal file
73
tests/test_sync_hooks.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import pytest
|
||||||
|
from api_hook_client import ApiHookClient
|
||||||
|
|
||||||
|
def test_api_ask_synchronous_flow(live_gui):
|
||||||
|
"""
|
||||||
|
Tests the full synchronous lifecycle of the /api/ask endpoint:
|
||||||
|
1. A client makes a blocking request.
|
||||||
|
2. An event is emitted with a unique request_id.
|
||||||
|
3. A separate agent responds to that request_id.
|
||||||
|
4. The original blocking request completes with the provided data.
|
||||||
|
"""
|
||||||
|
# The live_gui fixture starts the Manual Slop application with hooks on 8999.
|
||||||
|
client = ApiHookClient("http://127.0.0.1:8999")
|
||||||
|
|
||||||
|
# Drain existing events
|
||||||
|
client.get_events()
|
||||||
|
|
||||||
|
results = {"response": None, "error": None}
|
||||||
|
|
||||||
|
def make_blocking_request():
|
||||||
|
try:
|
||||||
|
# This POST will block until we call /api/ask/respond
|
||||||
|
# Note: /api/ask returns {'status': 'ok', 'response': ...}
|
||||||
|
resp = requests.post(
|
||||||
|
"http://127.0.0.1:8999/api/ask",
|
||||||
|
json={"prompt": "Should we proceed with the refactor?"},
|
||||||
|
timeout=10
|
||||||
|
)
|
||||||
|
results["response"] = resp.json()
|
||||||
|
except Exception as e:
|
||||||
|
results["error"] = str(e)
|
||||||
|
|
||||||
|
# Start the request in a background thread
|
||||||
|
t = threading.Thread(target=make_blocking_request)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
# Poll for the 'ask_received' event to find the generated request_id
|
||||||
|
request_id = None
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < 5:
|
||||||
|
events = client.get_events()
|
||||||
|
for ev in events:
|
||||||
|
if ev.get("type") == "ask_received":
|
||||||
|
request_id = ev.get("request_id")
|
||||||
|
break
|
||||||
|
if request_id:
|
||||||
|
break
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
assert request_id is not None, "Timed out waiting for 'ask_received' event"
|
||||||
|
|
||||||
|
# Respond to the task via the respond endpoint
|
||||||
|
expected_response = {"approved": True, "message": "Proceeding as requested."}
|
||||||
|
resp = requests.post(
|
||||||
|
"http://127.0.0.1:8999/api/ask/respond",
|
||||||
|
json={
|
||||||
|
"request_id": request_id,
|
||||||
|
"response": expected_response
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# Join the thread and verify the original request received the correct data
|
||||||
|
t.join(timeout=5)
|
||||||
|
assert not t.is_alive(), "Background thread failed to unblock"
|
||||||
|
assert results["error"] is None, f"Request failed: {results['error']}"
|
||||||
|
|
||||||
|
# The /api/ask endpoint returns {'status': 'ok', 'response': expected_response}
|
||||||
|
assert results["response"]["status"] == "ok"
|
||||||
|
assert results["response"]["response"] == expected_response
|
||||||
Reference in New Issue
Block a user