90 lines
3.0 KiB
Python
90 lines
3.0 KiB
Python
# coding=utf-8
|
|
"""Useful stuff for tests."""
|
|
# Copyright 2017 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import ast
|
|
import sys
|
|
import unittest
|
|
|
|
from six.moves import zip
|
|
|
|
|
|
class TestCase(unittest.TestCase):
  """unittest.TestCase with helpers for comparing ASTs."""

  def checkAstsEqual(self, a, b):
    """Compares two ASTs and fails if there are differences.

    Ignores `ctx` fields (any `ast.expr_context`, e.g. Load/Store/Del) and
    formatting info: only the entries of each node's `_fields` are compared,
    so position attributes such as `lineno`/`col_offset` and any other
    attributes attached to the nodes are not considered.

    Args:
      a: The first AST node, or None.
      b: The second AST node, or None.

    Raises:
      AssertionError: (via `self.fail`) if the trees differ. The message
        contains a dump of both trees followed by the underlying assertion,
        which names the path (e.g. `root.body[0].value`) of the first
        difference.
    """
    if a is None and b is None:
      return
    try:
      self._checkAstValuesEqual(a, b, 'root')
    except AssertionError as ae:
      self.fail('ASTs differ:\n%s\n !=\n%s\n\n%s' % (
          self._dumpAstForMessage(a), self._dumpAstForMessage(b), ae))

  def _checkAstValuesEqual(self, a, b, path):
    """Recursively asserts that two corresponding AST values are equivalent.

    Each node is visited exactly once, so the cost is linear in tree size.

    Args:
      a: A value from the first tree: an AST node, a list of values, or a
        plain value (str, number, None, ...).
      b: The corresponding value from the second tree.
      path: Human-readable location of `a`/`b`, used in failure messages.
    """
    if isinstance(a, list) or isinstance(b, list):
      self.assertIsInstance(a, list, path)
      self.assertIsInstance(b, list, path)
      # zip() alone would silently ignore trailing extra items on one side.
      self.assertEqual(len(a), len(b), 'length mismatch at %s' % path)
      for i, (item_a, item_b) in enumerate(zip(a, b)):
        self._checkAstValuesEqual(item_a, item_b, '%s[%d]' % (path, i))
    elif isinstance(a, ast.AST) or isinstance(b, ast.AST):
      self.assertIsNotNone(a, path)
      self.assertIsNotNone(b, path)
      self.assertEqual(type(a), type(b), path)
      # Read `_fields` from the node itself; instantiating the node type with
      # no arguments is deprecated since Python 3.13.
      for field in a._fields:
        a_val = getattr(a, field, None)
        b_val = getattr(b, field, None)
        # Expression contexts (Load/Store/Del/...) are deliberately ignored.
        if (isinstance(a_val, ast.expr_context) or
            isinstance(b_val, ast.expr_context)):
          continue
        self._checkAstValuesEqual(a_val, b_val, '%s.%s' % (path, field))
    else:
      self.assertEqual(a, b, path)

  @staticmethod
  def _dumpAstForMessage(node):
    """Returns `ast.dump(node)`, or `repr(node)` when `node` is not an AST."""
    # ast.dump() raises TypeError for non-AST values such as None, which would
    # otherwise replace the real assertion failure with an unrelated error.
    return ast.dump(node) if isinstance(node, ast.AST) else repr(node)
|
|
|
|
|
|
# Python 3's unittest renamed `assertItemsEqual` to `assertCountEqual`. When
# the old spelling is missing, alias it so tests can use either name.
try:
  TestCase.assertItemsEqual  # pylint: disable=pointless-statement
except AttributeError:
  TestCase.assertItemsEqual = TestCase.assertCountEqual
|
|
|
|
|
|
def requires_features(*features):
  """Builds a decorator that skips a test unless every feature is supported.

  Args:
    *features: Feature names, each checked with `supports_feature`.

  Returns:
    A `unittest.skipIf` decorator. The test is skipped when any feature is
    unsupported, and the skip reason lists the missing features.
  """
  # Evaluate each feature exactly once. The skip condition and the skip
  # message are both derived from this single list.
  missing = [f for f in features if not supports_feature(f)]
  return unittest.skipIf(
      bool(missing),
      'Tests features which are not supported by this version of python. '
      'Missing: %r' % missing)
|
|
|
|
|
|
def supports_feature(feature):
  """Tells whether the running interpreter supports a named feature.

  Recognized names:
    'bytes_node', 'exec_node', 'fstring': the `ast` module defines the
      corresponding node class (Bytes, Exec, JoinedStr) as an AST subclass.
    'type_annotations': a function with parameter and return annotations
      parses.
    'mixed_tabs_spaces': the interpreter's major version is below 3.

  Args:
    feature: The feature name to check.

  Returns:
    True if the feature is supported. False if it is not, or if `feature` is
    not one of the recognized names.
  """

  def defines_node(class_name):
    # `ast` must define `class_name`, and it must be a real AST node class.
    return hasattr(ast, class_name) and issubclass(
        getattr(ast, class_name), ast.AST)

  def annotations_parse():
    try:
      ast.parse('def foo(bar: str=123) -> None: pass')
    except SyntaxError:
      return False
    return True

  # Checks are lazy and are tried in order; only the matching one runs.
  checks = (
      ('bytes_node', lambda: defines_node('Bytes')),
      ('exec_node', lambda: defines_node('Exec')),
      ('type_annotations', annotations_parse),
      ('fstring', lambda: defines_node('JoinedStr')),
      # Python 2 counts tabs as 8 spaces for indentation
      ('mixed_tabs_spaces', lambda: sys.version_info[0] < 3),
  )
  for name, check in checks:
    if feature == name:
      return check()
  return False
|