# Source listing: 65 lines, 1.8 KiB, Python.
import asyncio
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from rook.agent import AgentLoop, PRIMARY_MODEL, FALLBACK_MODEL
|
|
|
|
|
|
def test_agent_loop_init():
    """A freshly constructed AgentLoop starts on the primary model with empty state."""
    agent = AgentLoop(system='You are Rook.')

    # Starts on the primary model.
    assert agent.model == PRIMARY_MODEL
    # No conversation turns recorded yet.
    assert agent.history == []
    # Tool registry is exposed as a dict.
    assert isinstance(agent.tools, dict)
|
|
|
|
|
|
def test_register_tool():
    """register_tool() makes the tool visible under its name in the tools mapping."""
    agent = AgentLoop(system='You are Rook.')
    tool_name = 'ping'

    agent.register_tool(tool_name, lambda: 'pong')

    assert tool_name in agent.tools
|
|
|
|
|
|
def test_send_returns_string():
    """send() returns the text of a single, text-only model reply."""
    # Fake Messages API reply whose only content block is text.
    reply = MagicMock(content=[MagicMock(type='text', text='Hello')])

    # Dotted keys let configure_mock wire client.messages.create in one step.
    client = MagicMock()
    client.configure_mock(**{'messages.create.return_value': reply})

    # NOTE(review): patching 'anthropic.Anthropic' only takes effect if
    # rook.agent looks the class up through the anthropic module at call
    # time; a `from anthropic import Anthropic` binding would bypass it.
    with patch('anthropic.Anthropic', return_value=client):
        agent = AgentLoop(system='You are Rook.')
        result = asyncio.run(agent.send('hi'))

    assert result == 'Hello'
|
|
|
|
|
|
def test_tool_dispatch_called():
    """A tool_use reply runs the registered tool, then the follow-up text is returned."""
    # First model turn: ask for the 'ping' tool with no input.
    tool_use_block = MagicMock(type='tool_use', id='tu_1', input={})
    # `name` is a Mock constructor argument (it names the mock itself), so the
    # `.name` attribute must be assigned after creation.
    tool_use_block.name = 'ping'
    # Real Messages API replies always carry stop_reason; setting it keeps the
    # fixture faithful whether AgentLoop branches on stop_reason or on the
    # content block types (a bare Mock attribute never equals 'tool_use').
    first_response = MagicMock(content=[tool_use_block], stop_reason='tool_use')

    # Second model turn: plain text that ends the exchange.
    text_block = MagicMock(type='text', text='done')
    second_response = MagicMock(content=[text_block], stop_reason='end_turn')

    mock_client = MagicMock()
    # Successive create() calls return the two turns in order.
    mock_client.messages.create.side_effect = [first_response, second_response]

    ping_fn = MagicMock(return_value='pong')

    with patch('anthropic.Anthropic', return_value=mock_client):
        loop = AgentLoop(system='You are Rook.')
        loop.register_tool('ping', ping_fn)
        result = asyncio.run(loop.send('call ping'))

    ping_fn.assert_called_once()
    # One call for the tool request, one for the reply after the tool result.
    assert mock_client.messages.create.call_count == 2
    assert result == 'done'
|
|
|
|
|
|
def test_run_in_thread():
    """run_in_thread() returns a started daemon thread."""
    loop = AgentLoop(system='You are Rook.')

    t = loop.run_in_thread()

    # NOTE(review): assumes the worker stays running (e.g. an event loop) long
    # enough for this check; a short-lived target would make it flaky — confirm.
    assert t.is_alive()
    # Thread.daemon is a bool; assert it directly rather than comparing to True.
    assert t.daemon
    # Bounded wait only; a daemon thread does not block interpreter exit.
    t.join(timeout=0.1)
|