Managing thirdparty package: defer.
This commit is contained in:
+66
@@ -0,0 +1,66 @@
|
||||
from ast import (
|
||||
AST,
|
||||
AsyncFunctionDef,
|
||||
Call,
|
||||
Compare,
|
||||
FunctionDef,
|
||||
In,
|
||||
Lambda,
|
||||
Load,
|
||||
Module,
|
||||
Name,
|
||||
NodeTransformer,
|
||||
arg,
|
||||
arguments,
|
||||
copy_location,
|
||||
fix_missing_locations,
|
||||
)
|
||||
from ast import (
|
||||
walk as ast_walk,
|
||||
)
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class RewriteDefer(NodeTransformer):
|
||||
def __init__(self, root: AST) -> None:
|
||||
super().__init__()
|
||||
self._dirty = False
|
||||
self._root = root
|
||||
|
||||
@classmethod
|
||||
def transform(cls, node: FunctionDef | AsyncFunctionDef) -> Optional[Module]:
|
||||
instance = cls(node)
|
||||
node = instance.visit(node)
|
||||
if not instance._dirty:
|
||||
return None
|
||||
return fix_missing_locations(Module(body=[node], type_ignores=[]))
|
||||
|
||||
def visit_Compare(self, node: Compare):
|
||||
match node:
|
||||
case Compare(ops=[In()], comparators=[Name(id="defer", ctx=Load())]):
|
||||
names = [n for n in ast_walk(node.left) if isinstance(n, Name)]
|
||||
fn = Lambda(
|
||||
args=arguments(
|
||||
args=[arg(arg=n.id, annotation=None) for n in names],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[],
|
||||
posonlyargs=[],
|
||||
),
|
||||
body=node.left,
|
||||
)
|
||||
call = Call(
|
||||
func=Name(id="defer", ctx=Load()), args=[fn, *names], keywords=[]
|
||||
)
|
||||
copy_location(call, node)
|
||||
self._dirty = True
|
||||
return call
|
||||
case _:
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any:
|
||||
if node is self._root:
|
||||
return self.generic_visit(node)
|
||||
return node
|
||||
|
||||
visit_AsyncFunctionDef = visit_FunctionDef
|
||||
Reference in New Issue
Block a user