fix(headless): Apply review suggestions for track 'manual_slop_headless_20260225'
This commit is contained in:
24
gui_2.py
24
gui_2.py
@@ -307,6 +307,7 @@ class App:
|
||||
self._discussion_names_dirty = True
|
||||
|
||||
def create_api(self) -> FastAPI:
|
||||
"""Creates and configures the FastAPI application for headless mode."""
|
||||
api = FastAPI(title="Manual Slop Headless API")
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
@@ -322,22 +323,26 @@ class App:
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_api_key(header_key: str = Depends(api_key_header)):
|
||||
"""Validates the API key from the request header against configuration."""
|
||||
headless_cfg = self.config.get("headless", {})
|
||||
config_key = headless_cfg.get("api_key", "").strip()
|
||||
env_key = os.environ.get("SLOP_API_KEY", "").strip()
|
||||
target_key = env_key or config_key
|
||||
if not target_key:
|
||||
return None
|
||||
# If no key is configured, we must deny access by default for security
|
||||
raise HTTPException(status_code=403, detail="API Key not configured on server")
|
||||
if header_key == target_key:
|
||||
return header_key
|
||||
raise HTTPException(status_code=403, detail="Could not validate API Key")
|
||||
|
||||
@api.get("/health")
|
||||
def health():
|
||||
"""Basic health check endpoint."""
|
||||
return {"status": "ok"}
|
||||
|
||||
@api.get("/status", dependencies=[Depends(get_api_key)])
|
||||
def status():
|
||||
"""Returns the current status of the AI provider and active project."""
|
||||
return {
|
||||
"provider": self.current_provider,
|
||||
"model": self.current_model,
|
||||
@@ -348,6 +353,7 @@ class App:
|
||||
|
||||
@api.get("/api/v1/pending_actions", dependencies=[Depends(get_api_key)])
|
||||
def pending_actions():
|
||||
"""Lists all PowerShell scripts awaiting manual confirmation."""
|
||||
actions = []
|
||||
with self._pending_dialog_lock:
|
||||
# Include multi-actions from headless mode
|
||||
@@ -368,6 +374,7 @@ class App:
|
||||
|
||||
@api.post("/api/v1/confirm/{action_id}", dependencies=[Depends(get_api_key)])
|
||||
def confirm_action(action_id: str, req: ConfirmRequest):
|
||||
"""Approves or denies a pending PowerShell script execution."""
|
||||
success = self.resolve_pending_action(action_id, req.approved)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Action ID {action_id} not found")
|
||||
@@ -375,6 +382,7 @@ class App:
|
||||
|
||||
@api.get("/api/v1/sessions", dependencies=[Depends(get_api_key)])
|
||||
def list_sessions():
|
||||
"""Lists all available session log files."""
|
||||
log_dir = Path("logs")
|
||||
if not log_dir.exists():
|
||||
return []
|
||||
@@ -382,6 +390,7 @@ class App:
|
||||
|
||||
@api.get("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)])
|
||||
def get_session(filename: str):
|
||||
"""Retrieves the content of a specific session log file."""
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||
log_path = Path("logs") / filename
|
||||
@@ -395,6 +404,7 @@ class App:
|
||||
|
||||
@api.delete("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)])
|
||||
def delete_session(filename: str):
|
||||
"""Deletes a specific session log file."""
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||
log_path = Path("logs") / filename
|
||||
@@ -408,6 +418,7 @@ class App:
|
||||
|
||||
@api.get("/api/v1/context", dependencies=[Depends(get_api_key)])
|
||||
def get_context():
|
||||
"""Returns the current file and screenshot context configuration."""
|
||||
return {
|
||||
"files": self.files,
|
||||
"screenshots": self.screenshots,
|
||||
@@ -417,6 +428,7 @@ class App:
|
||||
|
||||
@api.post("/api/v1/generate", dependencies=[Depends(get_api_key)])
|
||||
def generate(req: GenerateRequest):
|
||||
"""Triggers an AI generation request using the current project context."""
|
||||
if not req.prompt.strip():
|
||||
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
||||
|
||||
@@ -486,6 +498,7 @@ class App:
|
||||
|
||||
@api.post("/api/v1/stream", dependencies=[Depends(get_api_key)])
|
||||
async def stream(req: GenerateRequest):
|
||||
"""Placeholder for streaming AI generation responses (Not yet implemented)."""
|
||||
# Streaming implementation would require ai_client to support yield-based responses.
|
||||
# Currently added as a placeholder to satisfy spec requirements.
|
||||
raise HTTPException(status_code=501, detail="Streaming endpoint (/api/v1/stream) is not yet supported in this version.")
|
||||
@@ -971,6 +984,15 @@ class App:
|
||||
return output
|
||||
|
||||
def resolve_pending_action(self, action_id: str, approved: bool):
|
||||
"""Resolves a pending PowerShell script confirmation by its ID.
|
||||
|
||||
Args:
|
||||
action_id: The unique identifier for the pending action.
|
||||
approved: True if the script should be executed, False otherwise.
|
||||
|
||||
Returns:
|
||||
bool: True if the action was found and resolved, False otherwise.
|
||||
"""
|
||||
with self._pending_dialog_lock:
|
||||
if action_id in self._pending_actions:
|
||||
dialog = self._pending_actions[action_id]
|
||||
|
||||
@@ -11,6 +11,11 @@ class TestHeadlessAPI(unittest.TestCase):
|
||||
patch('gui_2.ai_client.set_provider'), \
|
||||
patch('gui_2.session_logger.close_session'):
|
||||
self.app_instance = gui_2.App()
|
||||
# Set a default API key for tests
|
||||
self.test_api_key = "test-secret-key"
|
||||
self.app_instance.config["headless"] = {"api_key": self.test_api_key}
|
||||
self.headers = {"X-API-KEY": self.test_api_key}
|
||||
|
||||
# Clear any leftover state
|
||||
self.app_instance._pending_actions = {}
|
||||
self.app_instance._pending_dialog = None
|
||||
@@ -51,7 +56,7 @@ class TestHeadlessAPI(unittest.TestCase):
|
||||
}
|
||||
}]
|
||||
|
||||
response = self.client.post("/api/v1/generate", json=payload)
|
||||
response = self.client.post("/api/v1/generate", json=payload, headers=self.headers)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
self.assertEqual(data["text"], "Hello from Mock AI")
|
||||
@@ -64,7 +69,7 @@ class TestHeadlessAPI(unittest.TestCase):
|
||||
dialog = gui_2.ConfirmDialog("dir", ".")
|
||||
self.app_instance._pending_actions[dialog._uid] = dialog
|
||||
|
||||
response = self.client.get("/api/v1/pending_actions")
|
||||
response = self.client.get("/api/v1/pending_actions", headers=self.headers)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
self.assertEqual(len(data), 1)
|
||||
@@ -77,7 +82,7 @@ class TestHeadlessAPI(unittest.TestCase):
|
||||
self.app_instance._pending_actions[dialog._uid] = dialog
|
||||
|
||||
payload = {"approved": True}
|
||||
response = self.client.post("/api/v1/confirm/test-confirm-id", json=payload)
|
||||
response = self.client.post("/api/v1/confirm/test-confirm-id", json=payload, headers=self.headers)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(dialog._done)
|
||||
self.assertTrue(dialog._approved)
|
||||
@@ -90,7 +95,7 @@ class TestHeadlessAPI(unittest.TestCase):
|
||||
dummy_log.write_text("dummy content")
|
||||
|
||||
try:
|
||||
response = self.client.get("/api/v1/sessions")
|
||||
response = self.client.get("/api/v1/sessions", headers=self.headers)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
self.assertIn("test_session_api.log", data)
|
||||
@@ -99,12 +104,19 @@ class TestHeadlessAPI(unittest.TestCase):
|
||||
dummy_log.unlink()
|
||||
|
||||
def test_get_context_endpoint(self):
|
||||
response = self.client.get("/api/v1/context")
|
||||
response = self.client.get("/api/v1/context", headers=self.headers)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
self.assertIn("files", data)
|
||||
self.assertIn("screenshots", data)
|
||||
self.assertIn("files_base_dir", data)
|
||||
|
||||
def test_endpoint_no_api_key_configured(self):
|
||||
# Test the security fix specifically
|
||||
with patch.dict(self.app_instance.config, {"headless": {"api_key": ""}}):
|
||||
response = self.client.get("/status", headers=self.headers)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
self.assertEqual(response.json()["detail"], "API Key not configured on server")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user