summaryrefslogtreecommitdiffhomepage
path: root/scripts/jinja2/visitor.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/jinja2/visitor.py')
-rw-r--r--scripts/jinja2/visitor.py39
1 files changed, 22 insertions, 17 deletions
diff --git a/scripts/jinja2/visitor.py b/scripts/jinja2/visitor.py
index 413e7c3..b150e57 100644
--- a/scripts/jinja2/visitor.py
+++ b/scripts/jinja2/visitor.py
@@ -1,17 +1,19 @@
-# -*- coding: utf-8 -*-
+"""API for traversing the AST nodes. Implemented by the compiler and
+meta introspection.
"""
- jinja2.visitor
- ~~~~~~~~~~~~~~
+import typing as t
- This module implements a visitor for the nodes.
+from .nodes import Node
- :copyright: (c) 2010 by the Jinja Team.
- :license: BSD.
-"""
-from jinja2.nodes import Node
+if t.TYPE_CHECKING:
+ import typing_extensions as te
+
+ class VisitCallable(te.Protocol):
+ def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
+ ...
-class NodeVisitor(object):
+class NodeVisitor:
"""Walks the abstract syntax tree and call visitor functions for every
node found. The visitor functions may return values which will be
forwarded by the `visit` method.
@@ -23,22 +25,23 @@ class NodeVisitor(object):
(return value `None`) the `generic_visit` visitor is used instead.
"""
- def get_visitor(self, node):
+ def get_visitor(self, node: Node) -> "t.Optional[VisitCallable]":
"""Return the visitor function for this node or `None` if no visitor
exists for this node. In that case the generic visit function is
used instead.
"""
- method = 'visit_' + node.__class__.__name__
- return getattr(self, method, None)
+ return getattr(self, f"visit_{type(node).__name__}", None) # type: ignore
- def visit(self, node, *args, **kwargs):
+ def visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Visit a node."""
f = self.get_visitor(node)
+
if f is not None:
return f(node, *args, **kwargs)
+
return self.generic_visit(node, *args, **kwargs)
- def generic_visit(self, node, *args, **kwargs):
+ def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Called if no explicit visitor function exists for a node."""
for node in node.iter_child_nodes():
self.visit(node, *args, **kwargs)
@@ -55,7 +58,7 @@ class NodeTransformer(NodeVisitor):
replacement takes place.
"""
- def generic_visit(self, node, *args, **kwargs):
+ def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> Node:
for field, old_value in node.iter_fields():
if isinstance(old_value, list):
new_values = []
@@ -77,11 +80,13 @@ class NodeTransformer(NodeVisitor):
setattr(node, field, new_node)
return node
- def visit_list(self, node, *args, **kwargs):
+ def visit_list(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.List[Node]:
"""As transformers may return lists in some places this method
can be used to enforce a list as return value.
"""
rv = self.visit(node, *args, **kwargs)
+
if not isinstance(rv, list):
- rv = [rv]
+ return [rv]
+
return rv