fix(headless): Apply review suggestions for track 'manual_slop_headless_20260225'

This commit is contained in:
2026-02-25 13:33:59 -05:00
parent 63fd391dff
commit 9b50bfa75e
2 changed files with 40 additions and 6 deletions

View File

@@ -307,6 +307,7 @@ class App:
self._discussion_names_dirty = True self._discussion_names_dirty = True
def create_api(self) -> FastAPI: def create_api(self) -> FastAPI:
"""Creates and configures the FastAPI application for headless mode."""
api = FastAPI(title="Manual Slop Headless API") api = FastAPI(title="Manual Slop Headless API")
class GenerateRequest(BaseModel): class GenerateRequest(BaseModel):
@@ -322,22 +323,26 @@ class App:
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
async def get_api_key(header_key: str = Depends(api_key_header)): async def get_api_key(header_key: str = Depends(api_key_header)):
"""Validates the API key from the request header against configuration."""
headless_cfg = self.config.get("headless", {}) headless_cfg = self.config.get("headless", {})
config_key = headless_cfg.get("api_key", "").strip() config_key = headless_cfg.get("api_key", "").strip()
env_key = os.environ.get("SLOP_API_KEY", "").strip() env_key = os.environ.get("SLOP_API_KEY", "").strip()
target_key = env_key or config_key target_key = env_key or config_key
if not target_key: if not target_key:
return None # If no key is configured, we must deny access by default for security
raise HTTPException(status_code=403, detail="API Key not configured on server")
if header_key == target_key: if header_key == target_key:
return header_key return header_key
raise HTTPException(status_code=403, detail="Could not validate API Key") raise HTTPException(status_code=403, detail="Could not validate API Key")
@api.get("/health") @api.get("/health")
def health(): def health():
"""Basic health check endpoint."""
return {"status": "ok"} return {"status": "ok"}
@api.get("/status", dependencies=[Depends(get_api_key)]) @api.get("/status", dependencies=[Depends(get_api_key)])
def status(): def status():
"""Returns the current status of the AI provider and active project."""
return { return {
"provider": self.current_provider, "provider": self.current_provider,
"model": self.current_model, "model": self.current_model,
@@ -348,6 +353,7 @@ class App:
@api.get("/api/v1/pending_actions", dependencies=[Depends(get_api_key)]) @api.get("/api/v1/pending_actions", dependencies=[Depends(get_api_key)])
def pending_actions(): def pending_actions():
"""Lists all PowerShell scripts awaiting manual confirmation."""
actions = [] actions = []
with self._pending_dialog_lock: with self._pending_dialog_lock:
# Include multi-actions from headless mode # Include multi-actions from headless mode
@@ -368,6 +374,7 @@ class App:
@api.post("/api/v1/confirm/{action_id}", dependencies=[Depends(get_api_key)]) @api.post("/api/v1/confirm/{action_id}", dependencies=[Depends(get_api_key)])
def confirm_action(action_id: str, req: ConfirmRequest): def confirm_action(action_id: str, req: ConfirmRequest):
"""Approves or denies a pending PowerShell script execution."""
success = self.resolve_pending_action(action_id, req.approved) success = self.resolve_pending_action(action_id, req.approved)
if not success: if not success:
raise HTTPException(status_code=404, detail=f"Action ID {action_id} not found") raise HTTPException(status_code=404, detail=f"Action ID {action_id} not found")
@@ -375,6 +382,7 @@ class App:
@api.get("/api/v1/sessions", dependencies=[Depends(get_api_key)]) @api.get("/api/v1/sessions", dependencies=[Depends(get_api_key)])
def list_sessions(): def list_sessions():
"""Lists all available session log files."""
log_dir = Path("logs") log_dir = Path("logs")
if not log_dir.exists(): if not log_dir.exists():
return [] return []
@@ -382,6 +390,7 @@ class App:
@api.get("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)]) @api.get("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)])
def get_session(filename: str): def get_session(filename: str):
"""Retrieves the content of a specific session log file."""
if ".." in filename or "/" in filename or "\\" in filename: if ".." in filename or "/" in filename or "\\" in filename:
raise HTTPException(status_code=400, detail="Invalid filename") raise HTTPException(status_code=400, detail="Invalid filename")
log_path = Path("logs") / filename log_path = Path("logs") / filename
@@ -395,6 +404,7 @@ class App:
@api.delete("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)]) @api.delete("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)])
def delete_session(filename: str): def delete_session(filename: str):
"""Deletes a specific session log file."""
if ".." in filename or "/" in filename or "\\" in filename: if ".." in filename or "/" in filename or "\\" in filename:
raise HTTPException(status_code=400, detail="Invalid filename") raise HTTPException(status_code=400, detail="Invalid filename")
log_path = Path("logs") / filename log_path = Path("logs") / filename
@@ -408,6 +418,7 @@ class App:
@api.get("/api/v1/context", dependencies=[Depends(get_api_key)]) @api.get("/api/v1/context", dependencies=[Depends(get_api_key)])
def get_context(): def get_context():
"""Returns the current file and screenshot context configuration."""
return { return {
"files": self.files, "files": self.files,
"screenshots": self.screenshots, "screenshots": self.screenshots,
@@ -417,6 +428,7 @@ class App:
@api.post("/api/v1/generate", dependencies=[Depends(get_api_key)]) @api.post("/api/v1/generate", dependencies=[Depends(get_api_key)])
def generate(req: GenerateRequest): def generate(req: GenerateRequest):
"""Triggers an AI generation request using the current project context."""
if not req.prompt.strip(): if not req.prompt.strip():
raise HTTPException(status_code=400, detail="Prompt cannot be empty") raise HTTPException(status_code=400, detail="Prompt cannot be empty")
@@ -486,6 +498,7 @@ class App:
@api.post("/api/v1/stream", dependencies=[Depends(get_api_key)]) @api.post("/api/v1/stream", dependencies=[Depends(get_api_key)])
async def stream(req: GenerateRequest): async def stream(req: GenerateRequest):
"""Placeholder for streaming AI generation responses (Not yet implemented)."""
# Streaming implementation would require ai_client to support yield-based responses. # Streaming implementation would require ai_client to support yield-based responses.
# Currently added as a placeholder to satisfy spec requirements. # Currently added as a placeholder to satisfy spec requirements.
raise HTTPException(status_code=501, detail="Streaming endpoint (/api/v1/stream) is not yet supported in this version.") raise HTTPException(status_code=501, detail="Streaming endpoint (/api/v1/stream) is not yet supported in this version.")
@@ -971,6 +984,15 @@ class App:
return output return output
def resolve_pending_action(self, action_id: str, approved: bool): def resolve_pending_action(self, action_id: str, approved: bool):
"""Resolves a pending PowerShell script confirmation by its ID.
Args:
action_id: The unique identifier for the pending action.
approved: True if the script should be executed, False otherwise.
Returns:
bool: True if the action was found and resolved, False otherwise.
"""
with self._pending_dialog_lock: with self._pending_dialog_lock:
if action_id in self._pending_actions: if action_id in self._pending_actions:
dialog = self._pending_actions[action_id] dialog = self._pending_actions[action_id]

View File

@@ -11,6 +11,11 @@ class TestHeadlessAPI(unittest.TestCase):
patch('gui_2.ai_client.set_provider'), \ patch('gui_2.ai_client.set_provider'), \
patch('gui_2.session_logger.close_session'): patch('gui_2.session_logger.close_session'):
self.app_instance = gui_2.App() self.app_instance = gui_2.App()
# Set a default API key for tests
self.test_api_key = "test-secret-key"
self.app_instance.config["headless"] = {"api_key": self.test_api_key}
self.headers = {"X-API-KEY": self.test_api_key}
# Clear any leftover state # Clear any leftover state
self.app_instance._pending_actions = {} self.app_instance._pending_actions = {}
self.app_instance._pending_dialog = None self.app_instance._pending_dialog = None
@@ -51,7 +56,7 @@ class TestHeadlessAPI(unittest.TestCase):
} }
}] }]
response = self.client.post("/api/v1/generate", json=payload) response = self.client.post("/api/v1/generate", json=payload, headers=self.headers)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
data = response.json() data = response.json()
self.assertEqual(data["text"], "Hello from Mock AI") self.assertEqual(data["text"], "Hello from Mock AI")
@@ -64,7 +69,7 @@ class TestHeadlessAPI(unittest.TestCase):
dialog = gui_2.ConfirmDialog("dir", ".") dialog = gui_2.ConfirmDialog("dir", ".")
self.app_instance._pending_actions[dialog._uid] = dialog self.app_instance._pending_actions[dialog._uid] = dialog
response = self.client.get("/api/v1/pending_actions") response = self.client.get("/api/v1/pending_actions", headers=self.headers)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
data = response.json() data = response.json()
self.assertEqual(len(data), 1) self.assertEqual(len(data), 1)
@@ -77,7 +82,7 @@ class TestHeadlessAPI(unittest.TestCase):
self.app_instance._pending_actions[dialog._uid] = dialog self.app_instance._pending_actions[dialog._uid] = dialog
payload = {"approved": True} payload = {"approved": True}
response = self.client.post("/api/v1/confirm/test-confirm-id", json=payload) response = self.client.post("/api/v1/confirm/test-confirm-id", json=payload, headers=self.headers)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertTrue(dialog._done) self.assertTrue(dialog._done)
self.assertTrue(dialog._approved) self.assertTrue(dialog._approved)
@@ -90,7 +95,7 @@ class TestHeadlessAPI(unittest.TestCase):
dummy_log.write_text("dummy content") dummy_log.write_text("dummy content")
try: try:
response = self.client.get("/api/v1/sessions") response = self.client.get("/api/v1/sessions", headers=self.headers)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
data = response.json() data = response.json()
self.assertIn("test_session_api.log", data) self.assertIn("test_session_api.log", data)
@@ -99,12 +104,19 @@ class TestHeadlessAPI(unittest.TestCase):
dummy_log.unlink() dummy_log.unlink()
def test_get_context_endpoint(self): def test_get_context_endpoint(self):
response = self.client.get("/api/v1/context") response = self.client.get("/api/v1/context", headers=self.headers)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
data = response.json() data = response.json()
self.assertIn("files", data) self.assertIn("files", data)
self.assertIn("screenshots", data) self.assertIn("screenshots", data)
self.assertIn("files_base_dir", data) self.assertIn("files_base_dir", data)
def test_endpoint_no_api_key_configured(self):
# Test the security fix specifically
with patch.dict(self.app_instance.config, {"headless": {"api_key": ""}}):
response = self.client.get("/status", headers=self.headers)
self.assertEqual(response.status_code, 403)
self.assertEqual(response.json()["detail"], "API Key not configured on server")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()