167 lines
6.1 KiB
Python
167 lines
6.1 KiB
Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Contains function to log if devices are compatible with mixed precision."""
|
|
|
|
import itertools
|
|
|
|
import tensorflow.compat.v2 as tf
|
|
|
|
# isort: off
|
|
from tensorflow.python.platform import tf_logging
|
|
|
|
_COMPAT_CHECK_PREFIX = "Mixed precision compatibility check (mixed_float16): "
|
|
_COMPAT_CHECK_OK_PREFIX = _COMPAT_CHECK_PREFIX + "OK"
|
|
_COMPAT_CHECK_WARNING_PREFIX = _COMPAT_CHECK_PREFIX + "WARNING"
|
|
_COMPAT_CHECK_WARNING_SUFFIX = (
|
|
"If you will use compatible GPU(s) not attached to this host, e.g. by "
|
|
"running a multi-worker model, you can ignore this warning. This message "
|
|
"will only be logged once"
|
|
)
|
|
|
|
|
|
def _dedup_strings(device_strs):
|
|
"""Groups together consecutive identical strings.
|
|
|
|
For example, given:
|
|
['GPU 1', 'GPU 2', 'GPU 2', 'GPU 3', 'GPU 3', 'GPU 3']
|
|
This function returns:
|
|
['GPU 1', 'GPU 2 (x2)', 'GPU 3 (x3)']
|
|
|
|
Args:
|
|
device_strs: A list of strings, each representing a device.
|
|
|
|
Returns:
|
|
A copy of the input, but identical consecutive strings are merged into a
|
|
single string.
|
|
"""
|
|
new_device_strs = []
|
|
for device_str, vals in itertools.groupby(device_strs):
|
|
num = len(list(vals))
|
|
if num == 1:
|
|
new_device_strs.append(device_str)
|
|
else:
|
|
new_device_strs.append("%s (x%d)" % (device_str, num))
|
|
return new_device_strs
|
|
|
|
|
|
def _log_device_compatibility_check(policy_name, gpu_details_list):
|
|
"""Logs a compatibility check if the devices support the policy.
|
|
|
|
Currently only logs for the policy mixed_float16.
|
|
|
|
Args:
|
|
policy_name: The name of the dtype policy.
|
|
gpu_details_list: A list of dicts, one dict per GPU. Each dict
|
|
is the device details for a GPU, as returned by
|
|
`tf.config.experimental.get_device_details()`.
|
|
"""
|
|
if policy_name != "mixed_float16":
|
|
# TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This
|
|
# requires checking if a TPU is available.
|
|
return
|
|
supported_device_strs = []
|
|
unsupported_device_strs = []
|
|
for details in gpu_details_list:
|
|
name = details.get("device_name", "Unknown GPU")
|
|
cc = details.get("compute_capability")
|
|
if cc:
|
|
device_str = f"{name}, compute capability {cc[0]}.{cc[1]}"
|
|
if cc >= (7, 0):
|
|
supported_device_strs.append(device_str)
|
|
else:
|
|
unsupported_device_strs.append(device_str)
|
|
else:
|
|
unsupported_device_strs.append(
|
|
name + ", no compute capability (probably not an Nvidia GPU)"
|
|
)
|
|
|
|
if unsupported_device_strs:
|
|
warning_str = _COMPAT_CHECK_WARNING_PREFIX + "\n"
|
|
if supported_device_strs:
|
|
warning_str += (
|
|
"Some of your GPUs may run slowly with dtype policy "
|
|
"mixed_float16 because they do not all have compute "
|
|
"capability of at least 7.0. Your GPUs:\n"
|
|
)
|
|
elif len(unsupported_device_strs) == 1:
|
|
warning_str += (
|
|
"Your GPU may run slowly with dtype policy mixed_float16 "
|
|
"because it does not have compute capability of at least "
|
|
"7.0. Your GPU:\n"
|
|
)
|
|
else:
|
|
warning_str += (
|
|
"Your GPUs may run slowly with dtype policy "
|
|
"mixed_float16 because they do not have compute "
|
|
"capability of at least 7.0. Your GPUs:\n"
|
|
)
|
|
for device_str in _dedup_strings(
|
|
supported_device_strs + unsupported_device_strs
|
|
):
|
|
warning_str += " " + device_str + "\n"
|
|
warning_str += (
|
|
"See https://developer.nvidia.com/cuda-gpus for a list of "
|
|
"GPUs and their compute capabilities.\n"
|
|
)
|
|
warning_str += _COMPAT_CHECK_WARNING_SUFFIX
|
|
tf_logging.warning(warning_str)
|
|
elif not supported_device_strs:
|
|
tf_logging.warning(
|
|
"%s\n"
|
|
"The dtype policy mixed_float16 may run slowly because "
|
|
"this machine does not have a GPU. Only Nvidia GPUs with "
|
|
"compute capability of at least 7.0 run quickly with "
|
|
"mixed_float16.\n%s"
|
|
% (_COMPAT_CHECK_WARNING_PREFIX, _COMPAT_CHECK_WARNING_SUFFIX)
|
|
)
|
|
elif len(supported_device_strs) == 1:
|
|
tf_logging.info(
|
|
"%s\n"
|
|
"Your GPU will likely run quickly with dtype policy "
|
|
"mixed_float16 as it has compute capability of at least "
|
|
"7.0. Your GPU: %s"
|
|
% (_COMPAT_CHECK_OK_PREFIX, supported_device_strs[0])
|
|
)
|
|
else:
|
|
tf_logging.info(
|
|
"%s\n"
|
|
"Your GPUs will likely run quickly with dtype policy "
|
|
"mixed_float16 as they all have compute capability of at "
|
|
"least 7.0" % _COMPAT_CHECK_OK_PREFIX
|
|
)
|
|
|
|
|
|
_logged_compatibility_check = False
|
|
|
|
|
|
def log_device_compatibility_check(policy_name):
|
|
"""Logs a compatibility check if the devices support the policy.
|
|
|
|
Currently only logs for the policy mixed_float16. A log is shown only the
|
|
first time this function is called.
|
|
|
|
Args:
|
|
policy_name: The name of the dtype policy.
|
|
"""
|
|
global _logged_compatibility_check
|
|
if _logged_compatibility_check:
|
|
return
|
|
_logged_compatibility_check = True
|
|
gpus = tf.config.list_physical_devices("GPU")
|
|
gpu_details_list = [
|
|
tf.config.experimental.get_device_details(g) for g in gpus
|
|
]
|
|
_log_device_compatibility_check(policy_name, gpu_details_list)
|