diff --git a/src/openai_schemas.py b/src/openai_schemas.py index 76dd5e2e..7fab91e3 100644 --- a/src/openai_schemas.py +++ b/src/openai_schemas.py @@ -16,10 +16,14 @@ CONVENTION: 1-space indentation. NO COMMENTS. """ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields as dc_fields from typing import Any, Callable, Optional -from src.type_aliases import JsonValue +from src.type_aliases import JsonValue, Metadata + + +def _from_dict_filter(cls: type, data: Metadata) -> Metadata: + return {k: v for k, v in data.items() if k in {f.name for f in dc_fields(cls)}} @dataclass(frozen=True) @@ -44,11 +48,16 @@ class ToolCall: }, } + @classmethod + def from_dict(cls, data: Metadata) -> "ToolCall": + fn = ToolCallFunction(**_from_dict_filter(ToolCallFunction, data.get("function", {}))) + return cls(**{**_from_dict_filter(cls, data), "function": fn}) + @dataclass(frozen=True) class ChatMessage: role: str - content: str | list # str for text; list of content parts for multimodal (text + image_url, etc.) + content: str | list tool_calls: Optional[tuple[ToolCall, ...]] = None tool_call_id: Optional[str] = None name: Optional[str] = None @@ -63,6 +72,14 @@ class ChatMessage: d["name"] = self.name return d + @classmethod + def from_dict(cls, data: Metadata) -> "ChatMessage": + raw_tool_calls = data.get("tool_calls") + tool_calls = None + if raw_tool_calls is not None: + tool_calls = tuple(ToolCall.from_dict(tc) for tc in raw_tool_calls) + return cls(**{**_from_dict_filter(cls, data), "tool_calls": tool_calls}) + @dataclass(frozen=True) class UsageStats: @@ -71,6 +88,10 @@ class UsageStats: cache_read_tokens: int = 0 cache_creation_tokens: int = 0 + @classmethod + def from_dict(cls, data: Metadata) -> "UsageStats": + return cls(**_from_dict_filter(cls, data)) + @dataclass(frozen=True) class NormalizedResponse: