# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Progress tracker for uploader."""


import contextlib
from datetime import datetime
import sys
import time

def readable_time_string():
    """Get a human-readable time string for the present."""
    return datetime.now().strftime("%Y-%m-%dT%H:%M:%S")


def readable_bytes_string(bytes):
    """Get a human-readable string for number of bytes."""
    if bytes >= 2**20:
        return "%.1f MB" % (float(bytes) / 2**20)
    elif bytes >= 2**10:
        return "%.1f kB" % (float(bytes) / 2**10)
    else:
        return "%d B" % bytes

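
# Example outputs of the helpers above (illustrative values only):
#
#     readable_time_string()           -> e.g. "2020-06-02T14:15:09"
#     readable_bytes_string(512)       -> "512 B"
#     readable_bytes_string(1536)      -> "1.5 kB"
#     readable_bytes_string(2 * 2**20) -> "2.0 MB"
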

class UploadStats:
    """Statistics of uploading."""

    def __init__(self):
        self._last_summarized_timestamp = time.time()
        self._last_data_added_timestamp = 0
        self._num_scalars = 0
        self._num_tensors = 0
        self._num_tensors_skipped = 0
        self._tensor_bytes = 0
        self._tensor_bytes_skipped = 0
        self._num_blobs = 0
        self._num_blobs_skipped = 0
        self._blob_bytes = 0
        self._blob_bytes_skipped = 0
        self._plugin_names = set()

    def add_scalars(self, num_scalars):
        """Add a batch of scalars.

        Args:
            num_scalars: Number of scalars uploaded in this batch.
        """
        self._refresh_last_data_added_timestamp()
        self._num_scalars += num_scalars

    def add_tensors(
        self,
        num_tensors,
        num_tensors_skipped,
        tensor_bytes,
        tensor_bytes_skipped,
    ):
"""Add a batch of tensors.
|
||
|
|
||
|
Args:
|
||
|
num_tensors: Number of tensors encountered in this batch, including
|
||
|
the ones skipped due to reasons such as large exceeding limit.
|
||
|
num_tensors: Number of tensors skipped. This describes a subset of
|
||
|
`num_tensors` and hence must be `<= num_tensors`.
|
||
|
tensor_bytes: Total byte size of tensors encountered in this batch,
|
||
|
including the skipped ones.
|
||
|
tensor_bytes_skipped: Total byte size of the tensors skipped due to
|
||
|
reasons such as size exceeding limit.
|
||
|
"""
|
||
|
assert num_tensors_skipped <= num_tensors
|
||
|
assert tensor_bytes_skipped <= tensor_bytes
|
||
|
self._refresh_last_data_added_timestamp()
|
||
|
self._num_tensors += num_tensors
|
||
|
self._num_tensors_skipped += num_tensors_skipped
|
||
|
self._tensor_bytes += tensor_bytes
|
||
|
self._tensor_bytes_skipped = tensor_bytes_skipped
|
||
|
|
||
|

    def add_blob(self, blob_bytes, is_skipped):
        """Add a blob.

        Args:
            blob_bytes: Byte size of the blob.
            is_skipped: Whether the uploading of the blob is skipped due to
                reasons such as size exceeding limit.
        """
        self._refresh_last_data_added_timestamp()
        self._num_blobs += 1
        self._blob_bytes += blob_bytes
        if is_skipped:
            self._num_blobs_skipped += 1
            self._blob_bytes_skipped += blob_bytes

    def add_plugin(self, plugin_name):
        """Add a plugin.

        Args:
            plugin_name: Name of the plugin.
        """
        self._refresh_last_data_added_timestamp()
        self._plugin_names.add(plugin_name)

    @property
    def num_scalars(self):
        return self._num_scalars

    @property
    def num_tensors(self):
        return self._num_tensors

    @property
    def num_tensors_skipped(self):
        return self._num_tensors_skipped

    @property
    def tensor_bytes(self):
        return self._tensor_bytes

    @property
    def tensor_bytes_skipped(self):
        return self._tensor_bytes_skipped

    @property
    def num_blobs(self):
        return self._num_blobs

    @property
    def num_blobs_skipped(self):
        return self._num_blobs_skipped

    @property
    def blob_bytes(self):
        return self._blob_bytes

    @property
    def blob_bytes_skipped(self):
        return self._blob_bytes_skipped

    @property
    def plugin_names(self):
        return self._plugin_names

    def has_data(self):
        """Whether any data has been tracked by this instance.

        This includes tensor and blob data that were scanned but skipped.

        Returns:
            Whether this stats tracking object has tracked any data.
        """
        return (
            self._num_scalars > 0
            or self._num_tensors > 0
            or self._num_blobs > 0
        )

    def summarize(self):
        """Get a summary string for actually-uploaded and skipped data.

        Calling this method also marks the "last_summarized" timestamp, so
        that has_new_data_since_last_summarize() can later report the correct
        value.

        Returns:
            A tuple with two items:
            - A string summarizing all data uploaded so far.
            - If any data was skipped, a string for all skipped data. Else,
              `None`.
        """
        self._last_summarized_timestamp = time.time()
        string_pieces = []
        string_pieces.append("%d scalars" % self._num_scalars)
        uploaded_tensor_count = self._num_tensors - self._num_tensors_skipped
        uploaded_tensor_bytes = self._tensor_bytes - self._tensor_bytes_skipped
        string_pieces.append(
            "0 tensors"
            if not uploaded_tensor_count
            else (
                "%d tensors (%s)"
                % (
                    uploaded_tensor_count,
                    readable_bytes_string(uploaded_tensor_bytes),
                )
            )
        )
        uploaded_blob_count = self._num_blobs - self._num_blobs_skipped
        uploaded_blob_bytes = self._blob_bytes - self._blob_bytes_skipped
        string_pieces.append(
            "0 binary objects"
            if not uploaded_blob_count
            else (
                "%d binary objects (%s)"
                % (
                    uploaded_blob_count,
                    readable_bytes_string(uploaded_blob_bytes),
                )
            )
        )
        skipped_string = (
            self._skipped_summary() if self._skipped_any() else None
        )
        return ", ".join(string_pieces), skipped_string

    def _skipped_any(self):
        """Whether any data was skipped."""
        return self._num_tensors_skipped or self._num_blobs_skipped

    def has_new_data_since_last_summarize(self):
        return (
            self._last_data_added_timestamp > self._last_summarized_timestamp
        )

    def _skipped_summary(self):
        """Get a summary string for skipped data."""
        string_pieces = []
        if self._num_tensors_skipped:
            string_pieces.append(
                "%d tensors (%s)"
                % (
                    self._num_tensors_skipped,
                    readable_bytes_string(self._tensor_bytes_skipped),
                )
            )
        if self._num_blobs_skipped:
            string_pieces.append(
                "%d binary objects (%s)"
                % (
                    self._num_blobs_skipped,
                    readable_bytes_string(self._blob_bytes_skipped),
                )
            )
        return ", ".join(string_pieces)

    def _refresh_last_data_added_timestamp(self):
        self._last_data_added_timestamp = time.time()

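
# A minimal sketch of how UploadStats is typically driven (illustrative only;
# in practice it is fed by the uploader through UploadTracker below):
#
#     stats = UploadStats()
#     stats.add_scalars(10)
#     stats.add_tensors(
#         num_tensors=4,
#         num_tensors_skipped=1,
#         tensor_bytes=4096,
#         tensor_bytes_skipped=2048,
#     )
#     stats.add_blob(blob_bytes=2 * 2**20, is_skipped=False)
#     uploaded, skipped = stats.summarize()
#     # uploaded == "10 scalars, 3 tensors (2.0 kB), 1 binary objects (2.0 MB)"
#     # skipped  == "1 tensors (2.0 kB)"
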

_STYLE_RESET = "\033[0m"
_STYLE_BOLD = "\033[1m"
_STYLE_GREEN = "\033[32m"
_STYLE_YELLOW = "\033[33m"
_STYLE_DARKGRAY = "\033[90m"
_STYLE_ERASE_LINE = "\033[2K"


class UploadTracker:
    """Tracker for uploader progress and status."""

    _SUPPORTED_VERBOSITY_VALUES = (0, 1)

    def __init__(self, verbosity, one_shot=False):
        if verbosity not in self._SUPPORTED_VERBOSITY_VALUES:
            raise ValueError(
                "Unsupported verbosity value %s (supported values: %s)"
                % (verbosity, self._SUPPORTED_VERBOSITY_VALUES)
            )
        self._verbosity = verbosity
        self._stats = UploadStats()
        self._send_count = 0
        self._one_shot = one_shot

    def _dummy_generator(self):
        while True:
            # Yield an arbitrary value 0: The progress bar is indefinite.
            yield 0

    def _overwrite_line_message(self, message, color_code=_STYLE_GREEN):
        """Overwrite the current line with a stylized message."""
        if not self._verbosity:
            return
        message += "." * 3
        sys.stdout.write(
            _STYLE_ERASE_LINE + color_code + message + _STYLE_RESET + "\r"
        )
        sys.stdout.flush()

    def _single_line_message(self, message):
        """Write a timestamped single line, with newline, to stdout."""
        if not self._verbosity:
            return
        start_message = "%s[%s]%s %s\n" % (
            _STYLE_BOLD,
            readable_time_string(),
            _STYLE_RESET,
            message,
        )
        sys.stdout.write(start_message)
        sys.stdout.flush()

    def has_data(self):
        """Determine if any data has been uploaded under the tracker's watch."""
        return self._stats.has_data()

    def _update_cumulative_status(self):
        """Write an update summarizing the data uploaded since the start."""
        if not self._verbosity:
            return
        if not self._stats.has_new_data_since_last_summarize():
            return
        uploaded_str, skipped_str = self._stats.summarize()
        uploaded_message = "%s[%s]%s Total uploaded: %s\n" % (
            _STYLE_BOLD,
            readable_time_string(),
            _STYLE_RESET,
            uploaded_str,
        )
        sys.stdout.write(uploaded_message)
        if skipped_str:
            sys.stdout.write(
                "%sTotal skipped: %s\n%s"
                % (_STYLE_DARKGRAY, skipped_str, _STYLE_RESET)
            )
        sys.stdout.flush()
        # TODO(cais): Add summary of what plugins have been involved, once it's
        # clear how to get canonical plugin names.

    def add_plugin_name(self, plugin_name):
        self._stats.add_plugin(plugin_name)

    @contextlib.contextmanager
    def send_tracker(self):
        """Create a context manager for a round of data sending."""
        self._send_count += 1
        if self._send_count == 1:
            self._single_line_message("Started scanning logdir.")
        try:
            # self._reset_bars()
            self._overwrite_line_message("Data upload starting")
            yield
        finally:
            self._update_cumulative_status()
            if self._one_shot:
                self._single_line_message("Done scanning logdir.")
            else:
                self._overwrite_line_message(
                    "Listening for new data in logdir",
                    color_code=_STYLE_YELLOW,
                )

    @contextlib.contextmanager
    def scalars_tracker(self, num_scalars):
        """Create a context manager for tracking a scalar batch upload.

        Args:
            num_scalars: Number of scalars in the batch.
        """
        self._overwrite_line_message("Uploading %d scalars" % num_scalars)
        try:
            yield
        finally:
            self._stats.add_scalars(num_scalars)

    @contextlib.contextmanager
    def tensors_tracker(
        self,
        num_tensors,
        num_tensors_skipped,
        tensor_bytes,
        tensor_bytes_skipped,
    ):
        """Create a context manager for tracking a tensor batch upload.

        Args:
            num_tensors: Total number of tensors in the batch.
            num_tensors_skipped: Number of tensors skipped (a subset of
                `num_tensors`). Hence this must be `<= num_tensors`.
            tensor_bytes: Total byte size of the tensors in the batch.
            tensor_bytes_skipped: Byte size of skipped tensors in the batch (a
                subset of `tensor_bytes`). Must be `<= tensor_bytes`.
        """
        if num_tensors_skipped:
            message = "Uploading %d tensors (%s) (Skipping %d tensors, %s)" % (
                num_tensors - num_tensors_skipped,
                readable_bytes_string(tensor_bytes - tensor_bytes_skipped),
                num_tensors_skipped,
                readable_bytes_string(tensor_bytes_skipped),
            )
        else:
            message = "Uploading %d tensors (%s)" % (
                num_tensors,
                readable_bytes_string(tensor_bytes),
            )
        self._overwrite_line_message(message)
        try:
            yield
        finally:
            self._stats.add_tensors(
                num_tensors,
                num_tensors_skipped,
                tensor_bytes,
                tensor_bytes_skipped,
            )

    @contextlib.contextmanager
    def blob_tracker(self, blob_bytes):
        """Create a context manager for tracking the upload of a blob.

        Args:
            blob_bytes: Total byte size of the blob being uploaded.
        """
        self._overwrite_line_message(
            "Uploading binary object (%s)" % readable_bytes_string(blob_bytes)
        )
        try:
            yield _BlobTracker(self._stats, blob_bytes)
        finally:
            pass

class _BlobTracker:
    def __init__(self, upload_stats, blob_bytes):
        self._upload_stats = upload_stats
        self._blob_bytes = blob_bytes

    def mark_uploaded(self, is_uploaded):
        self._upload_stats.add_blob(
            self._blob_bytes, is_skipped=(not is_uploaded)
        )
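
# A minimal sketch of how an uploader loop might drive UploadTracker
# (illustrative only; the argument values below are made up for the example):
#
#     tracker = UploadTracker(verbosity=1, one_shot=False)
#     with tracker.send_tracker():
#         with tracker.scalars_tracker(num_scalars=100):
#             pass  # ... send the scalar batch to the server ...
#         with tracker.tensors_tracker(
#             num_tensors=4,
#             num_tensors_skipped=1,
#             tensor_bytes=4096,
#             tensor_bytes_skipped=2048,
#         ):
#             pass  # ... send the tensor batch ...
#         with tracker.blob_tracker(blob_bytes=2 * 2**20) as blob:
#             blob.mark_uploaded(True)  # after the blob has been sent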