67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
from ast import (
|
|
AST,
|
|
AsyncFunctionDef,
|
|
Call,
|
|
Compare,
|
|
FunctionDef,
|
|
In,
|
|
Lambda,
|
|
Load,
|
|
Module,
|
|
Name,
|
|
NodeTransformer,
|
|
arg,
|
|
arguments,
|
|
copy_location,
|
|
fix_missing_locations,
|
|
)
|
|
from ast import (
|
|
walk as ast_walk,
|
|
)
|
|
from typing import Any, Optional
|
|
|
|
|
|
class RewriteDefer(NodeTransformer):
|
|
def __init__(self, root: AST) -> None:
|
|
super().__init__()
|
|
self._dirty = False
|
|
self._root = root
|
|
|
|
@classmethod
|
|
def transform(cls, node: FunctionDef | AsyncFunctionDef) -> Optional[Module]:
|
|
instance = cls(node)
|
|
node = instance.visit(node)
|
|
if not instance._dirty:
|
|
return None
|
|
return fix_missing_locations(Module(body=[node], type_ignores=[]))
|
|
|
|
def visit_Compare(self, node: Compare):
|
|
match node:
|
|
case Compare(ops=[In()], comparators=[Name(id="defer", ctx=Load())]):
|
|
names = [n for n in ast_walk(node.left) if isinstance(n, Name)]
|
|
fn = Lambda(
|
|
args=arguments(
|
|
args=[arg(arg=n.id, annotation=None) for n in names],
|
|
kwonlyargs=[],
|
|
kw_defaults=[],
|
|
defaults=[],
|
|
posonlyargs=[],
|
|
),
|
|
body=node.left,
|
|
)
|
|
call = Call(
|
|
func=Name(id="defer", ctx=Load()), args=[fn, *names], keywords=[]
|
|
)
|
|
copy_location(call, node)
|
|
self._dirty = True
|
|
return call
|
|
case _:
|
|
return node
|
|
|
|
def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any:
|
|
if node is self._root:
|
|
return self.generic_visit(node)
|
|
return node
|
|
|
|
visit_AsyncFunctionDef = visit_FunctionDef
|