feat(controller): Integrate py_get_definition for on-demand lookup
This commit is contained in:
@@ -19,7 +19,7 @@ Focus: Parse @symbol syntax from user input
|
|||||||
## Phase 2: Definition Retrieval
|
## Phase 2: Definition Retrieval
|
||||||
Focus: Use existing MCP tool to get definitions
|
Focus: Use existing MCP tool to get definitions
|
||||||
|
|
||||||
- [ ] Task 2.1: Integrate py_get_definition
|
- [~] Task 2.1: Integrate py_get_definition
|
||||||
- WHERE: `src/gui_2.py`
|
- WHERE: `src/gui_2.py`
|
||||||
- WHAT: Call MCP tool for each symbol
|
- WHAT: Call MCP tool for each symbol
|
||||||
- HOW:
|
- HOW:
|
||||||
|
|||||||
@@ -46,6 +46,13 @@ def parse_symbols(text: str) -> list[str]:
|
|||||||
"""
|
"""
|
||||||
return re.findall(r"@([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)", text)
|
return re.findall(r"@([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)", text)
|
||||||
|
|
||||||
|
def get_symbol_definition(symbol: str, files: list[str]) -> tuple[str, str] | None:
|
||||||
|
for file_path in files:
|
||||||
|
result = mcp_client.py_get_definition(file_path, symbol)
|
||||||
|
if 'not found' not in result.lower():
|
||||||
|
return (file_path, result)
|
||||||
|
return None
|
||||||
|
|
||||||
class GenerateRequest(BaseModel):
|
class GenerateRequest(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
auto_add_history: bool = True
|
auto_add_history: bool = True
|
||||||
|
|||||||
@@ -409,7 +409,7 @@ def py_get_definition(path: str, name: str) -> str:
|
|||||||
start = cast(int, getattr(node, "lineno")) - 1
|
start = cast(int, getattr(node, "lineno")) - 1
|
||||||
end = cast(int, getattr(node, "end_lineno"))
|
end = cast(int, getattr(node, "end_lineno"))
|
||||||
return "".join(lines[start:end])
|
return "".join(lines[start:end])
|
||||||
return f"ERROR: could not find definition '{name}' in {path}"
|
return f"ERROR: definition '{name}' not found in {path}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"ERROR retrieving definition '{name}' from '{path}': {e}"
|
return f"ERROR retrieving definition '{name}' from '{path}': {e}"
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +1,58 @@
|
|||||||
import pytest
|
import unittest
|
||||||
from src.app_controller import parse_symbols
|
from unittest.mock import patch, MagicMock
|
||||||
|
from src.app_controller import parse_symbols, get_symbol_definition
|
||||||
|
|
||||||
def test_parse_symbols_basic():
|
class TestSymbolLookup(unittest.TestCase):
|
||||||
|
def test_parse_symbols_basic(self):
|
||||||
text = "Check @MyClass and @my_func."
|
text = "Check @MyClass and @my_func."
|
||||||
symbols = parse_symbols(text)
|
symbols = parse_symbols(text)
|
||||||
assert symbols == ["MyClass", "my_func"]
|
self.assertEqual(symbols, ["MyClass", "my_func"])
|
||||||
|
|
||||||
def test_parse_symbols_methods():
|
def test_parse_symbols_methods(self):
|
||||||
text = "Calling @MyClass.my_method and @AnotherClass.method_name."
|
text = "Calling @MyClass.my_method and @AnotherClass.method_name."
|
||||||
symbols = parse_symbols(text)
|
symbols = parse_symbols(text)
|
||||||
assert symbols == ["MyClass.my_method", "AnotherClass.method_name"]
|
self.assertEqual(symbols, ["MyClass.my_method", "AnotherClass.method_name"])
|
||||||
|
|
||||||
def test_parse_symbols_no_symbols():
|
def test_parse_symbols_no_symbols(self):
|
||||||
text = "This string has no symbols."
|
text = "This string has no symbols."
|
||||||
symbols = parse_symbols(text)
|
symbols = parse_symbols(text)
|
||||||
assert symbols == []
|
self.assertEqual(symbols, [])
|
||||||
|
|
||||||
def test_parse_symbols_mixed():
|
def test_parse_symbols_mixed(self):
|
||||||
text = "Mixed text: @Class1, @func_2, and some text @MyClass.method."
|
text = "Mixed text: @Class1, @func_2, and some text @MyClass.method."
|
||||||
symbols = parse_symbols(text)
|
symbols = parse_symbols(text)
|
||||||
assert symbols == ["Class1", "func_2", "MyClass.method"]
|
self.assertEqual(symbols, ["Class1", "func_2", "MyClass.method"])
|
||||||
|
|
||||||
def test_parse_symbols_edge_cases():
|
def test_parse_symbols_edge_cases(self):
|
||||||
text = "@LeadingSymbol and @SymbolAtEnd"
|
text = "@LeadingSymbol and @SymbolAtEnd"
|
||||||
symbols = parse_symbols(text)
|
symbols = parse_symbols(text)
|
||||||
assert symbols == ["LeadingSymbol", "SymbolAtEnd"]
|
self.assertEqual(symbols, ["LeadingSymbol", "SymbolAtEnd"])
|
||||||
|
|
||||||
|
def test_get_symbol_definition_found(self):
|
||||||
|
files = ["file1.py", "file2.py"]
|
||||||
|
symbol = "my_func"
|
||||||
|
def_content = "def my_func():\n pass"
|
||||||
|
|
||||||
|
with patch("src.mcp_client.py_get_definition") as mock_get_def:
|
||||||
|
# First file not found, second file found
|
||||||
|
mock_get_def.side_effect = [
|
||||||
|
"ERROR: definition 'my_func' not found in file1.py",
|
||||||
|
def_content
|
||||||
|
]
|
||||||
|
|
||||||
|
result = get_symbol_definition(symbol, files)
|
||||||
|
self.assertEqual(result, ("file2.py", def_content))
|
||||||
|
self.assertEqual(mock_get_def.call_count, 2)
|
||||||
|
|
||||||
|
def test_get_symbol_definition_not_found(self):
|
||||||
|
files = ["file1.py"]
|
||||||
|
symbol = "my_func"
|
||||||
|
|
||||||
|
with patch("src.mcp_client.py_get_definition") as mock_get_def:
|
||||||
|
mock_get_def.return_value = "ERROR: definition 'my_func' not found in file1.py"
|
||||||
|
|
||||||
|
result = get_symbol_definition(symbol, files)
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user