625 lines
19 KiB
Python
625 lines
19 KiB
Python
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Type inference.
|
||
|
|
||
|
This analysis annotates all symbol nodes of an AST with type information
|
||
|
extracted from static sources:
|
||
|
* type annotations
|
||
|
* global and local symbols visible to the function at analysis time
|
||
|
* literals
|
||
|
|
||
|
Important: This analysis is static, and does not detect dynamic type changes.
|
||
|
The analysis attempts to use the values of external symbols, if available. These
|
||
|
values are also considered static for the purpose of analysis.
|
||
|
|
||
|
Requires reaching function definitions analysis.
|
||
|
"""
|
||
|
|
||
|
import itertools
|
||
|
|
||
|
from typing import Any, Callable, Dict, Set
|
||
|
|
||
|
import gast
|
||
|
|
||
|
from tensorflow.python.autograph.pyct import anno
|
||
|
from tensorflow.python.autograph.pyct import cfg
|
||
|
from tensorflow.python.autograph.pyct import qual_names
|
||
|
from tensorflow.python.autograph.pyct import transformer
|
||
|
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||
|
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||
|
|
||
|
|
||
|
class Resolver(object):
  """Resolver objects handle the process of looking up actual names and types.

  This is an abstract interface: every res_* method below raises
  NotImplementedError and must be overridden by a subclass.

  Unless noted otherwise, all res_* methods:
   * have a first namespace argument, mapping string to actual values
   * have a second types_namespace argument, mapping qualified names (QN) to
     sets of actual inferred types
   * specify names as QN objects
   * specify types as a Set of inferred types

  Unless noted otherwise, all res_* methods must return either:
   * a set of `type` objects
   * None
  """

  def res_name(self, ns, types_ns, name):
    """Resolves the type/value of an external (e.g. closure, global) variable.

    Args:
      ns: namespace
      types_ns: types namespace
      name: QN, the symbol name

    Returns:
      Tuple (types, static_value). The first element is the set of types to
      use for inference. The second is the static value to use. Either
      element may be None to treat it as unknown; callers always unpack a
      pair, so the tuple itself must always be returned.
    """
    raise NotImplementedError('subclasses must implement')

  def res_value(self, ns, value):
    """Resolves the type of a literal or static value.

    Args:
      ns: namespace
      value: the literal or static Python value

    Returns:
      Set of types, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    """Resolves the type of a (possibly annotated) function argument.

    Args:
      ns: namespace
      types_ns: types namespace
      f_name: str, the function name
      name: QN, the argument name
      type_anno: QN of the type annotating the argument, or None if the
        argument is not annotated
      f_is_local: bool, whether the function is a local function

    Returns:
      Set of the argument types, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    """Resolves the return type of an external function or method call.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the Call AST node
      f_type: types of the actual function being called (as annotated on
        node.func), if known; None otherwise
      args: types of each respective argument in node.args; an entry is None
        when that argument's type is unknown
      keywords: types of each respective argument in node.keywords; an entry
        is None when that argument's type is unknown

    Returns:
      Tuple (return_type, side_effect_types). The first element is just the
      return types of the function. The second element is either None or a
      dict mapping symbol names (QN objects) to sets of types, and allows
      modelling side effects of functions (for example via global or
      nonlocal).
    """
    raise NotImplementedError('subclasses must implement')

  # TODO(mdan): Clean this up.
  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    """Resolves the return type of a slice (subscript) operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node_or_slice: the Subscript AST node, or an int element index when
        the slice models tuple/list unpacking
      value: types of the value being subscripted
      slice_: types of the subscript

    Returns:
      Set of types, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_compare(self, ns, types_ns, node, left, right):
    """Resolves the return type of a comparison operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the Compare AST node
      left: types of the left operand
      right: list with the types of each respective comparator

    Returns:
      Set of types, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_unop(self, ns, types_ns, node, opnd):
    """Resolves the return type of a unary operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the UnaryOp AST node
      opnd: types of the operand

    Returns:
      Set of types, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_binop(self, ns, types_ns, node, left, right):
    """Resolves the return type of a binary operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the BinOp AST node
      left: types of the left operand
      right: types of the right operand

    Returns:
      Set of types, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_list_literal(self, ns, elt_types):
    """Resolves the type of a list literal from its elements.

    Args:
      ns: namespace
      elt_types: tuple with the types of each element; an entry is None when
        that element's type is unknown

    Returns:
      Set of types, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')
|
||
|
|
||
|
|
||
|
class _TypeMap(object):
  """Abstraction for the state of the CFG walk for type inference.

  This is a value type. Only implements the strictly necessary operators.

  Attributes:
    types: Dict[qual_names.QN, Set[Type]], mapping symbols to the set of
      possible types.
  """

  # Mutable value type with a custom __eq__: instances must not be hashed.
  __hash__ = None

  def __init__(self, init_from=None):
    """Creates an empty map, or a copy of `init_from`.

    The copy is one level deep: both the dict and each per-symbol set are new
    objects, so mutating the copy never affects `init_from`.
    """
    if init_from:
      assert isinstance(init_from, _TypeMap)
      self.types = {s: set(ts) for s, ts in init_from.types.items()}
    else:
      self.types = {}

  def __eq__(self, other):
    if not isinstance(other, _TypeMap):
      # Let Python fall back to its default comparison instead of crashing
      # on a missing `.types` attribute.
      return NotImplemented
    # Dict equality compares the key sets and then each per-symbol type set.
    return self.types == other.types

  def __ne__(self, other):
    eq = self.__eq__(other)
    if eq is NotImplemented:
      return NotImplemented
    return not eq

  def __or__(self, other):
    """Returns a new map holding, per symbol, the union of both maps' types.

    Neither operand is modified.
    """
    assert isinstance(other, _TypeMap)
    result = _TypeMap(self)
    for s, other_types in other.types.items():
      # `result` owns fresh sets (copied above or created here), so updating
      # them in place cannot leak into `self` or `other`.
      result.types.setdefault(s, set()).update(other_types)
    return result

  def __repr__(self):
    return 'SymbolTable {}'.format(self.types)
|
||
|
|
||
|
|
||
|
# Sentinel passed as the getattr() default in StmtInferrer.visit_Attribute, so
# that a missing attribute can be told apart from one whose value is None.
NO_VALUE = object()
|
||
|
|
||
|
|
||
|
class StmtInferrer(gast.NodeVisitor):
  """Runs type inference on a single AST statement.

  This visitor annotates most nodes with type information. It also records the
  types of the symbols modified by this statement in its new_symbols property.

  Note: this inferrer is able to capture side effects of functions, however,
  these side effects will not be applied to the current expression. Doing so
  would create too much of a dependence on the runtime's internal rules about
  execution order.

  Example:

    def f():
      nonlocal a
      a = 1
      return a

    a = 0.0
    b = f() + a  # a = float; side effect of f() ignored
    print(a)  # a = int; side effect of f() accounted for
  """

  def __init__(self,
               resolver: Resolver,
               scope: activity.Scope,
               namespace: Dict[qual_names.QN, Any],
               closure_types: Dict[qual_names.QN, Set[Any]],
               types_in: _TypeMap):
    self.resolver = resolver
    self.scope = scope
    self.namespace = namespace
    self.closure_types = closure_types
    self.types_in = types_in
    # Maps QN -> set of types for every symbol this statement (re)defines,
    # including side effects reported by calls.
    self.new_symbols = {}

    # rvalue type. This property is set when encountering an assign operation,
    # so that visiting nodes with Store ctx (typically found on left side of
    # assignments) can infer the type they should receive.
    self.rtype = None

  def visit(self, node):
    """Visits `node`, annotating it with the inferred types if known.

    Returns:
      The set of inferred types, or None if unknown.
    """
    types = super().visit(node)
    if __debug__:
      self._check_set(types)
    if types is not None:
      # TODO(mdan): Normalize by removing subtypes.
      anno.setanno(node, anno.Static.TYPES, tuple(types))
    return types

  def _check_set(self, value):
    """Raises ValueError unless `value` is None or a set."""
    if value is not None and not isinstance(value, set):
      raise ValueError('{} method expected to return set, got {}'.format(
          self.resolver, value))

  def visit_Constant(self, node):
    types = self.resolver.res_value(self.namespace, node.value)
    if __debug__:
      self._check_set(types)
    return types

  def _apply_unpacking(self, node):
    """Propagates self.rtype into each element of a Tuple/List store target.

    Each element receives the result of slicing the rvalue type by the
    element's position (via Resolver.res_slice). self.rtype is restored
    afterwards.
    """
    assert isinstance(node.ctx, gast.Store)
    if self.rtype is not None:
      original_stype = self.rtype
      # TODO(mdan): Find a better way to express unpacking.
      i_type = self.resolver.res_value(self.namespace, 0)
      for i, elt in enumerate(node.elts):
        self.rtype = self.resolver.res_slice(
            self.namespace, self.types_in.types, i, original_stype, i_type)
        self.visit(elt)
      self.rtype = original_stype
      return original_stype
    return None

  def visit_Tuple(self, node):
    if isinstance(node.ctx, gast.Load):
      elt_types = []
      for elt in node.elts:
        types_ = self.visit(elt)
        if types_ is None:
          # One unknown element makes the whole tuple type unknown.
          return None
        elt_types.append(types_)
      # One tuple type per combination of the elements' possible types.
      return set(itertools.product(*elt_types))
    return self._apply_unpacking(node)

  def visit_List(self, node):
    if isinstance(node.ctx, gast.Load):
      # Unknown element types are passed through as None entries.
      elt_types = tuple(self.visit(elt) for elt in node.elts)
      return self.resolver.res_list_literal(self.namespace, elt_types)
    return self._apply_unpacking(node)

  def visit_Set(self, node):
    raise NotImplementedError()

  def visit_Name(self, node):
    name = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Load):
      types = self.types_in.types.get(name, None)
      if types is None:
        if (name not in self.scope.bound) or (name in self.scope.nonlocals):
          # TODO(mdan): Test with global variables.
          if name in self.closure_types:
            types = self.closure_types[name]
          else:
            types, value = self.resolver.res_name(
                self.namespace, self.types_in.types, name)
            if value is not None:
              anno.setanno(node, anno.Static.VALUE, value)

    elif isinstance(node.ctx, gast.Param):
      # The direct parent is the whole function scope. See activity.py.
      f_is_local = self.scope.parent.parent is not None

      type_name = anno.getanno(node.annotation, anno.Basic.QN, None)
      types = self.resolver.res_arg(self.namespace, self.types_in.types,
                                    self.scope.function_name, name, type_name,
                                    f_is_local)
      if types is not None:
        self.new_symbols[name] = types

    elif isinstance(node.ctx, gast.Store):
      if self.rtype is not None:
        self.new_symbols[name] = self.rtype
      types = self.rtype

    else:
      assert False, 'unknown ctx'

    if __debug__:
      self._check_set(types)

    return types

  def visit_Attribute(self, node):
    parent_types = self.visit(node.value)

    # Attempt to use the static value if known.
    parent_value = anno.Static.VALUE.of(node.value, None)
    if parent_value is not None:
      static_value = getattr(parent_value, node.attr, NO_VALUE)

      if static_value is NO_VALUE:
        # Unexpected failure to resolve attribute. Ask the resolver about the
        # full name instead.
        # Note: the resolver expects the types dict (as at every other call
        # site), not the _TypeMap wrapper.
        types, static_value = self.resolver.res_name(
            self.namespace, self.types_in.types, anno.Basic.QN.of(node))
        anno.setanno(node, anno.Static.VALUE, static_value)
        if __debug__:
          self._check_set(types)
        return types

    else:
      # Fall back to the type if that is known.
      if parent_types is None:
        return None

      inferred_values = [getattr(t, node.attr, None) for t in parent_types]
      if not inferred_values:
        return None

      static_value = inferred_values[0]
      if static_value is None:
        return None

      if any(v is not static_value for v in inferred_values[1:]):
        # Static value not stable, assume it's dynamic.
        return None

    types = self.resolver.res_value(self.namespace, static_value)
    anno.setanno(node, anno.Static.VALUE, static_value)

    if __debug__:
      self._check_set(types)

    return types

  def visit_FunctionDef(self, node):
    """Records a Callable type for the defined function; returns None."""
    f_name = qual_names.QN(node.name)

    if node.decorator_list:
      raise NotImplementedError('decorators: {}'.format(node.decorator_list))

    ret_types = None
    if node.returns:
      ret_types, _ = self.resolver.res_name(
          self.namespace, self.types_in.types, anno.Basic.QN.of(node.returns))
      if __debug__:
        self._check_set(ret_types)

    if ret_types is None:
      ret_types = {Any}

    f_types = set()
    for rt in ret_types:
      f_types.add(Callable[[Any], rt])

    self.new_symbols[f_name] = f_types
    # The definition of a function is an expression, hence has no return value.
    return None

  def _resolve_typed_callable(self, f_types, arg_types, keyword_types):
    """Extracts return types from typing.Callable types; no side effects."""
    ret_types = set()
    for t in f_types:

      if isinstance(t, Callable):
        # Note: these are undocumented - may be version-specific!
        # Callable[[x], y]: __args__ are (x, y)
        args = t.__args__
        if args:
          ret_types.add(args[-1])
        else:
          ret_types.add(Any)
      else:
        raise NotImplementedError('callable type {}'.format(type(t)))

    # Side effects can not be inferred based on type alone.
    side_effects = None
    return ret_types, side_effects

  def visit_Call(self, node):
    self.visit(node.func)

    f_name = anno.Basic.QN.of(node.func)
    arg_types = [self.visit(a) for a in node.args]
    keyword_types = [self.visit(kw.value) for kw in node.keywords]

    if f_name in self.scope.bound:
      # Local function, use local type definitions, if available.
      f_type = self.types_in.types.get(f_name, None)
      if f_type is None:
        # No static type info available, nothing more to do.
        ret_type, side_effects = None, None
      else:
        ret_type, side_effects = self._resolve_typed_callable(
            f_type, arg_types, keyword_types)

    else:
      # Nonlocal function, resolve externally.
      f_type = anno.Static.TYPES.of(node.func, None)
      ret_type, side_effects = self.resolver.res_call(self.namespace,
                                                      self.types_in.types, node,
                                                      f_type, arg_types,
                                                      keyword_types)

    if __debug__:
      self._check_set(ret_type)
      if side_effects:
        if not isinstance(side_effects, dict):
          raise ValueError(
              'side effects must be dict, got {}'.format(side_effects))
        for k, v in side_effects.items():
          if not isinstance(k, qual_names.QN):
            raise ValueError('side effect keys must be QNs, got {}'.format(k))
          self._check_set(v)

    if side_effects:
      self.new_symbols.update(side_effects)
    return ret_type

  def visit_Expr(self, node):
    return self.visit(node.value)

  def visit_Assign(self, node):
    # Make the rvalue type available to the Store-context targets.
    self.rtype = self.visit(node.value)

    for t in node.targets:
      self.visit(t)

    self.rtype = None

  def visit_Subscript(self, node):
    val_types = self.visit(node.value)
    slice_types = self.visit(node.slice)

    if val_types is None or slice_types is None:
      return None

    types = self.resolver.res_slice(
        self.namespace, self.types_in.types, node, val_types, slice_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_Compare(self, node):
    left_types = self.visit(node.left)
    right_types = [self.visit(c) for c in node.comparators]

    if left_types is None or any(t is None for t in right_types):
      return None

    types = self.resolver.res_compare(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_BinOp(self, node):
    left_types = self.visit(node.left)
    right_types = self.visit(node.right)

    if left_types is None or right_types is None:
      return None

    types = self.resolver.res_binop(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_UnaryOp(self, node):
    opnd_types = self.visit(node.operand)

    if opnd_types is None:
      return None

    types = self.resolver.res_unop(
        self.namespace, self.types_in.types, node, opnd_types)

    if __debug__:
      self._check_set(types)

    return types
|
||
|
|
||
|
|
||
|
class Analyzer(cfg.GraphVisitor):
  """Propagates type information forward across the statements of a CFG.

  A node's incoming types are the union of its predecessors' outgoing types
  (plus inherited closure types at the entry node). A StmtInferrer then
  derives the node's outgoing types from them.
  """

  def __init__(self, graph, resolver, namespace, scope, closure_types):
    """Creates a new analyzer.

    Args:
      graph: cfg.Graph
      resolver: Resolver
      namespace: Dict[str, Any]
      scope: activity.Scope
      closure_types: Dict[QN, Set]
    """
    super(Analyzer, self).__init__(graph)
    self.resolver = resolver
    self.namespace = namespace
    self.scope = scope
    self.closure_types = closure_types

    # Only closure symbols the function does not bind itself are inherited;
    # they seed the types flowing into the graph's entry node.
    inherited = {}
    for qn, qn_types in closure_types.items():
      if qn not in scope.bound:
        inherited[qn] = qn_types

    self.context_types = None
    if inherited:
      self.context_types = _TypeMap()
      self.context_types.types = inherited

  def init_state(self, _):
    """Every CFG node starts out with an empty type map."""
    return _TypeMap()

  def _update_closure_types(self, ast_node, types):
    """Merges the symbol types in `types` into ast_node's CLOSURE_TYPES."""
    closure = anno.Static.CLOSURE_TYPES.of(ast_node, None)
    if closure is None:
      closure = {}
      anno.Static.CLOSURE_TYPES.add_to(ast_node, closure)

    for qn, qn_types in types.types.items():
      if qn not in closure:
        closure[qn] = set(qn_types)
      else:
        closure[qn].update(qn_types)

  def visit_node(self, node):
    """Recomputes in/out types for `node`; returns True if its output changed."""
    previous_out = self.out[node]

    incoming = _TypeMap()
    for pred in node.prev:
      incoming |= self.out[pred]
    if (self.context_types is not None) and (node is self.graph.entry):
      incoming |= self.context_types

    outgoing = _TypeMap(incoming)
    stmt = node.ast_node

    stmt_inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,
                                 self.closure_types, incoming)
    stmt_inferrer.visit(stmt)
    outgoing.types.update(stmt_inferrer.new_symbols)

    fndefs = anno.Static.DEFINED_FNS_IN.of(stmt)
    stmt_scope = anno.Static.SCOPE.of(stmt, None)
    if stmt_scope is not None:
      # TODO(mdan): Check that it's actually safe to skip nodes without scope.
      read_names = set(map(str, stmt_scope.read))
      # Functions read by this statement see the types valid at this point.
      for fndef in fndefs:
        if fndef.name in read_names:
          self._update_closure_types(fndef, outgoing)

    self.in_[node] = incoming
    self.out[node] = outgoing

    return previous_out != outgoing
|
||
|
|
||
|
|
||
|
class FunctionVisitor(transformer.Base):
  """Walks an AST and runs type inference on each function's own CFG."""

  def __init__(self, source_info, graphs, resolver):
    super(FunctionVisitor, self).__init__(source_info)
    self.graphs = graphs
    self.resolver = resolver

  def visit_FunctionDef(self, node):
    """Analyzes `node`'s CFG, then descends into its nested functions."""
    graph = self.graphs[node]
    fn_scope = anno.getanno(node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
    inherited_types = anno.getanno(node, anno.Static.CLOSURE_TYPES, {})

    Analyzer(graph, self.resolver, self.ctx.info.namespace, fn_scope,
             inherited_types).visit_forward()

    # Nested functions are visited only after this analysis, which may have
    # recorded CLOSURE_TYPES on them.
    node.body = self.visit_block(node.body)
    return node
|
||
|
|
||
|
|
||
|
def resolve(node, source_info, graphs, resolver):
  """Runs type inference over every function found in `node`.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    resolver: Resolver

  Returns:
    ast.AST
  """
  return FunctionVisitor(source_info, graphs, resolver).visit(node)
|