3RNN/Lib/site-packages/tensorflow/tools/compatibility/ipynb.py

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""A module to support operations on ipynb files"""
import collections
import copy
import json
import re
import shutil
import tempfile

CodeLine = collections.namedtuple("CodeLine", ["cell_number", "code"])


def is_python(cell):
  """Checks if the cell consists of Python code."""
  return (cell["cell_type"] == "code"  # code cells only
          and cell["source"]  # non-empty cells
          and not cell["source"][0].startswith("%%"))  # multiline eg: %%bash
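
# Illustration (not part of the original module): how `is_python` treats a few
# cell dicts as they appear in notebook JSON. The sample cells below are
# hypothetical, following the standard nbformat layout.
#
#   >>> is_python({"cell_type": "code", "source": ["import tensorflow as tf\n"]})
#   True
#   >>> is_python({"cell_type": "markdown", "source": ["# Title\n"]})
#   False
#   >>> is_python({"cell_type": "code", "source": ["%%bash\n", "ls\n"]})
#   False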


def process_file(in_filename, out_filename, upgrader):
  """The function where we inject the support for ipynb upgrade."""
  print("Extracting code lines from original notebook")
  raw_code, notebook = _get_code(in_filename)
  raw_lines = [cl.code for cl in raw_code]

  # The function follows the original flow from `upgrader.process_file`
  with tempfile.NamedTemporaryFile("w", delete=False) as temp_file:

    processed_file, new_file_content, log, process_errors = (
        upgrader.update_string_pasta("\n".join(raw_lines), in_filename))

    if temp_file and processed_file:
      new_notebook = _update_notebook(notebook, raw_code,
                                      new_file_content.split("\n"))
      json.dump(new_notebook, temp_file)
    else:
      raise SyntaxError(
          "Was not able to process the file: \n%s\n" % "".join(log))

    files_processed = processed_file
    report_text = upgrader._format_log(log, in_filename, out_filename)
    errors = process_errors

  shutil.move(temp_file.name, out_filename)

  return files_processed, report_text, errors
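
# Illustration (not part of the original module): a sketch of how this entry
# point is typically driven by the tf_upgrade_v2 tooling. The notebook
# filenames are placeholders.
#
#   from tensorflow.tools.compatibility import ast_edits
#   from tensorflow.tools.compatibility import ipynb
#   from tensorflow.tools.compatibility import tf_upgrade_v2
#
#   upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
#   files_processed, report_text, errors = ipynb.process_file(
#       "notebook_v1.ipynb", "notebook_v2.ipynb", upgrader)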


def skip_magic(code_line, magic_list):
  """Checks if the cell has magic that is not Python-based.

  Args:
      code_line: A line of Python code
      magic_list: A list of jupyter "magic" exceptions

  Returns:
    True if the line is a jupyter "magic" line rather than a Python line.

  >>> skip_magic('!ls -laF', ['%', '!', '?'])
  True
  """
  for magic in magic_list:
    if code_line.startswith(magic):
      return True

  return False


def check_line_split(code_line):
  r"""Checks if a line was split with `\`.

  Args:
      code_line: A line of Python code

  Returns:
    Whether the line was split with `\` (a match object if so, otherwise None).

  >>> bool(check_line_split("!gcloud ml-engine models create ${MODEL} \\\n"))
  True
  """
  return re.search(r"\\\s*\n$", code_line)


def _get_code(input_file):
  """Loads the ipynb file and returns a list of CodeLines."""
  raw_code = []

  with open(input_file) as in_file:
    notebook = json.load(in_file)

  cell_index = 0
  for cell in notebook["cells"]:
    if is_python(cell):
      cell_lines = cell["source"]

      is_line_split = False
      for line_idx, code_line in enumerate(cell_lines):

        # Sometimes, jupyter cells contain more than Python code.
        # The idea is to comment such lines out for the duration of the upgrade.
        if skip_magic(code_line, ["%", "!", "?"]) or is_line_split:
          # Found a special character, need to "encode"
          code_line = "###!!!" + code_line

          # if this line ends with `\` -> also skip the next line
          is_line_split = check_line_split(code_line)

        if is_line_split:
          is_line_split = check_line_split(code_line)

        # Sometimes people leave a trailing \n at the end of a cell so that
        # only related lines are migrated and the diff stays as small as
        # possible -> here is another hack
        if (line_idx == len(cell_lines) - 1) and code_line.endswith("\n"):
          code_line = code_line.replace("\n", "###===")

        # Sometimes a line starts with `\n` and has content after it;
        # the replace below is the hack for that.
        raw_code.append(
            CodeLine(cell_index,
                     code_line.rstrip().replace("\n", "###===")))

      cell_index += 1

  return raw_code, notebook
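
# Illustration (not part of the original module): what the "encoding" above
# produces for a hypothetical two-cell notebook. Magic lines get the "###!!!"
# prefix so the upgrader sees them as comments; newlines become "###===" so
# every CodeLine stays a single physical line.
#
#   cell 0 source: ["!pip install tensorflow\n", "import tensorflow as tf"]
#   cell 1 source: ["print(tf.__version__)\n"]
#
#   _get_code(...) would yield roughly:
#     CodeLine(cell_number=0, code="###!!!!pip install tensorflow")
#     CodeLine(cell_number=0, code="import tensorflow as tf")
#     CodeLine(cell_number=1, code="print(tf.__version__)###===")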


def _update_notebook(original_notebook, original_raw_lines, updated_code_lines):
  """Updates notebook, once migration is done."""

  new_notebook = copy.deepcopy(original_notebook)
  # validate that the number of lines is the same
  assert len(original_raw_lines) == len(updated_code_lines), \
      ("The lengths of input and converted files are not the same: "
       "{} vs {}".format(len(original_raw_lines), len(updated_code_lines)))

  code_cell_idx = 0
  for cell in new_notebook["cells"]:
    if not is_python(cell):
      continue

    applicable_lines = [
        idx for idx, code_line in enumerate(original_raw_lines)
        if code_line.cell_number == code_cell_idx
    ]

    new_code = [updated_code_lines[idx] for idx in applicable_lines]

    cell["source"] = "\n".join(new_code).replace("###!!!", "").replace(
        "###===", "\n")
    code_cell_idx += 1

  return new_notebook
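
# Illustration (not part of the original module): the marker replacement above
# is the inverse of the encoding in `_get_code`. Continuing the hypothetical
# example, an upgraded line such as
#     "###!!!!pip install tensorflow"
# is written back to the cell as
#     "!pip install tensorflow"
# and "print(tf.__version__)###===" becomes "print(tf.__version__)\n",
# restoring the trailing newline the notebook originally had.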