diff --git a/api_hooks.py b/api_hooks.py index 1bbfd93..5996bba 100644 --- a/api_hooks.py +++ b/api_hooks.py @@ -3,17 +3,61 @@ import threading from http.server import HTTPServer, BaseHTTPRequestHandler import logging +class HookServerInstance(HTTPServer): + def __init__(self, server_address, RequestHandlerClass, app): + super().__init__(server_address, RequestHandlerClass) + self.app = app + class HookHandler(BaseHTTPRequestHandler): def do_GET(self): + app = self.server.app if self.path == '/status': self.send_response(200) self.send_header('Content-Type', 'application/json') self.end_headers() self.wfile.write(json.dumps({'status': 'ok'}).encode('utf-8')) + elif self.path == '/api/project': + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'project': app.project}).encode('utf-8')) + elif self.path == '/api/session': + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'session': {'entries': app.disc_entries}}).encode('utf-8')) else: self.send_response(404) self.end_headers() + def do_POST(self): + app = self.server.app + content_length = int(self.headers.get('Content-Length', 0)) + body = self.rfile.read(content_length) + + try: + data = json.loads(body.decode('utf-8')) if body else {} + if self.path == '/api/project': + app.project = data.get('project', app.project) + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'status': 'updated'}).encode('utf-8')) + elif self.path == '/api/session': + app.disc_entries = data.get('session', {}).get('entries', app.disc_entries) + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'status': 'updated'}).encode('utf-8')) + else: + self.send_response(404) + self.end_headers() + except Exception as e: + self.send_response(500) + self.send_header('Content-Type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'error': str(e)}).encode('utf-8')) + def log_message(self, format, *args): logging.info("Hook API: " + format % args) @@ -27,7 +71,7 @@ class HookServer: def start(self): if not getattr(self.app, 'test_hooks_enabled', False): return - self.server = HTTPServer(('127.0.0.1', self.port), HookHandler) + self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app) self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.thread.start() logging.info(f"Hook server started on port {self.port}") diff --git a/tests/test_hooks.py b/tests/test_hooks.py index 731a1ea..5a72123 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -42,5 +42,32 @@ def test_ipc_server_starts_and_responds(): assert response.status == 200 data = json.loads(response.read().decode()) assert data.get("status") == "ok" + + # Test project GET + req = urllib.request.Request("http://127.0.0.1:8999/api/project") + with urllib.request.urlopen(req) as response: + assert response.status == 200 + data = json.loads(response.read().decode()) + assert "project" in data + + # Test session GET + req = urllib.request.Request("http://127.0.0.1:8999/api/session") + with urllib.request.urlopen(req) as response: + assert response.status == 200 + data = json.loads(response.read().decode()) + assert "session" in data + + # Test project POST + req = urllib.request.Request("http://127.0.0.1:8999/api/project", method="POST", data=json.dumps({"project": {"foo": "bar"}}).encode("utf-8"), headers={'Content-Type': 'application/json'}) + with urllib.request.urlopen(req) as response: + assert response.status == 200 + assert app_mock.project == {"foo": "bar"} + + # Test session POST + req = urllib.request.Request("http://127.0.0.1:8999/api/session", method="POST", data=json.dumps({"session": {"entries": [{"role": "User", "content": "hi"}]}}).encode("utf-8"), headers={'Content-Type': 'application/json'}) + with urllib.request.urlopen(req) as response: + assert response.status == 200 + assert app_mock.disc_entries == [{"role": "User", "content": "hi"}] + finally: server.stop()