feat(palette): add fuzzy_match with subsequence matching and scoring
This commit is contained in:
+86
-29
@@ -5,42 +5,99 @@ from typing import Optional, Callable, List, Dict, Any
|
||||
|
||||
@dataclass
class Command:
    """A single entry that can be registered in the command palette.

    Only ``id``, ``title`` and ``category`` are required; every other field
    has a default.
    """

    # Unique key: CommandRegistry stores commands under this value and
    # rejects a second registration with the same id.
    id: str
    # Display text; this is the string fuzzy_match searches.
    title: str
    # Grouping label; CommandRegistry.register assigns "uncategorized" to
    # commands it builds from bare callables.
    category: str
    # Optional key binding. The expected format is not defined in this module.
    shortcut: Optional[str] = None
    description: str = ""
    # NOTE(review): presumably a condition expression evaluated by the host
    # UI -- nothing in this module reads it; confirm against the caller.
    enabled_when: Optional[str] = None
    # Callable attached to the command; CommandRegistry.register fills this in
    # when it wraps a bare callable.
    action: Optional[Callable] = None
|
||||
|
||||
|
||||
@dataclass
class ScoredCommand:
    """A command paired with the relevance score computed for one query."""

    command: Command
    # Higher is better; fuzzy_match sorts its results on this, descending.
    score: float
|
||||
|
||||
|
||||
class CommandRegistry:
    """In-memory collection of commands, keyed by ``Command.id``."""

    def __init__(self) -> None:
        # A dict preserves insertion order, so all() lists commands in the
        # order they were registered.
        self._commands: Dict[str, Command] = {}

    def register(self, command_or_callable: Any) -> Any:
        """Add a command to the registry and return the argument unchanged.

        A ``Command`` instance is stored as-is. Anything else is treated as a
        callable and wrapped in a new ``Command`` whose id is the callable's
        ``__name__``, whose title is that name with underscores replaced by
        spaces and title-cased, whose category is ``"uncategorized"`` and
        whose action is the callable itself.

        Because the original argument is returned, the method can be applied
        as a decorator without replacing the decorated function.

        Raises:
            ValueError: a command with the same id is already registered.
        """
        if isinstance(command_or_callable, Command):
            cmd = command_or_callable
        else:
            name = command_or_callable.__name__
            cmd = Command(
                id=name,
                title=name.replace("_", " ").title(),
                category="uncategorized",
                action=command_or_callable,
            )
        if cmd.id in self._commands:
            raise ValueError(f"Command {cmd.id} already registered")
        self._commands[cmd.id] = cmd
        return command_or_callable

    def all(self) -> List[Command]:
        """Return a new list holding every registered command."""
        return list(self._commands.values())

    def get(self, command_id: str) -> Optional[Command]:
        """Return the command stored under *command_id*, or None if absent."""
        return self._commands.get(command_id)
|
||||
|
||||
|
||||
def fuzzy_match(query: str, candidates: List[Command], top_n: int = 20) -> List[ScoredCommand]:
    """Return the best-scoring candidates whose title fuzzy-matches *query*.

    A candidate matches when the lower-cased query is a subsequence of its
    lower-cased title, so matching is case-insensitive. Each match is scored
    by ``_compute_score`` and the results are ordered from highest to lowest
    score.

    Args:
        query: Text typed by the user.
        candidates: Commands to search; only ``title`` is inspected.
        top_n: Maximum number of results. Zero or a negative value yields an
            empty list.

    Returns:
        At most *top_n* ``ScoredCommand`` items, best first.
    """
    if top_n <= 0:
        # A negative bound in the final slice would drop results from the
        # tail instead of limiting the count, so refuse it up front.
        return []

    needle = query.lower()
    matches: List[ScoredCommand] = []
    for cmd in candidates:
        haystack = cmd.title.lower()
        if _is_subsequence(needle, haystack):
            matches.append(
                ScoredCommand(command=cmd, score=_compute_score(needle, haystack))
            )

    # list.sort is stable (also with reverse=True), so commands with equal
    # scores keep the order they had in *candidates*.
    matches.sort(key=lambda r: r.score, reverse=True)
    return matches[:top_n]
|
||||
|
||||
|
||||
def _is_subsequence(query: str, target: str) -> bool:
|
||||
qi = 0
|
||||
for ch in target:
|
||||
if qi < len(query) and ch == query[qi]:
|
||||
qi += 1
|
||||
return qi == len(query)
|
||||
|
||||
|
||||
def _compute_score(query: str, target: str) -> float:
    """Score how well *query* matches *target*; higher is better.

    Both strings are compared as given (the caller passes lower-cased text).
    The score is the sum of:

    * +1.0 if *target* starts with *query*, otherwise +0.5 if
      ``_starts_at_word_boundary(query, target)`` is true;
    * +0.3 if *query* occurs in *target* as one unbroken substring;
    * -0.1 per skipped character reported by ``_count_gaps``, applied only
      when no unbroken occurrence exists.
    """
    score = 0.0

    if target.startswith(query):
        score += 1.0
    elif _starts_at_word_boundary(query, target):
        score += 0.5

    if _is_contiguous(query, target):
        # A gap-free alignment exists, so there is nothing to penalise.
        # _count_gaps follows the leftmost greedy alignment, which can report
        # gaps even then (e.g. "fin" in "configure find settings"), and would
        # otherwise cancel out the bonus awarded here.
        score += 0.3
    else:
        score -= 0.1 * _count_gaps(query, target)

    return score
|
||||
|
||||
|
||||
def _starts_at_word_boundary(query: str, target: str) -> bool:
|
||||
if not target.startswith(query):
|
||||
return False
|
||||
return len(query) == 0 or not query[0].isalnum() or len(target) == len(query) or not target[len(query)].isalnum()
|
||||
|
||||
|
||||
def _is_contiguous(query: str, target: str) -> bool:
|
||||
return query in target
|
||||
|
||||
|
||||
def _count_gaps(query: str, target: str) -> int:
|
||||
qi = 0
|
||||
gaps = 0
|
||||
last_match = -1
|
||||
for ti, ch in enumerate(target):
|
||||
if qi < len(query) and ch == query[qi]:
|
||||
if last_match >= 0 and ti - last_match > 1:
|
||||
gaps += ti - last_match - 1
|
||||
last_match = ti
|
||||
qi += 1
|
||||
return gaps
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
from src.command_palette import Command, ScoredCommand, fuzzy_match
|
||||
|
||||
|
||||
def _cmd(id: str, title: str) -> Command:
    """Build a Command with the given id and title in the "test" category."""
    return Command(id=id, title=title, category="test")
|
||||
|
||||
|
||||
def test_fuzzy_match_prefix_ranks_first():
    """A title that starts with the query is ranked above other matches."""
    candidates = [
        _cmd("find", "Find in Selection"),
        _cmd("fold", "Fold All"),
        _cmd("config", "Configure Settings"),
    ]
    results = fuzzy_match("fin", candidates, top_n=10)
    assert len(results) > 0
    assert results[0].command.id == "find"
    assert results[0].score > 0.5
    # "Fold All" contains neither "i" nor "n", so it cannot match at all.
    assert "fold" not in [r.command.id for r in results]
|
||||
|
||||
|
||||
def test_fuzzy_match_subsequence_match():
    """Query characters may be separated by other characters in the title."""
    candidates = [_cmd("x", "Find")]
    results = fuzzy_match("fd", candidates, top_n=10)
    assert len(results) == 1
    assert isinstance(results[0], ScoredCommand)
    assert results[0].command.id == "x"
|
||||
|
||||
|
||||
def test_fuzzy_match_no_match_returns_empty():
    """A query that is not a subsequence of any title yields no results."""
    candidates = [_cmd("x", "foo bar")]
    results = fuzzy_match("xyz", candidates, top_n=10)
    assert results == []
|
||||
|
||||
|
||||
def test_fuzzy_match_top_n_limits_results():
    """No more than top_n results are returned, ordered best first."""
    candidates = [_cmd(f"cmd_{i}", f"Command {i}") for i in range(50)]
    results = fuzzy_match("cmd", candidates, top_n=10)
    assert len(results) == 10
    scores = [r.score for r in results]
    assert scores == sorted(scores, reverse=True)
|
||||
|
||||
|
||||
def test_fuzzy_match_score_higher_for_exact_prefix():
    """A prefix match beats a match that begins later in the title."""
    candidates = [
        _cmd("a", "find"),
        _cmd("b", "Configure Find Settings"),
    ]
    results = fuzzy_match("fin", candidates, top_n=10)
    # Both titles contain "fin" as a subsequence, so both must be returned.
    assert len(results) == 2
    assert results[0].command.id == "a"
    assert results[0].score > results[1].score
|
||||
Reference in New Issue
Block a user