"""Tests for Tier 4 patch generation.

Covers the unified-diff prompt in ``src.mma_prompts`` and
``src.ai_client.run_tier4_patch_generation``. The Gemini client is
always mocked, so no network calls are made.
"""
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import pytest

from src import ai_client
from src import mma_prompts


@contextmanager
def _mock_gemini(text: str) -> Iterator[MagicMock]:
    """Patch the Gemini plumbing in ``src.ai_client`` for a single test.

    Patches ``_ensure_gemini_client``, ``_gemini_client`` and ``types``,
    then yields the mocked client. Its ``models.generate_content``
    returns a response object whose ``.text`` is *text*.
    """
    # create=True: the client attribute may not exist until it is first
    # initialised, so patching must not require it to be present already.
    with patch("src.ai_client._ensure_gemini_client"), \
         patch("src.ai_client._gemini_client", create=True) as mock_client, \
         patch("src.ai_client.types"):
        mock_resp = MagicMock()
        mock_resp.text = text
        mock_client.models.generate_content.return_value = mock_resp
        yield mock_client


def test_tier4_patch_prompt_exists() -> None:
    """Test that TIER4_PATCH_PROMPT is defined in mma_prompts.py."""
    assert hasattr(mma_prompts, "TIER4_PATCH_PROMPT")
    prompt = mma_prompts.TIER4_PATCH_PROMPT
    assert "unified diff" in prompt.lower() or "diff -u" in prompt.lower()
    assert "---" in prompt
    assert "+++" in prompt
    assert "@@" in prompt


def test_tier4_patch_prompt_format_instructions() -> None:
    """Test that the patch prompt includes format instructions for unified diff."""
    prompt = mma_prompts.TIER4_PATCH_PROMPT
    # The git-style "--- a/" / "+++ b/" prefixes are accepted but not
    # required: any prompt containing the bare markers passes.
    # NOTE(review): as written this duplicates checks in
    # test_tier4_patch_prompt_exists; tighten it to require the a/ b/
    # prefixes if that is the intended contract.
    assert "---" in prompt
    assert "+++" in prompt


def test_run_tier4_patch_generation_exists() -> None:
    """Test that run_tier4_patch_generation function exists in ai_client."""
    assert hasattr(ai_client, "run_tier4_patch_generation")
    assert callable(ai_client.run_tier4_patch_generation)


def test_run_tier4_patch_generation_empty_error() -> None:
    """Test that run_tier4_patch_generation returns empty string on empty error."""
    with _mock_gemini(""):
        result = ai_client.run_tier4_patch_generation("", "file context")
    assert result == ""


def test_run_tier4_patch_generation_calls_ai() -> None:
    """Test that run_tier4_patch_generation invokes the Gemini generate_content API."""
    with _mock_gemini(
        "--- a/test.py\n+++ b/test.py\n@@ -1 +1 @@\n-old\n+new"
    ) as mock_client:
        error = "TypeError: unsupported operand"
        file_context = "def foo():\n pass"
        ai_client.run_tier4_patch_generation(error, file_context)
        mock_client.models.generate_content.assert_called()


def test_run_tier4_patch_generation_returns_diff() -> None:
    """Test that run_tier4_patch_generation returns diff text."""
    expected_diff = "--- a/src/test.py\n+++ b/src/test.py\n@@ -10,5 +10,6 @@\n def test_func():\n- old_value = 1\n+ old_value = 1\n+ new_value = 2"
    with _mock_gemini(expected_diff):
        result = ai_client.run_tier4_patch_generation("error", "context")
    assert "---" in result
    assert "+++" in result
    assert "@@" in result