fix(headless): Apply review suggestions for track 'manual_slop_headless_20260225'
This commit is contained in:
24
gui_2.py
24
gui_2.py
@@ -307,6 +307,7 @@ class App:
|
|||||||
self._discussion_names_dirty = True
|
self._discussion_names_dirty = True
|
||||||
|
|
||||||
def create_api(self) -> FastAPI:
|
def create_api(self) -> FastAPI:
|
||||||
|
"""Creates and configures the FastAPI application for headless mode."""
|
||||||
api = FastAPI(title="Manual Slop Headless API")
|
api = FastAPI(title="Manual Slop Headless API")
|
||||||
|
|
||||||
class GenerateRequest(BaseModel):
|
class GenerateRequest(BaseModel):
|
||||||
@@ -322,22 +323,26 @@ class App:
|
|||||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||||
|
|
||||||
async def get_api_key(header_key: str = Depends(api_key_header)):
|
async def get_api_key(header_key: str = Depends(api_key_header)):
|
||||||
|
"""Validates the API key from the request header against configuration."""
|
||||||
headless_cfg = self.config.get("headless", {})
|
headless_cfg = self.config.get("headless", {})
|
||||||
config_key = headless_cfg.get("api_key", "").strip()
|
config_key = headless_cfg.get("api_key", "").strip()
|
||||||
env_key = os.environ.get("SLOP_API_KEY", "").strip()
|
env_key = os.environ.get("SLOP_API_KEY", "").strip()
|
||||||
target_key = env_key or config_key
|
target_key = env_key or config_key
|
||||||
if not target_key:
|
if not target_key:
|
||||||
return None
|
# If no key is configured, we must deny access by default for security
|
||||||
|
raise HTTPException(status_code=403, detail="API Key not configured on server")
|
||||||
if header_key == target_key:
|
if header_key == target_key:
|
||||||
return header_key
|
return header_key
|
||||||
raise HTTPException(status_code=403, detail="Could not validate API Key")
|
raise HTTPException(status_code=403, detail="Could not validate API Key")
|
||||||
|
|
||||||
@api.get("/health")
|
@api.get("/health")
|
||||||
def health():
|
def health():
|
||||||
|
"""Basic health check endpoint."""
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
@api.get("/status", dependencies=[Depends(get_api_key)])
|
@api.get("/status", dependencies=[Depends(get_api_key)])
|
||||||
def status():
|
def status():
|
||||||
|
"""Returns the current status of the AI provider and active project."""
|
||||||
return {
|
return {
|
||||||
"provider": self.current_provider,
|
"provider": self.current_provider,
|
||||||
"model": self.current_model,
|
"model": self.current_model,
|
||||||
@@ -348,6 +353,7 @@ class App:
|
|||||||
|
|
||||||
@api.get("/api/v1/pending_actions", dependencies=[Depends(get_api_key)])
|
@api.get("/api/v1/pending_actions", dependencies=[Depends(get_api_key)])
|
||||||
def pending_actions():
|
def pending_actions():
|
||||||
|
"""Lists all PowerShell scripts awaiting manual confirmation."""
|
||||||
actions = []
|
actions = []
|
||||||
with self._pending_dialog_lock:
|
with self._pending_dialog_lock:
|
||||||
# Include multi-actions from headless mode
|
# Include multi-actions from headless mode
|
||||||
@@ -368,6 +374,7 @@ class App:
|
|||||||
|
|
||||||
@api.post("/api/v1/confirm/{action_id}", dependencies=[Depends(get_api_key)])
|
@api.post("/api/v1/confirm/{action_id}", dependencies=[Depends(get_api_key)])
|
||||||
def confirm_action(action_id: str, req: ConfirmRequest):
|
def confirm_action(action_id: str, req: ConfirmRequest):
|
||||||
|
"""Approves or denies a pending PowerShell script execution."""
|
||||||
success = self.resolve_pending_action(action_id, req.approved)
|
success = self.resolve_pending_action(action_id, req.approved)
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(status_code=404, detail=f"Action ID {action_id} not found")
|
raise HTTPException(status_code=404, detail=f"Action ID {action_id} not found")
|
||||||
@@ -375,6 +382,7 @@ class App:
|
|||||||
|
|
||||||
@api.get("/api/v1/sessions", dependencies=[Depends(get_api_key)])
|
@api.get("/api/v1/sessions", dependencies=[Depends(get_api_key)])
|
||||||
def list_sessions():
|
def list_sessions():
|
||||||
|
"""Lists all available session log files."""
|
||||||
log_dir = Path("logs")
|
log_dir = Path("logs")
|
||||||
if not log_dir.exists():
|
if not log_dir.exists():
|
||||||
return []
|
return []
|
||||||
@@ -382,6 +390,7 @@ class App:
|
|||||||
|
|
||||||
@api.get("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)])
|
@api.get("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)])
|
||||||
def get_session(filename: str):
|
def get_session(filename: str):
|
||||||
|
"""Retrieves the content of a specific session log file."""
|
||||||
if ".." in filename or "/" in filename or "\\" in filename:
|
if ".." in filename or "/" in filename or "\\" in filename:
|
||||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||||
log_path = Path("logs") / filename
|
log_path = Path("logs") / filename
|
||||||
@@ -395,6 +404,7 @@ class App:
|
|||||||
|
|
||||||
@api.delete("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)])
|
@api.delete("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)])
|
||||||
def delete_session(filename: str):
|
def delete_session(filename: str):
|
||||||
|
"""Deletes a specific session log file."""
|
||||||
if ".." in filename or "/" in filename or "\\" in filename:
|
if ".." in filename or "/" in filename or "\\" in filename:
|
||||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||||
log_path = Path("logs") / filename
|
log_path = Path("logs") / filename
|
||||||
@@ -408,6 +418,7 @@ class App:
|
|||||||
|
|
||||||
@api.get("/api/v1/context", dependencies=[Depends(get_api_key)])
|
@api.get("/api/v1/context", dependencies=[Depends(get_api_key)])
|
||||||
def get_context():
|
def get_context():
|
||||||
|
"""Returns the current file and screenshot context configuration."""
|
||||||
return {
|
return {
|
||||||
"files": self.files,
|
"files": self.files,
|
||||||
"screenshots": self.screenshots,
|
"screenshots": self.screenshots,
|
||||||
@@ -417,6 +428,7 @@ class App:
|
|||||||
|
|
||||||
@api.post("/api/v1/generate", dependencies=[Depends(get_api_key)])
|
@api.post("/api/v1/generate", dependencies=[Depends(get_api_key)])
|
||||||
def generate(req: GenerateRequest):
|
def generate(req: GenerateRequest):
|
||||||
|
"""Triggers an AI generation request using the current project context."""
|
||||||
if not req.prompt.strip():
|
if not req.prompt.strip():
|
||||||
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
||||||
|
|
||||||
@@ -486,6 +498,7 @@ class App:
|
|||||||
|
|
||||||
@api.post("/api/v1/stream", dependencies=[Depends(get_api_key)])
|
@api.post("/api/v1/stream", dependencies=[Depends(get_api_key)])
|
||||||
async def stream(req: GenerateRequest):
|
async def stream(req: GenerateRequest):
|
||||||
|
"""Placeholder for streaming AI generation responses (Not yet implemented)."""
|
||||||
# Streaming implementation would require ai_client to support yield-based responses.
|
# Streaming implementation would require ai_client to support yield-based responses.
|
||||||
# Currently added as a placeholder to satisfy spec requirements.
|
# Currently added as a placeholder to satisfy spec requirements.
|
||||||
raise HTTPException(status_code=501, detail="Streaming endpoint (/api/v1/stream) is not yet supported in this version.")
|
raise HTTPException(status_code=501, detail="Streaming endpoint (/api/v1/stream) is not yet supported in this version.")
|
||||||
@@ -971,6 +984,15 @@ class App:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def resolve_pending_action(self, action_id: str, approved: bool):
|
def resolve_pending_action(self, action_id: str, approved: bool):
|
||||||
|
"""Resolves a pending PowerShell script confirmation by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_id: The unique identifier for the pending action.
|
||||||
|
approved: True if the script should be executed, False otherwise.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the action was found and resolved, False otherwise.
|
||||||
|
"""
|
||||||
with self._pending_dialog_lock:
|
with self._pending_dialog_lock:
|
||||||
if action_id in self._pending_actions:
|
if action_id in self._pending_actions:
|
||||||
dialog = self._pending_actions[action_id]
|
dialog = self._pending_actions[action_id]
|
||||||
|
|||||||
@@ -11,6 +11,11 @@ class TestHeadlessAPI(unittest.TestCase):
|
|||||||
patch('gui_2.ai_client.set_provider'), \
|
patch('gui_2.ai_client.set_provider'), \
|
||||||
patch('gui_2.session_logger.close_session'):
|
patch('gui_2.session_logger.close_session'):
|
||||||
self.app_instance = gui_2.App()
|
self.app_instance = gui_2.App()
|
||||||
|
# Set a default API key for tests
|
||||||
|
self.test_api_key = "test-secret-key"
|
||||||
|
self.app_instance.config["headless"] = {"api_key": self.test_api_key}
|
||||||
|
self.headers = {"X-API-KEY": self.test_api_key}
|
||||||
|
|
||||||
# Clear any leftover state
|
# Clear any leftover state
|
||||||
self.app_instance._pending_actions = {}
|
self.app_instance._pending_actions = {}
|
||||||
self.app_instance._pending_dialog = None
|
self.app_instance._pending_dialog = None
|
||||||
@@ -51,7 +56,7 @@ class TestHeadlessAPI(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
}]
|
}]
|
||||||
|
|
||||||
response = self.client.post("/api/v1/generate", json=payload)
|
response = self.client.post("/api/v1/generate", json=payload, headers=self.headers)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
self.assertEqual(data["text"], "Hello from Mock AI")
|
self.assertEqual(data["text"], "Hello from Mock AI")
|
||||||
@@ -64,7 +69,7 @@ class TestHeadlessAPI(unittest.TestCase):
|
|||||||
dialog = gui_2.ConfirmDialog("dir", ".")
|
dialog = gui_2.ConfirmDialog("dir", ".")
|
||||||
self.app_instance._pending_actions[dialog._uid] = dialog
|
self.app_instance._pending_actions[dialog._uid] = dialog
|
||||||
|
|
||||||
response = self.client.get("/api/v1/pending_actions")
|
response = self.client.get("/api/v1/pending_actions", headers=self.headers)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
self.assertEqual(len(data), 1)
|
self.assertEqual(len(data), 1)
|
||||||
@@ -77,7 +82,7 @@ class TestHeadlessAPI(unittest.TestCase):
|
|||||||
self.app_instance._pending_actions[dialog._uid] = dialog
|
self.app_instance._pending_actions[dialog._uid] = dialog
|
||||||
|
|
||||||
payload = {"approved": True}
|
payload = {"approved": True}
|
||||||
response = self.client.post("/api/v1/confirm/test-confirm-id", json=payload)
|
response = self.client.post("/api/v1/confirm/test-confirm-id", json=payload, headers=self.headers)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertTrue(dialog._done)
|
self.assertTrue(dialog._done)
|
||||||
self.assertTrue(dialog._approved)
|
self.assertTrue(dialog._approved)
|
||||||
@@ -90,7 +95,7 @@ class TestHeadlessAPI(unittest.TestCase):
|
|||||||
dummy_log.write_text("dummy content")
|
dummy_log.write_text("dummy content")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.client.get("/api/v1/sessions")
|
response = self.client.get("/api/v1/sessions", headers=self.headers)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
self.assertIn("test_session_api.log", data)
|
self.assertIn("test_session_api.log", data)
|
||||||
@@ -99,12 +104,19 @@ class TestHeadlessAPI(unittest.TestCase):
|
|||||||
dummy_log.unlink()
|
dummy_log.unlink()
|
||||||
|
|
||||||
def test_get_context_endpoint(self):
|
def test_get_context_endpoint(self):
|
||||||
response = self.client.get("/api/v1/context")
|
response = self.client.get("/api/v1/context", headers=self.headers)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
self.assertIn("files", data)
|
self.assertIn("files", data)
|
||||||
self.assertIn("screenshots", data)
|
self.assertIn("screenshots", data)
|
||||||
self.assertIn("files_base_dir", data)
|
self.assertIn("files_base_dir", data)
|
||||||
|
|
||||||
|
def test_endpoint_no_api_key_configured(self):
|
||||||
|
# Test the security fix specifically
|
||||||
|
with patch.dict(self.app_instance.config, {"headless": {"api_key": ""}}):
|
||||||
|
response = self.client.get("/status", headers=self.headers)
|
||||||
|
self.assertEqual(response.status_code, 403)
|
||||||
|
self.assertEqual(response.json()["detail"], "API Key not configured on server")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user