Managing thirdparty package: defer.

This commit is contained in:
2026-05-13 05:09:23 -04:00
parent 8d6c91d306
commit 9266add6a1
10 changed files with 388 additions and 112 deletions
+1
View File
@@ -0,0 +1 @@
from .sugar import defer as defer
+32
View File
@@ -0,0 +1,32 @@
from contextlib import ExitStack
from types import FrameType
from typing import Any, Callable, Optional, ParamSpec, TypeVar
P = ParamSpec("P")
T = TypeVar("T")
class _Defer:
def __init__(self, tracefn: Optional[Callable]) -> None:
self.tracefn = tracefn or (lambda *_: None)
self._stack = ExitStack()
self._stack.__enter__()
def push(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs):
self._stack.callback(fn, *args, **kwargs)
def __call__(self, frame: FrameType, event: str, arg: Any):
self.tracefn(frame, event, arg)
match event:
case "call":
self._stack.__enter__()
case "return":
self._stack.__exit__(None, None, None)
self._stack = ExitStack()
case "exception":
self._stack.__exit__(*arg)
self._stack = ExitStack()
return self
+23
View File
@@ -0,0 +1,23 @@
import sys
from typing import Any, Callable, ParamSpec, TypeVar
from defer._defer import _Defer
from defer.sugar._parse import _ParseDefer
P = ParamSpec("P")
T = TypeVar("T")
class Defer:
def __call__(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs):
if sys.gettrace() is None:
sys.settrace(_ParseDefer.IDENTITY)
frame = sys._getframe(1)
if not isinstance(frame.f_trace, _Defer):
frame.f_trace = _Defer(frame.f_trace)
frame.f_trace.push(fn, *args, **kwargs)
frame.f_trace_lines = False
def __contains__(self, fn: Any):
breakpoint()
self(fn)
+13
View File
@@ -0,0 +1,13 @@
from types import FunctionType
class DeferErrror(RuntimeError):
pass
class FreeVarsError(DeferErrror):
def __init__(self, fn: FunctionType) -> None:
super().__init__(
"deferred function must not have free variables",
)
self.add_note("free vars: " + str(list(fn.__code__.co_freevars)))
+13
View File
@@ -0,0 +1,13 @@
import sys
from defer.defer import Defer
from defer.sugar._parse import _ParseDefer
def install():
if not isinstance(sys.gettrace(), _ParseDefer):
sys.settrace(_ParseDefer(sys.gettrace()))
defer = Defer()
install()
+67
View File
@@ -0,0 +1,67 @@
import sys
from ast import (
AsyncFunctionDef,
FunctionDef,
)
from collections import deque
from types import FrameType
from typing import Any, Callable, Optional, cast
from executing.executing import Executing, Source
from defer.errors import FreeVarsError
from defer.sugar.transformer import RewriteDefer
class _ParseDefer:
IDENTITY = lambda *_: None # noqa: E731
def __init__(self, tracefn: Optional[Callable]) -> None:
self.tracefn = tracefn or self.IDENTITY
self.pending: deque[Executing] = deque()
def __call__(self, frame: FrameType, event: str, arg: Any):
self.tracefn(frame, event, arg)
if any(
frame.f_code.co_filename.startswith(path)
for path in {
sys.base_exec_prefix,
sys.base_prefix,
sys.exec_prefix,
sys.prefix,
"<frozen ",
"<string>",
"<pytest ",
}
):
return self
if event != "line":
return self
if not (stmt := next(iter(exc.statements), None)):
return self
if isinstance(stmt, (AsyncFunctionDef, FunctionDef)):
if stmt.name.startswith("__") and stmt.name.endswith("__"):
return self
self.pending.append(exc)
return self
return self
stmts = self.pending.pop().statements
node = cast(FunctionDef | AsyncFunctionDef, next(iter(stmts)))
fn = frame.f_locals[node.name]
if fn.__module__ in sys.stdlib_module_names:
return self
if fn.__code__.co_freevars:
raise FreeVarsError(fn)
if not (ast := RewriteDefer.transform(node)):
return self
locals = frame.f_locals.copy()
del locals[node.name]
exec(compile(ast, frame.f_code.co_filename, "exec"), frame.f_globals, locals)
frame.f_locals[node.name].__code__ = locals[node.name].__code__
return self
+66
View File
@@ -0,0 +1,66 @@
from ast import (
AST,
AsyncFunctionDef,
Call,
Compare,
FunctionDef,
In,
Lambda,
Load,
Module,
Name,
NodeTransformer,
arg,
arguments,
copy_location,
fix_missing_locations,
)
from ast import (
walk as ast_walk,
)
from typing import Any, Optional
class RewriteDefer(NodeTransformer):
def __init__(self, root: AST) -> None:
super().__init__()
self._dirty = False
self._root = root
@classmethod
def transform(cls, node: FunctionDef | AsyncFunctionDef) -> Optional[Module]:
instance = cls(node)
node = instance.visit(node)
if not instance._dirty:
return None
return fix_missing_locations(Module(body=[node], type_ignores=[]))
def visit_Compare(self, node: Compare):
match node:
case Compare(ops=[In()], comparators=[Name(id="defer", ctx=Load())]):
names = [n for n in ast_walk(node.left) if isinstance(n, Name)]
fn = Lambda(
args=arguments(
args=[arg(arg=n.id, annotation=None) for n in names],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
posonlyargs=[],
),
body=node.left,
)
call = Call(
func=Name(id="defer", ctx=Load()), args=[fn, *names], keywords=[]
)
copy_location(call, node)
self._dirty = True
return call
case _:
return node
def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any:
if node is self._root:
return self.generic_visit(node)
return node
visit_AsyncFunctionDef = visit_FunctionDef