1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions for configuring TensorFlow execution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from typing import Union
22
23from tensorflow.python.eager import context
24from tensorflow.python.util import _pywrap_tensor_float_32_execution
25from tensorflow.python.util import deprecation
26from tensorflow.python.util.tf_export import tf_export
27
28
@tf_export('config.experimental.tensor_float_32_execution_enabled')
def tensor_float_32_execution_enabled():
  """Reports whether TensorFloat-32 execution is currently enabled.

  TensorFloat-32 is on by default. Use
  `tf.config.experimental.enable_tensor_float_32_execution` to change it.

  Returns:
    True if TensorFloat-32 is enabled (the default) and False otherwise.
  """
  is_enabled = _pywrap_tensor_float_32_execution.is_enabled()
  return is_enabled
40
41
@tf_export('config.experimental.enable_tensor_float_32_execution')
def enable_tensor_float_32_execution(enabled):
  """Enable or disable the use of TensorFloat-32 on supported hardware.

  [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format),
  or TF32 for short, is a math mode for NVIDIA Ampere GPUs. TensorFloat-32
  execution causes certain float32 ops, such as matrix multiplications and
  convolutions, to run much faster on Ampere GPUs but with reduced precision.
  This reduced precision should not impact convergence of deep learning models
  in practice.

  TensorFloat-32 is enabled by default. TensorFloat-32 is only supported on
  Ampere GPUs, so all other hardware will use the full float32 precision
  regardless of whether TensorFloat-32 is enabled or not. If you want to use the
  full float32 precision on Ampere, you can disable TensorFloat-32 execution
  with this function. For example:

  ```python
  x = tf.fill((2, 2), 1.0001)
  y = tf.fill((2, 2), 1.)
  # TensorFloat-32 is enabled, so matmul is run with reduced precision
  print(tf.linalg.matmul(x, y))  # [[2., 2.], [2., 2.]]
  tf.config.experimental.enable_tensor_float_32_execution(False)
  # Matmul is run with full precision
  print(tf.linalg.matmul(x, y))  # [[2.0002, 2.0002], [2.0002, 2.0002]]
  ```

  To check whether TensorFloat-32 execution is currently enabled, use
  `tf.config.experimental.tensor_float_32_execution_enabled`.

  If TensorFloat-32 is enabled, float32 inputs of supported ops, such as
  `tf.linalg.matmul`, will be rounded from 23 bits of precision to 10 bits of
  precision in most cases. This allows the ops to execute much faster by
  utilizing the GPU's tensor cores. TensorFloat-32 has the same dynamic range as
  float32, meaning it is no more likely to underflow or overflow than float32.
  Ops still use float32 accumulation when TensorFloat-32 is enabled. Enabling or
  disabling TensorFloat-32 only affects Ampere GPUs and subsequent GPUs that
  support TensorFloat-32.

  Note that TensorFloat-32 is not always used in supported ops, as only inputs
  of certain shapes are supported. Support for more input shapes and more ops
  may be added in the future. As a result, precision of float32 ops may decrease
  in minor versions of TensorFlow.

  TensorFloat-32 is also used for some complex64 ops. Currently, TensorFloat-32
  is used in fewer cases for complex64 than for float32.

  Args:
    enabled: Bool indicating whether to enable TensorFloat-32 execution.
  """
  _pywrap_tensor_float_32_execution.enable(enabled)
93
94
@tf_export('config.threading.get_intra_op_parallelism_threads')
def get_intra_op_parallelism_threads():
  """Returns how many threads a single op may use for parallel work.

  Ops such as matrix multiplications and reductions can divide their work
  across several threads to run faster. A value of 0 means the system chooses
  a suitable number on its own.

  Returns:
    Number of parallel threads
  """
  ctx = context.context()
  return ctx.intra_op_parallelism_threads
107
108
@tf_export('config.threading.set_intra_op_parallelism_threads')
def set_intra_op_parallelism_threads(num_threads):
  """Sets how many threads a single op may use for parallel work.

  Ops such as matrix multiplications and reductions can divide their work
  across several threads to run faster. A value of 0 lets the system choose a
  suitable number on its own.

  Args:
    num_threads: Number of parallel threads
  """
  ctx = context.context()
  ctx.intra_op_parallelism_threads = num_threads
121
122
@tf_export('config.threading.get_inter_op_parallelism_threads')
def get_inter_op_parallelism_threads():
  """Returns how many threads run independent operations in parallel.

  This is the number of threads available to independent, non-blocking
  operations. A value of 0 means the system chooses a suitable number.

  Returns:
    Number of parallel threads
  """
  ctx = context.context()
  return ctx.inter_op_parallelism_threads
134
135
@tf_export('config.threading.set_inter_op_parallelism_threads')
def set_inter_op_parallelism_threads(num_threads):
  """Sets how many threads run independent operations in parallel.

  This is the number of threads available to independent, non-blocking
  operations. A value of 0 lets the system choose a suitable number.

  Args:
    num_threads: Number of parallel threads
  """
  ctx = context.context()
  ctx.inter_op_parallelism_threads = num_threads
147
148
@tf_export('config.optimizer.get_jit')
def get_optimizer_jit() -> str:
  """Returns JIT compilation configuration for code inside `tf.function`.

  Possible return values:
    - `"autoclustering"` if
      [autoclustering](https://www.tensorflow.org/xla#auto-clustering) is
      enabled.
    - `""` when no default compilation is applied.

  Returns:
    The JIT configuration string: `"autoclustering"` or `""`.
  """
  # The context flag is only tested for truthiness; translate it back to the
  # public string values accepted by `tf.config.optimizer.set_jit`.
  return 'autoclustering' if context.context().optimizer_jit else ''
161
162
@tf_export('config.optimizer.set_jit')
# The keyword naming the deprecated value must match the real parameter name
# (`enabled`): the decorator looks deprecated values up by argument name, so a
# keyword for a non-existent parameter would never emit the warning.
@deprecation.deprecated_arg_values(
    None,
    '`True` setting is deprecated, use `autoclustering` instead.',
    warn_once=True,
    enabled=True)
def set_optimizer_jit(enabled: Union[bool, str]):
  """Configure JIT compilation.

  Note: compilation is only applied to code that is compiled into a
  graph (in TF2 that's only a code inside `tf.function`).

  Args:
    enabled: JIT compilation configuration.
    Possible values:
     - `"autoclustering"` (`True` is a deprecated alias): perform
     [autoclustering](https://www.tensorflow.org/xla#auto-clustering)
     (automatically identify and compile clusters of nodes) on all graphs using
     [XLA](https://www.tensorflow.org/xla).
     - `False`: do not automatically compile any graphs.
  """
  # Both the preferred string and the deprecated boolean alias turn on
  # autoclustering; any value not equal to either of them disables it.
  autoclustering_enabled = enabled in (True, 'autoclustering')
  context.context().optimizer_jit = autoclustering_enabled
186
187
@tf_export('config.optimizer.get_experimental_options')
def get_optimizer_experimental_options():
  """Returns the currently configured experimental optimizer options.

  See `tf.config.optimizer.set_experimental_options` for the list of options
  that can currently be set.

  These optimizations only take effect in graph mode (inside `tf.function`).
  Because the options are experimental, the set of supported keys may change.

  Returns:
    Dictionary of configured experimental optimizer options
  """
  ctx = context.context()
  return ctx.get_optimizer_experimental_options()
202
203
@tf_export('config.optimizer.set_experimental_options')
def set_optimizer_experimental_options(options):
  """Set experimental optimizer options.

  Note that optimizations are only applied in graph mode (within tf.function).
  In addition, as these are experimental options, the list is subject to change.

  Args:
    options: Dictionary of experimental optimizer options to configure.
      Valid keys:
      - layout_optimizer: Optimize tensor layouts
        e.g. This will try to use NCHW layout on GPU which is faster.
      - constant_folding: Fold constants
        Statically infer the value of tensors when possible, and materialize the
        result using constants.
      - shape_optimization: Simplify computations made on shapes.
      - remapping: Remap subgraphs onto more efficient implementations.
      - arithmetic_optimization: Simplify arithmetic ops with common
        sub-expression elimination and arithmetic simplification.
      - dependency_optimization: Control dependency optimizations. Remove
        redundant control dependencies, which may enable other optimization.
        This optimizer is also essential for pruning Identity and NoOp nodes.
      - loop_optimization: Loop optimizations.
      - function_optimization: Function optimizations and inlining.
      - debug_stripper: Strips debug-related nodes from the graph.
      - disable_model_pruning: Disable removal of unnecessary ops from the
        graph.
      - scoped_allocator_optimization: Try to allocate some independent Op
        outputs contiguously in order to merge or eliminate downstream Ops.
      - pin_to_host_optimization: Force small ops onto the CPU.
      - implementation_selector: Enable the swap of kernel implementations based
        on the device placement.
      - auto_mixed_precision: Change certain float32 ops to float16 on Volta
        GPUs and above. Without the use of loss scaling, this can cause
        numerical underflow (see
        `keras.mixed_precision.experimental.LossScaleOptimizer`).
      - disable_meta_optimizer: Disable the entire meta optimizer.
      - min_graph_nodes: The minimum number of nodes in a graph to optimize.
        For smaller graphs, optimization is skipped.
  """
  context.context().set_optimizer_experimental_options(options)
244
245
@tf_export('config.get_soft_device_placement')
def get_soft_device_placement():
  """Reports whether soft device placement is enabled.

  With soft placement on, an op falls back to the CPU when any of these hold:
    1. there's no GPU implementation for the OP
    2. no GPU devices are known or registered
    3. need to co-locate with reftype input(s) which are from CPU

  Returns:
    If soft placement is enabled.
  """
  ctx = context.context()
  return ctx.soft_device_placement
259
260
@tf_export('config.set_soft_device_placement')
def set_soft_device_placement(enabled):
  """Turns soft device placement on or off.

  With soft placement on, an op falls back to the CPU when any of these hold:
    1. there's no GPU implementation for the OP
    2. no GPU devices are known or registered
    3. need to co-locate with reftype input(s) which are from CPU

  Args:
    enabled: Whether to enable soft placement.
  """
  ctx = context.context()
  ctx.soft_device_placement = enabled
274
275
@tf_export('config.experimental.get_device_policy')
def get_device_policy():
  """Gets the current device policy.

  The device policy decides how an operation that needs its inputs on one
  device (e.g., on GPU:0) deals with inputs that live on another device
  (e.g. GPU:1).

  Only the calling thread's policy is reported; threads started afterwards
  begin with the default policy again.

  Returns:
    Current thread device policy, one of `'silent'`, `'silent_for_int32'`,
    `'warn'` or `'explicit'`.

  Raises:
    ValueError: If the context reports an unrecognized policy value.
  """
  current = context.context().device_policy
  # Checked in order; the first matching placement constant wins.
  known_policies = (
      (context.DEVICE_PLACEMENT_SILENT, 'silent'),
      (context.DEVICE_PLACEMENT_SILENT_FOR_INT32, 'silent_for_int32'),
      (context.DEVICE_PLACEMENT_WARN, 'warn'),
      (context.DEVICE_PLACEMENT_EXPLICIT, 'explicit'),
  )
  for policy_value, policy_name in known_policies:
    if current == policy_value:
      return policy_name
  raise ValueError('Not a valid device policy: %r' % current)
300
301
@tf_export('config.experimental.set_device_policy')
def set_device_policy(device_policy):
  """Sets the current thread device policy.

  The device policy decides how an operation that needs its inputs on one
  device (e.g., on GPU:0) deals with inputs that live on another device
  (e.g. GPU:1).

  Passing None selects a policy automatically; which policy that is may change
  over time.

  Only the calling thread is affected; threads started afterwards begin with
  the default policy again.

  Args:
    device_policy: A device policy.
      Valid values:
      - None: Switch to a system default.
      - 'warn': Copies the tensors which are not on the right device and logs
          a warning.
      - 'explicit': Raises an error if the placement is not as required.
      - 'silent': Silently copies the tensors. Note that this may hide
          performance problems as there is no notification provided when
          operations are blocked on the tensor being copied between devices.
      - 'silent_for_int32': silently copies `int32` tensors, raising errors on
          the other ones.

  Raises:
      ValueError: If an invalid `device_policy` is passed.
  """
  if device_policy is None:
    context.context().device_policy = None
    return

  # Public policy names paired with the context constants they select.
  policy_by_name = (
      ('silent', context.DEVICE_PLACEMENT_SILENT),
      ('silent_for_int32', context.DEVICE_PLACEMENT_SILENT_FOR_INT32),
      ('warn', context.DEVICE_PLACEMENT_WARN),
      ('explicit', context.DEVICE_PLACEMENT_EXPLICIT),
  )
  for policy_name, policy_value in policy_by_name:
    if device_policy == policy_name:
      context.context().device_policy = policy_value
      return
  raise ValueError('Not a valid device policy: %r' % device_policy)
343
344
@tf_export('config.experimental.get_synchronous_execution')
def get_synchronous_execution():
  """Gets whether operations are executed synchronously or asynchronously.

  TensorFlow can execute operations synchronously or asynchronously. If
  asynchronous execution is enabled, operations may return "non-ready" handles.

  Returns:
    True if the current execution mode is synchronous, False otherwise
    (i.e. operations are dispatched asynchronously).
  """
  return context.context().execution_mode == context.SYNC
356
357
@tf_export('config.experimental.set_synchronous_execution')
def set_synchronous_execution(enable):
  """Chooses between synchronous and asynchronous operation execution.

  With asynchronous execution, operations may hand back "non-ready" handles
  instead of finished results.

  Passing None lets TensorFlow pick a value automatically; that choice may
  differ between TensorFlow releases.

  Args:
    enable: Whether operations should be dispatched synchronously.
      Valid values:
      - None: sets the system default.
      - True: executes each operation synchronously.
      - False: executes each operation asynchronously.
  """
  if enable is None:
    mode = None
  else:
    mode = context.SYNC if enable else context.ASYNC
  context.context().execution_mode = mode
381
382
@tf_export('config.list_physical_devices',
           'config.experimental.list_physical_devices')
@deprecation.deprecated_endpoints(
    'config.experimental.list_physical_devices')
def list_physical_devices(device_type=None):
  """Return a list of physical devices visible to the host runtime.

  Physical devices are the hardware devices present on the host machine. By
  default every discovered CPU and GPU device is considered visible.

  This API can query the physical hardware before the runtime is initialized,
  which leaves room to call further configuration APIs afterwards. By contrast,
  `tf.config.list_logical_devices` initializes the runtime in order to list the
  configured devices.

  The following example lists the number of visible GPUs on the host.

  >>> physical_devices = tf.config.list_physical_devices('GPU')
  >>> print("Num GPUs:", len(physical_devices))
  Num GPUs: ...

  Note that the number of GPUs available to the runtime may still change during
  runtime initialization, for example when some devices are marked as not
  visible or are split into multiple logical devices.

  Args:
    device_type: (optional string) Only include devices matching this device
      type. For example "CPU" or "GPU".

  Returns:
    List of discovered `tf.config.PhysicalDevice` objects
  """
  ctx = context.context()
  return ctx.list_physical_devices(device_type)
416
417
@tf_export('config.list_logical_devices',
           'config.experimental.list_logical_devices')
@deprecation.deprecated_endpoints(
    'config.experimental.list_logical_devices')
def list_logical_devices(device_type=None):
  """Return a list of logical devices created by runtime.

  Logical devices may correspond to physical devices or remote devices in the
  cluster. Operations and tensors may be placed on these devices by using the
  `name` of the `tf.config.LogicalDevice`.

  Calling `tf.config.list_logical_devices` triggers the runtime to configure any
  `tf.config.PhysicalDevice` visible to the runtime, thereby preventing
  further configuration. To avoid runtime initialization, call
  `tf.config.list_physical_devices` instead.

  For example:

  >>> logical_devices = tf.config.list_logical_devices('GPU')
  >>> if len(logical_devices) > 1:
  ...   # Allocate on GPU:0
  ...   with tf.device(logical_devices[0].name):
  ...     one = tf.constant(1)
  ...   # Allocate on GPU:1
  ...   with tf.device(logical_devices[1].name):
  ...     two = tf.constant(2)

  Args:
    device_type: (optional string) Only include devices matching this device
      type. For example "CPU" or "GPU".

  Returns:
    List of initialized `LogicalDevice`s
  """
  return context.context().list_logical_devices(device_type=device_type)
453
454
@tf_export('config.get_visible_devices',
           'config.experimental.get_visible_devices')
@deprecation.deprecated_endpoints(
    'config.experimental.get_visible_devices')
def get_visible_devices(device_type=None):
  """Get the list of visible physical devices.

  Returns the `PhysicalDevice`s that are currently marked as visible to the
  runtime. Once the runtime is initialized, each visible device has at least
  one `LogicalDevice` associated with it.

  The following example verifies all visible GPUs have been disabled:

  >>> physical_devices = tf.config.list_physical_devices('GPU')
  >>> try:
  ...   # Disable all GPUS
  ...   tf.config.set_visible_devices([], 'GPU')
  ...   visible_devices = tf.config.get_visible_devices()
  ...   for device in visible_devices:
  ...     assert device.device_type != 'GPU'
  ... except:
  ...   # Invalid device or cannot modify virtual devices once initialized.
  ...   pass

  Args:
    device_type: (optional string) Only include devices matching this device
      type. For example "CPU" or "GPU".

  Returns:
    List of visible `PhysicalDevice`s
  """
  ctx = context.context()
  return ctx.get_visible_devices(device_type)
487
488
@tf_export('config.set_visible_devices',
           'config.experimental.set_visible_devices')
@deprecation.deprecated_endpoints(
    'config.experimental.set_visible_devices')
def set_visible_devices(devices, device_type=None):
  """Set the list of visible devices.

  Chooses which `PhysicalDevice` objects the runtime can see. TensorFlow only
  allocates memory and places operations on visible physical devices, because
  no `LogicalDevice` is created for the others. Every discovered device is
  visible by default.

  The following example demonstrates disabling the first GPU on the machine.

  >>> physical_devices = tf.config.list_physical_devices('GPU')
  >>> try:
  ...   # Disable first GPU
  ...   tf.config.set_visible_devices(physical_devices[1:], 'GPU')
  ...   logical_devices = tf.config.list_logical_devices('GPU')
  ...   # Logical device was not created for first GPU
  ...   assert len(logical_devices) == len(physical_devices) - 1
  ... except:
  ...   # Invalid device or cannot modify virtual devices once initialized.
  ...   pass

  Args:
    devices: List of `PhysicalDevice`s to make visible
    device_type: (optional) Only configure devices matching this device type.
      For example "CPU" or "GPU". Other devices will be left unaltered.

  Raises:
    ValueError: If argument validation fails.
    RuntimeError: Runtime is already initialized.
  """
  ctx = context.context()
  ctx.set_visible_devices(devices, device_type)
524
525
@tf_export('config.experimental.get_memory_info')
def get_memory_info(device):
  """Get memory info for the chosen device, as a dict.

  The returned dict describes how much device memory is in use. For example:

  >>> if tf.config.list_physical_devices('GPU'):
  ...   # Returns a dict in the form {'current': <current mem usage>,
  ...   #                             'peak': <peak mem usage>}
  ...   tf.config.experimental.get_memory_info('GPU:0')

  Keys currently returned:
    `'current'`: The current memory used by the device, in bytes.
    `'peak'`: The peak memory used by the device across the run of the program,
        in bytes.

  Further keys, including device-specific ones, may be added later.

  At present an exception is raised for the CPU.

  On GPUs, TensorFlow grabs all of the memory by default unless
  `tf.config.experimental.set_memory_growth` says otherwise. The dict reports
  only the current and peak memory TensorFlow is actually using, not the
  memory TensorFlow has reserved on the GPU.

  Args:
    device: Device string to get the memory information for, e.g. `"GPU:0"`. See
      https://www.tensorflow.org/api_docs/python/tf/device for specifying device
      strings.

  Returns:
    A dict with keys `'current'` and `'peak'`, specifying the current and peak
    memory usage respectively.

  Raises:
    ValueError: Non-existent or CPU device specified.

  """
  ctx = context.context()
  return ctx.get_memory_info(device)
566
567
@deprecation.deprecated(
    None,
    "Use tf.config.experimental.get_memory_info(device)['current'] instead.")
@tf_export('config.experimental.get_memory_usage')
def get_memory_usage(device):
  """Get the current memory usage, in bytes, for the chosen device.

  This function is deprecated in favor of
  `tf.config.experimental.get_memory_info`. Calling this function is equivalent
  to calling `tf.config.experimental.get_memory_info(device)['current']`.

  See https://www.tensorflow.org/api_docs/python/tf/device for specifying device
  strings.

  For example:

  >>> gpu_devices = tf.config.list_physical_devices('GPU')
  >>> if gpu_devices:
  ...   tf.config.experimental.get_memory_usage('GPU:0')

  Does not work for CPU.

  For GPUs, TensorFlow will allocate all the memory by default, unless changed
  with `tf.config.experimental.set_memory_growth`. This function only returns
  the memory that TensorFlow is actually using, not the memory that TensorFlow
  has allocated on the GPU.

  Args:
    device: Device string to get the bytes in use for, e.g. `"GPU:0"`

  Returns:
    Total memory usage in bytes.

  Raises:
    ValueError: Non-existent or CPU device specified.
  """
  return get_memory_info(device)['current']
605
606
@tf_export('config.experimental.get_memory_growth')
def get_memory_growth(device):
  """Get if memory growth is enabled for a `PhysicalDevice`.

  When memory growth is on for a `PhysicalDevice`, runtime initialization does
  not allocate all of that device's memory up front.

  For example:

  >>> physical_devices = tf.config.list_physical_devices('GPU')
  >>> try:
  ...   tf.config.experimental.set_memory_growth(physical_devices[0], True)
  ...   assert tf.config.experimental.get_memory_growth(physical_devices[0])
  ... except:
  ...   # Invalid device or cannot modify virtual devices once initialized.
  ...   pass

  Args:
    device: `PhysicalDevice` to query

  Returns:
    A boolean indicating the memory growth setting for the `PhysicalDevice`.

  Raises:
    ValueError: Invalid `PhysicalDevice` specified.
  """
  ctx = context.context()
  return ctx.get_memory_growth(device)
634
635
@tf_export('config.experimental.set_memory_growth')
def set_memory_growth(device, enable):
  """Set if memory growth should be enabled for a `PhysicalDevice`.

  When memory growth is on for a `PhysicalDevice`, runtime initialization does
  not allocate all of that device's memory up front. Memory growth cannot be
  configured on a `PhysicalDevice` that has virtual devices configured.

  For example:

  >>> physical_devices = tf.config.list_physical_devices('GPU')
  >>> try:
  ...   tf.config.experimental.set_memory_growth(physical_devices[0], True)
  ... except:
  ...   # Invalid device or cannot modify virtual devices once initialized.
  ...   pass

  Args:
    device: `PhysicalDevice` to configure
    enable: (Boolean) Whether to enable or disable memory growth

  Raises:
    ValueError: Invalid `PhysicalDevice` specified.
    RuntimeError: Runtime is already initialized.
  """
  ctx = context.context()
  ctx.set_memory_growth(device, enable)
662
663
@tf_export('config.experimental.get_device_details')
def get_device_details(device):
  """Returns details about a physical device.

  This API takes in a `tf.config.PhysicalDevice` returned by
  `tf.config.list_physical_devices`. It returns a dict with string keys
  containing various details about the device. Each key is only supported by a
  subset of devices, so you should not assume the returned dict will have any
  particular key.

  >>> gpu_devices = tf.config.list_physical_devices('GPU')
  >>> if gpu_devices:
  ...   details = tf.config.experimental.get_device_details(gpu_devices[0])
  ...   details.get('device_name', 'Unknown GPU')

  Currently, details are only returned for GPUs. This function returns an
  empty dict if passed a non-GPU device.

  The returned dict may have the following keys:
  * `'device_name'`: A human-readable name of the device as a string, e.g.
    "Titan V". Unlike `tf.config.PhysicalDevice.name`, this will be the same for
    multiple devices if each device is the same model. Currently only available
    for GPUs.
  * `'compute_capability'`: The
    [compute capability](https://developer.nvidia.com/cuda-gpus) of the device
    as a tuple of two ints, in the form `(major_version, minor_version)`. Only
    available for NVIDIA GPUs.

  Note: This is similar to `tf.sysconfig.get_build_info` in that both functions
  can return information relating to GPUs. However, this function returns
  run-time information about a specific device (such as a GPU's compute
  capability), while `tf.sysconfig.get_build_info` returns compile-time
  information about how TensorFlow was built (such as what version of CUDA
  TensorFlow was built for).

  Args:
    device: A `tf.config.PhysicalDevice` returned by
      `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.

  Returns:
    A dict with string keys.
  """
  return context.context().get_device_details(device)
707
708
@tf_export('config.get_logical_device_configuration',
           'config.experimental.get_virtual_device_configuration')
@deprecation.deprecated_endpoints(
    'config.experimental.get_virtual_device_configuration')
def get_logical_device_configuration(device):
  """Get the virtual device configuration for a `tf.config.PhysicalDevice`.

  Returns whatever list of `tf.config.LogicalDeviceConfiguration` objects was
  previously configured through `tf.config.set_logical_device_configuration`.

  For example:

  >>> physical_devices = tf.config.list_physical_devices('CPU')
  >>> assert len(physical_devices) == 1, "No CPUs found"
  >>> configs = tf.config.get_logical_device_configuration(
  ...   physical_devices[0])
  >>> try:
  ...   assert configs is None
  ...   tf.config.set_logical_device_configuration(
  ...     physical_devices[0],
  ...     [tf.config.LogicalDeviceConfiguration(),
  ...      tf.config.LogicalDeviceConfiguration()])
  ...   configs = tf.config.get_logical_device_configuration(
  ...     physical_devices[0])
  ...   assert len(configs) == 2
  ... except:
  ...   # Cannot modify virtual devices once initialized.
  ...   pass

  Args:
    device: `PhysicalDevice` to query

  Returns:
    List of `tf.config.LogicalDeviceConfiguration` objects or
    `None` if no virtual device configuration has been set for this physical
    device.
  """
  ctx = context.context()
  return ctx.get_logical_device_configuration(device)
748
749
@tf_export('config.set_logical_device_configuration',
           'config.experimental.set_virtual_device_configuration')
@deprecation.deprecated_endpoints(
    'config.experimental.set_virtual_device_configuration')
def set_logical_device_configuration(device, logical_devices):
  """Set the logical device configuration for a `tf.config.PhysicalDevice`.

  Once the runtime is initialized, each visible `tf.config.PhysicalDevice` gets
  a single `tf.config.LogicalDevice` by default. Passing a list of
  `tf.config.LogicalDeviceConfiguration` objects instead creates several
  logical devices on the same `tf.config.PhysicalDevice`.

  The following example splits the CPU into 2 logical devices:

  >>> physical_devices = tf.config.list_physical_devices('CPU')
  >>> assert len(physical_devices) == 1, "No CPUs found"
  >>> # Specify 2 virtual CPUs. Note currently memory limit is not supported.
  >>> try:
  ...   tf.config.set_logical_device_configuration(
  ...     physical_devices[0],
  ...     [tf.config.LogicalDeviceConfiguration(),
  ...      tf.config.LogicalDeviceConfiguration()])
  ...   logical_devices = tf.config.list_logical_devices('CPU')
  ...   assert len(logical_devices) == 2
  ...
  ...   tf.config.set_logical_device_configuration(
  ...     physical_devices[0],
  ...     [tf.config.LogicalDeviceConfiguration(),
  ...      tf.config.LogicalDeviceConfiguration(),
  ...      tf.config.LogicalDeviceConfiguration(),
  ...      tf.config.LogicalDeviceConfiguration()])
  ... except:
  ...   # Cannot modify logical devices once initialized.
  ...   pass

  The following example splits the GPU into 2 logical devices with 100 MB each:

  >>> physical_devices = tf.config.list_physical_devices('GPU')
  >>> try:
  ...   tf.config.set_logical_device_configuration(
  ...     physical_devices[0],
  ...     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
  ...      tf.config.LogicalDeviceConfiguration(memory_limit=100)])
  ...
  ...   logical_devices = tf.config.list_logical_devices('GPU')
  ...   assert len(logical_devices) == len(physical_devices) + 1
  ...
  ...   tf.config.set_logical_device_configuration(
  ...     physical_devices[0],
  ...     [tf.config.LogicalDeviceConfiguration(memory_limit=10),
  ...      tf.config.LogicalDeviceConfiguration(memory_limit=10)])
  ... except:
  ...   # Invalid device or cannot modify logical devices once initialized.
  ...   pass

  Args:
    device: The `PhysicalDevice` to configure.
    logical_devices: List of `tf.config.LogicalDeviceConfiguration` objects to
      allocate for the specified `PhysicalDevice`, or None to use the default
      configuration.

  Raises:
    ValueError: If argument validation fails.
    RuntimeError: Runtime is already initialized.
  """
  ctx = context.context()
  ctx.set_logical_device_configuration(device, logical_devices)
816
817
@tf_export('config.experimental.enable_mlir_bridge')
def enable_mlir_bridge():
  """Enables experimental MLIR-Based TensorFlow Compiler Bridge.

  DO NOT USE, DEV AND TESTING ONLY AT THE MOMENT.

  NOTE: The MLIR-based TensorFlow Compiler is still under active development
  and lacks features; please do not rely on it. This API is only meant for
  development and testing.

  The TensorFlow Compiler Bridge (TF Bridge) translates parts of a TensorFlow
  graph into a form that a backend compiler such as XLA can take as input.
  """
  ctx = context.context()
  ctx.enable_mlir_bridge = True
833
834
@tf_export('config.experimental.enable_mlir_graph_optimization')
def enable_mlir_graph_optimization():
  """Enables experimental MLIR-Based TensorFlow Compiler Optimizations.

  DO NOT USE, DEV AND TESTING ONLY AT THE MOMENT.

  NOTE: MLIR-Based TensorFlow Compiler is under active development and has
  missing features, please refrain from using. This API exists for development
  and testing only.

  TensorFlow Compiler Optimizations are responsible for general graph-level
  optimizations that, in the current stack, are mostly done by Grappler graph
  optimizers.
  """
  context.context().enable_mlir_graph_optimization = True
850
851
@tf_export('config.experimental.disable_mlir_bridge')
def disable_mlir_bridge():
  """Turns off the experimental MLIR-Based TensorFlow Compiler Bridge."""
  ctx = context.context()
  ctx.enable_mlir_bridge = False
856
857
@tf_export('config.experimental.disable_mlir_graph_optimization')
def disable_mlir_graph_optimization():
  """Turns off the experimental MLIR-Based TensorFlow Compiler Optimizations."""
  ctx = context.context()
  ctx.enable_mlir_graph_optimization = False
862