74 lines
3.3 KiB
Python
74 lines
3.3 KiB
Python
import unittest
|
|
from typing import Any
|
|
from unittest.mock import patch, MagicMock
|
|
import json
|
|
import orchestrator_pm
|
|
import mma_prompts
|
|
|
|
class TestOrchestratorPM(unittest.TestCase):
    """Unit tests for ``orchestrator_pm.generate_tracks``.

    Every test patches ``ai_client.send`` (the model call) and
    ``summarize.build_summary_markdown`` (the repo-map builder) so the
    function under test runs against canned data only.

    Note: stacked ``@patch`` decorators inject their mocks bottom-up, so the
    innermost patch (``ai_client.send``) arrives as the first mock argument
    and ``summarize.build_summary_markdown`` as the second.
    """

    @patch('summarize.build_summary_markdown')
    @patch('ai_client.send')
    def test_generate_tracks_success(self, mock_send: Any, mock_summarize: Any) -> None:
        """A plain JSON array returned by the model is parsed into tracks."""
        # Setup mocks
        mock_summarize.return_value = "REPO_MAP_CONTENT"
        mock_response_data = [
            {
                "id": "track_1",
                "type": "Track",
                "module": "test_module",
                "persona": "Tech Lead",
                "severity": "Medium",
                "goal": "Test goal",
                "acceptance_criteria": ["criteria 1"]
            }
        ]
        mock_send.return_value = json.dumps(mock_response_data)
        user_request = "Implement unit tests"
        project_config = {"files": {"paths": ["src"]}}
        file_items = [{"path": "src/main.py", "content": "print('hello')"}]

        # Execute
        result = orchestrator_pm.generate_tracks(user_request, project_config, file_items)

        # The repo map must be built from exactly the supplied file items.
        mock_summarize.assert_called_once_with(file_items)

        # The Tier 1 system prompt is set globally rather than passed to
        # ai_client.send, so it cannot be checked through the mock's kwargs.
        # We can at least require that the prompt it relies on exists.
        self.assertIn('tier1_epic_init', mma_prompts.PROMPTS)

        # Verify the single ai_client.send call: empty markdown context, and a
        # user message that embeds both the request and the repo map.
        mock_send.assert_called_once()
        _, kwargs = mock_send.call_args
        self.assertEqual(kwargs['md_content'], "")
        self.assertIn(user_request, kwargs['user_message'])
        self.assertIn("REPO_MAP_CONTENT", kwargs['user_message'])

        # Verify result: one track in the response yields one track out.
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]['id'], mock_response_data[0]['id'])

    @patch('summarize.build_summary_markdown')
    @patch('ai_client.send')
    def test_generate_tracks_markdown_wrapped(self, mock_send: Any, mock_summarize: Any) -> None:
        """JSON inside a markdown code fence (with or without a ``json`` tag) is extracted.

        The expected result also shows that a track without a title gets the
        default ``"Untitled Track"``.
        """
        mock_summarize.return_value = "REPO_MAP"
        payload = json.dumps([{"id": "track_1"}])
        expected_result = [{"id": "track_1", "title": "Untitled Track"}]

        wrapped_responses = {
            # Wrapped in ```json ... ``` with surrounding chatter.
            "json-tagged fence with prose": f"Here is the plan:\n```json\n{payload}\n```\nHope this helps.",
            # Wrapped in a bare ``` ... ``` fence.
            "bare fence": f"```\n{payload}\n```",
        }
        for label, response in wrapped_responses.items():
            # subTest keeps going after a failure and reports which wrapping broke.
            with self.subTest(wrapping=label):
                mock_send.return_value = response
                result = orchestrator_pm.generate_tracks("req", {}, [])
                self.assertEqual(result, expected_result)

    @patch('summarize.build_summary_markdown')
    @patch('ai_client.send')
    def test_generate_tracks_malformed_json(self, mock_send: Any, mock_summarize: Any) -> None:
        """Unparseable model output yields an empty list and a printed error."""
        mock_summarize.return_value = "REPO_MAP"
        mock_send.return_value = "NOT A JSON"

        # print is patched to capture the error report instead of writing it.
        with patch('builtins.print') as mock_print:
            result = orchestrator_pm.generate_tracks("req", {}, [])

        self.assertEqual(result, [])
        # The suffix is the stdlib json decoder's message for input that is
        # not JSON at all; it is embedded verbatim in the printed error.
        mock_print.assert_any_call("Error parsing Tier 1 response: Expecting value: line 1 column 1 (char 0)")
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
|