# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities that help manage directory path in distributed settings.
|
|
|
|
In multi-worker training, the need to write a file to distributed file
|
|
location often requires only one copy done by one worker despite many workers
|
|
that are involved in training. The option to only perform saving by chief is
|
|
not feasible for a couple of reasons: 1) Chief and workers may each contain
|
|
a client that runs the same piece of code and it's preferred not to make
|
|
any distinction between the code run by chief and other workers, and 2)
|
|
saving of model or model's related information may require SyncOnRead
|
|
variables to be read, which needs the cooperation of all workers to perform
|
|
all-reduce.
|
|
|
|
This set of utility is used so that only one copy is written to the needed
|
|
directory, by supplying a temporary write directory path for workers that don't
|
|
need to save, and removing the temporary directory once file writing is done.
|
|
|
|
Example usage:
|
|
```
|
|
# Before using a directory to write file to.
|
|
self.log_write_dir = write_dirpath(self.log_dir, get_distribution_strategy())
|
|
# Now `self.log_write_dir` can be safely used to write file to.
|
|
|
|
...
|
|
|
|
# After the file is written to the directory.
|
|
remove_temp_dirpath(self.log_dir, get_distribution_strategy())
|
|
|
|
```
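
The same pattern applies to a single file via the file-level helpers in this
module. A minimal sketch (`self.model_path` and `model` are hypothetical):
```
# Resolve the write path; non-chief workers get a temporary location.
write_path = write_filepath(self.model_path, get_distribution_strategy())
model.save(write_path)

# Remove the temporary directory, if one was used.
remove_temp_dir_with_filepath(self.model_path, get_distribution_strategy())
```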

Experimental. API is subject to change.
"""

import os

import requests
import tensorflow.compat.v2 as tf

GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}
_GCE_METADATA_URL_ENV_VARIABLE = "GCE_METADATA_IP"


def _get_base_dirpath(strategy):
    task_id = strategy.extended._task_id
    return "workertemp_" + str(task_id)


def _is_temp_dir(dirpath, strategy):
    return dirpath.endswith(_get_base_dirpath(strategy))


def _get_temp_dir(dirpath, strategy):
    if _is_temp_dir(dirpath, strategy):
        temp_dir = dirpath
    else:
        temp_dir = os.path.join(dirpath, _get_base_dirpath(strategy))
    tf.io.gfile.makedirs(temp_dir)
    return temp_dir
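
# Illustrative behavior of `_get_temp_dir` (hypothetical paths; the suffix
# comes from the worker's task id):
#
#   _get_temp_dir("/tmp/logs", strategy)               # task_id == 1
#   # -> creates and returns "/tmp/logs/workertemp_1"
#   _get_temp_dir("/tmp/logs/workertemp_1", strategy)
#   # -> already a temp dir; returns it unchanged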


def write_dirpath(dirpath, strategy):
    """Returns the directory path to use when saving with distribution.

    `dirpath` will be created if it doesn't exist.

    Args:
        dirpath: Original dirpath that would be used without distribution.
        strategy: The tf.distribute strategy object currently used.

    Returns:
        The directory path to use when saving with distribution.
    """
    if strategy is None:
        # Infer strategy from `distribution_strategy_context` if not given.
        strategy = tf.distribute.get_strategy()
    if strategy is None:
        # If a strategy is still not available, this is not distributed
        # training. Fall back to the original dirpath.
        return dirpath
    if not strategy.extended._in_multi_worker_mode():
        return dirpath
    if strategy.extended.should_checkpoint:
        return dirpath
    # If this worker is not the chief and hence should not save the file,
    # save it to a temporary directory to be removed later.
    return _get_temp_dir(dirpath, strategy)
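
# A minimal sketch of how `write_dirpath` resolves paths under
# `tf.distribute.MultiWorkerMirroredStrategy` (hypothetical paths):
#
#   strategy = tf.distribute.MultiWorkerMirroredStrategy()
#   path = write_dirpath("/tmp/ckpt", strategy)
#   # On the chief:         path == "/tmp/ckpt"
#   # On worker task 1:     path == "/tmp/ckpt/workertemp_1"
#   # Without distribution: path == "/tmp/ckpt"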


def remove_temp_dirpath(dirpath, strategy):
    """Removes the temporary directory once writing is finished.

    Args:
        dirpath: Original dirpath that would be used without distribution.
        strategy: The tf.distribute strategy object currently used.
    """
    if strategy is None:
        # Infer strategy from `distribution_strategy_context` if not given.
        strategy = tf.distribute.get_strategy()
    if strategy is None:
        # If a strategy is still not available, this is not distributed
        # training. Fall back to a no-op.
        return
    # TODO(anjalisridhar): Consider removing the check for multi worker mode
    # since it is redundant when used with the should_checkpoint property.
    if (
        strategy.extended._in_multi_worker_mode()
        and not strategy.extended.should_checkpoint
    ):
        # If this worker is not the chief and hence should not save the
        # file, remove the temporary directory.
        tf.compat.v1.gfile.DeleteRecursively(_get_temp_dir(dirpath, strategy))
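
# Note that only non-chief workers delete anything here: the chief wrote to
# the real `dirpath` and has no temporary directory to clean up. For example
# (hypothetical path):
#
#   remove_temp_dirpath("/tmp/ckpt", strategy)
#   # chief:         no-op
#   # worker task 1: deletes "/tmp/ckpt/workertemp_1"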


def write_filepath(filepath, strategy):
    """Returns the file path to use when saving with distribution.

    The directory containing `filepath` will be created if it doesn't exist.

    Args:
        filepath: Original filepath that would be used without distribution.
        strategy: The tf.distribute strategy object currently used.

    Returns:
        The file path to use when saving with distribution.
    """
    dirpath = os.path.dirname(filepath)
    base = os.path.basename(filepath)
    return os.path.join(write_dirpath(dirpath, strategy), base)
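
# Illustrative example (hypothetical path): on worker task 1,
#
#   write_filepath("/tmp/ckpt/model.h5", strategy)
#   # -> "/tmp/ckpt/workertemp_1/model.h5"
#
# while on the chief it returns "/tmp/ckpt/model.h5" unchanged.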


def remove_temp_dir_with_filepath(filepath, strategy):
    """Removes the temporary directory for `filepath` once writing finishes.

    Args:
        filepath: Original filepath that would be used without distribution.
        strategy: The tf.distribute strategy object currently used.
    """
    remove_temp_dirpath(os.path.dirname(filepath), strategy)


def _on_gcp():
    """Detect whether the current running environment is on GCP."""
    gce_metadata_endpoint = "http://" + os.environ.get(
        _GCE_METADATA_URL_ENV_VARIABLE, "metadata.google.internal"
    )

    try:
        # Time out in 5 seconds, in case the test environment has
        # connectivity issues. There is no default timeout, which means the
        # request might otherwise block forever.
        response = requests.get(
            f"{gce_metadata_endpoint}/computeMetadata/v1/instance/hostname",
            headers=GCP_METADATA_HEADER,
            timeout=5,
        )
        # Return a bool rather than the raw status code: only a 200 response
        # from the metadata server indicates we are running on GCE.
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False
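
# The probe endpoint can be redirected, e.g. for testing, via the
# environment variable above (the value shown is illustrative; it is the
# conventional GCE metadata server address):
#
#   os.environ["GCE_METADATA_IP"] = "169.254.169.254"
#   # _on_gcp() now probes
#   # http://169.254.169.254/computeMetadata/v1/instance/hostname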


def support_on_demand_checkpoint_callback(strategy):
    """Returns whether on-demand checkpointing is supported in this setup."""
    if _on_gcp() and isinstance(
        strategy, tf.distribute.MultiWorkerMirroredStrategy
    ):
        return True

    return False
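
# A minimal usage sketch (the caller shown here is hypothetical):
#
#   strategy = tf.distribute.MultiWorkerMirroredStrategy()
#   if support_on_demand_checkpoint_callback(strategy):
#       ...  # register a checkpoint callback that saves on preemption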