1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional operations for RaggedTensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20from tensorflow.python.ops.ragged import ragged_tensor
21from tensorflow.python.util import nest
22from tensorflow.python.util.lazy_loader import LazyLoader
23
24
# Lazily import map_fn to defer loading it until first use — presumably this
# avoids a circular import between this module and tensorflow.python.ops
# (NOTE(review): confirm the cycle; LazyLoader is normally used for that).
map_fn_lib = LazyLoader(
    "map_fn_lib", globals(),
    "tensorflow.python.ops.map_fn")
28
29
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None):
  """Transforms `elems` by applying `fn` to each element unstacked on axis 0.

  `fn` is called once per element, where an "element" is the (possibly nested)
  structure obtained by slicing every tensor in `elems` along its first
  dimension.  The per-element results are stacked back together to form the
  output.  This version additionally supports `RaggedTensor` inputs and
  outputs: use `RaggedTensorType` inside `dtype` to declare a ragged result.

  If `elems` is a nested structure of tensors, every tensor in it must share
  the same size in dimension 0, and `fn` receives a matching structure of
  slices.  When `fn` produces a structure different from that of `elems`,
  `dtype` is required and must mirror `fn`'s output structure; otherwise
  `dtype` defaults to the dtypes of `elems`.

  To operate on the nonzero values of a `SparseTensor`, do not pass the
  `SparseTensor` itself; instead map over its `values`:

  ```python
  result = SparseTensor(
      input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  (If `fn` is expressible directly in TensorFlow ops, simply apply it to
  `input.values` without `map_fn`.)

  When executing eagerly, iterations never run in parallel regardless of
  `parallel_iterations`; wrap the call in `tf.contrib.eager.defun` to regain
  parallel execution (non-TensorFlow Python side effects inside the function
  will then not run — debug without `defun` first).

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `dtype` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`. Use
      `RaggedTensorType` to declare an output of type `RaggedTensor`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A possibly nested sequence of potentially ragged tensors.  Each
    tensor packs the results of applying `fn` to tensors unpacked from `elems`
    along the first dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  #### Examples:

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```

    ```python
    elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]])
    mean = map_fn(tf.reduce_mean, elems)
    # mean == [2, 4, 6]
    ```

    ```python
    elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]], dtype=tf.int64)
    out = map_fn(fn=lambda x: x+1, elems,
      dtype=ragged.RaggedTensorType(type=tf.int64, ragged_rank=0))
    # out = tf.ragged.constant([[2, 3, 4], [5, 6], [7, 8]])
    ```
  """
  # Default the output types to the input dtypes when the caller gave none.
  if dtype is None:
    dtype = nest.map_structure(lambda e: e.dtype, elems)
  # map_fn_lib understands RaggedTensorSpec, not RaggedTensorType, so convert
  # any RaggedTensorType entries before delegating.
  output_types = nest.map_structure(_ragged_type_to_spec, dtype)
  return map_fn_lib.map_fn(fn, elems, output_types, parallel_iterations,
                           back_prop, swap_memory, infer_shape, name)
172
173
def _ragged_type_to_spec(t):
  """Converts a `RaggedTensorType` into an equivalent `RaggedTensorSpec`.

  Values that are not `RaggedTensorType` are returned unchanged.

  Args:
    t: An entry from a `dtype` structure passed to `map_fn`.

  Returns:
    A `RaggedTensorSpec` describing a single mapped `fn` output, or `t` itself.
  """
  if not isinstance(t, ragged_tensor.RaggedTensorType):
    return t
  # RaggedTensorType describes the *stacked* result of mapping `fn`, whereas
  # RaggedTensorSpec must describe one per-element `fn` output.  Stacking adds
  # a ragged dimension, so the spec's ragged_rank is one less.
  return ragged_tensor.RaggedTensorSpec(
      None, t.dtype, t.ragged_rank - 1, t.row_splits_dtype)
183