1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Functional operations."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.core.framework import attr_value_pb2
23from tensorflow.python.eager import context
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import function
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gen_functional_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import tensor_array_ops
34from tensorflow.python.ops import variable_scope as vs
35# pylint: disable=unused-import
36from tensorflow.python.ops.gen_functional_ops import remote_call
37# pylint: enable=unused-import
38from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
39from tensorflow.python.util import compat
40from tensorflow.python.util import function_utils
41from tensorflow.python.util import nest
42from tensorflow.python.util.tf_export import tf_export
43
44
45# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export("foldl")
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  `fn` is invoked as
  `fn(accumulator, slices)`, where `slices` has the same nested structure as
  `elems`.  That is, if `elems` is `(t1, [t2, t3, [t4, t5]])`, then the second
  argument of `fn` is a nest of slices shaped like
  `(t1, [t2, t3, [t4, t5]])`.

  In graph mode, if the current variable scope has no caching device, one is
  installed for the duration of this call and removed again before returning,
  including when `fn` raises.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    # `n` is bound in the enclosing scope before this helper is first called.
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldl", [elems]):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    try:
      # Convert elems to tensor array. n may be known statically.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem")
          for elem in nest.flatten(elems)
      ]
      n = (tensor_shape.dimension_value(elems_flat[0].shape[0])
           or array_ops.shape(elems_flat[0])[0])

      elems_ta = nest.map_structure(create_ta, elems)

      if initializer is None:
        # Seed the accumulator with element 0 and start folding at element 1.
        a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
        i = constant_op.constant(1)
      else:
        a = initializer
        i = constant_op.constant(0)

      def compute(i, a):
        elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
        a = fn(a, elem_i)
        return [i + 1, a]

      _, r_a = control_flow_ops.while_loop(
          lambda i, a: i < n, compute, [i, a],
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)

      return r_a
    finally:
      # Undo the temporary caching device even if `fn` or loop construction
      # raised, so the caller's variable scope is never left modified.
      if varscope_caching_device_was_none:
        varscope.set_caching_device(None)
153
154
@tf_export("foldr")
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn. If `initializer` is None, `elems` must contain at least
  one element, and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  `fn` is invoked as
  `fn(accumulator, slices)`, where `slices` has the same nested structure as
  `elems`.  That is, if `elems` is `(t1, [t2, t3, [t4, t5]])`, then the second
  argument of `fn` is a nest of slices shaped like
  `(t1, [t2, t3, [t4, t5]])`.

  In graph mode, if the current variable scope has no caching device, one is
  installed for the duration of this call and removed again before returning,
  including when `fn` raises.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    # `n` is bound in the enclosing scope before this helper is first called.
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldr", [elems]):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally and not
      # issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    try:
      # Convert elems to tensor array. n may be known statically.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem")
          for elem in nest.flatten(elems)
      ]
      n = (tensor_shape.dimension_value(elems_flat[0].shape[0])
           or array_ops.shape(elems_flat[0])[0])

      elems_ta = nest.map_structure(create_ta, elems)

      if initializer is None:
        # Seed the accumulator with the last element; the loop body
        # decrements `i` before reading, so folding resumes at index n - 2.
        i = n - 1
        a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      else:
        i = n
        a = initializer

      def compute(i, a):
        i -= 1
        elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
        a_out = fn(a, elem)
        return [i, a_out]

      _, r_a = control_flow_ops.while_loop(
          lambda i, a: i > 0,
          compute, [i, a],
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)

      return r_a
    finally:
      # Undo the temporary caching device even if `fn` or loop construction
      # raised, so the caller's variable scope is never left modified.
      if varscope_caching_device_was_none:
        varscope.set_caching_device(None)
264
265
@tf_export("scan")
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
         swap_memory=False, infer_shape=True, reverse=False, name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn. If `initializer` is None, `elems` must contain
  at least one element, and its first element (or its last element, if
  `reverse=True`) is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If reverse=True, it's fn(initializer, values[-1]).shape.

  This method also allows multi-arity `elems` and accumulator.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
   one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  In graph mode, if the current variable scope has no caching device, one is
  installed for the duration of this call and removed again before returning,
  including when `fn` raises.

  Args:
    fn: The callable to be performed.  It accepts two arguments.  The first
      will have the same structure as `initializer` if one is provided,
      otherwise it will have the same structure as `elems`.  The second
      will have the same (possibly nested) structure as `elems`.  Its output
      must have the same structure as `initializer` if one is provided,
      otherwise it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first
      to last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  # Helpers that translate between the caller's (possibly nested) structures
  # and the flat lists used inside the loop.
  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if initializer is None:
    # Without an initializer the accumulator mirrors the structure of `elems`.
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(initializer)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
    def output_pack(x):
      return (nest.pack_sequence_as(initializer, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "scan", elems_flat):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    try:
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]

      # Convert elems to tensor array. n may be known statically.
      n = tensor_shape.dimension_value(elems_flat[0].shape[0])
      if n is None:
        n = array_ops.shape(elems_flat[0])[0]

      # TensorArrays are always flat
      elems_ta = [
          tensor_array_ops.TensorArray(dtype=elem.dtype, size=n,
                                       dynamic_size=False,
                                       element_shape=elem.shape[1:],
                                       infer_shape=True)
          for elem in elems_flat]
      # Unpack elements
      elems_ta = [
          elem_ta.unstack(elem)
          for elem_ta, elem in zip(elems_ta, elems_flat)]

      if initializer is None:
        # The first element visited seeds the accumulator, so the loop starts
        # one step in.
        a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
        i = constant_op.constant(1)
      else:
        initializer_flat = output_flatten(initializer)
        a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
        i = constant_op.constant(0)

      # Create a tensor array to store the intermediate values.
      accs_ta = [
          tensor_array_ops.TensorArray(
              dtype=init.dtype, size=n,
              element_shape=init.shape if infer_shape else None,
              dynamic_size=False,
              infer_shape=infer_shape)
          for init in a_flat]

      if initializer is None:
        # The seed value is also the first output.
        accs_ta = [acc_ta.write(n - 1 if reverse else 0, a)
                   for (acc_ta, a) in zip(accs_ta, a_flat)]

      def compute(i, a_flat, tas):
        """The loop body of scan.

        Args:
          i: the loop counter.
          a_flat: the accumulator value(s), flattened.
          tas: the output accumulator TensorArray(s), flattened.

        Returns:
          (next_i, a_flat, tas): the updated counter (`i - 1` when scanning
            in reverse, `i + 1` otherwise) + new accumulator values +
            updated TensorArrays

        Raises:
          TypeError: if initializer and fn() output structure do not match
          ValueError: if initializer and fn() output lengths do not match
        """
        packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
        packed_a = output_pack(a_flat)
        a_out = fn(packed_a, packed_elems)
        nest.assert_same_structure(
            elems if initializer is None else initializer, a_out)
        flat_a_out = output_flatten(a_out)
        tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
        if reverse:
          next_i = i - 1
        else:
          next_i = i + 1
        return (next_i, flat_a_out, tas)

      if reverse:
        initial_i = n - 1 - i
        condition = lambda i, _1, _2: i >= 0
      else:
        initial_i = i
        condition = lambda i, _1, _2: i < n
      _, _, r_a = control_flow_ops.while_loop(
          condition, compute, (initial_i, a_flat, accs_ta),
          parallel_iterations=parallel_iterations,
          back_prop=back_prop, swap_memory=swap_memory,
          maximum_iterations=n)

      results_flat = [r.stack() for r in r_a]

      # Combine the static leading dimension of every input.  `merge_with`
      # returns the merged dimension (and raises on a mismatch), so keep its
      # result: a length known only from a later element is then preserved.
      n_static = tensor_shape.Dimension(tensor_shape.dimension_value(
          elems_flat[0].get_shape().with_rank_at_least(1)[0]))
      for elem in elems_flat[1:]:
        elem_dim = tensor_shape.Dimension(tensor_shape.dimension_value(
            elem.get_shape().with_rank_at_least(1)[0]))
        n_static = n_static.merge_with(elem_dim)
      for r in results_flat:
        r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
            r.get_shape()[1:]))

      return output_pack(results_flat)
    finally:
      # Undo the temporary caching device even if `fn` or loop construction
      # raised, so the caller's variable scope is never left modified.
      if varscope_caching_device_was_none:
        varscope.set_caching_device(None)
494
495
496# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  The output dtypes passed to the underlying op are read from the output
  signature of `then_branch`.

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the
      scalar is a numerical value, non-zero means True and zero means
      False; if the scalar is a string, non-empty means True and empty
      means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
        whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
        whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
  """
  then_signature = then_branch.definition.signature
  output_types = [out_arg.type for out_arg in then_signature.output_arg]
  # pylint: disable=protected-access
  return gen_functional_ops._if(
      cond, inputs, output_types, then_branch, else_branch, name=name)
524
525
def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  The output dtypes passed to the underlying op are read from the input
  signature of `f`.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for.

      The function 'f' must be a numerical function which takes N inputs and
      produces M outputs. Its gradient function 'g' is a function that takes
      N + M inputs and produces N outputs.

      I.e. if we have
         (y1, y2, ..., yM) = f(x1, x2, ..., xN),
      then, g is
         (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN,
                                           dL/dy1, dL/dy2, ..., dL/dyM),

      where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
      loss function). dL/dxi is the partial derivative of L with respect
      to xi.

    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhifengc): Needs some math expert to say the comment above better.
  input_signature = f.definition.signature.input_arg
  gradient_types = [in_arg.type for in_arg in input_signature]
  return symbolic_gradient(
      input=inputs, Tout=gradient_types, f=f, name=name)
556
557
def _LoopBodyCaptureWrapper(func):
  """Wraps `func` so its captured inputs are returned alongside its outputs.

  Args:
    func: A defined function exposing `declared_input_types` and `name`.

  Returns:
    A `Defun`-decorated function named "<func.name>_Wrapper" with the same
    declared input types as `func`.
  """
  wrapper_input_types = func.declared_input_types
  wrapper_name = "%s_Wrapper" % func.name

  @function.Defun(*wrapper_input_types, func_name=wrapper_name)
  def Wrapper(*args):
    """A wrapper that handles loop-carried captured inputs."""
    result = func(*args)
    extra_args = tuple(function.get_extra_args())
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(result, ops.Operation):
      return extra_args
    # N-ary functions return a tuple of Tensors.
    if isinstance(result, tuple):
      return result + extra_args
    # Unary functions return a single Tensor value.
    return (result,) + extra_args

  return Wrapper
579
580
581# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  When `body` has implicitly captured inputs, they are appended to `input_`
  as extra loop-carried values, `cond` and `body` are wrapped to accept them,
  and they are sliced off the result before it is returned.

  Args:
    input_: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function that takes 'input' and returns a tensor.  If the tensor
      is a scalar of non-boolean, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and False
      otherwise.
    body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
    name: A name for the operation (optional).
    hostmem: A list of integer. If i is in the list, input[i] is a
      host memory tensor.

  Raises:
    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
      have different signatures.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if cond.captured_inputs:
    raise ValueError("While op 'cond' argument must be a function "
                     "without implicitly captured inputs.")

  cond_types = cond.declared_input_types
  body_types = body.declared_input_types
  if cond_types != body_types:
    raise ValueError(
        "While op 'cond' and 'body' signatures do not match. %r vs %r" %
        (cond_types, body_types))

  captured = body.captured_inputs
  if not captured:
    outputs = gen_functional_ops._while(input_, cond, body, name=name)
  else:
    # `cond` must accept the same inputs as the wrapped body, so widen its
    # signature with the dtypes of the captured tensors and ignore them.
    cond_dtypes = list(body_types) + [tensor.dtype for tensor in captured]

    @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
    def CondWrapper(*args):
      """A wrapper that handles loop-carried captured inputs."""
      return cond(*args[:len(body.declared_input_types)])

    outputs = gen_functional_ops._while(
        input_ + captured,
        CondWrapper,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    outputs = outputs[:-len(captured)]

  if hostmem:
    while_op = outputs[0].op

    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend(hostmem)
    while_op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    while_op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return outputs
646
647
# b/36459430
#
# Ideally, we would not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a For loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a For loop, we need to rewrite it using a While op.
#
# It should be possible and probably better to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While.

  The loop state passed to `While` is `[0, n, start, delta] + inputs`, where
  `n` is the iteration count computed from `start`, `limit` and `delta`.

  Args:
    start: The first value of the iteration index.
    limit: The bound of the iteration range.
    delta: The step between consecutive iteration indices.
    inputs: A list of loop-carried input tensors.
    forbody: The loop body. It is called as `forbody(index, *args)` and must
      expose `declared_input_types` and `name`.
    name: A name for the operation (optional).
    hostmem: (optional) A list of integers indexing into `inputs`; each index
      is shifted by 4 before being forwarded to `While`.

  Returns:
    A list holding the `While` results with the first four entries removed.
  """
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over the range(n) and use iter * delta + start as the real
  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
  # 100).
  d = math_ops.abs(delta)
  # XLA on TPUs doesn't support integer division
  # The iteration count (|limit - start| + d - 1) / d is therefore computed
  # in float32 and cast back to int32.
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input list
  # of the WhileBody function. WhileCond does not call forbody, and so does not
  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
  # must have identical inputs, we have to augment the cond signature to take
  # the same types as the carried loop variables.
  # Four int32 slots (i, n, start, delta) replace forbody's first declared
  # input, which is the iteration index.
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  # Continue while the zero-based counter `i` is below the iteration count.
  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    # Map the zero-based counter back to the real iteration index.
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  # The four bookkeeping values are always marked as host memory; the
  # caller's indices are shifted past them.
  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Drop the four bookkeeping values (i, n, start, delta) that were prepended
  # to `inputs`.
  return list(results[4:len(results)])
721
722
def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  When `body` has implicitly captured inputs, they are appended to `inputs`
  as extra loop-carried values, `body` is wrapped to return them, and they
  are sliced off the result before it is returned.

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integer. If i is in the list, inputs[i] is a
      host memory tensor. In other words, (i+1)-th argument of the body
      function is expecting a host memory.
    rewrite_with_while: If True, using While op to implement the For.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)

  captured = body.captured_inputs
  if not captured:
    outputs = gen_functional_ops._for(
        start, limit, delta, inputs, body, name=name)
  else:
    outputs = gen_functional_ops._for(
        start,
        limit,
        delta,
        inputs + captured,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    outputs = outputs[:-len(captured)]

  if hostmem:
    # The op's inputs begin with start/limit/delta, so input indices are
    # offset by three; output indices are not.
    num_for_params = 3
    for_op = outputs[0].op

    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + idx for idx in hostmem])
    for_op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    for_op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return outputs
776# pylint: enable=invalid-name,protected-access
777
778
def partitioned_call(args, f, tout=None, executing_eagerly=None, config=None,
                     executor_type=None):
  """Calls `f` through a (Stateful)PartitionedCall op.

  The op executes the function while respecting device annotations.
  Currently, only those functions that execute within the same address space
  can be executed.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: A list containing the output dtype enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If
      `None`, all optimizations are disabled. Currently only handled for eager
      defined functions. Not used on the graph-mode path taken when `f` has
      no outputs; that path always attaches the disabled-rewriter config.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled, or
    the `Operation` if not.
  """
  # Fill in every argument the caller left unspecified.
  if tout is None:
    tout = tuple(arg.type for arg in f.definition.signature.output_arg)
  if executing_eagerly is None:
    executing_eagerly = context.executing_eagerly()
  if config is None:
    config = function_utils.get_disabled_rewriter_config()
  if executor_type is None:
    executor_type = ""

  if executing_eagerly or len(tout) > 0:
    # The generated bindings suffice whenever we run eagerly or the function
    # produces at least one output.
    call_op = (
        gen_functional_ops.stateful_partitioned_call
        if f.stateful_ops else gen_functional_ops.partitioned_call)
    outputs = call_op(
        args=args,
        Tout=tout,
        f=f,
        config_proto=config,
        executor_type=executor_type)
    return outputs or None

  # Graph mode with no outputs: the generated binding would hand back an empty
  # list, so the op is created directly in order to return the `Operation`.
  args = [ops.internal_convert_to_tensor(arg) for arg in args]
  attr_value = attr_value_pb2.AttrValue
  attrs = {
      "Tin": attr_value(
          list=attr_value.ListValue(
              type=[arg.dtype.as_datatype_enum for arg in args])),
      "Tout": attr_value(list=attr_value.ListValue(type=tout)),
      "f": attr_value(func=attr_value_pb2.NameAttrList(name=f.name)),
      # When running in graph mode, the graph and function graphs are
      # optimized (i.e. run through grappler) per the session options, so any
      # eager-specific rewriting can be disabled.
      "config_proto": attr_value(
          s=function_utils.get_disabled_rewriter_config()),
      "executor_type": attr_value(s=compat.as_bytes(executor_type)),
  }

  graph = ops.get_default_graph()
  f.add_to_graph(graph)
  if f.stateful_ops:
    op_type = "StatefulPartitionedCall"
  else:
    op_type = "PartitionedCall"
  op = graph.create_op(
      op_type,
      args,
      tout,
      compute_shapes=False,
      name="PartitionedFunctionCall",
      attrs=attrs)
  return op.outputs or op
867