feat(tier4): Add patch generation for auto-patching
- Add TIER4_PATCH_PROMPT to mma_prompts.py with unified diff format - Add run_tier4_patch_generation function to ai_client.py - Import mma_prompts in ai_client.py - Add unit tests for patch generation
This commit is contained in:
67
tests/test_tier4_patch_generation.py
Normal file
67
tests/test_tier4_patch_generation.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from src import ai_client
|
||||
from src import mma_prompts
|
||||
|
||||
def test_tier4_patch_prompt_exists() -> None:
|
||||
"""Test that TIER4_PATCH_PROMPT is defined in mma_prompts.py."""
|
||||
assert hasattr(mma_prompts, "TIER4_PATCH_PROMPT")
|
||||
prompt = mma_prompts.TIER4_PATCH_PROMPT
|
||||
assert "unified diff" in prompt.lower() or "diff -u" in prompt.lower()
|
||||
assert "---" in prompt
|
||||
assert "+++" in prompt
|
||||
assert "@@" in prompt
|
||||
|
||||
def test_tier4_patch_prompt_format_instructions() -> None:
|
||||
"""Test that the patch prompt includes format instructions for unified diff."""
|
||||
prompt = mma_prompts.TIER4_PATCH_PROMPT
|
||||
assert "--- a/" in prompt or "---" in prompt
|
||||
assert "+++ b/" in prompt or "+++" in prompt
|
||||
|
||||
def test_run_tier4_patch_generation_exists() -> None:
|
||||
"""Test that run_tier4_patch_generation function exists in ai_client."""
|
||||
assert hasattr(ai_client, "run_tier4_patch_generation")
|
||||
assert callable(ai_client.run_tier4_patch_generation)
|
||||
|
||||
def test_run_tier4_patch_generation_empty_error() -> None:
|
||||
"""Test that run_tier4_patch_generation returns empty string on empty error."""
|
||||
with patch("src.ai_client._ensure_gemini_client"), \
|
||||
patch("src.ai_client._gemini_client") as mock_client:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.text = ""
|
||||
mock_client.models.generate_content.return_value = mock_resp
|
||||
|
||||
result = ai_client.run_tier4_patch_generation("", "file context")
|
||||
assert result == ""
|
||||
|
||||
def test_run_tier4_patch_generation_calls_ai() -> None:
|
||||
"""Test that run_tier4_patch_generation calls the AI with the correct prompt."""
|
||||
with patch("src.ai_client._ensure_gemini_client"), \
|
||||
patch("src.ai_client._gemini_client", create=True) as mock_client, \
|
||||
patch("src.ai_client.types") as mock_types:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.text = "--- a/test.py\n+++ b/test.py\n@@ -1 +1 @@\n-old\n+new"
|
||||
mock_client.models.generate_content.return_value = mock_resp
|
||||
mock_types.GenerateContentConfig = MagicMock()
|
||||
|
||||
error = "TypeError: unsupported operand"
|
||||
file_context = "def foo():\n pass"
|
||||
result = ai_client.run_tier4_patch_generation(error, file_context)
|
||||
|
||||
mock_client.models.generate_content.assert_called()
|
||||
|
||||
def test_run_tier4_patch_generation_returns_diff() -> None:
|
||||
"""Test that run_tier4_patch_generation returns diff text."""
|
||||
with patch("src.ai_client._ensure_gemini_client"), \
|
||||
patch("src.ai_client._gemini_client", create=True) as mock_client, \
|
||||
patch("src.ai_client.types") as mock_types:
|
||||
expected_diff = "--- a/src/test.py\n+++ b/src/test.py\n@@ -10,5 +10,6 @@\n def test_func():\n- old_value = 1\n+ old_value = 1\n+ new_value = 2"
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.text = expected_diff
|
||||
mock_client.models.generate_content.return_value = mock_resp
|
||||
mock_types.GenerateContentConfig = MagicMock()
|
||||
|
||||
result = ai_client.run_tier4_patch_generation("error", "context")
|
||||
assert "---" in result
|
||||
assert "+++" in result
|
||||
assert "@@" in result
|
||||
Reference in New Issue
Block a user