# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Interface and utility functions to XLA. This module wraps the XLA client(s) and builders to standardize their interfaces and provide some automatic type mapping logic for converting between Numpy and XLA. There are also a handful of related casting utilities. """ from functools import partial, lru_cache import importlib import io import json import logging import os import platform as py_platform import pkgutil import sys import threading from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import warnings import numpy as np from jax._src import lib from jax._src import distributed from jax._src.config import flags, bool_env, config, int_env from jax._src.lib import xla_client from jax._src.lib import xla_extension_version from jax._src import traceback_util from jax._src import util iree: Optional[Any] try: import jax._src.iree as iree # type: ignore except (ModuleNotFoundError, ImportError): iree = None logger = logging.getLogger(__name__) jax_plugins: Optional[Any] try: import jax_plugins # type: ignore except ModuleNotFoundError: jax_plugins = None except ImportError as e: logger.error("Failed to import jax_plugins: %s", e) jax_plugins = None traceback_util.register_exclusion(__file__) XlaBackend = xla_client.Client FLAGS = flags.FLAGS # TODO(phawkins): Remove jax_xla_backend. 
flags.DEFINE_string(
    'jax_xla_backend', '',
    'Deprecated, please use --jax_platforms instead.')
flags.DEFINE_string(
    'jax_backend_target',
    os.getenv('JAX_BACKEND_TARGET', '').lower(),
    'Either "local" or "rpc:address" to connect to a remote service target.')
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
flags.DEFINE_string(
    'jax_platform_name',
    os.getenv('JAX_PLATFORM_NAME', '').lower(),
    'Deprecated, please use --jax_platforms instead.')
flags.DEFINE_bool(
    'jax_disable_most_optimizations',
    bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')
flags.DEFINE_integer(
    'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
    'Optional profile version for XLA compilation. '
    'This is meaningful only when XLA is configured to '
    'support the remote compilation profile feature.')
flags.DEFINE_string(
    'jax_cuda_visible_devices', 'all',
    'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
    'comma-separate list of integer device IDs.')
flags.DEFINE_string(
    'jax_rocm_visible_devices', 'all',
    'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
    'comma-separate list of integer device IDs.')


def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
    use_auto_spmd_partitioning: bool = False,
    auto_spmd_partitioning_mesh_shape=None,
    auto_spmd_partitioning_mesh_ids=None,
    env_options_overrides: Optional[Dict[str, str]] = None,
) -> xla_client.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional ndarray of jax devices indicating the assignment
      of logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`. A 1-D assignment is accepted when `num_partitions`
      is 1 and is treated as a single column.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.
    use_auto_spmd_partitioning: boolean indicating whether to automatically
      generate XLA shardings for SPMD partitioner.
    auto_spmd_partitioning_mesh_shape: device mesh shape used to create
      auto_spmd_partitioning search space. ``None`` means an empty list.
    auto_spmd_partitioning_mesh_ids: device ids used to create
      auto_spmd_partitioning search space. ``None`` means an empty list.
    env_options_overrides: dict of additional options parsed by the compiler

  Returns:
    A populated ``xla_client.CompileOptions``.

  Raises:
    ValueError: if `device_assignment` is not 2-D (after the 1-D allowance
      above) or its shape disagrees with `num_replicas`/`num_partitions`.
  """
  # Avoid shared mutable defaults; an omitted value behaves like the old `[]`.
  if auto_spmd_partitioning_mesh_shape is None:
    auto_spmd_partitioning_mesh_shape = []
  if auto_spmd_partitioning_mesh_ids is None:
    auto_spmd_partitioning_mesh_ids = []

  compile_options = xla_client.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
  if use_auto_spmd_partitioning:
    build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
    build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids

  if device_assignment is not None:
    logger.debug(
        'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    # Without this check a 1-D input with num_partitions > 1 would fail with
    # an opaque IndexError on `shape[1]` below.
    if device_assignment.ndim != 2:
      raise ValueError(
          'device_assignment must be 2D with shape (num_replicas, '
          f'num_partitions), got shape {device_assignment.shape}.')

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    # Object arrays hold device objects; XLA wants their integer ids.
    if device_assignment.dtype == object:
      device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
          device_assignment)
    device_assignment = xla_client.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  if env_options_overrides is not None:
    compile_options.env_option_overrides = list(env_options_overrides.items())

  debug_options = compile_options.executable_build_options.debug_options
  if lib.cuda_path is not None:
    debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

  if FLAGS.jax_disable_most_optimizations:
    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  compile_options.profile_version = FLAGS.jax_xla_profile_version
  return compile_options


# Backends


def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]:
  """Creates a TPU client, warning if creation takes longer than `timer_secs`.

  The warning is emitted from a `threading.Timer`; the timer is cancelled in a
  `finally` clause whether or not client creation succeeds.
  """
  def _log_warning():
    warnings.warn(
        f'TPU backend initialization is taking more than {timer_secs} seconds. '
        'Did you run your code on all TPU hosts? '
        'See https://jax.readthedocs.io/en/latest/multi_process.html '
        'for more information.')

  # Will log a warning after `timer_secs`.
  t = threading.Timer(timer_secs, _log_warning)
  t.start()

  try:
    client = xla_client.make_tpu_client()
  finally:
    t.cancel()

  return client


# Backends, in increasing order of preference.
# We have no particular opinion about how "backends" relate to "devices". For
# example, there could be multiple backends that provide the same kind of
# device.
BackendFactory = Callable[[], Optional[xla_client.Client]]
_backend_factories: Dict[str, Tuple[BackendFactory, int]] = {}
_default_backend: Optional[xla_client.Client] = None
_backends : Dict[str, xla_client.Client] = {}
_backends_errors : Dict[str, str] = {}
_backend_lock = threading.Lock()


def register_backend_factory(name: str, factory: BackendFactory, *,
                             priority: int = 0) -> None:
  """Registers `factory` as the creator of backend `name` with `priority`.

  Raises:
    RuntimeError: if a backend named `name` has already been initialized.
  """
  with _backend_lock:
    if name in _backends:
      raise RuntimeError(f"Backend {name} already initialized")
    _backend_factories[name] = (factory, priority)


register_backend_factory('interpreter', xla_client.make_interpreter_client,
                         priority=-100)
register_backend_factory('cpu',
                         partial(xla_client.make_cpu_client, use_tfrt=True),
                         priority=0)


def make_gpu_client(
    *, platform_name: str, visible_devices_flag: str
) -> xla_client.Client:
  """Creates a GPU client for `platform_name`.

  Args:
    platform_name: the platform passed through to `xla_client.make_gpu_client`.
    visible_devices_flag: name of the FLAGS attribute holding either "all" or
      a comma-separated list of integer device ids to restrict the client to.
  """
  visible_devices = getattr(FLAGS, visible_devices_flag, "all")
  allowed_devices = None
  if visible_devices != "all":
    allowed_devices = {int(x) for x in visible_devices.split(",")}
  if xla_extension_version < 160:
    return xla_client.make_gpu_client(
        distributed_client=distributed.global_state.client,
        node_id=distributed.global_state.process_id,
        platform_name=platform_name,
        allowed_devices=allowed_devices,
    )
  else:
    # Remove `type: ignore` when the min jaxlib version (xla_extension_version)
    # >= 160.
    return xla_client.make_gpu_client(
        distributed_client=distributed.global_state.client,
        node_id=distributed.global_state.process_id,
        num_nodes=distributed.global_state.num_processes,
        platform_name=platform_name,
        allowed_devices=allowed_devices,
    )  # type: ignore


if hasattr(xla_client, "make_gpu_client"):
  register_backend_factory(
      'cuda', partial(make_gpu_client, platform_name='cuda',
                      visible_devices_flag='jax_cuda_visible_devices'),
      priority=200)
  register_backend_factory(
      'rocm', partial(make_gpu_client, platform_name='rocm',
                      visible_devices_flag='jax_rocm_visible_devices'),
      priority=200)

if hasattr(xla_client, "make_tpu_client"):
  register_backend_factory(
      'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)

if hasattr(xla_client, "make_plugin_device_client"):
  # It is assumed that if jax has been built with a plugin client, then the
  # user wants to use the plugin client by default. Therefore, it gets the
  # highest priority.
  register_backend_factory("plugin", xla_client.make_plugin_device_client,
                           priority=400)


def _get_pjrt_plugin_names_and_library_paths(
    plugins_from_env: str,
) -> Dict[str, str]:
  """Gets the names and library paths of PJRT plugins to load from env var.

  Args:
    plugins_from_env: plugin names and paths from env var. It is in the format
      of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for windows).

  Returns:
    A dict of {plugin_name: library path} for the PJRT plugins to load.
    Malformed entries are skipped with a warning.
  """
  if not plugins_from_env:
    return {}

  pjrt_plugins = {}
  for plugin in plugins_from_env.split(','):
    try:
      # os.path.pathsep is ':' on POSIX and ';' on Windows.
      name, library_path = plugin.split(os.path.pathsep)
      pjrt_plugins[name] = library_path
    except ValueError:
      logger.warning(
          'invalid value %s in env var PJRT_NAMES_AND_LIBRARY_PATHS %s',
          plugin,
          plugins_from_env,
      )
  return pjrt_plugins


def _get_pjrt_plugin_config(
    json_path: str,
) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]:
  """Gets PJRT plugin configuration from a json file.

  The json file needs to have a "library_path" field for the plugin library
  path. It can have an optional "create_options" field for the options used
  when creating a PJRT plugin client. The value of "create_options" is
  key-value pairs. Please see xla_client._NameValueMapping for the supported
  types of values.

  Returns:
    A `(library_path, create_options)` tuple; `create_options` is None when
    the field is absent.

  Raises:
    ValueError: if the file does not contain a JSON object with a
      "library_path" field.
  """
  # JSON text is UTF-8 by specification; don't depend on the locale encoding.
  with io.open(json_path, 'r', encoding='utf-8') as f:
    plugin_config = json.load(f)
  if (not isinstance(plugin_config, dict)
      or 'library_path' not in plugin_config):
    raise ValueError(
        'PJRT plugin config file should contain "library_path" field.'
    )
  return (plugin_config['library_path'], plugin_config.get('create_options'))


def discover_pjrt_plugins() -> None:
  """Discovers plugins in the namespace package `jax_plugins` and import them.

  There are two methods used to discover plugin modules. They are intended
  to be used together by implementors in order to cover all packaging and
  development cases:

  1. Define a globally unique module under the `jax_plugins` namespace
     package (i.e. just create a `jax_plugins` directory and define your
     module below it).
  2. If building a package via pyproject.toml or setup.py, advertise your
     plugin module name by including an entry-point under the `jax_plugins`
     group which points to your full module name.

  During Jax startup, Jax will load each module discovered in such a way and
  call its `initialize()` function. It is expected that this function should
  register its concrete plugin name/implementations via call(s) to
  `jax._src.xla_bridge.register_plugin(name, priority=, library_path=,
  options=)`. Since `initialize()` functions are called for all installed
  plugins, they should avoid doing expensive, non-registration related work.

  Modules are loaded in sorted name order so that initialization order is
  deterministic. Import or `initialize()` failures are logged and the plugin
  is skipped; they do not abort startup.

  TODO: We should provide a variant of `register_plugin` which allows the
  library_path and options to be resolved via a callback. This would enable
  light-weight plugin registration in cases where options need to be derived
  from heavy-weight system initialization.
  """
  plugin_modules = set()
  # Scan installed modules under |jax_plugins|. Note that not all packaging
  # scenarios are amenable to such scanning, so we also use the entry-point
  # method to seed the list.
  if jax_plugins:
    for _, name, _ in pkgutil.iter_modules(
        jax_plugins.__path__, jax_plugins.__name__ + '.'
    ):
      logger.debug("Discovered path based JAX plugin: %s", name)
      plugin_modules.add(name)
  else:
    logger.debug("No jax_plugins namespace packages available")

  # Augment with advertised entrypoints.
  if sys.version_info < (3, 10):
    # Use the backport library because it provides a forward-compatible
    # implementation.
    from importlib_metadata import entry_points
  else:
    from importlib.metadata import entry_points

  for entry_point in entry_points(group="jax_plugins"):
    logger.debug("Discovered entry-point based JAX plugin: %s",
                 entry_point.value)
    plugin_modules.add(entry_point.value)

  # Now load and initialize them all. Sorting makes the order independent of
  # set iteration order.
  for plugin_module_name in sorted(plugin_modules):
    logger.debug("Loading plugin module %s", plugin_module_name)
    try:
      plugin_module = importlib.import_module(plugin_module_name)
    except ModuleNotFoundError:
      logger.warning("Jax plugin configuration error: Plugin module %s "
                     "does not exist", plugin_module_name)
      continue
    except ImportError:
      logger.exception("Jax plugin configuration error: Plugin module %s "
                       "could not be loaded", plugin_module_name)
      continue

    # Best-effort: a failing plugin is logged and skipped, but
    # KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
      plugin_module.initialize()
    except Exception:  # pylint: disable=broad-except
      logger.exception("Jax plugin configuration error: Exception when "
                       "calling %s.initialize()", plugin_module_name)


# TODO(b/261345120): decide on a public name and expose a public method which is
# an alias of this method.
def register_plugin(
    plugin_name: str,
    *,
    priority: int = 400,
    library_path: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, List[int], float]]] = None,
) -> None:
  """Registers a backend factory for the PJRT plugin.

  Args:
    plugin_name: the name of the plugin.
    priority: the priority this plugin should be registered in jax backends.
      Default to be 400.
    library_path: Optional. The full path to the .so file of the plugin.
      Required when the plugin is dynamically linked.
    options: Optional. It is used when creating a PJRT plugin client.
  """
  def factory():
    # Plugin may already be statically linked in some configurations.
    if not xla_client.pjrt_plugin_loaded(plugin_name):
      if library_path is None:
        raise ValueError(
            'The library path is None when trying to dynamically load the'
            ' plugin.'
        )
      xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
    return xla_client.make_c_api_client(plugin_name, options)

  logger.debug(
      'registering PJRT plugin %s from %s', plugin_name, library_path
  )
  register_backend_factory(plugin_name, factory, priority=priority)


def register_pjrt_plugin_factories_from_env() -> None:
  """Registers backend factories for PJRT plugins.

  A backend factory will be registered for every PJRT plugin in the input
  string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2'
  for windows). The path can be a path to the plugin library or a path to the
  plugin configuration json file. The json file needs to have a "library_path"
  field for the plugin library path. It can have an optional "create_options"
  field for the options used when creating a PJRT plugin client. The value of
  "create_options" is key-value pairs. Please see xla_client._NameValueMapping
  for the supported types of values.

  TPU PJRT plugin will be loaded and registered separately in make_tpu_client.
  """
  pjrt_plugins = _get_pjrt_plugin_names_and_library_paths(
      os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', '')
  )
  for plugin_name, path in pjrt_plugins.items():
    if path.endswith('.json'):
      library_path, options = _get_pjrt_plugin_config(path)
    else:
      library_path = path
      options = None
    # register_plugin emits the "registering PJRT plugin" debug log itself.
    register_plugin(plugin_name, library_path=library_path, options=options)


# Plugins in the namespace package `jax_plugins` will be imported.
discover_pjrt_plugins()

# Registers plugins names and paths set in env var PJRT_NAMES_AND_LIBRARY_PATHS,
# in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for
# windows).
register_pjrt_plugin_factories_from_env()

if iree is not None:
  register_backend_factory("iree", iree.iree_client_factory, priority=-100)

_platform_aliases = {
    "cuda": "gpu",
    "rocm": "gpu",
}

_alias_to_platforms: Dict[str, List[str]] = {}
for _platform, _alias in _platform_aliases.items():
  _alias_to_platforms.setdefault(_alias, []).append(_platform)


def is_known_platform(platform: str) -> bool:
  # A platform is valid if there is a registered factory for it. It does not
  # matter if we were unable to initialize that platform; we only care that
  # we've heard of it and it isn't, e.g., a typo.
  return (platform in _backend_factories.keys() or
          platform in _platform_aliases.keys())


def canonicalize_platform(platform: str) -> str:
  """Replaces platform aliases with their concrete equivalent.

  In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
  hardware is actually present. We want to distinguish "cuda" and "rocm" for
  purposes such as MLIR lowering rules, but in many cases we don't want to
  force users to care.

  Raises:
    RuntimeError: if `platform` is an alias but none of its concrete
      platforms has an initialized backend.
  """
  platforms = _alias_to_platforms.get(platform, None)
  if platforms is None:
    return platform

  b = backends()
  for p in platforms:
    if p in b.keys():
      return p
  raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                     f"platforms that are instances of {platform} are present. "
                     "Platforms present are: " + ",".join(b.keys()))


def expand_platform_alias(platform: str) -> List[str]:
  """Expands, e.g., "gpu" to ["cuda", "rocm"].

  This is used for convenience reasons: we expect cuda and rocm to act similarly
  in many respects since they share most of the same code.
  """
  return _alias_to_platforms.get(platform, [platform])


def is_gpu(platform: str) -> bool:
  """Returns True if `platform` is one of the concrete GPU platforms."""
  return platform in ("cuda", "rocm")


def backends() -> Dict[str, xla_client.Client]:
  """Initializes the backends on first call and returns them by platform name.

  When `config.jax_platforms` is set, only the listed platforms (with aliases
  expanded) are initialized, earlier entries having higher priority, and any
  initialization failure is raised. Otherwise every registered factory is
  tried; failures other than 'cpu'/'interpreter' are recorded in
  `_backends_errors` and skipped. The highest-priority successful backend
  becomes `_default_backend`.
  """
  global _backends
  global _backends_errors
  global _default_backend

  with _backend_lock:
    if _backends:
      return _backends
    # NOTE(review): if an exception is raised below, `_backends` keeps the
    # entries initialized so far and later calls return that partial dict.
    if config.jax_platforms:
      jax_platforms = config.jax_platforms.split(",")
      platforms = []
      # Allow platform aliases in the list of platforms.
      for platform in jax_platforms:
        platforms.extend(expand_platform_alias(platform))
      # Drop repeats (e.g. "gpu,cuda" yields "cuda" twice) so each platform's
      # client is created once; the first occurrence keeps its priority.
      platforms = list(dict.fromkeys(platforms))
      priorities = range(len(platforms), 0, -1)
      platforms_and_priorities = list(zip(platforms, priorities))
    else:
      platforms_and_priorities = [
          (platform, priority)
          for platform, (_, priority) in _backend_factories.items()]
    default_priority = -1000
    for platform, priority in platforms_and_priorities:
      try:
        backend = _init_backend(platform)
        _backends[platform] = backend

        if priority > default_priority:
          _default_backend = backend
          default_priority = priority
      except Exception as err:
        if platform in ('cpu', 'interpreter'):
          # We always expect the CPU and interpreter backends to initialize
          # successfully.
          raise
        # If the backend isn't built into the binary, or if it has no devices,
        # we expect a RuntimeError.
        err_msg = f"Unable to initialize backend '{platform}': {err}"
        if config.jax_platforms:
          err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
          raise RuntimeError(err_msg) from err
        _backends_errors[platform] = str(err)
        logger.info(err_msg)

    assert _default_backend is not None
    # We don't warn about falling back to CPU on Mac OS, because we don't
    # support anything else there at the moment and warning would be pointless.
    if (py_platform.system() != "Darwin" and
        _default_backend.platform == "cpu" and
        FLAGS.jax_platform_name != 'cpu'):
      logger.warning('No GPU/TPU found, falling back to CPU. '
                     '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
    return _backends


def _clear_backends() -> None:
  """Drops all initialized backends and every cache derived from them."""
  global _backends
  global _backends_errors
  global _default_backend

  logger.info("Clearing JAX backend caches.")
  with _backend_lock:
    _backends = {}
    _backends_errors = {}
    _default_backend = None

  # These functions are lru_cached and their results (clients, devices,
  # process counts) come from the backends cleared above; leaving them cached
  # would keep returning stale objects.
  get_backend.cache_clear()
  local_devices.cache_clear()
  process_count.cache_clear()


def _init_backend(platform: str) -> xla_client.Client:
  """Creates the backend for `platform` via its registered factory.

  Raises:
    RuntimeError: if `platform` has no registered factory, the factory returns
      None, or the resulting backend has no devices.
  """
  factory, unused_priority = _backend_factories.get(platform, (None, None))
  if factory is None:
    raise RuntimeError(
        f"Backend '{platform}' is not in the list of known backends: "
        f"{list(_backend_factories.keys())}.")

  logger.debug("Initializing backend '%s'", platform)
  backend = factory()
  # TODO(skye): consider raising more descriptive errors directly from backend
  # factories instead of returning None.
  if backend is None:
    raise RuntimeError(f"Could not initialize backend '{platform}'")
  if backend.device_count() == 0:
    raise RuntimeError(f"Backend '{platform}' provides no devices.")
  util.distributed_debug_log(("Initialized backend", backend.platform),
                             ("process_index", backend.process_index()),
                             ("device_count", backend.device_count()),
                             ("local_devices", backend.local_devices()))
  logger.debug("Backend '%s' initialized", platform)
  return backend


def _get_backend_uncached(
    platform: Union[None, str, xla_client.Client] = None
) -> xla_client.Client:
  """Resolves `platform` to a backend without caching.

  A non-string, non-None `platform` is returned unchanged. Otherwise the name
  falls back to the deprecated flags and, failing those, the default backend.

  Raises:
    RuntimeError: if the named platform failed to initialize or is unknown.
  """
  # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
  # 'backend' values are handled
  if platform is not None and not isinstance(platform, str):
    return platform

  platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
              or None)

  bs = backends()
  if platform is not None:
    platform = canonicalize_platform(platform)
    backend = bs.get(platform, None)
    if backend is None:
      if platform in _backends_errors:
        raise RuntimeError(f"Backend '{platform}' failed to initialize: "
                           f"{_backends_errors[platform]}")
      raise RuntimeError(f"Unknown backend {platform}")
    return backend
  else:
    assert _default_backend is not None
    return _default_backend
@lru_cache(maxsize=None)  # don't use util.memoize because there is no X64 dependence.
def get_backend(
    platform: Union[None, str, xla_client.Client] = None
) -> xla_client.Client:
  """Returns the backend for `platform`, memoized per argument value."""
  return _get_backend_uncached(platform)


def get_device_backend(
    device: Optional[xla_client.Device] = None,
) -> xla_client.Client:
  """Returns the Backend associated with `device`, or the default Backend."""
  if device is not None:
    return device.client
  return get_backend()


def device_count(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the total number of devices.

  On most platforms, this is the same as :py:func:`jax.local_device_count`.
  However, on multi-process platforms where different devices are associated
  with different processes, this will return the total number of devices across
  all processes.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Number of devices.
  """
  return int(get_backend(backend).device_count())


def local_device_count(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the number of devices addressable by this process."""
  return int(get_backend(backend).local_device_count())


def devices(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> List[xla_client.Device]:
  """Returns a list of all devices for a given backend.

  .. currentmodule:: jaxlib.xla_extension

  Each device is represented by a subclass of :class:`Device` (e.g.
  :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
  equal to ``device_count(backend)``. Local devices can be identified by
  comparing :attr:`Device.process_index` to the value returned by
  :py:func:`jax.process_index`.

  If ``backend`` is ``None``, returns all the devices from the default backend.
  The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
  otherwise ``'cpu'``.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    List of Device subclasses.
  """
  return get_backend(backend).devices()


def default_backend() -> str:
  """Returns the platform name of the default XLA backend."""
  return get_backend(None).platform


@lru_cache
def local_devices(process_index: Optional[int] = None,
                  backend: Optional[Union[str, xla_client.Client]] = None,
                  host_id: Optional[int] = None) -> List[xla_client.Device]:
  """Like :py:func:`jax.devices`, but only returns devices local to a given process.

  If ``process_index`` is ``None``, returns devices local to this process.

  Args:
    process_index: the integer index of the process. Valid process indices are
      ``range(jax.process_count())``.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.
    host_id: Deprecated alias for ``process_index``. When given, a warning is
      emitted and it overrides ``process_index``.

  Returns:
    List of Device subclasses. Results are cached per argument tuple, so the
    same list object is returned on repeated calls; do not mutate it.

  Raises:
    ValueError: if ``process_index`` is not in
      ``range(process_count(backend))``.
  """
  if host_id is not None:
    warnings.warn(
        "The argument to jax.local_devices has been renamed from `host_id` to "
        "`process_index`. This alias will eventually be removed; please update "
        "your code.", stacklevel=2)
    process_index = host_id
  if process_index is None:
    process_index = get_backend(backend).process_index()
  # Validate against the requested backend: its process count may differ from
  # the default backend's.
  if not (0 <= process_index < process_count(backend)):
    raise ValueError(f"Unknown process_index {process_index}")
  return [d for d in devices(backend) if d.process_index == process_index]


def process_index(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the integer process index of this process.

  On most platforms, this will always be 0. This will vary on multi-process
  platforms though.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Integer process index.
  """
  return get_backend(backend).process_index()


# TODO: remove this sometime after jax 0.2.13 is released
def host_id(backend: Optional[Union[str, xla_client.Client]] = None) -> int:
  """Deprecated alias of :py:func:`process_index` (emits a warning)."""
  warnings.warn(
      "jax.host_id has been renamed to jax.process_index. This alias "
      "will eventually be removed; please update your code.", stacklevel=2)
  return process_index(backend)


@lru_cache
def process_count(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the number of JAX processes associated with the backend."""
  return max(d.process_index for d in devices(backend)) + 1


# TODO: remove this sometime after jax 0.2.13 is released
def host_count(backend: Optional[Union[str, xla_client.Client]] = None) -> int:
  """Deprecated alias of :py:func:`process_count` (emits a warning)."""
  warnings.warn(
      "jax.host_count has been renamed to jax.process_count. This alias "
      "will eventually be removed; please update your code.", stacklevel=2)
  return process_count(backend)


# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> List[int]:
  """Deprecated; returns ``list(range(process_count(backend)))``."""
  warnings.warn(
      "jax.host_ids has been deprecated; please use range(jax.process_count()) "
      "instead. jax.host_ids will eventually be removed; please update your "
      "code.", stacklevel=2)
  return list(range(process_count(backend)))


def using_pjrt_c_api(backend=None) -> bool:
  """Returns True if the backend's platform_version mentions the PJRT C API."""
  return "PJRT C API" in get_backend(backend).platform_version


# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
def make_pjrt_tpu_topology(topology_name=None, **kwargs):
  # TODO(b/261484192): Make a system for lazily loading libtpu.so and call
  # that inside make_tfrt_tpu_c_api_device_topology.
  get_backend()  # Properly initialize libtpu.so.
  return xla_client.make_tfrt_tpu_c_api_device_topology(
      topology_name, **kwargs
  )