Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/dtensor/test_util.py

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras utilities for DTensor unit test."""
import numpy as np
import tensorflow.compat.v2 as tf
from absl.testing import parameterized
# isort: off
from tensorflow.dtensor.python import api as dtensor_api
from tensorflow.python.eager import context
_DEFAULT_GPU_MEMORY_LIMIT = 200 # MB


class DTensorBaseTest(tf.test.TestCase, parameterized.TestCase):
    """Provides comparison helper for dtensor vs local results."""

    @classmethod
    def setUpClass(cls):
        super(DTensorBaseTest, cls).setUpClass()

    def tearDown(self):
        super().tearDown()
        # Make sure all async ops finish.
        context.async_wait()
        # TODO(hthu): Remove the reset once we fixed the CopyToMesh with
        # DefaultMesh placement issue.
        reset_dtensor()
    @staticmethod
    def configTestMesh(device_type_mesh_map):
        """Configures the mesh corresponding to the current test context.

        If the test runs on a CPU mesh, sets up virtual devices on the CPU.
        If it runs on a GPU mesh, sets up virtual devices on the GPU with
        proper memory limits.
        If it runs on a TPU mesh, initializes the TPU system.

        Args:
            device_type_mesh_map: A dictionary containing a device_type ->
                mesh mapping.

        Returns:
            A properly configured mesh for use in the test.
        """
        reset_context()

        def get_mesh(device_type):
            mesh = device_type_mesh_map.get(device_type, None)
            if mesh is None:
                dt = device_type
                raise ValueError(f"Requires a {dt} mesh to run test on {dt}.")
            return mesh

        mesh = None
        if tf.config.list_physical_devices("GPU"):
            mesh = get_mesh("GPU")
            reset_logical_devices("GPU", np.prod(mesh.shape()))
        else:
            mesh = get_mesh("CPU")
            reset_logical_devices("CPU", np.prod(mesh.shape()))

        context.ensure_initialized()
        return mesh
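

# Illustrative sketch (not part of the original file): a typical test built on
# `DTensorBaseTest` constructs candidate meshes with the helpers defined below
# and selects one via `configTestMesh`. The subclass name, the mesh dimension
# name "batch", and the 2-device shape are assumptions made for the example.
#
#     class ExampleDTensorTest(DTensorBaseTest):
#         def setUp(self):
#             super().setUp()
#             global_ids = create_device_ids_array((2,))
#             local_ids = np.ravel(global_ids).tolist()
#             mesh_dict = {
#                 "CPU": tf.experimental.dtensor.Mesh(
#                     ["batch"],
#                     global_ids,
#                     local_ids,
#                     create_device_list((2,), "CPU"),
#                 )
#             }
#             self.mesh = self.configTestMesh(mesh_dict)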


def create_device_array(shape, device_type):
    """Returns a numpy array of `tf.DeviceSpec`s with the given mesh shape."""
    device_count = np.prod(shape)
    return np.asarray(
        [
            tf.DeviceSpec(
                job="localhost/replica:0/task:0",
                device_type=device_type,
                device_index=i,
            )
            for i in range(device_count)
        ]
    ).reshape(shape)


def create_device_list(shape, device_type):
    """Returns a flat list of `tf.DeviceSpec`s covering the given mesh shape."""
    devices = create_device_array(shape, device_type)
    return np.ravel(devices).tolist()


def create_device_ids_array(shape):
    """Returns consecutive device ids (0..n-1) reshaped to the mesh shape."""
    device_count = np.prod(shape)
    return np.arange(device_count).reshape(shape)
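

# Illustrative sketch (not part of the original file) of what the helpers above
# return for an example 2x2 mesh shape:
#
#     create_device_ids_array((2, 2))
#     # -> array([[0, 1],
#     #           [2, 3]])
#
#     create_device_list((2, 2), "CPU")
#     # -> a flat list of four `tf.DeviceSpec`s for CPU devices 0..3 on the
#     #    "localhost/replica:0/task:0" job.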


def reset_context():
    """Resets the TF eager context so devices can be reconfigured."""
    context._reset_context()


def reset_logical_devices(device_type, count):
    """Resets logical devices for CPU/GPU.

    Logical devices can only be instantiated once on a particular context. For
    now, context re-use is triggering some function duplication errors, so we
    reset the context on each call.

    Args:
        device_type: The device_type to reset.
        count: The number of virtual devices to reset to.
    """
    reset_context()
    devices = tf.config.list_physical_devices(device_type)
    if device_type.upper() == "CPU":
        tf.config.set_logical_device_configuration(
            devices[0],
            [
                tf.config.LogicalDeviceConfiguration(),
            ]
            * count,
        )
    elif device_type.upper() == "GPU":
        tf.config.set_logical_device_configuration(
            devices[0],
            [
                tf.config.LogicalDeviceConfiguration(
                    memory_limit=_DEFAULT_GPU_MEMORY_LIMIT
                ),
            ]
            * count,
        )
    else:
        dt = device_type
        raise ValueError(
            f"Resetting logical devices for unsupported device type: {dt}"
        )
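

# Illustrative sketch (not part of the original file): splitting a single
# physical CPU into several virtual devices so a multi-device mesh can run on
# one host. The device count of 8 is an arbitrary example value.
#
#     reset_logical_devices("CPU", 8)
#     assert len(tf.config.list_logical_devices("CPU")) == 8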


def reset_dtensor():
    """Resets the DTensor API state between tests."""
    dtensor_api._reset()