# Copyright 2022 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Keras utilities for DTensor unit test.""" import numpy as np import tensorflow.compat.v2 as tf from absl.testing import parameterized # isort: off from tensorflow.dtensor.python import api as dtensor_api from tensorflow.python.eager import context _DEFAULT_GPU_MEMORY_LIMIT = 200 # MB class DTensorBaseTest(tf.test.TestCase, parameterized.TestCase): """Provides comparison helper for dtensor vs local results.""" @classmethod def setUpClass(cls): super(DTensorBaseTest, cls).setUpClass() def tearDown(self): super().tearDown() # Make sure all async ops finish. context.async_wait() # TODO(hthu): Remove the reset once we fixed the CopyToMesh with # DefaultMesh placement issue. reset_dtensor() @staticmethod def configTestMesh(device_type_mesh_map): """Configs corresponding mesh given test context. If runs on a CPU mesh, set virtual device on CPU. If runs on a GPU mesh, sets virtual device on GPU with proper memory limits. if runs on a TPU mesh, initializes TPU system. Args: device_type_mesh_map: A dictionary containing device_type -> mesh mapping. Returns: A properly configured mesh for use in test. """ reset_context() def get_mesh(device_type): mesh = device_type_mesh_map.get(device_type, None) if mesh is None: dt = device_type raise ValueError(f"Requires a {dt} mesh to run test on {dt}.") return mesh mesh = None if tf.config.list_physical_devices("GPU"): mesh = get_mesh("GPU") reset_logical_devices("GPU", np.prod(mesh.shape())) else: mesh = get_mesh("CPU") reset_logical_devices("CPU", np.prod(mesh.shape())) context.ensure_initialized() return mesh def create_device_array(shape, device_type): device_count = np.prod(shape) return np.asarray( [ tf.DeviceSpec( job="localhost/replica:0/task:0", device_type=device_type, device_index=i, ) for i in range(device_count) ] ).reshape(shape) def create_device_list(shape, device_type): devices = create_device_array(shape, device_type) return np.ravel(devices).tolist() def create_device_ids_array(shape): device_count = np.prod(shape) return np.arange(device_count).reshape(shape) def reset_context(): context._reset_context() def reset_logical_devices(device_type, count): """Resets logical devices for CPU/GPU. Logical devices can only be instantiated once on a particular context. For now, context re-use is triggering some function duplication errors, so we reset the context on each call. Args: device_type: The device_type to reset. count: numbers of virtual device to reset to. """ reset_context() devices = tf.config.list_physical_devices(device_type) if device_type.upper() == "CPU": tf.config.set_logical_device_configuration( devices[0], [ tf.config.LogicalDeviceConfiguration(), ] * count, ) elif device_type.upper() == "GPU": tf.config.set_logical_device_configuration( devices[0], [ tf.config.LogicalDeviceConfiguration( memory_limit=_DEFAULT_GPU_MEMORY_LIMIT ), ] * count, ) else: dt = device_type raise ValueError( f"resetting logical device for non-supported device type: {dt}" ) def reset_dtensor(): dtensor_api._reset()