3RNN/Lib/site-packages/pasta/base/token_generator.py

514 lines
18 KiB
Python
Raw Permalink Normal View History

2024-05-26 19:49:15 +02:00
# coding=utf-8
"""Token generator for analyzing source code in logical units.
This module contains the TokenGenerator used for annotating a parsed syntax tree
with source code formatting.
"""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import collections
import contextlib
import itertools
import tokenize
from six import StringIO
from pasta.base import formatting as fmt
from pasta.base import fstring_utils
# Alias for extracting token names
TOKENS = tokenize
Token = collections.namedtuple('Token', ('type', 'src', 'start', 'end', 'line'))
FORMATTING_TOKENS = (TOKENS.INDENT, TOKENS.DEDENT, TOKENS.NL, TOKENS.NEWLINE,
TOKENS.COMMENT)
class TokenGenerator(object):
"""Helper for sequentially parsing Python source code, token by token.
Holds internal state during parsing, including:
_tokens: List of tokens in the source code, as parsed by `tokenize` module.
_parens: Stack of open parenthesis at the current point in parsing.
_hints: Number of open parentheses, brackets, etc. at the current point.
_scope_stack: Stack containing tuples of nodes where the last parenthesis that
was open is related to one of the nodes on the top of the stack.
_lines: Full lines of the source code.
_i: Index of the last token that was parsed. Initially -1.
_loc: (lineno, column_offset) pair of the position in the source that has been
parsed to. This should be either the start or end of the token at index _i.
Arguments:
ignore_error_tokens: If True, will ignore error tokens. Otherwise, an error
token will cause an exception. This is useful when the source being parsed
contains invalid syntax, e.g. if it is in an fstring context.
"""
def __init__(self, source, ignore_error_token=False):
self.lines = source.splitlines(True)
self._tokens = list(_generate_tokens(source, ignore_error_token))
self._parens = []
self._hints = 0
self._scope_stack = []
self._len = len(self._tokens)
self._i = -1
self._loc = self.loc_begin()
def chars_consumed(self):
return len(self._space_between((1, 0), self._tokens[self._i].end))
def loc_begin(self):
"""Get the start column of the current location parsed to."""
if self._i < 0:
return (1, 0)
return self._tokens[self._i].start
def loc_end(self):
"""Get the end column of the current location parsed to."""
if self._i < 0:
return (1, 0)
return self._tokens[self._i].end
def peek(self):
"""Get the next token without advancing."""
if self._i + 1 >= self._len:
return None
return self._tokens[self._i + 1]
def peek_non_whitespace(self):
"""Get the next non-whitespace token without advancing."""
return self.peek_conditional(lambda t: t.type not in FORMATTING_TOKENS)
def peek_conditional(self, condition):
"""Get the next token of the given type without advancing."""
return next((t for t in self._tokens[self._i + 1:] if condition(t)), None)
def next(self, advance=True):
"""Consume the next token and optionally advance the current location."""
self._i += 1
if self._i >= self._len:
return None
if advance:
self._loc = self._tokens[self._i].end
return self._tokens[self._i]
def rewind(self, amount=1):
"""Rewind the token iterator."""
self._i -= amount
def whitespace(self, max_lines=None, comment=False):
"""Parses whitespace from the current _loc to the next non-whitespace.
Arguments:
max_lines: (optional int) Maximum number of lines to consider as part of
the whitespace. Valid values are None, 0 and 1.
comment: (boolean) If True, look for a trailing comment even when not in
a parenthesized scope.
Pre-condition:
`_loc' represents the point before which everything has been parsed and
after which nothing has been parsed.
Post-condition:
`_loc' is exactly at the character that was parsed to.
"""
next_token = self.peek()
if not comment and next_token and next_token.type == TOKENS.COMMENT:
return ''
def predicate(token):
return (token.type in (TOKENS.INDENT, TOKENS.DEDENT) or
token.type == TOKENS.COMMENT and (comment or self._hints) or
token.type == TOKENS.ERRORTOKEN and token.src == ' ' or
max_lines is None and token.type in (TOKENS.NL, TOKENS.NEWLINE))
whitespace = list(self.takewhile(predicate, advance=False))
next_token = self.peek()
result = ''
for tok in itertools.chain(whitespace,
((next_token,) if next_token else ())):
result += self._space_between(self._loc, tok.start)
if tok != next_token:
result += tok.src
self._loc = tok.end
else:
self._loc = tok.start
# Eat a single newline character
if ((max_lines is None or max_lines > 0) and
next_token and next_token.type in (TOKENS.NL, TOKENS.NEWLINE)):
result += self.next().src
return result
def block_whitespace(self, indent_level):
"""Parses whitespace from the current _loc to the end of the block."""
# Get the normal suffix lines, but don't advance the token index unless
# there is no indentation to account for
start_i = self._i
full_whitespace = self.whitespace(comment=True)
if not indent_level:
return full_whitespace
self._i = start_i
# Trim the full whitespace into only lines that match the indentation level
lines = full_whitespace.splitlines(True)
try:
last_line_idx = next(i for i, line in reversed(list(enumerate(lines)))
if line.startswith(indent_level + '#'))
except StopIteration:
# No comment lines at the end of this block
self._loc = self._tokens[self._i].end
return ''
lines = lines[:last_line_idx + 1]
# Advance the current location to the last token in the lines we've read
end_line = self._tokens[self._i].end[0] + 1 + len(lines)
list(self.takewhile(lambda tok: tok.start[0] < end_line))
self._loc = self._tokens[self._i].end
return ''.join(lines)
def dots(self, num_dots):
"""Parse a number of dots.
This is to work around an oddity in python3's tokenizer, which treats three
`.` tokens next to each other in a FromImport's level as an ellipsis. This
parses until the expected number of dots have been seen.
"""
result = ''
dots_seen = 0
prev_loc = self._loc
while dots_seen < num_dots:
tok = self.next()
assert tok.src in ('.', '...')
result += self._space_between(prev_loc, tok.start) + tok.src
dots_seen += tok.src.count('.')
prev_loc = self._loc
return result
def open_scope(self, node, single_paren=False):
"""Open a parenthesized scope on the given node."""
result = ''
parens = []
start_i = self._i
start_loc = prev_loc = self._loc
# Eat whitespace or '(' tokens one at a time
for tok in self.takewhile(
lambda t: t.type in FORMATTING_TOKENS or t.src == '('):
# Stores all the code up to and including this token
result += self._space_between(prev_loc, tok.start)
if tok.src == '(' and single_paren and parens:
self.rewind()
self._loc = tok.start
break
result += tok.src
if tok.src == '(':
# Start a new scope
parens.append(result)
result = ''
start_i = self._i
start_loc = self._loc
prev_loc = self._loc
if parens:
# Add any additional whitespace on to the last open-paren
next_tok = self.peek()
parens[-1] += result + self._space_between(self._loc, next_tok.start)
self._loc = next_tok.start
# Add each paren onto the stack
for paren in parens:
self._parens.append(paren)
self._scope_stack.append(_scope_helper(node))
else:
# No parens were encountered, then reset like this method did nothing
self._i = start_i
self._loc = start_loc
def close_scope(self, node, prefix_attr='prefix', suffix_attr='suffix',
trailing_comma=False, single_paren=False):
"""Close a parenthesized scope on the given node, if one is open."""
# Ensures the prefix + suffix are not None
if fmt.get(node, prefix_attr) is None:
fmt.set(node, prefix_attr, '')
if fmt.get(node, suffix_attr) is None:
fmt.set(node, suffix_attr, '')
if not self._parens or node not in self._scope_stack[-1]:
return
symbols = {')'}
if trailing_comma:
symbols.add(',')
parsed_to_i = self._i
parsed_to_loc = prev_loc = self._loc
encountered_paren = False
result = ''
for tok in self.takewhile(
lambda t: t.type in FORMATTING_TOKENS or t.src in symbols):
# Consume all space up to this token
result += self._space_between(prev_loc, tok.start)
if tok.src == ')' and single_paren and encountered_paren:
self.rewind()
parsed_to_i = self._i
parsed_to_loc = tok.start
fmt.append(node, suffix_attr, result)
break
# Consume the token itself
result += tok.src
if tok.src == ')':
# Close out the open scope
encountered_paren = True
self._scope_stack.pop()
fmt.prepend(node, prefix_attr, self._parens.pop())
fmt.append(node, suffix_attr, result)
result = ''
parsed_to_i = self._i
parsed_to_loc = tok.end
if not self._parens or node not in self._scope_stack[-1]:
break
prev_loc = tok.end
# Reset back to the last place where we parsed anything
self._i = parsed_to_i
self._loc = parsed_to_loc
def hint_open(self):
"""Indicates opening a group of parentheses or brackets."""
self._hints += 1
def hint_closed(self):
"""Indicates closing a group of parentheses or brackets."""
self._hints -= 1
if self._hints < 0:
raise ValueError('Hint value negative')
@contextlib.contextmanager
def scope(self, node, attr=None, trailing_comma=False):
"""Context manager to handle a parenthesized scope."""
self.open_scope(node, single_paren=(attr is not None))
yield
if attr:
self.close_scope(node, prefix_attr=attr + '_prefix',
suffix_attr=attr + '_suffix',
trailing_comma=trailing_comma,
single_paren=True)
else:
self.close_scope(node, trailing_comma=trailing_comma)
def is_in_scope(self):
"""Return True iff there is a scope open."""
return self._parens or self._hints
def str(self):
"""Parse a full string literal from the input."""
def predicate(token):
return (token.type in (TOKENS.STRING, TOKENS.COMMENT) or
self.is_in_scope() and token.type in (TOKENS.NL, TOKENS.NEWLINE))
return self.eat_tokens(predicate)
def eat_tokens(self, predicate):
"""Parse input from tokens while a given condition is met."""
content = ''
prev_loc = self._loc
tok = None
for tok in self.takewhile(predicate, advance=False):
content += self._space_between(prev_loc, tok.start)
content += tok.src
prev_loc = tok.end
if tok:
self._loc = tok.end
return content
def fstr(self):
"""Parses an fstring, including subexpressions.
Returns:
A generator function which, when repeatedly reads a chunk of the fstring
up until the next subexpression and yields that chunk, plus a new token
generator to use to parse the subexpression. The subexpressions in the
original fstring data are replaced by placeholders to make it possible to
fill them in with new values, if desired.
"""
def fstr_parser():
# Reads the whole fstring as a string, then parses it char by char
if self.peek_non_whitespace().type == TOKENS.STRING:
# Normal fstrings are one ore more STRING tokens, maybe mixed with
# spaces, e.g.: f"Hello, {name}"
str_content = self.str()
else:
# Format specifiers in fstrings are also JoinedStr nodes, but these are
# arbitrary expressions, e.g. in: f"{value:{width}.{precision}}", the
# format specifier is an fstring: "{width}.{precision}" but these are
# not STRING tokens.
def fstr_eater(tok):
if tok.type == TOKENS.OP and tok.src == '}':
if fstr_eater.level <= 0:
return False
fstr_eater.level -= 1
if tok.type == TOKENS.OP and tok.src == '{':
fstr_eater.level += 1
return True
fstr_eater.level = 0
str_content = self.eat_tokens(fstr_eater)
indexed_chars = enumerate(str_content)
val_idx = 0
i = -1
result = ''
while i < len(str_content) - 1:
i, c = next(indexed_chars)
result += c
# When an open bracket is encountered, start parsing a subexpression
if c == '{':
# First check if this is part of an escape sequence
# (f"{{" is used to escape a bracket literal)
nexti, nextc = next(indexed_chars)
if nextc == '{':
result += c
continue
indexed_chars = itertools.chain([(nexti, nextc)], indexed_chars)
# Add a placeholder onto the result
result += fstring_utils.placeholder(val_idx) + '}'
val_idx += 1
# Yield a new token generator to parse the subexpression only
tg = TokenGenerator(str_content[i+1:], ignore_error_token=True)
yield (result, tg)
result = ''
# Skip the number of characters consumed by the subexpression
for tg_i in range(tg.chars_consumed()):
i, c = next(indexed_chars)
# Eat up to and including the close bracket
i, c = next(indexed_chars)
while c != '}':
i, c = next(indexed_chars)
# Yield the rest of the fstring, when done
yield (result, None)
return fstr_parser
def _space_between(self, start_loc, end_loc):
"""Parse the space between a location and the next token"""
if start_loc > end_loc:
raise ValueError('start_loc > end_loc', start_loc, end_loc)
if start_loc[0] > len(self.lines):
return ''
prev_row, prev_col = start_loc
end_row, end_col = end_loc
if prev_row == end_row:
return self.lines[prev_row - 1][prev_col:end_col]
return ''.join(itertools.chain(
(self.lines[prev_row - 1][prev_col:],),
self.lines[prev_row:end_row - 1],
(self.lines[end_row - 1][:end_col],) if end_col > 0 else '',
))
def next_name(self):
"""Parse the next name token."""
last_i = self._i
def predicate(token):
return token.type != TOKENS.NAME
unused_tokens = list(self.takewhile(predicate, advance=False))
result = self.next(advance=False)
self._i = last_i
return result
def next_of_type(self, token_type):
"""Parse a token of the given type and return it."""
token = self.next()
if token.type != token_type:
raise ValueError("Expected %r but found %r\nline %d: %s" % (
tokenize.tok_name[token_type], token.src, token.start[0],
self.lines[token.start[0] - 1]))
return token
def takewhile(self, condition, advance=True):
"""Parse tokens as long as a condition holds on the next token."""
prev_loc = self._loc
token = self.next(advance=advance)
while token is not None and condition(token):
yield token
prev_loc = self._loc
token = self.next(advance=advance)
self.rewind()
self._loc = prev_loc
def _scope_helper(node):
"""Get the closure of nodes that could begin a scope at this point.
For instance, when encountering a `(` when parsing a BinOp node, this could
indicate that the BinOp itself is parenthesized OR that the BinOp's left node
could be parenthesized.
E.g.: (a + b * c) or (a + b) * c or (a) + b * c
^ ^ ^
Arguments:
node: (ast.AST) Node encountered when opening a scope.
Returns:
A closure of nodes which that scope might apply to.
"""
if isinstance(node, ast.Attribute):
return (node,) + _scope_helper(node.value)
if isinstance(node, ast.Subscript):
return (node,) + _scope_helper(node.value)
if isinstance(node, ast.Assign):
return (node,) + _scope_helper(node.targets[0])
if isinstance(node, ast.AugAssign):
return (node,) + _scope_helper(node.target)
if isinstance(node, ast.Expr):
return (node,) + _scope_helper(node.value)
if isinstance(node, ast.Compare):
return (node,) + _scope_helper(node.left)
if isinstance(node, ast.BoolOp):
return (node,) + _scope_helper(node.values[0])
if isinstance(node, ast.BinOp):
return (node,) + _scope_helper(node.left)
if isinstance(node, ast.Tuple) and node.elts:
return (node,) + _scope_helper(node.elts[0])
if isinstance(node, ast.Call):
return (node,) + _scope_helper(node.func)
if isinstance(node, ast.GeneratorExp):
return (node,) + _scope_helper(node.elt)
if isinstance(node, ast.IfExp):
return (node,) + _scope_helper(node.body)
return (node,)
def _generate_tokens(source, ignore_error_token=False):
token_generator = tokenize.generate_tokens(StringIO(source).readline)
try:
for tok in token_generator:
yield Token(*tok)
except tokenize.TokenError:
if not ignore_error_token:
raise