1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions for configuring TensorFlow execution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import context
22from tensorflow.python.util.tf_export import tf_export
23
24
@tf_export('config.gpu.get_per_process_memory_fraction')
def get_gpu_per_process_memory_fraction():
  """Return the share of available GPU memory each process may allocate.

  The value is a fraction: 1.0 corresponds to allocating all of the GPU
  memory, while 0.5 lets the process allocate up to half of what is available.

  Returns:
    The per-process GPU memory fraction currently held by the context.
  """
  ctx = context.context()
  return ctx.gpu_per_process_memory_fraction
36
37
@tf_export('config.gpu.set_per_process_memory_fraction')
def set_gpu_per_process_memory_fraction(fraction):
  """Choose the share of available GPU memory each process may allocate.

  The value is a fraction: 1.0 corresponds to allocating all of the GPU
  memory, while 0.5 lets the process allocate up to half of what is available.

  Args:
    fraction: Fraction of GPU memory to allocate.
  """
  ctx = context.context()
  ctx.gpu_per_process_memory_fraction = fraction
49
50
@tf_export('config.gpu.get_per_process_memory_growth')
def get_gpu_per_process_memory_growth():
  """Report whether GPU memory is allowed to grow instead of pre-allocated.

  Returns:
    The memory growth setting currently held by the context.
  """
  ctx = context.context()
  return ctx.gpu_per_process_memory_growth
59
60
@tf_export('config.gpu.set_per_process_memory_growth')
def set_gpu_per_process_memory_growth(enabled):
  """Choose between pre-allocating GPU memory and letting it grow.

  Args:
    enabled: Indicates whether GPU memory growth should be enabled.
  """
  ctx = context.context()
  ctx.gpu_per_process_memory_growth = enabled
69
70
# NOTE(review): unlike the other getters in this file, the exported name has
# no `get_` prefix -- confirm this is intended before relying on it.
@tf_export('config.threading.intra_op_parallelism_threads')
def get_intra_op_parallelism_threads():
  """Return the number of threads used for parallelism inside a single op.

  Some operations, such as matrix multiplication and reductions, can use
  parallel threads to run faster. A value of 0 means the system picks an
  appropriate number.

  Returns:
    Number of parallel threads.
  """
  ctx = context.context()
  return ctx.intra_op_parallelism_threads
83
84
@tf_export('config.threading.set_intra_op_parallelism_threads')
def set_intra_op_parallelism_threads(num_threads):
  """Choose the number of threads used for parallelism inside a single op.

  Some operations, such as matrix multiplication and reductions, can use
  parallel threads to run faster. A value of 0 means the system picks an
  appropriate number.

  Args:
    num_threads: Number of parallel threads.
  """
  ctx = context.context()
  ctx.intra_op_parallelism_threads = num_threads
97
98
# NOTE(review): unlike the other getters in this file, the exported name has
# no `get_` prefix -- confirm this is intended before relying on it.
@tf_export('config.threading.inter_op_parallelism_threads')
def get_inter_op_parallelism_threads():
  """Return the number of threads used to run independent ops in parallel.

  This is the number of threads used by independent non-blocking operations.
  A value of 0 means the system picks an appropriate number.

  Returns:
    Number of parallel threads.
  """
  ctx = context.context()
  return ctx.inter_op_parallelism_threads
110
111
@tf_export('config.threading.set_inter_op_parallelism_threads')
def set_inter_op_parallelism_threads(num_threads):
  """Choose the number of threads used to run independent ops in parallel.

  This is the number of threads used by independent non-blocking operations.
  A value of 0 means the system picks an appropriate number.

  Args:
    num_threads: Number of parallel threads.
  """
  ctx = context.context()
  ctx.inter_op_parallelism_threads = num_threads
123
124
@tf_export('config.get_soft_device_placement')
def get_soft_device_placement():
  """Report whether soft device placement is enabled.

  When enabled, an op is placed on CPU if any of the following holds:
    1. there is no GPU implementation for the op
    2. no GPU devices are known or registered
    3. it needs to co-locate with reftype input(s) which are from CPU

  Returns:
    The soft placement setting currently held by the context.
  """
  ctx = context.context()
  return ctx.soft_device_placement
138
139
@tf_export('config.set_soft_device_placement')
def set_soft_device_placement(enabled):
  """Turn soft device placement on or off.

  When enabled, an op is placed on CPU if any of the following holds:
    1. there is no GPU implementation for the op
    2. no GPU devices are known or registered
    3. it needs to co-locate with reftype input(s) which are from CPU

  Args:
    enabled: Whether to enable soft placement.
  """
  ctx = context.context()
  ctx.soft_device_placement = enabled
153
154
@tf_export('config.experimental.get_device_policy')
def get_device_policy():
  """Return the name of the current device policy.

  The device policy controls how operations requiring inputs on a specific
  device (e.g., on GPU:0) handle inputs on a different device (e.g. GPU:1).

  Only the device policy of the current thread is reported. Any subsequently
  started thread will again use the default policy.

  Returns:
    One of 'silent', 'silent_for_int32', 'warn' or 'explicit'.

  Raises:
    ValueError: If the context's policy matches none of the known constants.
  """
  current = context.context().device_policy
  # Checked in order with `==`, mirroring a plain if/elif chain.
  known_policies = (
      (context.DEVICE_PLACEMENT_SILENT, 'silent'),
      (context.DEVICE_PLACEMENT_SILENT_FOR_INT32, 'silent_for_int32'),
      (context.DEVICE_PLACEMENT_WARN, 'warn'),
      (context.DEVICE_PLACEMENT_EXPLICIT, 'explicit'),
  )
  for policy, policy_name in known_policies:
    if current == policy:
      return policy_name
  raise ValueError('Not a valid device policy: %r' % current)
179
180
@tf_export('config.experimental.set_device_policy')
def set_device_policy(device_policy):
  """Select the device policy for the current thread.

  The device policy controls how operations requiring inputs on a specific
  device (e.g., on GPU:0) handle inputs on a different device (e.g. GPU:1).

  When using the default, an appropriate policy will be picked automatically.
  The default policy may change over time.

  Only the device policy of the current thread is changed. Any subsequently
  started thread will again use the default policy.

  Args:
    device_policy: A device policy.
      Valid values:
      - None: Switch to a system default.
      - 'warn': Copies the tensors which are not on the right device and logs
          a warning.
      - 'explicit': Raises an error if the placement is not as required.
      - 'silent': Silently copies the tensors. Note that this may hide
          performance problems as there is no notification provided when
          operations are blocked on the tensor being copied between devices.
      - 'silent_for_int32': silently copies `int32` tensors, raising errors on
          the other ones.

  Raises:
      ValueError: If an invalid `device_policy` is passed.
  """
  # Checked in order with `==`, mirroring a plain if/elif chain.
  named_policies = (
      ('silent', context.DEVICE_PLACEMENT_SILENT),
      ('silent_for_int32', context.DEVICE_PLACEMENT_SILENT_FOR_INT32),
      ('warn', context.DEVICE_PLACEMENT_WARN),
      ('explicit', context.DEVICE_PLACEMENT_EXPLICIT),
  )
  for policy_name, policy in named_policies:
    if device_policy == policy_name:
      context.context().device_policy = policy
      return

  # Reject anything else before touching the context.
  if device_policy is not None:
    raise ValueError('Not a valid device policy: %r' % device_policy)
  context.context().device_policy = None
222
223
@tf_export('config.experimental.get_synchronous_execution')
def get_synchronous_execution():
  """Report whether operations are executed synchronously.

  TensorFlow can execute operations synchronously or asynchronously. If
  asynchronous execution is enabled, operations may return "non-ready" handles.

  Returns:
    The result of comparing the context's execution mode with `context.SYNC`.
  """
  mode = context.context().execution_mode
  return mode == context.SYNC
235
236
@tf_export('config.experimental.set_synchronous_execution')
def set_synchronous_execution(enable):
  """Choose between synchronous and asynchronous execution of operations.

  TensorFlow can execute operations synchronously or asynchronously. If
  asynchronous execution is enabled, operations may return "non-ready" handles.

  When `enable` is set to None, an appropriate value will be picked
  automatically. The value picked may change between TensorFlow releases.

  Args:
    enable: Whether operations should be dispatched synchronously.
      Valid values:
      - None: sets the system default.
      - True: executes each operation synchronously.
      - False: executes each operation asynchronously.
  """
  if enable is None:
    mode = None
  else:
    mode = context.SYNC if enable else context.ASYNC
  context.context().execution_mode = mode
260