# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code transformation exceptions."""

import collections

from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.util import traceback_utils


class FrameInfo(
    collections.namedtuple('FrameInfo',
                           ('filename', 'lineno', 'function_name', 'code',
                            'is_converted', 'is_allowlisted'))):
  """One entry of a translated stack trace.

  Fields:
    filename: file the frame is reported against.
    lineno: line number within `filename`.
    function_name: name of the function the frame belongs to.
    code: source text of the line, or None when unavailable.
    is_converted: True when the frame was translated through a source map.
    is_allowlisted: True when the frame was entered from the converter module.
  """

  __slots__ = ()


def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
  """Summarizes inner traceback frames up to the call to a given function.

  This function walks the traceback from the innermost (i.e. most recent) frame
  outwards until it reaches the first frame that source_map can map back to
  original code, and returns a translated stack trace ending at that frame. If
  no such frame is found, the entire stack trace is summarized.

  For example, the following code:

    def f():
      for i in tf.range(1):
        z = y + i  # z only defined here

  Would generate this traceback:

    <converted code>
        ag__.for_stmt(...)
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Which is then processed into:

    <f>
        for i in tf.range(1):
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Args:
    tb: Sequence of traceback.FrameSummary, the traceback corresponding to an
      error; each entry is unpacked as (filename, lineno, function name, source
      text). Typically, the output of
      traceback.StackSummary.extract(capture_locals=True).
    source_map: Dict[LineLocation, OriginInfo], a source map as created by
      origin_info.create_source_map.
    converter_filename: str, the file path of the converted module. Call frames
      corresponding to this module are elided and the frames they called into
      are marked as allowlisted. Note that frames enclosing converted code are
      dropped using a different mechanism.

  Returns:
    Tuple[FrameInfo, ...], ordered from the innermost frame outwards.
  """
  result_frames = []
  for filename, line_number, function_name, text in reversed(tb):

    loc = origin_info.LineLocation(filename=filename, lineno=line_number)
    if loc in source_map:
      origin = source_map[loc]
      result_frames.append(
          FrameInfo(
              filename=origin.loc.filename,
              lineno=origin.loc.lineno,
              function_name=origin.function_name,
              code=origin.source_code_line,
              is_converted=True,
              is_allowlisted=False))
      # The first mapped frame closes the summary; outer frames are ignored.
      break

    if filename == converter_filename:
      # Converter frames are elided; the frame recorded just before (the one
      # this converter frame called into) is flagged as allowlisted instead.
      if result_frames:
        prev = result_frames[-1]
        assert not prev.is_converted  # See the if above.
        result_frames[-1] = prev._replace(
            is_converted=False, is_allowlisted=True)
      continue

    result_frames.append(
        FrameInfo(
            filename=filename,
            lineno=line_number,
            function_name=function_name,
            code=text,
            is_converted=False,
            is_allowlisted=False))

  return tuple(result_frames)


# Exception types whose constructor accepts a single message string.
KNOWN_STRING_CONSTRUCTOR_ERRORS = (
    AssertionError,
    AttributeError,
    NameError,
    NotImplementedError,
    RuntimeError,
    StopIteration,
    TypeError,
    UnboundLocalError,
    ValueError,
)


# KeyError escapes newlines in strings. We create a special subclass
# that doesn't do that. Overriding the name for display purposes; hopefully
# that won't create too many surprises.
class MultilineMessageKeyError(KeyError):
  """KeyError whose string form is a preformatted, possibly multiline message.

  Args:
    message: str, the text returned verbatim by str() on this exception.
    original_key: value forwarded to the KeyError constructor (kept in args).
  """

  def __init__(self, message, original_key):
    super().__init__(original_key)
    self.__message = message

  def __str__(self):
    return self.__message


# Present the subclass under the name of the built-in it stands in for.
MultilineMessageKeyError.__name__ = KeyError.__name__


class ErrorMetadataBase(object):
  """Container objects attached to exceptions raised in user code.

  This metadata allows re-raising exceptions that occur in generated code, with
  a custom error message that includes a stack trace relative to user-readable
  code from which the generated code originated.

  Attributes:
    translated_stack: tuple of frame records, innermost frame first.
    cause_message: str, message of the exception at the root of the chain.
  """

  __slots__ = ('translated_stack', 'cause_message')

  def __init__(self, callsite_tb, cause_metadata, cause_message, source_map,
               converter_filename):
    """Builds the metadata for one call site.

    Args:
      callsite_tb: traceback summary of the call site, as accepted by
        _stack_trace_inside_mapped_code.
      cause_metadata: ErrorMetadataBase of the exception that caused this one,
        or None when this is the root of the chain.
      cause_message: str, message of the error; only used when cause_metadata
        is None, otherwise the cause's own message is kept.
      source_map: source map forwarded to _stack_trace_inside_mapped_code.
      converter_filename: str, forwarded to _stack_trace_inside_mapped_code.
    """
    translated_stack = _stack_trace_inside_mapped_code(
        callsite_tb, source_map, converter_filename)

    if cause_metadata is None:
      self.translated_stack = translated_stack
      self.cause_message = cause_message
    else:
      # Daisy chain the translated stacks: only the outermost frame of this
      # call site is appended. The slice (rather than [-1]) keeps an empty
      # translated stack from raising IndexError while reporting an error.
      self.translated_stack = (
          cause_metadata.translated_stack + tuple(translated_stack[-1:]))
      self.cause_message = cause_metadata.cause_message

  def get_message(self):
    """Returns the message for the underlying exception."""
    lines = ['in user code:', '']

    # translated_stack is innermost-first; print the outermost frame first.
    for frame_info in reversed(self.translated_stack):
      if (traceback_utils.is_traceback_filtering_enabled() and
          not traceback_utils.include_frame(frame_info.filename)):
        continue

      # Same format with Python traceback.
      # NOTE(review): whitespace inside the literals below is reproduced
      # exactly as found; confirm the intended indentation widths.
      formatted_line = (f' File "{frame_info.filename}", line '
                        f'{frame_info.lineno}, in {frame_info.function_name}')
      if frame_info.is_converted:
        formatted_line += ' *'
      elif frame_info.is_allowlisted:
        formatted_line += ' **'
      lines.append(formatted_line)

      code_snippet = '' if frame_info.code is None else frame_info.code.strip()
      lines.append(' {}'.format(code_snippet))

    lines.append('')
    lines.extend(' ' + line for line in self.cause_message.split('\n'))
    lines.append('')

    return '\n'.join(lines)

  def create_exception(self, source_error):
    """Creates exception from source_error.

    Args:
      source_error: the exception instance being re-raised.

    Returns:
      A new exception whose message is get_message() and whose traceback is
      that of source_error, or None when the type of source_error is not known
      to be constructible from a single message string.
    """
    preferred_type = type(source_error)

    if preferred_type is KeyError:
      # Checked first so that the KeyError special case always wins: a plain
      # KeyError would escape the newlines of the multiline message.
      to_ret = MultilineMessageKeyError(self.get_message(), self.cause_message)
    elif (preferred_type.__init__ is Exception.__init__ or
          preferred_type in KNOWN_STRING_CONSTRUCTOR_ERRORS):
      # The type either inherits Exception's constructor unchanged or is
      # listed as accepting a single message string.
      to_ret = preferred_type(self.get_message())
    else:
      return None

    return to_ret.with_traceback(source_error.__traceback__)

  def to_exception(self, source_error):
    """Returns the exception to raise in place of source_error.

    The result has context display suppressed and carries this metadata object
    as its `ag_error_metadata` attribute.

    NOTE(review): create_exception returns None for unsupported exception
    types, in which case the attribute assignment below raises AttributeError.
    Presumably subclasses override create_exception so that it always returns
    an exception; confirm.

    Args:
      source_error: the exception instance being re-raised.
    """
    exc = self.create_exception(source_error)
    exc.__suppress_context__ = True
    exc.ag_error_metadata = self
    return exc