3RNN/Lib/site-packages/pasta/base/test_utils.py
2024-05-26 19:49:15 +02:00

90 lines
3.0 KiB
Python

# coding=utf-8
"""Useful stuff for tests."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import sys
import unittest
from six.moves import zip
class TestCase(unittest.TestCase):
  """unittest.TestCase with a helper for structurally comparing ASTs."""

  # Expression-context node types whose field values are not recursed into
  # (see checkAstsEqual). Each name is looked up defensively because ast.Param
  # is a legacy node class that may not be defined on every Python version.
  _CTX_TYPES = tuple(
      getattr(ast, ctx_name) for ctx_name in ('Load', 'Store', 'Param')
      if hasattr(ast, ctx_name))

  @staticmethod
  def _dump(node):
    """Formats `node` for a failure message; non-AST values use repr()."""
    if isinstance(node, ast.AST):
      return ast.dump(node)
    return repr(node)

  def checkAstsEqual(self, a, b):
    """Compares two ASTs and fails if there are differences.

    Node types, `_fields` values and list lengths are compared. Formatting
    info (line/column attributes, which live in `_attributes` rather than
    `_fields`) is ignored, and `ctx` field values (Load/Store/Param) are not
    recursed into. Note that ast.walk still visits ctx nodes, so their types
    are compared pairwise along with every other node.

    Args:
      a: The first AST node, or None.
      b: The second AST node, or None.

    Raises:
      AssertionError: via self.fail, with both trees dumped, if a and b
        differ. Two None arguments compare equal.
    """
    if a is None and b is None:
      return
    try:
      self.assertIsNotNone(a)
      self.assertIsNotNone(b)
      for node_a, node_b in zip(ast.walk(a), ast.walk(b)):
        self.assertEqual(type(node_a), type(node_b))
        # Read _fields from the node type instead of from a freshly built
        # instance: Python 3.13+ warns when a node with required fields is
        # constructed without arguments.
        for field in type(node_a)._fields:
          a_val = getattr(node_a, field, None)
          b_val = getattr(node_b, field, None)
          if isinstance(a_val, list):
            self.assertIsInstance(b_val, list)
            # zip() stops at the shorter list, so check lengths explicitly.
            self.assertEqual(len(a_val), len(b_val))
            for item_a, item_b in zip(a_val, b_val):
              if isinstance(item_a, ast.AST) or isinstance(item_b, ast.AST):
                self.checkAstsEqual(item_a, item_b)
              else:
                # Plain list items such as identifier strings (Global.names)
                # or None placeholders (Dict.keys for `**` unpacking).
                self.assertEqual(item_a, item_b)
          elif isinstance(a_val, ast.AST) or isinstance(b_val, ast.AST):
            if (not isinstance(a_val, self._CTX_TYPES) and
                not isinstance(b_val, self._CTX_TYPES)):
              self.assertIsNotNone(a_val)
              self.assertIsNotNone(b_val)
              self.checkAstsEqual(a_val, b_val)
          else:
            self.assertEqual(a_val, b_val)
    except AssertionError as ae:
      # _dump tolerates None, so a one-sided None reports a real failure
      # message instead of raising TypeError from ast.dump.
      self.fail('ASTs differ:\n%s\n !=\n%s\n\n%s' % (
          self._dump(a), self._dump(b), ae))
# Python 3 spells this assertion assertCountEqual; alias the Python 2 name
# assertItemsEqual onto TestCase when the base class does not provide it.
if getattr(TestCase, 'assertItemsEqual', None) is None:
  TestCase.assertItemsEqual = TestCase.assertCountEqual
def requires_features(*features):
  """Builds a decorator that skips a test unless every feature is supported.

  Args:
    *features: Feature names understood by `supports_feature`.

  Returns:
    A `unittest.skipIf` decorator; the skip reason lists the unsupported
    features.
  """
  missing = [name for name in features if not supports_feature(name)]
  return unittest.skipIf(
      bool(missing),
      'Tests features which are not supported by this version of python. '
      'Missing: %r' % missing)
def _has_ast_node(name):
  """Returns True if the `ast` module exposes an AST node class `name`.

  Some node names (e.g. ast.Bytes on Python 3.12/3.13) are deprecated aliases
  whose attribute access emits a DeprecationWarning; the warning is suppressed
  here because this lookup is only a feature probe.
  """
  import warnings
  with warnings.catch_warnings():
    warnings.simplefilter('ignore', DeprecationWarning)
    node_type = getattr(ast, name, None)
  return isinstance(node_type, type) and issubclass(node_type, ast.AST)


def supports_feature(feature):
  """Returns whether the running Python supports the named test feature.

  Args:
    feature: One of 'bytes_node', 'exec_node', 'type_annotations', 'fstring'
      or 'mixed_tabs_spaces'. Any other name is reported as unsupported.

  Returns:
    True if the feature is available, False otherwise.
  """
  if feature == 'bytes_node':
    # NOTE(review): on Python 3.8-3.13 ast.Bytes still exists as a deprecated
    # alias even though ast.parse emits ast.Constant for bytes literals;
    # confirm this is the intended meaning of 'bytes_node'.
    return _has_ast_node('Bytes')
  if feature == 'exec_node':
    return _has_ast_node('Exec')
  if feature == 'type_annotations':
    try:
      ast.parse('def foo(bar: str=123) -> None: pass')
    except SyntaxError:
      return False
    return True
  if feature == 'fstring':
    return _has_ast_node('JoinedStr')
  # Python 2 accepts mixed tab/space indentation (a tab counts as 8 columns);
  # Python 3 rejects ambiguous mixing with TabError.
  if feature == 'mixed_tabs_spaces':
    return sys.version_info[0] < 3
  return False