171 lines
5.2 KiB
Python
171 lines
5.2 KiB
Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# =============================================================================
|
|
"""A module to support operations on ipynb files"""
|
|
|
|
import collections
|
|
import copy
|
|
import json
|
|
import re
|
|
import shutil
|
|
import tempfile
|
|
|
|
CodeLine = collections.namedtuple("CodeLine", ["cell_number", "code"])
|
|
|
|
def is_python(cell):
|
|
"""Checks if the cell consists of Python code."""
|
|
return (cell["cell_type"] == "code" # code cells only
|
|
and cell["source"] # non-empty cells
|
|
and not cell["source"][0].startswith("%%")) # multiline eg: %%bash
|
|
|
|
|
|
def process_file(in_filename, out_filename, upgrader):
|
|
"""The function where we inject the support for ipynb upgrade."""
|
|
print("Extracting code lines from original notebook")
|
|
raw_code, notebook = _get_code(in_filename)
|
|
raw_lines = [cl.code for cl in raw_code]
|
|
|
|
# The function follows the original flow from `upgrader.process_fil`
|
|
with tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
|
|
|
|
processed_file, new_file_content, log, process_errors = (
|
|
upgrader.update_string_pasta("\n".join(raw_lines), in_filename))
|
|
|
|
if temp_file and processed_file:
|
|
new_notebook = _update_notebook(notebook, raw_code,
|
|
new_file_content.split("\n"))
|
|
json.dump(new_notebook, temp_file)
|
|
else:
|
|
raise SyntaxError(
|
|
"Was not able to process the file: \n%s\n" % "".join(log))
|
|
|
|
files_processed = processed_file
|
|
report_text = upgrader._format_log(log, in_filename, out_filename)
|
|
errors = process_errors
|
|
|
|
shutil.move(temp_file.name, out_filename)
|
|
|
|
return files_processed, report_text, errors
|
|
|
|
|
|
def skip_magic(code_line, magic_list):
|
|
"""Checks if the cell has magic, that is not Python-based.
|
|
|
|
Args:
|
|
code_line: A line of Python code
|
|
magic_list: A list of jupyter "magic" exceptions
|
|
|
|
Returns:
|
|
If the line jupyter "magic" line, not Python line
|
|
|
|
>>> skip_magic('!ls -laF', ['%', '!', '?'])
|
|
True
|
|
"""
|
|
|
|
for magic in magic_list:
|
|
if code_line.startswith(magic):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def check_line_split(code_line):
|
|
r"""Checks if a line was split with `\`.
|
|
|
|
Args:
|
|
code_line: A line of Python code
|
|
|
|
Returns:
|
|
If the line was split with `\`
|
|
|
|
>>> skip_magic("!gcloud ml-engine models create ${MODEL} \\\n")
|
|
True
|
|
"""
|
|
|
|
return re.search(r"\\\s*\n$", code_line)
|
|
|
|
|
|
def _get_code(input_file):
|
|
"""Loads the ipynb file and returns a list of CodeLines."""
|
|
|
|
raw_code = []
|
|
|
|
with open(input_file) as in_file:
|
|
notebook = json.load(in_file)
|
|
|
|
cell_index = 0
|
|
for cell in notebook["cells"]:
|
|
if is_python(cell):
|
|
cell_lines = cell["source"]
|
|
|
|
is_line_split = False
|
|
for line_idx, code_line in enumerate(cell_lines):
|
|
|
|
# Sometimes, jupyter has more than python code
|
|
# Idea is to comment these lines, for upgrade time
|
|
if skip_magic(code_line, ["%", "!", "?"]) or is_line_split:
|
|
# Found a special character, need to "encode"
|
|
code_line = "###!!!" + code_line
|
|
|
|
# if this cell ends with `\` -> skip the next line
|
|
is_line_split = check_line_split(code_line)
|
|
|
|
if is_line_split:
|
|
is_line_split = check_line_split(code_line)
|
|
|
|
# Sometimes, people leave \n at the end of cell
|
|
# in order to migrate only related things, and make the diff
|
|
# the smallest -> here is another hack
|
|
if (line_idx == len(cell_lines) - 1) and code_line.endswith("\n"):
|
|
code_line = code_line.replace("\n", "###===")
|
|
|
|
# sometimes a line would start with `\n` and content after
|
|
# that's the hack for this
|
|
raw_code.append(
|
|
CodeLine(cell_index,
|
|
code_line.rstrip().replace("\n", "###===")))
|
|
|
|
cell_index += 1
|
|
|
|
return raw_code, notebook
|
|
|
|
|
|
def _update_notebook(original_notebook, original_raw_lines, updated_code_lines):
|
|
"""Updates notebook, once migration is done."""
|
|
|
|
new_notebook = copy.deepcopy(original_notebook)
|
|
|
|
# validate that the number of lines is the same
|
|
assert len(original_raw_lines) == len(updated_code_lines), \
|
|
("The lengths of input and converted files are not the same: "
|
|
"{} vs {}".format(len(original_raw_lines), len(updated_code_lines)))
|
|
|
|
code_cell_idx = 0
|
|
for cell in new_notebook["cells"]:
|
|
if not is_python(cell):
|
|
continue
|
|
|
|
applicable_lines = [
|
|
idx for idx, code_line in enumerate(original_raw_lines)
|
|
if code_line.cell_number == code_cell_idx
|
|
]
|
|
|
|
new_code = [updated_code_lines[idx] for idx in applicable_lines]
|
|
|
|
cell["source"] = "\n".join(new_code).replace("###!!!", "").replace(
|
|
"###===", "\n")
|
|
code_cell_idx += 1
|
|
|
|
return new_notebook
|