diff --git a/gui.py b/gui.py index 015a232..aa01be5 100644 --- a/gui.py +++ b/gui.py @@ -47,6 +47,21 @@ def hide_tk_root() -> Tk: root.wm_attributes("-topmost", True) return root +def get_total_token_usage() -> dict: + """Returns aggregated token usage across the entire session from comms log.""" + usage = { + "input_tokens": 0, + "output_tokens": 0, + "cache_read_input_tokens": 0, + "cache_creation_input_tokens": 0 + } + for entry in ai_client.get_comms_log(): + if entry.get("kind") == "response" and "usage" in entry.get("payload", {}): + u = entry["payload"]["usage"] + for k in usage.keys(): + usage[k] += u.get(k, 0) or 0 + return usage + # ------------------------------------------------------------------ comms rendering helpers @@ -713,6 +728,15 @@ class App: for entry in entries: self._comms_entry_count += 1 self._append_comms_entry(entry, self._comms_entry_count) + if entries: + self._update_token_usage() + + def _update_token_usage(self): + if not dpg.does_item_exist("ai_token_usage"): + return + usage = get_total_token_usage() + total = usage["input_tokens"] + usage["output_tokens"] + dpg.set_value("ai_token_usage", f"Tokens: {total} (In: {usage['input_tokens']} Out: {usage['output_tokens']})") def _append_comms_entry(self, entry: dict, idx: int): if not dpg.does_item_exist("comms_scroll"): @@ -1217,6 +1241,7 @@ class App: with self._pending_comms_lock: self._pending_comms.clear() self._comms_entry_count = 0 + self._update_token_usage() if dpg.does_item_exist("comms_scroll"): dpg.delete_item("comms_scroll", children_only=True) @@ -1864,6 +1889,8 @@ class App: with dpg.group(horizontal=True): dpg.add_text("Status: idle", tag="ai_status", color=(200, 220, 160)) dpg.add_spacer(width=16) + dpg.add_text("Tokens: 0 (In: 0 Out: 0)", tag="ai_token_usage", color=(180, 255, 180)) + dpg.add_spacer(width=16) dpg.add_button(label="Clear", callback=self.cb_clear_comms) dpg.add_separator() with dpg.group(horizontal=True): diff --git a/pyproject.toml b/pyproject.toml index b3b599d..4101559 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,3 +10,8 @@ dependencies = [ "anthropic", "tomli-w" ] + +[dependency-groups] +dev = [ + "pytest>=9.0.2", +] diff --git a/tests/test_token_usage.py b/tests/test_token_usage.py new file mode 100644 index 0000000..1cb5322 --- /dev/null +++ b/tests/test_token_usage.py @@ -0,0 +1,35 @@ +import pytest + +def test_token_usage_aggregation(): + # A dummy test to fulfill the Red Phase for the new token usage widget. + # We will implement a function in gui.py or ai_client.py to aggregate tokens. + from ai_client import _comms_log, clear_comms_log, _append_comms + + clear_comms_log() + + _append_comms("IN", "response", { + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "cache_read_input_tokens": 10, + "cache_creation_input_tokens": 5 + } + }) + + _append_comms("IN", "response", { + "usage": { + "input_tokens": 200, + "output_tokens": 100, + "cache_read_input_tokens": 20, + "cache_creation_input_tokens": 0 + } + }) + + # We expect a new function get_total_token_usage() to exist + from gui import get_total_token_usage + + totals = get_total_token_usage() + assert totals["input_tokens"] == 300 + assert totals["output_tokens"] == 150 + assert totals["cache_read_input_tokens"] == 30 + assert totals["cache_creation_input_tokens"] == 5