fst.docs.d13_examples
Example recipes
These are a few snippets which do some real-world-ish things. The comment-handling blemishes are left in place, and your
actual mileage with comments at the edges of modifications will vary. It depends very much on correct usage of the
trivia option on both get() and put(). Comment handling in general needs a bit more work (a lot more work),
but comments are mostly preserved.
The examples are deliberately not the most efficient but are rather meant to show off fst usage and features. Some of
them are somewhat formatter-y, which is not the intended use of this module, but fine for demonstration purposes. If you
want to see an example of an intended use case then see Instrument expressions.
You will see a lot of .replace() of nodes while .walk()ing them. This is allowed because walk() is really a
transform() function; see fst.fst.FST.walk() for more details.
To be able to execute the examples, import this.
>>> from fst import *
This is just a print helper function for this documentation; you can ignore it.
>>> def pprint(src): # helper
...     print(src.replace('\n\n', '\n\xa0\n')) # replace() to avoid '<BLANKLINE>'
Type annotations to type comments
This doesn't do full validation, and more functionality could be added for class attributes and for updating them if they are set in a constructor, but it should be enough to show how something more complete would be done.
>>> def type_annotations_to_type_comments(src: str) -> str:
...     fst_ = FST(src, 'exec') # same as "fst.parse(src).f"
...
...     # walk the whole tree but only yield AnnAssign nodes
...     for f in fst_.walk(AnnAssign):
...         # if just an annotation then skip it, alternatively could
...         # clean and store for later addition to __init__() assign in class
...         if not f.value:
...             continue
...
...         # '.own_src()' gives us the original source exactly as written dedented
...         target = f.target.own_src()
...         value = f.value.own_src()
...
...         # we use ast_src() for the annotation to get a clean type string
...         annotation = f.annotation.ast_src()
...
...         # preserve any existing end-of-line comment
...         comment = ' # ' + comment if (comment := f.get_line_comment()) else ''
...
...         # reconstruct the line using the PEP 484 type comment style
...         new_src = f'{target} = {value} # type: {annotation}{comment}'
...
...         # replace the node, trivia=False preserves any leading comments
...         f.replace(new_src, trivia=False)
...
...     return fst_.src # same as fst.unparse(fst_.a)
>>> src = """
... def func():
... normal = assign
...
... x: int = 1
...
... # y is such and such
... y: float = 2.0 # more about y
... # y was a good variable...
...
... structure: tuple[
... tuple[int, int], # extraneous comment
... dict[str, Any], # could break stuff
... ] | None = None# blah
...
... call( # invalid but just for demonstration purposes
... some_arg, # non-extraneous comment
... some_kw=kw_value, # will not break stuff
... )[start : stop].attr: SomeClass = getthis()
... """.strip()
Original:
>>> pprint(src)
def func():
normal = assign
x: int = 1
# y is such and such
y: float = 2.0 # more about y
# y was a good variable...
structure: tuple[
tuple[int, int], # extraneous comment
dict[str, Any], # could break stuff
] | None = None# blah
call( # invalid but just for demonstration purposes
some_arg, # non-extraneous comment
some_kw=kw_value, # will not break stuff
)[start : stop].attr: SomeClass = getthis()
Processed:
>>> pprint(type_annotations_to_type_comments(src))
def func():
normal = assign
x = 1 # type: int
# y is such and such
y = 2.0 # type: float # more about y
# y was a good variable...
structure = None # type: tuple[tuple[int, int], dict[str, Any]] | None # blah
call( # invalid but just for demonstration purposes
some_arg, # non-extraneous comment
some_kw=kw_value, # will not break stuff
)[start : stop].attr = getthis() # type: SomeClass
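For comparison, the "clean type string" that ast_src() is used for above has a rough standard-library analogue in ast.unparse(), which regenerates source from the tree and so cannot carry comments or line breaks. The sketch below uses only the stdlib ast module (Python 3.9+), not fst, and the helper name clean_annotation is made up for this illustration.

```python
import ast

def clean_annotation(src: str) -> str:
    """Return the annotation of the first statement, regenerated from the AST."""
    stmt = ast.parse(src).body[0]
    assert isinstance(stmt, ast.AnnAssign)
    return ast.unparse(stmt.annotation)  # comments and line breaks are not in the AST

messy = (
    "structure: tuple[\n"
    "    tuple[int, int],  # extraneous comment\n"
    "    dict[str, Any],  # could break stuff\n"
    "] | None = None\n"
)
cleaned = clean_annotation(messy)
print(cleaned)
```

The regenerated annotation fits on one line, which is what makes it safe to place inside a trailing type comment.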
Inject logging metadata
You want to add a correlation_id=CID keyword argument to all logger.info() calls, but only if it's not already there.
>>> def inject_logging_metadata(src: str) -> str:
...     fst = FST(src, 'exec')
...
...     for f in fst.walk(Call):
...         if (f.func.is_Attribute
...             and f.func.attr == 'info'
...             and f.func.value.is_Name
...             and f.func.value.id == 'logger'
...             and not any(kw.arg == 'correlation_id' for kw in f.keywords)
...         ):
...             f.append('correlation_id=CID', trivia=(False, False))
...
...     return fst.src
>>> src = """
... logger.info('Hello world...') # ok
... logger.info('Already have id', correlation_id=other_cid) # ok
... logger.info() # yes, no logger message, too bad
...
... class cls:
... def method(self, thing, extra):
... if not thing:
... (logger).info( # just checking
... f'not a {thing}', # this is fine
... extra=extra, # also this
... )
... """.strip()
Original:
>>> pprint(src)
logger.info('Hello world...') # ok
logger.info('Already have id', correlation_id=other_cid) # ok
logger.info() # yes, no logger message, too bad
class cls:
def method(self, thing, extra):
if not thing:
(logger).info( # just checking
f'not a {thing}', # this is fine
extra=extra, # also this
)
Processed:
>>> pprint(inject_logging_metadata(src))
logger.info('Hello world...', correlation_id=CID) # ok
logger.info('Already have id', correlation_id=other_cid) # ok
logger.info(correlation_id=CID) # yes, no logger message, too bad
class cls:
def method(self, thing, extra):
if not thing:
(logger).info( # just checking
f'not a {thing}', # this is fine
extra=extra, # also this
correlation_id=CID
)
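The detection half of this recipe can be reproduced with the standard library alone, which may help when checking which calls would be touched before rewriting anything. This is a minimal sketch using stdlib ast rather than fst; the function name calls_missing_correlation_id and the sample input are assumptions made for illustration. It only reports line numbers, since stdlib ast cannot write the change back while keeping formatting and comments.

```python
import ast

def calls_missing_correlation_id(src: str) -> list[int]:
    """Line numbers of `logger.info(...)` calls without a `correlation_id` keyword."""
    missing = []
    for node in ast.walk(ast.parse(src)):
        if (isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == 'info'
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == 'logger'
            and not any(kw.arg == 'correlation_id' for kw in node.keywords)
        ):
            missing.append(node.lineno)
    return sorted(missing)

sample = (
    "logger.info('Hello world...')\n"
    "logger.info('Already have id', correlation_id=other_cid)\n"
    "(logger).info(f'not a {thing}', extra=extra)\n"
    "other.info('different object')\n"
)
print(calls_missing_correlation_id(sample))
```

Parentheses around logger do not appear in the tree, so the third call is matched just like the first.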
else if chain to elif
fst has elif <-> else if conversion built in, as it's needed for statement insertions and deletions from conditional
bodies, so it's fairly easy to leverage it to change these kinds of chains. The inverse of this operation can be done just
by changing the f.is_elif() check to is True and the elif_ parameter to elif_=False in the replace(), though
you may need to tweak the trivia parameters for best results.
Yes, the comments on the replaced else headers disappear. They could be preserved explicitly by using get_line_comment()
and then inserting them above the if manually using put_src(). Eventually this should be automatic.
>>> def else_if_chain_to_elifs(src):
...     fst = FST(src, 'exec')
...
...     for f in fst.walk(If): # we will only get the `ast.If` nodes
...         if (len(f.orelse) == 1
...             and f.orelse[0].is_elif() is False # False means normal `if`
...         ):
...             f.orelse[0].replace( # can replace while walking
...                 f.orelse[0].copy(trivia=('block', 'all')),
...                 trivia=(False, 'all'), # trivia specifies how to handle comments
...                 elif_=True, # elif_=True is default, here to show usage
...             )
...
...     return fst.src
>>> src = r"""
... def func():
... # pre-if-a
... if a: # if-a
... # pre-i
... i = 1 # i
... # post-i
...
... else: # else-a
... # pre-if-b
... if b: # if-b
... # pre-j
... j = 2 # j
... # post-j
...
... else: # else-b
... # pre-if-c
... if c: # if-c
... # pre-k
... k = 3 # k
... # post-k
...
... else: # else-c
... # pre-l
... l = 4 # l
... # post-l
...
... # post-else-c
...
... # post-else-b
...
... # post-else-a
... """.strip()
Original:
>>> pprint(src)
def func():
# pre-if-a
if a: # if-a
# pre-i
i = 1 # i
# post-i
else: # else-a
# pre-if-b
if b: # if-b
# pre-j
j = 2 # j
# post-j
else: # else-b
# pre-if-c
if c: # if-c
# pre-k
k = 3 # k
# post-k
else: # else-c
# pre-l
l = 4 # l
# post-l
# post-else-c
# post-else-b
# post-else-a
Processed:
>>> pprint(else_if_chain_to_elifs(src))
def func():
# pre-if-a
if a: # if-a
# pre-i
i = 1 # i
# post-i
# pre-if-b
elif b: # if-b
# pre-j
j = 2 # j
# post-j
# pre-if-c
elif c: # if-c
# pre-k
k = 3 # k
# post-k
else: # else-c
# pre-l
l = 4 # l
# post-l
# post-else-c
# post-else-b
# post-else-a
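The reason a source-aware check such as is_elif() is needed at all can be seen with the standard library: an elif and an else containing a single if produce the same tree shape, so the difference lives only in the source text. The following is a small stdlib-only sketch (no fst involved) with made-up sample sources.

```python
import ast

elif_form = (
    "if a:\n"
    "    i = 1\n"
    "elif b:\n"
    "    j = 2\n"
    "else:\n"
    "    k = 3\n"
)
else_if_form = (
    "if a:\n"
    "    i = 1\n"
    "else:\n"
    "    if b:\n"
    "        j = 2\n"
    "    else:\n"
    "        k = 3\n"
)

tree_elif = ast.parse(elif_form)
tree_else_if = ast.parse(else_if_form)

# ast.dump() omits line/column attributes by default, so this compares structure only
same_structure = ast.dump(tree_elif) == ast.dump(tree_else_if)

# in both forms the inner `if` is the single element of the outer `orelse`
inner = tree_elif.body[0].orelse[0]
print(same_structure, type(inner).__name__)
```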
Pull out nested functions
This is a bit trickier than it sounds because of possible nonlocal accesses by the inner functions. This is solved by checking for those accesses and not pulling those functions out. Names are also changed to avoid collision at the global scope.
The nonlocal variable check is not complete: a full check would determine whether the "free" symbols are actually globals or
builtins, in which case they can be ignored and the function can be moved. We also don't check the type_params
or returns fields of the function for possible nonlocal references but just assume they are all global for this example. Nor
do we check for name overrides in children when replacing the function name, but that is another detail not needed for
demonstration purposes.
>>> def pull_out_inner_funcs_safely(src):
...     fst = FST(src, 'exec')
...
...     for f in fst.walk({FunctionDef, AsyncFunctionDef}):
...         if (parent_scope := f.parent_named_scope()).is_funcdef: # func in a func
...             func_name = f.name
...             syms = f.scope_symbols(full=True)
...
...             # ignore any reference of function to itself
...             if func_name in syms['free']:
...                 del syms['free'][func_name]
...
...             # check if the inner function uses explicit or implicit nonlocals
...             if syms['nonlocal'] or syms['free']:
...                 continue
...
...             # check if any function args defaults use variables
...             if next(f.args.walk(Name), False):
...                 continue
...
...             # build global name from enclosing scopes
...             global_name = f'{parent_scope.name}_{func_name}'
...             top_scope = parent_scope
...
...             while up_scope := top_scope.parent_named_scope(mod=False):
...                 top_scope = up_scope
...                 global_name = f'{top_scope.name}_{global_name}'
...
...             # replace all occurrences of original inner name with new global one
...             # we do this first so that it includes the function being moved
...             for g in parent_scope.walk(Name):
...                 if g.id == func_name:
...                     g.replace(global_name)
...
...             f = f.cut()
...             f.name = global_name
...
...             # insert just before our top-level scope
...             fst.body.insert(f, top_scope.pfield.idx, pep8space=1)
...
...     return fst.src
>>> src = r"""
... def get_lookup(val):
... if not val:
... return
...
... def closure(): # can't pull out because of closure
... return val[0]
...
... def default_arg(val=val): # can't pull out because of default arg
... return val[0]
...
... def safe(val): # safe to pull out
... return val[0]
...
... return closure, default_arg, safe
...
... class cls:
... def method1(self, a, b):
... def fib(n): # recursive for fun
... if n <= 1:
... return n
...
... return fib(n - 1) + fib(n - 2)
...
... return fib(n)
... """.strip()
Original:
>>> pprint(src)
def get_lookup(val):
if not val:
return
def closure(): # can't pull out because of closure
return val[0]
def default_arg(val=val): # can't pull out because of default arg
return val[0]
def safe(val): # safe to pull out
return val[0]
return closure, default_arg, safe
class cls:
def method1(self, a, b):
def fib(n): # recursive for fun
if n <= 1:
return n
return fib(n - 1) + fib(n - 2)
return fib(n)
Processed:
>>> pprint(pull_out_inner_funcs_safely(src))
def get_lookup_safe(val): # safe to pull out
return val[0]
def get_lookup(val):
if not val:
return
def closure(): # can't pull out because of closure
return val[0]
def default_arg(val=val): # can't pull out because of default arg
return val[0]
return closure, default_arg, get_lookup_safe
def cls_method1_fib(n): # recursive for fun
if n <= 1:
return n
return cls_method1_fib(n - 1) + cls_method1_fib(n - 2)
class cls:
def method1(self, a, b):
return cls_method1_fib(n)
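The idea of "free" symbols used in the safety check can be explored with the standard library symtable module, which reports the names an inner function captures from an enclosing scope. This is only a stdlib sketch of the concept, not how fst computes scope_symbols(); the sample source is a trimmed copy of the one above.

```python
import symtable

src = '''
def get_lookup(val):
    def closure():
        return val[0]

    def default_arg(val=val):
        return val[0]

    def safe(val):
        return val[0]

    return closure, default_arg, safe
'''

module_table = symtable.symtable(src, '<example>', 'exec')
outer = module_table.get_children()[0]  # table for get_lookup()

# map each inner function name to the free variables it captures
frees = {t.get_name(): tuple(t.get_frees()) for t in outer.get_children()}
print(frees)
```

Note that default_arg has no free variables even though it depends on the outer val: the default expression is evaluated in the enclosing scope, which is why the recipe checks argument defaults separately.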
lambda to def
Maybe you have too many lambdas and want proper function defs for debugging or logging or other tools. Note that the defs are left in the same scope in case of nonlocals.
>>> def lambdas_to_defs(src):
...     fst = FST(src, 'exec')
...     indent = fst.indent # to show it's there, inferred from src, single level str
...
...     for f in fst.walk(Assign):
...         if (f.value.is_Lambda
...             and f.targets[0].is_Name
...             and len(f.targets) == 1 # for demo purposes just deal with this case
...         ):
...             flmb = f.value
...             fdef = FST(f"""
... def {f.targets[0].id}({flmb.args.src}):
... {indent}return _
...                 """.strip()) # template
...             fdef.body[0].value = flmb.body.copy()
...
...             # explicitly preserve lambda line comment
...             fdef.put_line_comment(f.get_line_comment(full=True), full=True)
...
...             f.replace(
...                 fdef,
...                 trivia=(False, 'line'), # eat line comment but not leading
...                 pep8space=1, # don't doublespace inserted func def (at mod scope)
...             )
...
...     return fst.src
>>> src = r"""
... # lambda comment
... mymin = lambda a, b: a if a < b else b # inline lambda comment
...
... # class comment
... class cls:
... name = lambda self: str(self)
...
... def method(self, a, b):
... add = lambda a, b: a + b
...
... return add(a, b)
... """.strip()
Original:
>>> pprint(src)
# lambda comment
mymin = lambda a, b: a if a < b else b # inline lambda comment
# class comment
class cls:
name = lambda self: str(self)
def method(self, a, b):
add = lambda a, b: a + b
return add(a, b)
Processed:
>>> pprint(src := lambdas_to_defs(src))
# lambda comment
def mymin(a, b): # inline lambda comment
return a if a < b else b
# class comment
class cls:
def name(self):
return str(self)
def method(self, a, b):
def add(a, b):
return a + b
return add(a, b)
Now let's also pull out the nested function.
>>> pprint(pull_out_inner_funcs_safely(src))
# lambda comment
def mymin(a, b): # inline lambda comment
return a if a < b else b
def cls_method_add(a, b):
return a + b
# class comment
class cls:
def name(self):
return str(self)
def method(self, a, b):
return cls_method_add(a, b)
isinstance() to __class__ is / in
Maybe you realize that all your isinstance() checks are incurring a 1.3% performance penalty and you don't like that
so you want to replace them all in your codebase with direct class identity checks. This can be done for non-base
classes (like most AST types are).
>>> NAMES = {
... 'Add', 'And', 'AnnAssign', 'Assert', 'Assign', 'AsyncFor', 'AsyncFunctionDef',
... 'AsyncWith', 'Attribute', 'AugAssign', 'Await', 'BinOp', 'BitAnd', 'BitOr',
... 'BitXor', 'BoolOp', 'Break', 'Call', 'ClassDef', 'Compare', 'Constant',
... 'Continue', 'Del', 'Delete', 'Dict', 'DictComp', 'Div', 'Eq', 'ExceptHandler',
... 'Expr', 'Expression', 'FloorDiv', 'For', 'FormattedValue', 'FunctionDef',
... 'FunctionType', 'GeneratorExp', 'Global', 'Gt', 'GtE', 'If', 'IfExp', 'Import',
... 'ImportFrom', 'In', 'Interactive', 'Interpolation', 'Invert', 'Is', 'IsNot',
... 'JoinedStr', 'LShift', 'Lambda', 'List', 'ListComp', 'Load', 'Lt', 'LtE',
... 'MatMult', 'Match', 'MatchAs', 'MatchClass', 'MatchMapping', 'MatchOr',
... 'MatchSequence', 'MatchSingleton', 'MatchStar', 'MatchValue', 'Mod', 'Module',
... 'Mult', 'Name', 'NamedExpr', 'Nonlocal', 'Not', 'NotEq', 'NotIn', 'Or',
... 'ParamSpec', 'Pass', 'Pow', 'RShift', 'Raise', 'Return', 'Set', 'SetComp',
... 'Slice', 'Starred', 'Store', 'Sub', 'Subscript', 'TemplateStr', 'Try', 'TryStar',
... 'Tuple', 'TypeAlias', 'TypeIgnore', 'TypeVar', 'TypeVarTuple', 'UAdd', 'USub',
... 'UnaryOp', 'While', 'With', 'Yield', 'YieldFrom', 'alias', 'arg', 'arguments',
... 'comprehension', 'keyword', 'match_case', 'withitem',
... }
>>> def isinstance_to_class_check(src):
...     fst = FST(src, 'exec')
...
...     for f in fst.walk(Call):
...         if (f.func.is_Name
...             and f.func.id == 'isinstance' # isinstance()
...         ):
...             ftest, ftype = f.args # assume there are two for isinstance()
...             fparent = f.parent
...
...             # isinstance(..., one of NAMES)
...             if ftype.is_Name and ftype.id in NAMES:
...                 op, notop = 'is', 'is not'
...             # isinstance(..., (one of NAMES, ...))
...             elif (
...                 ftype.is_Tuple
...                 and all(g.is_Name and g.id in NAMES for g in f.args[1].elts)
...             ):
...                 op, notop = 'in', 'not in'
...             else:
...                 continue
...
...             # 'not isinstance()' -> '__class__ is not' / 'not in'
...             if fparent.is_UnaryOp and fparent.op.is_Not:
...                 f = fparent
...                 fparent = f.parent
...                 op = notop
...
...             fnew = FST(f'_.__class__ {op} _')
...
...             fnew.left.value.replace(ftest.copy())
...             fnew.comparators[0].replace(ftype.copy())
...
...             if fparent.is_NamedExpr: # be nice and parenthesize ourselves here
...                 fnew.par() # we know we are the .value
...
...             f.replace(fnew, pars=True) # preserve our own pars if present
...
...     return fst.src
>>> src = """
... def is_valid_target(asts: AST | list[AST]) -> bool:
... \"\"\"Check if `asts` is a valid target for `Assign` or `For`
... operations. Must be `Name`, `Attribute`, `Subscript`
... and / or possibly nested `Starred`, `Tuple` and `List`.\"\"\"
...
... stack = [asts] if isinstance(asts, AST) else list(asts)
...
... while stack:
... if isinstance(a := stack.pop(), (Tuple, List)):
... stack.extend(a.elts)
... elif isinstance(a, Starred):
... stack.append(a.value)
... elif not isinstance(a, (Name, Attribute, Subscript)):
... return False
...
... return True
... """.strip()
Original:
>>> pprint(src)
def is_valid_target(asts: AST | list[AST]) -> bool:
"""Check if `asts` is a valid target for `Assign` or `For`
operations. Must be `Name`, `Attribute`, `Subscript`
and / or possibly nested `Starred`, `Tuple` and `List`."""
stack = [asts] if isinstance(asts, AST) else list(asts)
while stack:
if isinstance(a := stack.pop(), (Tuple, List)):
stack.extend(a.elts)
elif isinstance(a, Starred):
stack.append(a.value)
elif not isinstance(a, (Name, Attribute, Subscript)):
return False
return True
Processed:
>>> pprint(isinstance_to_class_check(src))
def is_valid_target(asts: AST | list[AST]) -> bool:
"""Check if `asts` is a valid target for `Assign` or `For`
operations. Must be `Name`, `Attribute`, `Subscript`
and / or possibly nested `Starred`, `Tuple` and `List`."""
stack = [asts] if isinstance(asts, AST) else list(asts)
while stack:
if (a := stack.pop()).__class__ in (Tuple, List):
stack.extend(a.elts)
elif a.__class__ is Starred:
stack.append(a.value)
elif a.__class__ not in (Name, Attribute, Subscript):
return False
return True
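Why the rewrite is restricted to non-base classes, and why isinstance(asts, AST) is left alone in the output, can be checked directly: the identity test only matches the exact class, while isinstance() also matches subclasses. A short stdlib-only sketch:

```python
import ast

node = ast.Name(id='x', ctx=ast.Load())

# leaf class: both checks agree
print(isinstance(node, ast.Name), node.__class__ is ast.Name)

# base class: isinstance() matches subclasses, the identity check does not
print(isinstance(node, ast.expr), node.__class__ is ast.expr)

# tuple form: `in` compares against each listed class exactly
print(node.__class__ in (ast.Name, ast.Attribute))
```

The rewrite therefore also changes behavior for any user-defined subclass of a listed class, so it is only equivalent when no such subclasses are in play.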
Squash nested withs
Slice operations make this easy enough. We only do synchronous with here, as you can't mix sync with async anyway. Yes,
the alignment is ugly; it will eventually be done properly, but the first priority was functional correctness. The comment on
the with ctx(): could be preserved explicitly, but we don't do it here; eventually this will be an automatic option.
>>> def squash_nested_withs(src: str) -> str:
...     fst = FST(src, 'exec')
...
...     for f in fst.walk(With): # we only get With nodes
...         while f.body[0].is_With: # first child is another With
...             # append child items to ours
...             f.items.extend(f.body[0].items.copy(), trivia=(False, False))
...
...             f.put_slice( # copy child body into our own
...                 f.body[0].get_slice(trivia=('all+', 'block'), cut=True),
...                 trivia=(False, False),
...             )
...
...     return fst.src
>>> src = r"""
... # with comment
... with open(a) as f:
... with (
... lock1, # first lock
... func() as lock2, # this gets preserved
... ):
... with ctx(): # this does not belong to ctx()
... # body comment
... pass
... # end body comment
...
... # post-with comment
... """.strip()
Original:
>>> pprint(src)
# with comment
with open(a) as f:
with (
lock1, # first lock
func() as lock2, # this gets preserved
):
with ctx(): # this does not belong to ctx()
# body comment
pass
# end body comment
# post-with comment
Processed:
>>> pprint(squash_nested_withs(src))
# with comment
with (open(a) as f,
lock1, # first lock
func() as lock2, # this gets preserved
ctx()
):
# body comment
pass
# end body comment
# post-with comment
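For contrast, here is roughly what the same structural change looks like with only the standard library. It is a sketch under assumptions: the class name SquashWiths is invented here, it merges only when the inner with is the sole statement in the body (a stricter condition than the recipe above), and it round-trips through ast.unparse() (Python 3.9+), so all comments and original formatting are lost.

```python
import ast

class SquashWiths(ast.NodeTransformer):
    """Merge `with a: with b: ...` into `with a, b: ...` (sync `with` only)."""

    def visit_With(self, node: ast.With) -> ast.With:
        # absorb the child while it is the only statement in our body
        while len(node.body) == 1 and isinstance(node.body[0], ast.With):
            child = node.body[0]
            node.items.extend(child.items)
            node.body = child.body
        self.generic_visit(node)  # handle any `with` deeper in the merged body
        return node

nested = (
    "with open(a) as f:\n"
    "    with lock1, func() as lock2:  # this comment is lost\n"
    "        with ctx():\n"
    "            pass\n"
)
tree = SquashWiths().visit(ast.parse(nested))
squashed = ast.unparse(ast.fix_missing_locations(tree))
print(squashed)
```

The tree ends up with a single With holding four items, but the comment is gone, which is the gap the fst version is meant to close.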
Add decorator to class methods
This one is very simple. We just want to add our own decorator to all class methods, not normal functions, with the
constraint that it be as close to the function as possible but not after a @contextmanager.
>>> def insert_decorator_to_class_methods(src: str) -> str:
...     fst = FST(src, 'exec')
...
...     for f in fst.walk({FunctionDef, AsyncFunctionDef}):
...         if f.parent.is_ClassDef:
...             if any((deco := d).is_Name and d.id == 'contextmanager'
...                    for d in f.decorator_list):
...                 idx = deco.pfield.idx
...             else:
...                 idx = 'end'
...
...             f.decorator_list.insert('@our_decorator', idx, trivia=False)
...
...     return fst.src
>>> src = r"""
... def normal_function():
... ...
...
... class SomeClass:
... @staticmethod
... def get_options() -> dict[str, Any]:
... ...
...
... # class comment
... @classmethod
... def get_cls_option(option: str, options: Mapping[str, Any] = {}) -> object:
... ...
...
... # another class comment
... async def set_async_inst_options(**options) -> dict[str, Any]:
... ...
...
... @ \
... staticmethod
... # intentionally screwy
... @ (
... contextmanager
... )
... def options(**options) -> Iterator[dict[str, Any]]:
... ...
... """.strip()
Original:
>>> pprint(src)
def normal_function():
...
class SomeClass:
@staticmethod
def get_options() -> dict[str, Any]:
...
# class comment
@classmethod
def get_cls_option(option: str, options: Mapping[str, Any] = {}) -> object:
...
# another class comment
async def set_async_inst_options(**options) -> dict[str, Any]:
...
@ \
staticmethod
# intentionally screwy
@ (
contextmanager
)
def options(**options) -> Iterator[dict[str, Any]]:
...
Processed:
>>> pprint(insert_decorator_to_class_methods(src))
def normal_function():
...
class SomeClass:
@staticmethod
@our_decorator
def get_options() -> dict[str, Any]:
...
# class comment
@classmethod
@our_decorator
def get_cls_option(option: str, options: Mapping[str, Any] = {}) -> object:
...
# another class comment
@our_decorator
async def set_async_inst_options(**options) -> dict[str, Any]:
...
@ \
staticmethod
# intentionally screwy
@our_decorator
@ (
contextmanager
)
def options(**options) -> Iterator[dict[str, Any]]:
...
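"As close to the function as possible" corresponds to the end of decorator_list because Python applies decorators bottom-up: the one written last wraps the function first. The plain-Python sketch below demonstrates that order; record and the position labels are hypothetical names used only for this illustration.

```python
applied = []

def record(name):
    """Decorator factory that logs the order in which decorators are applied."""
    def deco(func):
        applied.append(name)
        return func
    return deco

class SomeClass:
    @record('staticmethod-position')   # first in decorator_list, applied last
    @record('our_decorator-position')  # last in decorator_list, applied first
    def method(self):
        ...

print(applied)
```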
Comprehension to loop
We build up a body and replace the original comprehension Assign statement with the new statements.
>>> def list_comprehensions_to_loops(src):
...     fst = FST(src, 'exec')
...
...     for f in fst.walk():
...         if (f.is_Assign # to show we can check here instead of passing to walk()
...             and f.value.is_ListComp
...             and f.targets[0].is_Name
...             and len(f.targets) == 1
...         ):
...             var = f.targets[0].id
...             fcomp = f.value
...             fcur = ftop = FST(f'{var} = []\n_', 'exec')
...             # the `_` will become first `for`
...
...             for fgen in fcomp.generators:
...                 ffor = FST('for _ in _:\n _') # for loop, just copy the source
...                 ffor.target = fgen.target.copy()
...                 ffor.iter = fgen.iter.copy()
...
...                 fcur = fcur.body[-1].replace(ffor)
...                 fifs = fgen.ifs
...                 nifs = len(fifs)
...
...                 if nifs: # if no ifs then no test
...                     if nifs == 1: # if single test then just use that
...                         ftest = fifs[0].copy()
...
...                     else: # if multiple then join with `and`
...                         ftest = FST(' and '.join('_' * nifs))
...
...                         for i, fif in enumerate(fifs):
...                             ftest.values[i] = fif.copy()
...
...                     fifstmt = FST('if _:\n _')
...                     fifstmt.test = ftest
...
...                     fcur = fcur.body[-1].replace(fifstmt)
...
...             # the ffor is the last one processed above (the innermost)
...             fcur.body[-1].replace(f'{var}.append({fcomp.elt.own_src()})')
...
...             f.replace(
...                 ftop,
...                 one=False, # this allows to replace a single element with multiple
...                 trivia=(False, False)
...             )
...
...     return fst.src
>>> src = r"""
... def f(k):
... # safe comment
... clean = [i for i in k]
...
... # happy comment
... messy = [
... ( i ) # weird pars
... for (
... j
... ) in k # outer loop
... if
... j # misc comment
... and not validate(j)
... for i in j # inner loop
... if i
... if validate(i)
... ]
... # silly comment
...
... return clean + messy
... """.strip()
Original:
>>> pprint(src)
def f(k):
# safe comment
clean = [i for i in k]
# happy comment
messy = [
( i ) # weird pars
for (
j
) in k # outer loop
if
j # misc comment
and not validate(j)
for i in j # inner loop
if i
if validate(i)
]
# silly comment
return clean + messy
Processed:
>>> pprint(list_comprehensions_to_loops(src))
def f(k):
# safe comment
clean = []
for i in k:
clean.append(i)
# happy comment
messy = []
for j in k:
if (j # misc comment
and not validate(j)):
for i in j:
if i and validate(i):
messy.append(i)
# silly comment
return clean + messy
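The transformation relies on a comprehension being equivalent to nested for loops, with the if clauses of each generator combined by and. A quick plain-Python check of that equivalence, using a stand-in validate() and sample data invented for this sketch:

```python
def validate(x):
    """Stand-in predicate for this sketch only."""
    return x != 2

k = [[1, 2, 3], [], [0, 4]]

# comprehension with two generators; the inner one has two `if` clauses
as_comprehension = [i for j in k if j for i in j if i if validate(i)]

# equivalent loops: each generator is a `for`, its `if` clauses joined with `and`
as_loops = []
for j in k:
    if j:
        for i in j:
            if i and validate(i):
                as_loops.append(i)

print(as_comprehension, as_loops)
```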
Align equals
This is just here to show pure source modification without messing with the actual structure. It walks everything
and stores all Assign nodes, and any block of two or more nodes which start on consecutive lines and at the same column
gets aligned. All the source put function does is offset AST node locations for the source change. Same column is used as
a proxy for verifying same parent.
>>> def align_equals(fst):
...     flast = None
...     blocks = [] # [[feq1, feq2, ...], [feq1, ...], ...]
...
...     # first we build up list of contiguous Assign nodes
...     for f in fst.walk(Assign):
...         if not flast or f.col != flast.col or f.ln != flast.ln + 1:
...             blocks.append([])
...
...         blocks[-1].append(flast := f)
...
...     for block in blocks:
...         if len(block) > 1:
...             eq_col = max(f.targets[-1].pars().end_col for f in block)
...
...             for f in block:
...                 # we know this is all on one line by how we constructed it
...                 ln, _, _, end_col = f.targets[-1].pars()
...                 eq_str = f'{" " * (eq_col - end_col)} = '
...
...                 f.put_src(eq_str, ln, end_col, ln, f.value.pars().col, 'offset')
>>> src = r"""
... a = 1
... this = that
... whatever[f].a = "YAY!"
...
... on_multiple_lines = (
... 1, 2)
... we_dont_align = None
...
... ASTS_LEAF_FUNCDEF = {FunctionDef, AsyncFunctionDef}
... ASTS_LEAF_DEF = ASTS_LEAF_FUNCDEF | {ClassDef}
... ASTS_LEAF_DEF_OR_MOD = ASTS_LEAF_DEF | ASTS_LEAF_MOD
... ASTS_LEAF_FOR = {For, AsyncFor}
... ASTS_LEAF_WITH = {With, AsyncWith}
... ASTS_LEAF_TRY = {Try, TryStar}
... ASTS_LEAF_IMPORT = {Import, ImportFrom}
... """.strip()
Original:
>>> pprint(src)
a = 1
this = that
whatever[f].a = "YAY!"
on_multiple_lines = (
1, 2)
we_dont_align = None
ASTS_LEAF_FUNCDEF = {FunctionDef, AsyncFunctionDef}
ASTS_LEAF_DEF = ASTS_LEAF_FUNCDEF | {ClassDef}
ASTS_LEAF_DEF_OR_MOD = ASTS_LEAF_DEF | ASTS_LEAF_MOD
ASTS_LEAF_FOR = {For, AsyncFor}
ASTS_LEAF_WITH = {With, AsyncWith}
ASTS_LEAF_TRY = {Try, TryStar}
ASTS_LEAF_IMPORT = {Import, ImportFrom}
Processed:
>>> fst = FST(src, 'exec')
>>> align_equals(fst) # we pass as an FST so we can `.dump()` below
>>> pprint(fst.src)
a             = 1
this          = that
whatever[f].a = "YAY!"
on_multiple_lines = (
1, 2)
we_dont_align = None
ASTS_LEAF_FUNCDEF    = {FunctionDef, AsyncFunctionDef}
ASTS_LEAF_DEF        = ASTS_LEAF_FUNCDEF | {ClassDef}
ASTS_LEAF_DEF_OR_MOD = ASTS_LEAF_DEF | ASTS_LEAF_MOD
ASTS_LEAF_FOR        = {For, AsyncFor}
ASTS_LEAF_WITH       = {With, AsyncWith}
ASTS_LEAF_TRY        = {Try, TryStar}
ASTS_LEAF_IMPORT     = {Import, ImportFrom}
Here we do a quick dump of the first three statements to show that all the locations of the nodes were offset properly.
>>> _ = fst.body[:3].copy().dump('stmt')
Module - ROOT 0,0..2,22
.body[3]
0: a = 1
0] Assign - 0,0..0,17
.targets[1]
0] Name 'a' Store - 0,0..0,1
.value Constant 1 - 0,16..0,17
1: this = that
1] Assign - 1,0..1,20
.targets[1]
0] Name 'this' Store - 1,0..1,4
.value Name 'that' Load - 1,16..1,20
2: whatever[f].a = "YAY!"
2] Assign - 2,0..2,22
.targets[1]
0] Attribute - 2,0..2,13
.value Subscript - 2,0..2,11
.value Name 'whatever' Load - 2,0..2,8
.slice Name 'f' Load - 2,9..2,10
.ctx Load
.attr 'a'
.ctx Store
.value Constant 'YAY!' - 2,16..2,22
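The grouping rule used by align_equals() (same start column, start lines exactly one apart) can be tried out with the standard library alone. This is a stdlib ast sketch, not fst code; the function name contiguous_assign_blocks is made up here, and stdlib line numbers are 1-based whereas the fst dump above shows 0-based lines.

```python
import ast

def contiguous_assign_blocks(src: str) -> list[list[int]]:
    """Group Assign statements that start on consecutive lines at the same column.

    Returns the 1-based start line numbers of each group.
    """
    assigns = sorted(
        (n for n in ast.walk(ast.parse(src)) if isinstance(n, ast.Assign)),
        key=lambda n: n.lineno,
    )
    blocks: list[list[ast.Assign]] = []
    last = None
    for node in assigns:
        if (last is None
            or node.col_offset != last.col_offset
            or node.lineno != last.lineno + 1
        ):
            blocks.append([])  # start a new group
        blocks[-1].append(node)
        last = node
    return [[n.lineno for n in block] for block in blocks]

sample = (
    "a = 1\n"
    "this = that\n"
    "\n"
    "on_multiple_lines = (\n"
    "    1, 2)\n"
    "we_dont_align = None\n"
)
print(contiguous_assign_blocks(sample))
```

Because only start lines are compared, the statement following a multi-line assignment never joins its group, which matches the behavior shown in the processed output.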
Make all f-strings self-documenting
Note: This example in particular requires Python 3.12+ because of the comment in the multiline f-string. The function itself will mostly work on earlier Python versions.
Suppose you just want to improve the debug logs by adding self-documenting debug strings to all f-strings, e.g.
f"{var}" into f"{var=}". The example below shows how, though there is just one caveat to look out for if you want to
continue working with the AST tree:
The source is updated and all the node locations are fine, but the put_src(..., action='offset') only offsets node
locations and does not create the AST nodes for any new Constant strings due to the newly self-documenting
FormattedValue nodes. If you only care about source for output (like this example) then this is a non-issue. If you
need those nodes to continue working with an AST tree then you can do a root.reparse() after making all the changes,
or individual reparse() on each modified statement, or use put_src(..., action='reparse') for each individual
change, but that would be slower.
>>> def self_document_fstring_expressions(src: str) -> str:
...     fst = FST(src, 'exec')
...
...     for f in fst.walk(FormattedValue): # could add Interpolation to do both
...         _, _, end_ln, end_col = f.value.pars() # value end after parentheses
...
...         # hacky but valid way to check if '=' already there and not in comment
...         # between the end of the expression and end of FormattedValue
...         lines = f.get_src(end_ln, end_col, f.end_ln, f.end_col - 1, as_lines=True)
...
...         if not any(l.lstrip().startswith('=') for l in lines):
...             # insert the equals just after the expression
...             f.put_src('=', end_ln, end_col, end_ln, end_col, 'offset')
...
...     return fst.src
>>> src = """
... f'added here {a}, and here { ( b ) }'
...
... f'not added here {c=}'
...
... f\"\"\"{( # =========================
... d, # commented out =, so added
... )}, {e
... =} <- not commented out so not added\"\"\"
... """.strip()
Original:
>>> pprint(src)
f'added here {a}, and here { ( b ) }'
f'not added here {c=}'
f"""{( # =========================
d, # commented out =, so added
)}, {e
=} <- not commented out so not added"""
Processed:
>>> pprint(self_document_fstring_expressions(src))
f'added here {a=}, and here { ( b )= }'
f'not added here {c=}'
f"""{( # =========================
d, # commented out =, so added
)=}, {e
=} <- not commented out so not added"""
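The caveat about missing Constant nodes can be observed with the standard library: parsing a self-documenting f-string yields an extra Constant holding the expression text, which an offset-only source edit would not create. A small stdlib-only sketch (the helper name part_types is invented for this illustration):

```python
import ast

a = 3
plain = f'{a}'
documented = f'{a=}'  # expands to the expression text, '=', then repr() of the value
print(plain, documented)

def part_types(fstring_src: str) -> list[str]:
    """Node type names of the pieces that make up an f-string literal."""
    joined = ast.parse(fstring_src, mode='eval').body
    return [type(v).__name__ for v in joined.values]

print(part_types("f'{a}'"))
print(part_types("f'{a=}'"))
```

A full reparse of the modified source would produce the second shape, which is what the reparse options mentioned above are for.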
Normalize docstrings
get_docstr() gives you a normal string and put_docstr() puts it with appropriate formatting for a docstring. In this
case we also pass reput=True because we specifically want it to remove and then put the docstring expression again so
that it precedes any comments. Otherwise it just replaces the docstring in the location where it currently is.
>>> def normalize_docstrings(src):
...     fst = FST(src, 'exec')
...
...     for f in fst.walk():
...         if f.has_docstr:
...             f.put_docstr(f.get_docstr(), reput=True)
...
...     return fst.src
>>> src = """
... class cls:
... # docstr should be before this
... "Some\\nunformatted\\ndocstr." # what is this?!?
...
... def method(self):
...
... \"\"\"I'm not
... even properly
... aligned!!! \\U0001F92a
... Or am I\\x3f\"\"\"
... # comment
...
... pass
... """.strip()
Original:
>>> pprint(src)
class cls:
# docstr should be before this
"Some\nunformatted\ndocstr." # what is this?!?
def method(self):
"""I'm not
even properly
aligned!!! \U0001F92a
Or am I\x3f"""
# comment
pass
Processed:
>>> pprint(normalize_docstrings(src))
class cls:
"""Some
unformatted
docstr."""
# docstr should be before this
# what is this?!?
def method(self):
"""I'm not
even properly
aligned!!! 🤪
Or am I?"""
# comment
pass
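The standard library offers a read-only counterpart that may help when reasoning about what a "normal string" docstring looks like: ast.get_docstring() returns the evaluated string, optionally with indentation normalized by inspect.cleandoc(). This is a stdlib sketch for comparison only; it makes no claim about how get_docstr() is implemented, and it cannot write the docstring back.

```python
import ast

src = '''
def method(self):
    """I'm not
          even properly
        aligned!!!"""
    pass
'''

func = ast.parse(src).body[0]
raw = ast.get_docstring(func, clean=False)    # string value exactly as written
cleaned = ast.get_docstring(func, clean=True)  # indentation normalized via inspect.cleandoc()
print(repr(raw))
print(repr(cleaned))
```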
Reparenthesize expressions
Parentheses are normally handled completely automatically, and you can use this mechanism to clean up unnecessary
or ugly parenthesization. Yes, we realize some unnecessary parenthesization may be dictated by the currently enforced
aesthetic paradigm of any given project. This is just to show proper functional parenthesization by fst.
>>> def reparenthesize_simple(src):
...     fst = FST(src, 'exec')
...
...     for f in fst.walk():
...         if f.is_parenthesizable():
...             f.replace(f.copy())
...
...     return fst.src
>>> src = r"""
... (x * y) * (a + b) # "x * y" doesn't need pars, 'a + b' does
...
... (x * y) * z # "x * y" is unpard because doesn't change tree structure
...
... x * (y * z) # not unpard because would change structure (order of operations)
...
... a + ( ( (y) * (z))) # nested pars cleaned up
...
... x * ( ( (y) * (z))) # original pars normally left if needed, unless unpar() used
...
... if (
... (a <= b)
... ): # not needed
... return (a # hello
... < b) # needed for parsability
...
... match ("a" # implicit string pars needed
... "b"):
... case ((1) | (a)): # unnecessary pars
... pass
... case (a, b): # MatchSequence node intrinsic pars are not removed
... pass
... case ( ( (a, b) ) ): # but the unnecessary ones are
... pass
... """.strip()
Original:
>>> pprint(src)
(x * y) * (a + b) # "x * y" doesn't need pars, 'a + b' does
(x * y) * z # "x * y" is unpard because doesn't change tree structure
x * (y * z) # not unpard because would change structure (order of operations)
a + ( ( (y) * (z))) # nested pars cleaned up
x * ( ( (y) * (z))) # original pars normally left if needed, unless unpar() used
if (
(a <= b)
): # not needed
return (a # hello
< b) # needed for parsability
match ("a" # implicit string pars needed
"b"):
case ((1) | (a)): # unnecessary pars
pass
case (a, b): # MatchSequence node intrinsic pars are not removed
pass
case ( ( (a, b) ) ): # but the unnecessary ones are
pass
Processed:
>>> pprint(reparenthesize_simple(src))
x * y * (a + b) # "x * y" doesn't need pars, 'a + b' does
x * y * z # "x * y" is unpard because doesn't change tree structure
x * (y * z) # not unpard because would change structure (order of operations)
a + y * z # nested pars cleaned up
x * ( ( y * z)) # original pars normally left if needed, unless unpar() used
if a <= b: # not needed
return (a # hello
< b) # needed for parsability
match ("a" # implicit string pars needed
"b"):
case 1 | a: # unnecessary pars
pass
case (a, b): # MatchSequence node intrinsic pars are not removed
pass
case (a, b): # but the unnecessary ones are
pass
A slight tweak to the function, unparenthesizing each node first, forces needed parentheses to be regenerated, including those which were not in their normal form or location to begin with.
>>> def reparenthesize_full(src):
...     fst = FST(src, 'exec')
...
...     for f in fst.walk():
...         if f.is_parenthesizable():
...             f.unpar().replace(f.copy())
...
...     return fst.src
>>> pprint(reparenthesize_full(r"""
... a + ( ( (y) * (z))) # nested pars cleaned up
...
... x * ( ( (y) * (z))) # original pars normally left if needed, unless unpar() used
... """.strip()))
a + y * z # nested pars cleaned up
x * (y * z) # original pars normally left if needed, unless unpar() used
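The structural claims in the comments above can be checked independently of fst with the standard library ast module. This is only a minimal sketch, and same_tree() is a hypothetical helper introduced for illustration: since the AST carries no parentheses, two spellings are structurally equivalent when their ast.dump() output matches.

```python
import ast

def same_tree(a: str, b: str) -> bool:
    """Return True if two expression sources parse to the same AST."""
    return ast.dump(ast.parse(a, mode='eval')) == ast.dump(ast.parse(b, mode='eval'))

# multiplication is left-associative, so these parentheses are redundant
print(same_tree('(x * y) * z', 'x * y * z'))  # True

# removing these parentheses would regroup the operands
print(same_tree('x * (y * z)', 'x * y * z'))  # False

# precedence already groups the product under the addition
print(same_tree('a + ( ( (y) * (z)))', 'a + y * z'))  # True
```

Parentheses that are needed only for parsability, like the multiline return above, are a source-level concern and do not show up in this kind of tree comparison.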
Instrument expressions
Suppose an external library overloads some operators, there are problems, and you want to log all the
operations (or do something else with them). This is the type of problem that fst was conceived to solve easily. The
instrumentation is ugly, but it is meant to show that all these nested manipulations keep the source correct and working.
Note that this instrumentation relies on the fact that the syntactic order of the children of these particular node
types is also the order in which they are evaluated. This is not always the case, e.g. with the IfExp node the
middle (the test) gets evaluated first: THEN_THIS if FIRST_THIS else OR_THEN_THIS.
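The evaluation order difference can be observed directly in plain Python, without fst. In this small sketch mark() and order are hypothetical names used only to record when each operand is evaluated.

```python
order = []

def mark(name, value):
    """Record that an operand was evaluated and return its value."""
    order.append(name)
    return value

# BinOp: operands are evaluated in the order they are written
mark('left', 1) + mark('right', 2)
binop_order = order.copy()

order.clear()

# IfExp: the test in the middle is evaluated first, then only one branch runs
mark('body', 1) if mark('test', True) else mark('orelse', 2)
ifexp_order = order.copy()

print(binop_order)  # ['left', 'right']
print(ifexp_order)  # ['test', 'body']
```

An instrumentation that assigns temporaries in syntactic order would therefore evaluate an IfExp's branches too early, which is why the function below only handles BinOp, UnaryOp and Compare.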
>>> def instrument_operations(src):
...     fst = FST(src, 'exec')
...
...     for f in fst.walk():
...         if getattr(f, '_dirty', False): # did we generate this?
...             continue
...
...         if f.is_BinOp:
...             fargs = [f.left, f.right]
...         elif f.is_UnaryOp:
...             fargs = [f.operand]
...         elif f.is_Compare:
...             fargs = [f.left, *f.comparators]
...         else:
...             continue
...
...         for g in f.parents(): # make sure we don't overwrite vars in use by parent
...             if (base := getattr(g, '_base_idx', None)) is not None:
...                 break
...         else:
...             base = 0
...
...         nargs = len(fargs)
...         ftmps = ', '.join(f'_{i + base}' for i in range(nargs))
...         fsettmps = ', '.join(f'_{i + base} := _' for i in range(nargs))
...
...         fnew = FST(f'({fsettmps}, log({f.src!r}, {ftmps}), _)[{nargs + 1}]')
...         felts = fnew.value.elts
...
...         for i in range(nargs):
...             felts[i].value.replace(fargs[i].copy())
...             fargs[i].replace(f'_{i + base}')
...
...         fold = felts[-1].replace(f.copy()) # put operation using temp vars
...         fnew = f.replace(fnew)
...
...         fold._dirty = True # must set after the replace because FST may change
...         fnew._base_idx = base + nargs - 1 # safe start idx for children to use
...
...     return fst.src
>>> src = r"""
... def shape_score(a, b, c, d):
...     x = (a * b) - -(c + d)
...     y = ~(a - c) + (b ^ d)
...     z = (x // (abs(y) or 1))
...
...     return (z < 0) * -z + (z >= 0) * +z
... """.strip()
Original:
>>> pprint(src)
def shape_score(a, b, c, d):
    x = (a * b) - -(c + d)
    y = ~(a - c) + (b ^ d)
    z = (x // (abs(y) or 1))
    return (z < 0) * -z + (z >= 0) * +z
Processed:
>>> pprint(inst := instrument_operations(src))
def shape_score(a, b, c, d):
    x = (_0 := (_1 := a, _2 := b, log('a * b', _1, _2), _1 * _2)[3], _1 := (_1 := (_1 := c, _2 := d, log('c + d', _1, _2), _1 + _2)[3], log('-(c + d)', _1), -_1)[2], log('(a * b) - -(c + d)', _0, _1), _0 - _1)[3]
    y = (_0 := (_1 := (_1 := a, _2 := c, log('a - c', _1, _2), _1 - _2)[3], log('~(a - c)', _1), ~_1)[2], _1 := (_1 := b, _2 := d, log('b ^ d', _1, _2), _1 ^ _2)[3], log('~(a - c) + (b ^ d)', _0, _1), _0 + _1)[3]
    z = (_0 := x, _1 := abs(y) or 1, log('x // (abs(y) or 1)', _0, _1), _0 // _1)[3]
    return (_0 := (_1 := (_2 := z, _3 := 0, log('z < 0', _2, _3), _2 < _3)[3], _2 := (_2 := z, log('-z', _2), -_2)[2], log('(z < 0) * -z', _1, _2), _1 * _2)[3], _1 := (_1 := (_2 := z, _3 := 0, log('z >= 0', _2, _3), _2 >= _3)[3], _2 := (_2 := z, log('+z', _2), +_2)[2], log('(z >= 0) * +z', _1, _2), _1 * _2)[3], log('(z < 0) * -z + (z >= 0) * +z', _0, _1), _0 + _1)[3]
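Each instrumented node above has the same shape: a tuple that first binds the operands to temporaries with :=, then calls log() with them, then evaluates the operation on the temporaries, and is finally indexed to select that last element. Here is a hand-written instance of the pattern in plain Python, as a sketch only. The recording log() and the log_calls list are stand-ins introduced for illustration and are not part of fst.

```python
log_calls = []

def log(s, *args):  # stand-in logger that records instead of printing
    log_calls.append((s, args))

a, b = 3, 4

# tuple elements are evaluated left to right, [3] picks the operation's result
result = (_0 := a, _1 := b, log('a * b', _0, _1), _0 * _1)[3]

print(result)     # 12
print(log_calls)  # [('a * b', (3, 4))]
```

Because the whole construct is a single expression, it can be dropped in wherever the original operation appeared, including inside another instrumented operation.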
Now let's compile and execute these functions to make sure everything still works.
>>> def log(s, *args): # log function for the instrumented code
...     print(f'src: {s!r}, args: {args}')
>>> exec(src, original_ns := {})
>>> exec(inst, inst_ns := {'log': log})
>>> original_shape_score = original_ns['shape_score']
>>> inst_shape_score = inst_ns['shape_score']
Here we execute them with the same sets of values to make sure everything comes out the same. In addition, the instrumented code logs the source being executed and the arguments to those expressions (or does whatever else you want the log function to do).
>>> print(original_shape_score(1, 2, 3, 4))
1
>>> print(inst_shape_score(1, 2, 3, 4))
src: 'a * b', args: (1, 2)
src: 'c + d', args: (3, 4)
src: '-(c + d)', args: (7,)
src: '(a * b) - -(c + d)', args: (2, -7)
src: 'a - c', args: (1, 3)
src: '~(a - c)', args: (-2,)
src: 'b ^ d', args: (2, 4)
src: '~(a - c) + (b ^ d)', args: (1, 6)
src: 'x // (abs(y) or 1)', args: (9, 7)
src: 'z < 0', args: (1, 0)
src: '-z', args: (1,)
src: '(z < 0) * -z', args: (False, -1)
src: 'z >= 0', args: (1, 0)
src: '+z', args: (1,)
src: '(z >= 0) * +z', args: (True, 1)
src: '(z < 0) * -z + (z >= 0) * +z', args: (0, 1)
1
>>> print(original_shape_score(3, 2, 3, 4))
2
>>> print(inst_shape_score(3, 2, 3, 4))
src: 'a * b', args: (3, 2)
src: 'c + d', args: (3, 4)
src: '-(c + d)', args: (7,)
src: '(a * b) - -(c + d)', args: (6, -7)
src: 'a - c', args: (3, 3)
src: '~(a - c)', args: (0,)
src: 'b ^ d', args: (2, 4)
src: '~(a - c) + (b ^ d)', args: (-1, 6)
src: 'x // (abs(y) or 1)', args: (13, 5)
src: 'z < 0', args: (2, 0)
src: '-z', args: (2,)
src: '(z < 0) * -z', args: (False, -2)
src: 'z >= 0', args: (2, 0)
src: '+z', args: (2,)
src: '(z >= 0) * +z', args: (True, 2)
src: '(z < 0) * -z + (z >= 0) * +z', args: (0, 2)
2
>>> print(original_shape_score(3, -1, 9, 4))
10
>>> print(inst_shape_score(3, -1, 9, 4))
src: 'a * b', args: (3, -1)
src: 'c + d', args: (9, 4)
src: '-(c + d)', args: (13,)
src: '(a * b) - -(c + d)', args: (-3, -13)
src: 'a - c', args: (3, 9)
src: '~(a - c)', args: (-6,)
src: 'b ^ d', args: (-1, 4)
src: '~(a - c) + (b ^ d)', args: (5, -5)
src: 'x // (abs(y) or 1)', args: (10, 1)
src: 'z < 0', args: (10, 0)
src: '-z', args: (10,)
src: '(z < 0) * -z', args: (False, -10)
src: 'z >= 0', args: (10, 0)
src: '+z', args: (10,)
src: '(z >= 0) * +z', args: (True, 10)
src: '(z < 0) * -z + (z >= 0) * +z', args: (0, 10)
10
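The spot checks above can also be automated. The following sketch is not part of fst: it uses a small hand-instrumented function in place of the output of instrument_operations(), and the names ORIGINAL, INSTRUMENTED and check_equivalent() are hypothetical. It executes both sources and compares their results over a grid of inputs.

```python
import itertools

ORIGINAL = 'def f(a, b):\n    return (a * b) - -b\n'

# hand-instrumented stand-in for what instrument_operations() would generate
INSTRUMENTED = (
    'def f(a, b):\n'
    "    return (_0 := (_1 := a, _2 := b, log('a * b', _1, _2), _1 * _2)[3], "
    "_1 := (_1 := b, log('-b', _1), -_1)[2], "
    "log('(a * b) - -b', _0, _1), _0 - _1)[3]\n"
)

def check_equivalent(src_a, src_b, name, values, nargs):
    """Exec both sources and compare name(*args) over all argument combinations."""
    logged = []
    ns_a = {}
    ns_b = {'log': lambda s, *args: logged.append((s, args))}
    exec(src_a, ns_a)
    exec(src_b, ns_b)

    for args in itertools.product(values, repeat=nargs):
        assert ns_a[name](*args) == ns_b[name](*args), args

    return len(logged)  # number of logged operations

n_logged = check_equivalent(ORIGINAL, INSTRUMENTED, 'f', range(-3, 4), 2)

print(n_logged)  # 147, three logged operations for each of the 49 input pairs
```

An exhaustive grid like this only makes sense for small input ranges, for anything larger a random sample of arguments would be the more practical choice.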