Files
manual_slop/thirdparty/defer/sugar/transformer.py
T

67 lines
1.8 KiB
Python

from ast import (
AST,
AsyncFunctionDef,
Call,
Compare,
FunctionDef,
In,
Lambda,
Load,
Module,
Name,
NodeTransformer,
arg,
arguments,
copy_location,
fix_missing_locations,
)
from ast import (
walk as ast_walk,
)
from typing import Any, Optional
class RewriteDefer(NodeTransformer):
def __init__(self, root: AST) -> None:
super().__init__()
self._dirty = False
self._root = root
@classmethod
def transform(cls, node: FunctionDef | AsyncFunctionDef) -> Optional[Module]:
instance = cls(node)
node = instance.visit(node)
if not instance._dirty:
return None
return fix_missing_locations(Module(body=[node], type_ignores=[]))
def visit_Compare(self, node: Compare):
match node:
case Compare(ops=[In()], comparators=[Name(id="defer", ctx=Load())]):
names = [n for n in ast_walk(node.left) if isinstance(n, Name)]
fn = Lambda(
args=arguments(
args=[arg(arg=n.id, annotation=None) for n in names],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
posonlyargs=[],
),
body=node.left,
)
call = Call(
func=Name(id="defer", ctx=Load()), args=[fn, *names], keywords=[]
)
copy_location(call, node)
self._dirty = True
return call
case _:
return node
def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any:
if node is self._root:
return self.generic_visit(node)
return node
visit_AsyncFunctionDef = visit_FunctionDef