Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
2023-06-19 00:49:18 +02:00

625 lines
19 KiB
Python

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type inference.

This analysis annotates all symbol nodes of an AST with type information
extracted from static sources:
 * type annotations
 * global and local symbols visible to the function at analysis time
 * literals

Important: This analysis is static, and does not detect dynamic type changes.
The analysis attempts to use the values of external symbols, if available. These
values are also considered static for the purpose of analysis.

Requires reaching function definitions analysis.
"""
import itertools
from typing import Any, Callable, Dict, Set
import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
class Resolver:
  """Interface through which type inference looks up names and types.

  The inference code in this module does not resolve external information by
  itself; it delegates to the res_* hooks below. Every hook raises
  NotImplementedError in this base class, so subclasses must override the ones
  they need.

  Unless noted otherwise, all res_* methods:
    * take a namespace as first argument, mapping string to actual values
    * take a types namespace as second argument, mapping symbols to the types
      inferred for them so far
    * receive names as QN objects
    * receive types as a Set of inferred types

  Unless noted otherwise, all res_* methods must return either:
    * a set of `type` objects
    * None
  """

  def res_name(self, ns, types_ns, name):
    """Resolves the type and value of an external (closure, global) variable.

    Args:
      ns: namespace
      types_ns: types namespace
      name: symbol name

    Returns:
      Tuple (type, static_value). The first element is the type to use for
      inference. The second is the static value to use. Use None for an element
      to treat it as unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_value(self, ns, value):
    """Resolves the type of a literal or static value."""
    raise NotImplementedError('subclasses must implement')

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    """Resolves the type of a (possibly annotated) function argument.

    Args:
      ns: namespace
      types_ns: types namespace
      f_name: the name of the function that owns the argument
      name: the argument name
      type_anno: the type annotating the argument, if any
      f_is_local: bool, whether the function is a local function

    Returns:
      Set of the argument types.
    """
    raise NotImplementedError('subclasses must implement')

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    """Resolves the return type of an external function or method call.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the AST node of the call
      f_type: types of the actual function being called, if known
      args: types of each respective argument in node.args
      keywords: types of each respective argument in node.keywords

    Returns:
      Tuple (return_type, side_effect_types). The first element is just the
      return types of the function. The second element is a map from
      argument names to sets of types, and allows modelling side effects of
      functions (for example via global or nonlocal).
    """
    raise NotImplementedError('subclasses must implement')

  # TODO(mdan): Clean this up.
  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    """Resolves the return type of a slice operation."""
    raise NotImplementedError('subclasses must implement')

  def res_compare(self, ns, types_ns, node, left, right):
    """Resolves the return type of a comparison operation."""
    raise NotImplementedError('subclasses must implement')

  def res_unop(self, ns, types_ns, node, opnd):
    """Resolves the return type of a unary operation."""
    raise NotImplementedError('subclasses must implement')

  def res_binop(self, ns, types_ns, node, left, right):
    """Resolves the return type of a binary operation."""
    raise NotImplementedError('subclasses must implement')

  def res_list_literal(self, ns, elt_types):
    """Resolves the type of a list literal from the types of its elements."""
    raise NotImplementedError('subclasses must implement')
class _TypeMap(object):
  """Abstraction for the state of the CFG walk for type inference.

  This is a value type. Only implements the strictly necessary operators:
  copy construction, equality, inequality and union (`|`).

  Attributes:
    types: Dict[qual_names.QN, Set[Type]], mapping symbols to the set of
      possible types.
  """

  def __init__(self, init_from=None):
    """Creates an empty map, or a copy of `init_from` when one is given.

    Args:
      init_from: optional _TypeMap to copy. Each set of types is copied too, so
        updating the sets of the new map leaves the source map untouched.
    """
    if init_from:
      assert isinstance(init_from, _TypeMap)
      self.types = {
          s: set(other_types) for s, other_types in init_from.types.items()
      }
    else:
      self.types = {}

  def __eq__(self, other):
    """Two maps are equal when they map the same symbols to equal type sets."""
    if not isinstance(other, _TypeMap):
      # Let Python fall back to its default handling instead of failing with
      # an AttributeError on `other.types`.
      return NotImplemented
    return self.types == other.types

  def __ne__(self, other):
    equal = self.__eq__(other)
    if equal is NotImplemented:
      # NotImplemented must be propagated as is; negating it is an error.
      return equal
    return not equal

  def __or__(self, other):
    """Returns a new map holding, per symbol, the union of both type sets."""
    assert isinstance(other, _TypeMap)
    # Start from a copy so that neither operand is modified.
    result = _TypeMap(self)
    for s, other_types in other.types.items():
      result.types.setdefault(s, set()).update(other_types)
    return result

  def __repr__(self):
    return 'SymbolTable {}'.format(self.types)
# Sentinel used as the getattr() default in StmtInferrer.visit_Attribute, so
# that a missing attribute can be told apart from one whose value is None.
NO_VALUE = object()
class StmtInferrer(gast.NodeVisitor):
  """Runs type inference on a single AST statement.

  This visitor annotates most nodes with type information. It also collects the
  types of the symbols modified by this statement in its new_symbols property.

  Note: this inferrer is able to capture side effects of functions, however,
  these side effects will not be applied to the current expression. Doing so
  would create too much of a dependence on the runtime's internal rules about
  execution order.

  Example:

    def f():
      nonlocal a
      a = 1
      return a

    a = 0.0
    b = f() + a  # a = float; side effect of f() ignored
    print(a)  # a = int; side effect of f() accounted for

  Attributes:
    resolver: Resolver used for everything that can't be inferred locally.
    scope: activity.Scope of the function being analyzed.
    namespace: namespace passed through to the resolver.
    closure_types: types known for symbols closed over by the function.
    types_in: _TypeMap with the symbol types valid before this statement.
    new_symbols: Dict[QN, Set], types of the symbols this statement modified.
    rtype: types of the rvalue of the assignment being visited, or None.
  """

  def __init__(self,
               resolver: Resolver,
               scope: activity.Scope,
               namespace: Dict[qual_names.QN, Any],
               closure_types: Dict[qual_names.QN, Set[Any]],
               types_in: _TypeMap):
    self.resolver = resolver
    self.scope = scope
    self.namespace = namespace
    self.closure_types = closure_types
    self.types_in = types_in
    self.new_symbols = {}

    # rvalue type. This property is set when encountering an assign operation,
    # so that visiting nodes with Store ctx (typically found on left side of
    # assignments) can infer the type they should receive.
    self.rtype = None

  def visit(self, node):
    """Visits `node` and annotates it with the types inferred for it.

    Args:
      node: the AST node to visit.

    Returns:
      The set of types inferred for the node, or None if unknown.
    """
    types = super().visit(node)
    if __debug__:
      self._check_set(types)
    if types is not None:
      # TODO(mdan): Normalize by removing subtypes.
      anno.setanno(node, anno.Static.TYPES, tuple(types))
    return types

  def _check_set(self, value):
    """Raises ValueError unless `value` is None or a set."""
    if value is not None and not isinstance(value, set):
      raise ValueError('{} method expected to return set, got {}'.format(
          self.resolver, value))

  def visit_Constant(self, node):
    types = self.resolver.res_value(self.namespace, node.value)
    if __debug__:
      self._check_set(types)
    return types

  def _apply_unpacking(self, node):
    """Distributes the current rvalue type over the elements of a target.

    Each element is visited with rtype temporarily set to the type the resolver
    reports for indexing the rvalue at the element's position.

    Args:
      node: a Tuple or List node with Store ctx.

    Returns:
      The rvalue types, or None if no rvalue type is known.
    """
    assert isinstance(node.ctx, gast.Store)
    if self.rtype is not None:
      original_stype = self.rtype
      # TODO(mdan): Find a better way to express unpacking.
      i_type = self.resolver.res_value(self.namespace, 0)
      for i, elt in enumerate(node.elts):
        self.rtype = self.resolver.res_slice(
            self.namespace, self.types_in.types, i, original_stype, i_type)
        self.visit(elt)
      self.rtype = original_stype
      return original_stype
    return None

  def visit_Tuple(self, node):
    if isinstance(node.ctx, gast.Load):
      elt_types = ()
      for elt in node.elts:
        types_ = self.visit(elt)
        if types_ is None:
          # A single unknown element makes the whole tuple unknown.
          return None
        elt_types += (types_,)
      # Every combination of element types is a possible tuple type.
      return set(itertools.product(*elt_types))
    return self._apply_unpacking(node)

  def visit_List(self, node):
    if isinstance(node.ctx, gast.Load):
      elt_types = tuple(self.visit(elt) for elt in node.elts)
      return self.resolver.res_list_literal(self.namespace, elt_types)
    return self._apply_unpacking(node)

  def visit_Set(self, node):
    raise NotImplementedError()

  def visit_Name(self, node):
    """Infers the types of a name, according to its ctx (Load, Param, Store)."""
    name = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Load):
      types = self.types_in.types.get(name, None)
      if types is None:
        if (name not in self.scope.bound) or (name in self.scope.nonlocals):
          # TODO(mdan): Test with global variables.
          if name in self.closure_types:
            types = self.closure_types[name]
          else:
            types, value = self.resolver.res_name(
                self.namespace, self.types_in.types, name)
            if value is not None:
              anno.setanno(node, anno.Static.VALUE, value)

    elif isinstance(node.ctx, gast.Param):
      # The direct parent is the whole function scope. See activity.py.
      f_is_local = self.scope.parent.parent is not None

      type_name = anno.getanno(node.annotation, anno.Basic.QN, None)
      types = self.resolver.res_arg(self.namespace, self.types_in.types,
                                    self.scope.function_name, name, type_name,
                                    f_is_local)
      if types is not None:
        self.new_symbols[name] = types

    elif isinstance(node.ctx, gast.Store):
      if self.rtype is not None:
        self.new_symbols[name] = self.rtype
      types = self.rtype

    else:
      assert False, 'unknown ctx'

    if __debug__:
      self._check_set(types)

    return types

  def visit_Attribute(self, node):
    """Infers the types of an attribute from its parent's value or types."""
    parent_types = self.visit(node.value)

    # Attempt to use the static value if known.
    parent_value = anno.Static.VALUE.of(node.value, None)
    if parent_value is not None:
      static_value = getattr(parent_value, node.attr, NO_VALUE)

      if static_value is NO_VALUE:
        # Unexpected failure to resolve attribute. Ask the resolver about the
        # full name instead. The resolver receives the plain types dict, like
        # in every other resolver call of this class.
        types, static_value = self.resolver.res_name(
            self.namespace, self.types_in.types, anno.Basic.QN.of(node))
        anno.setanno(node, anno.Static.VALUE, static_value)
        if __debug__:
          self._check_set(types)
        return types

    else:
      # Fall back to the type if that is known.
      if parent_types is None:
        return None

      inferred_values = [getattr(t, node.attr, None) for t in parent_types]
      if not inferred_values:
        return None

      static_value = inferred_values[0]
      if static_value is None:
        return None

      if any(v is not static_value for v in inferred_values[1:]):
        # Static value not stable, assume it's dynamic.
        return None

    types = self.resolver.res_value(self.namespace, static_value)
    anno.setanno(node, anno.Static.VALUE, static_value)

    if __debug__:
      self._check_set(types)

    return types

  def visit_FunctionDef(self, node):
    """Records Callable types for a local function definition.

    Args:
      node: the FunctionDef node.

    Returns:
      None.

    Raises:
      NotImplementedError: if the function has decorators.
    """
    f_name = qual_names.QN(node.name)

    if node.decorator_list:
      raise NotImplementedError('decorators: {}'.format(node.decorator_list))

    ret_types = None
    if node.returns:
      ret_types, _ = self.resolver.res_name(
          self.namespace, self.types_in.types, anno.Basic.QN.of(node.returns))
      if __debug__:
        self._check_set(ret_types)

    if ret_types is None:
      ret_types = {Any}

    f_types = set()
    for rt in ret_types:
      f_types.add(Callable[[Any], rt])

    self.new_symbols[f_name] = f_types
    # A function definition is a statement, hence it has no type of its own.
    return None

  def _resolve_typed_callable(self, f_types, arg_types, keyword_types):
    """Derives the return types of a call from the callee's Callable types.

    Args:
      f_types: set of types of the function being called.
      arg_types: types of the positional arguments (currently unused).
      keyword_types: types of the keyword arguments (currently unused).

    Returns:
      Tuple (return types, side effects). The side effects are always None.

    Raises:
      NotImplementedError: if one of f_types is not a Callable.
    """
    ret_types = set()
    for t in f_types:

      if isinstance(t, Callable):
        # Note: these are undocumented - may be version-specific!
        # Callable[[x], y]: __args__ are (x, y)
        args = t.__args__
        if args:
          ret_types.add(args[-1])
        else:
          ret_types.add(Any)
      else:
        raise NotImplementedError('callable type {}'.format(type(t)))

    # Side effects can not be inferred based on type alone.
    side_effects = None
    return ret_types, side_effects

  def visit_Call(self, node):
    """Infers the return types of a call and records its side effects."""
    self.visit(node.func)

    f_name = anno.Basic.QN.of(node.func)
    arg_types = [self.visit(a) for a in node.args]
    keyword_types = [self.visit(kw.value) for kw in node.keywords]

    if f_name in self.scope.bound:
      # Local function, use local type definitions, if available.
      f_type = self.types_in.types.get(f_name, None)
      if f_type is None:
        # No static type info available, nothing more to do.
        ret_type, side_effects = None, None
      else:
        ret_type, side_effects = self._resolve_typed_callable(
            f_type, arg_types, keyword_types)

    else:
      # Nonlocal function, resolve externally.
      f_type = anno.Static.TYPES.of(node.func, None)
      ret_type, side_effects = self.resolver.res_call(self.namespace,
                                                      self.types_in.types, node,
                                                      f_type, arg_types,
                                                      keyword_types)

    if __debug__:
      self._check_set(ret_type)
      if side_effects:
        if not isinstance(side_effects, dict):
          raise ValueError(
              'side effects must be dict, got {}'.format(side_effects))
        for k, v in side_effects.items():
          if not isinstance(k, qual_names.QN):
            raise ValueError('side effect keys must be QNs, got {}'.format(k))
          self._check_set(v)

    if side_effects:
      self.new_symbols.update(side_effects)
    return ret_type

  def visit_Expr(self, node):
    return self.visit(node.value)

  def visit_Assign(self, node):
    # The targets are visited with rtype set, so that nodes with Store ctx can
    # pick up the type of the value being assigned.
    self.rtype = self.visit(node.value)

    for t in node.targets:
      self.visit(t)

    self.rtype = None

  def visit_Subscript(self, node):
    val_types = self.visit(node.value)
    slice_types = self.visit(node.slice)

    if val_types is None or slice_types is None:
      return None

    types = self.resolver.res_slice(
        self.namespace, self.types_in.types, node, val_types, slice_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_Compare(self, node):
    left_types = self.visit(node.left)
    right_types = [self.visit(c) for c in node.comparators]

    if left_types is None or any(t is None for t in right_types):
      return None

    types = self.resolver.res_compare(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_BinOp(self, node):
    left_types = self.visit(node.left)
    right_types = self.visit(node.right)

    if left_types is None or right_types is None:
      return None

    types = self.resolver.res_binop(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_UnaryOp(self, node):
    opnd_types = self.visit(node.operand)

    if opnd_types is None:
      return None

    types = self.resolver.res_unop(
        self.namespace, self.types_in.types, node, opnd_types)

    if __debug__:
      self._check_set(types)

    return types
class Analyzer(cfg.GraphVisitor):
  """Walks a function's CFG, carrying symbol types from statement to statement.

  The state attached to each CFG node is a _TypeMap.
  """

  def __init__(self, graph, resolver, namespace, scope, closure_types):
    """Creates a new analyzer.

    Args:
      graph: cfg.Graph
      resolver: Resolver
      namespace: Dict[str, Any]
      scope: activity.Scope
      closure_types: Dict[QN, Set]
    """
    super().__init__(graph)
    self.resolver = resolver
    self.namespace = namespace
    self.scope = scope
    self.closure_types = closure_types

    # Only the closure types of symbols that the function does not bind itself
    # are fed into the entry node of the graph.
    inherited = {}
    for qn, qn_types in closure_types.items():
      if qn not in scope.bound:
        inherited[qn] = qn_types

    self.context_types = None
    if inherited:
      self.context_types = _TypeMap()
      self.context_types.types = inherited

  def init_state(self, _):
    """Returns the initial state of a CFG node: an empty type map."""
    return _TypeMap()

  def _update_closure_types(self, ast_node, types):
    """Merges `types` into the CLOSURE_TYPES annotation of `ast_node`.

    Args:
      ast_node: the node whose annotation is updated; the annotation is created
        if the node doesn't have one yet.
      types: _TypeMap with the types to merge in.
    """
    known = anno.Static.CLOSURE_TYPES.of(ast_node, None)
    if known is None:
      known = {}
      anno.Static.CLOSURE_TYPES.add_to(ast_node, known)

    for qn, qn_types in types.types.items():
      known.setdefault(qn, set()).update(qn_types)

  def visit_node(self, node):
    """Recomputes the types flowing into and out of a CFG node.

    Args:
      node: the CFG node to process.

    Returns:
      True if the types flowing out of the node changed, False otherwise.
    """
    previous_out = self.out[node]

    # The incoming types are the union of what all predecessors produced.
    incoming = _TypeMap()
    for pred in node.prev:
      incoming = incoming | self.out[pred]
    if self.context_types is not None and node is self.graph.entry:
      incoming = incoming | self.context_types

    outgoing = _TypeMap(incoming)
    stmt = node.ast_node

    stmt_inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,
                                 self.closure_types, incoming)
    stmt_inferrer.visit(stmt)
    outgoing.types.update(stmt_inferrer.new_symbols)

    # Function definitions that reach this statement and are read by it get
    # the outgoing types recorded as their closure types.
    fn_defs = anno.Static.DEFINED_FNS_IN.of(stmt)
    stmt_scope = anno.Static.SCOPE.of(stmt, None)
    # TODO(mdan): Check that it's actually safe to skip nodes without scope.
    if stmt_scope is not None:
      read_names = set(map(str, stmt_scope.read))
      for fn_def in fn_defs:
        if fn_def.name in read_names:
          self._update_closure_types(fn_def, outgoing)

    self.in_[node] = incoming
    self.out[node] = outgoing

    return previous_out != outgoing
class FunctionVisitor(transformer.Base):
  """AST visitor that runs type inference on every function it encounters.

  Attributes:
    graphs: mapping from function definition nodes to their CFG.
    resolver: Resolver handed to the analyzer of each function.
  """

  def __init__(self, source_info, graphs, resolver):
    super().__init__(source_info)
    self.resolver = resolver
    self.graphs = graphs

  def visit_FunctionDef(self, node):
    """Analyzes `node`, then the functions nested in its body."""
    fn_graph = self.graphs[node]
    fn_scope = anno.getanno(node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
    # Functions without a closure types annotation start from an empty map.
    fn_closure_types = anno.getanno(node, anno.Static.CLOSURE_TYPES, {})

    Analyzer(fn_graph, self.resolver, self.ctx.info.namespace, fn_scope,
             fn_closure_types).visit_forward()

    # Recursively process any remaining subfunctions.
    node.body = self.visit_block(node.body)

    return node
def resolve(node, source_info, graphs, resolver):
  """Performs type inference on all functions found under an AST node.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    resolver: Resolver

  Returns:
    ast.AST, the node returned by the FunctionVisitor.
  """
  return FunctionVisitor(source_info, graphs, resolver).visit(node)