63 lines
2.5 KiB
Python
63 lines
2.5 KiB
Python
import unittest
|
|
from fastapi.testclient import TestClient
|
|
import gui_2
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
class TestHeadlessAPI(unittest.TestCase):
    """HTTP-level tests for the API object produced by ``gui_2.App.create_api``.

    One ``App`` instance and one ``TestClient`` are created per class and
    shared by all test methods.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared ``App``, its API object and the test client."""
        # The session logger's open/close hooks and the AI provider selection
        # are replaced with mocks while the App and its API are being built.
        with patch('gui_2.session_logger.open_session'):
            with patch('gui_2.ai_client.set_provider'):
                with patch('gui_2.session_logger.close_session'):
                    cls.app_instance = gui_2.App()
                    # Fall back to a MagicMock placeholder when App does not
                    # (yet) provide a create_api method.
                    cls.api = (
                        cls.app_instance.create_api()
                        if hasattr(cls.app_instance, 'create_api')
                        else MagicMock()
                    )
                    cls.client = TestClient(cls.api)

    def test_health_endpoint(self):
        """GET /health answers 200 with the body ``{"status": "ok"}``."""
        reply = self.client.get("/health")
        self.assertEqual(reply.status_code, 200)
        self.assertEqual(reply.json(), {"status": "ok"})

    def test_status_endpoint_unauthorized(self):
        """GET /status without an X-API-KEY header is rejected with 403."""
        # Configure a key so that the endpoint has something to demand.
        required_key_cfg = {"headless": {"api_key": "some-required-key"}}
        with patch.dict(self.app_instance.config, required_key_cfg):
            reply = self.client.get("/status")
        self.assertEqual(reply.status_code, 403)

    def test_status_endpoint_authorized(self):
        """GET /status with the configured key in X-API-KEY answers 200."""
        matching_key_cfg = {"headless": {"api_key": "test-secret-key"}}
        auth_headers = {"X-API-KEY": "test-secret-key"}
        with patch.dict(self.app_instance.config, matching_key_cfg):
            reply = self.client.get("/status", headers=auth_headers)
        self.assertEqual(reply.status_code, 200)

    def test_generate_endpoint(self):
        """POST /api/v1/generate returns the mocked text plus metadata and usage."""
        request_body = {"prompt": "Hello AI"}
        # A single "response" comms-log entry carrying the token usage that
        # the assertions below look for in the reply.
        fake_comms_log = [{
            "kind": "response",
            "payload": {
                "usage": {"input_tokens": 10, "output_tokens": 5}
            }
        }]
        # ai_client.send and ai_client.get_comms_log are mocked for the
        # duration of the request.
        with patch('gui_2.ai_client.send', return_value="Hello from Mock AI"):
            with patch('gui_2.ai_client.get_comms_log', return_value=fake_comms_log):
                reply = self.client.post("/api/v1/generate", json=request_body)

        self.assertEqual(reply.status_code, 200)
        body = reply.json()
        self.assertEqual(body["text"], "Hello from Mock AI")
        self.assertIn("metadata", body)
        self.assertEqual(body["usage"]["input_tokens"], 10)
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|