1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops to manipulate lists of tensors."""
16
17# pylint: disable=g-bad-name
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import gen_list_ops
27# go/tf-wildcard-import
28# pylint: disable=wildcard-import
29from tensorflow.python.ops.gen_list_ops import *
30# pylint: enable=wildcard-import
31from tensorflow.python.util.lazy_loader import LazyLoader
32
# Import cycle: list_ops -> control_flow_ops -> tensor_array_ops -> list_ops.
# control_flow_ops is therefore bound through a LazyLoader instead of a
# regular top-level import.
control_flow_ops = LazyLoader(
    "control_flow_ops",
    globals(),
    "tensorflow.python.ops.control_flow_ops")


# Op types explicitly marked as not differentiable.
for _op_type in ("TensorListConcatLists",
                 "TensorListElementShape",
                 "TensorListLength",
                 "TensorListPushBackBatch"):
  ops.NotDifferentiable(_op_type)
del _op_type  # Keep the loop variable out of the module namespace.
43
44
def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Wraps `gen_list_ops.empty_tensor_list`.

  Args:
    element_shape: Shape of the list elements; normalized with
      `_build_element_shape` before being handed to the op.
    element_dtype: Forwarded unchanged as the op's `element_dtype`.
    max_num_elements: Forwarded as the op's `max_num_elements`. `None` is
      replaced with -1.
    name: Optional name for the op.

  Returns:
    The output of `gen_list_ops.empty_tensor_list`.
  """
  size_limit = -1 if max_num_elements is None else max_num_elements
  normalized_shape = _build_element_shape(element_shape)
  return gen_list_ops.empty_tensor_list(
      element_shape=normalized_shape,
      element_dtype=element_dtype,
      max_num_elements=size_limit,
      name=name)
57
58
def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Wraps `gen_list_ops.tensor_list_reserve`, normalizing `element_shape`.

  Args:
    element_shape: Shape of the list elements; passed through
      `_build_element_shape` first.
    num_elements: Forwarded unchanged as the op's `num_elements`.
    element_dtype: Forwarded unchanged as the op's `element_dtype`.
    name: Optional name for the op.

  Returns:
    The output of `gen_list_ops.tensor_list_reserve`.
  """
  normalized_shape = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_reserve(
      element_shape=normalized_shape,
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
65
66
def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Wraps `gen_list_ops.tensor_list_from_tensor`.

  Args:
    tensor: Forwarded unchanged as the op's `tensor`.
    element_shape: Shape of the list elements; passed through
      `_build_element_shape` first.
    name: Optional name for the op.

  Returns:
    The output of `gen_list_ops.tensor_list_from_tensor`.
  """
  normalized_shape = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_from_tensor(
      tensor=tensor, element_shape=normalized_shape, name=name)
72
73
def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  """Wraps `gen_list_ops.tensor_list_get_item`.

  Args:
    input_handle: Forwarded unchanged as the op's `input_handle`.
    index: Forwarded unchanged as the op's `index`.
    element_dtype: Forwarded unchanged as the op's `element_dtype`.
    element_shape: Optional element shape; passed through
      `_build_element_shape`, so `None` becomes -1.
    name: Optional name for the op.

  Returns:
    The output of `gen_list_ops.tensor_list_get_item`.
  """
  normalized_shape = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=normalized_shape,
      element_dtype=element_dtype,
      name=name)
82
83
def tensor_list_pop_back(input_handle, element_dtype, name=None):
  """Wraps `gen_list_ops.tensor_list_pop_back` with `element_shape` set to -1.

  Args:
    input_handle: Forwarded unchanged as the op's `input_handle`.
    element_dtype: Forwarded unchanged as the op's `element_dtype`.
    name: Optional name for the op.

  Returns:
    The output of `gen_list_ops.tensor_list_pop_back`.
  """
  # This wrapper takes no element shape from the caller; -1 is the same value
  # `_build_element_shape` produces for an unknown shape.
  unknown_shape = -1
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=unknown_shape,
      element_dtype=element_dtype,
      name=name)
90
91
def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  """Wraps `gen_list_ops.tensor_list_gather`.

  Args:
    input_handle: Forwarded unchanged as the op's `input_handle`.
    indices: Forwarded unchanged as the op's `indices`.
    element_dtype: Forwarded unchanged as the op's `element_dtype`.
    element_shape: Optional element shape; passed through
      `_build_element_shape`, so `None` becomes -1.
    name: Optional name for the op.

  Returns:
    The output of `gen_list_ops.tensor_list_gather`.
  """
  normalized_shape = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=normalized_shape,
      element_dtype=element_dtype,
      name=name)
103
104
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Scatters `tensor` at `indices` into a new or an existing list.

  Args:
    tensor: Forwarded unchanged as the op's `tensor`.
    indices: Forwarded unchanged as the op's `indices`.
    element_shape: Optional element shape. Only used when `input_handle` is
      None, in which case it is passed through `_build_element_shape`.
    input_handle: Optional existing list. When given,
      `gen_list_ops.tensor_list_scatter_into_existing_list` is used and
      `element_shape` is ignored.
    name: Optional name for the op.

  Returns:
    The output of whichever scatter op was invoked.
  """
  if input_handle is None:
    # No destination list supplied: call the V2 scatter op with
    # num_elements=-1.
    return gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
  return gen_list_ops.tensor_list_scatter_into_existing_list(
      input_handle=input_handle, tensor=tensor, indices=indices, name=name)
120
121
def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  """Wraps `gen_list_ops.tensor_list_stack`.

  Args:
    input_handle: Forwarded unchanged as the op's `input_handle`.
    element_dtype: Forwarded unchanged as the op's `element_dtype`.
    num_elements: Forwarded unchanged as the op's `num_elements`; defaults
      to -1.
    element_shape: Optional element shape; passed through
      `_build_element_shape`, so `None` becomes -1.
    name: Optional name for the op.

  Returns:
    The output of `gen_list_ops.tensor_list_stack`.
  """
  normalized_shape = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=normalized_shape,
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)
133
134
def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  """Wraps `gen_list_ops.tensor_list_concat_v2`, returning its first output.

  Args:
    input_handle: Forwarded unchanged as the op's `input_handle`.
    element_dtype: Forwarded unchanged as the op's `element_dtype`.
    element_shape: Optional element shape; passed through
      `_build_element_shape`, so `None` becomes -1.
    name: Optional name for the op.

  Returns:
    The first output of `gen_list_ops.tensor_list_concat_v2`.
  """
  normalized_shape = _build_element_shape(element_shape)
  # An empty int64 tensor is supplied for `leading_dims`.
  empty_leading_dims = ops.convert_to_tensor([], dtype=dtypes.int64)
  outputs = gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=normalized_shape,
      leading_dims=empty_leading_dims,
      name=name)
  # Ignore the lengths output of TensorListConcat. It is only used during
  # gradient computation.
  return outputs[0]
145
146
def tensor_list_split(tensor, element_shape, lengths, name=None):
  """Wraps `gen_list_ops.tensor_list_split`.

  Args:
    tensor: Forwarded unchanged as the op's `tensor`.
    element_shape: Shape of the list elements; passed through
      `_build_element_shape` first.
    lengths: Forwarded unchanged as the op's `lengths`.
    name: Optional name for the op.

  Returns:
    The output of `gen_list_ops.tensor_list_split`.
  """
  normalized_shape = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_split(
      tensor=tensor,
      element_shape=normalized_shape,
      lengths=lengths,
      name=name)
153
154
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  Args:
    input_handle: The list to write into.
    index: Position at which `item` is written.
    item: The value to store.
    resize_if_index_out_of_bounds: If true, the list is first resized to
      `index + 1` whenever `index` is greater than or equal to its current
      length.
    name: Optional name for the set-item op.

  Returns:
    The output of `gen_list_ops.tensor_list_set_item`.
  """
  if resize_if_index_out_of_bounds:
    # The branch functions close over the handle that was passed in, not over
    # the rebinding of `input_handle` below.
    original_handle = input_handle

    def _grow_to_fit():
      return gen_list_ops.tensor_list_resize(original_handle, index + 1)

    def _keep_as_is():
      return original_handle

    current_length = gen_list_ops.tensor_list_length(original_handle)
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= current_length, _grow_to_fit, _keep_as_is)
  return gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)
172
173
@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  """Gradient for TensorListPushBack: pops the last entry off `dresult`.

  Args:
    op: The forward TensorListPushBack op.
    dresult: Gradient with respect to the op's output list.

  Returns:
    The outputs of `gen_list_ops.tensor_list_pop_back` applied to `dresult`.
  """
  # The popped element is given the runtime shape of the pushed tensor.
  pushed_shape = array_ops.shape(op.inputs[1])
  dtype = op.get_attr("element_dtype")
  return gen_list_ops.tensor_list_pop_back(
      input_handle=dresult, element_shape=pushed_shape, element_dtype=dtype)
180
181
@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  """Gradient for TensorListPopBack: pushes `delement` back onto `dlist`.

  Args:
    op: The forward TensorListPopBack op.
    dlist: Gradient with respect to the output list, or None.
    delement: Gradient with respect to the popped element.

  Returns:
    A pair `(list_grad, None)`.
  """
  if dlist is None:
    # No incoming list gradient: start from an empty list whose element shape
    # is read from the forward op's output list.
    output_element_shape = gen_list_ops.tensor_list_element_shape(
        op.outputs[0], shape_type=dtypes.int32)
    dlist = empty_tensor_list(
        element_dtype=delement.dtype, element_shape=output_element_shape)
  list_grad = gen_list_ops.tensor_list_push_back(dlist, delement)
  return list_grad, None
190
191
@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  """Gradient for TensorListStack: turns `dtensor` back into a list.

  Args:
    unused_op: The forward TensorListStack op (not used).
    dtensor: Gradient with respect to the stacked tensor.

  Returns:
    A pair `(list_grad, None)`.
  """
  # Each list element has the static shape of `dtensor` minus its first
  # dimension.
  per_element_shape = dtensor.shape[1:]
  list_grad = tensor_list_from_tensor(dtensor, element_shape=per_element_shape)
  return list_grad, None
195
196
@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat and TensorListConcatV2.

  Args:
    op: The forward concat op.
    dtensor: Gradient with respect to the concatenated tensor.
    unused_dlengths: Gradient with respect to the lengths output (not used).

  Returns:
    The list gradient alone for TensorListConcat, or the triple
    `(list_grad, None, None)` for TensorListConcatV2.
  """
  input_element_shape = gen_list_ops.tensor_list_element_shape(
      op.inputs[0], shape_type=dtypes.int32)
  # Split the incoming gradient using the lengths the forward op produced.
  list_grad = tensor_list_split(
      dtensor, element_shape=input_element_shape, lengths=op.outputs[1])
  if op.type != "TensorListConcatV2":
    return list_grad
  return list_grad, None, None
210
211
@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  """Gradient for TensorListSplit: concatenates `dlist` back into a tensor.

  Args:
    op: The forward TensorListSplit op.
    dlist: Gradient with respect to the output list.

  Returns:
    A triple `(tensor_grad, None, None)`.
  """
  split_tensor, _, lengths = op.inputs
  # Element shape is [-1] followed by the runtime shape of the split tensor
  # with its first dimension dropped.
  trailing_dims = array_ops.slice(array_ops.shape(split_tensor), [1], [-1])
  concat_element_shape = array_ops.concat([[-1], trailing_dims], axis=0)
  concat_outputs = gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=concat_element_shape,
      leading_dims=lengths,
      element_dtype=split_tensor.dtype)
  return concat_outputs[0], None, None
222
223
@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor.

  Args:
    op: The forward TensorListFromTensor op.
    dlist: Gradient with respect to the output list, or None.

  Returns:
    A pair `(tensor_grad, None)`.
  """
  source = op.inputs[0]
  source_dims = source.shape.dims
  # Statically known leading dimension, or None when the rank or that
  # dimension is unknown.
  num_elements = source_dims[0].value if source_dims else None
  if dlist is None:
    output_element_shape = gen_list_ops.tensor_list_element_shape(
        op.outputs[0], shape_type=dtypes.int32)
    dlist = empty_tensor_list(
        element_dtype=source.dtype, element_shape=output_element_shape)
  # Stack using the runtime shape of the source tensor with its first
  # dimension dropped.
  trailing_dims = array_ops.slice(array_ops.shape(source), [1], [-1])
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=trailing_dims,
      element_dtype=source.dtype,
      num_elements=num_elements)
  return tensor_grad, None
244
245
@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Reserves a list as long as the forward input list and writes `ditem` at the
  index that was read.

  Args:
    op: The forward TensorListGetItem op.
    ditem: Gradient with respect to the retrieved item.

  Returns:
    A triple `(list_grad, None, None)`.
  """
  forward_list = op.inputs[0]
  forward_index = op.inputs[1]
  list_size = gen_list_ops.tensor_list_length(forward_list)
  forward_element_shape = gen_list_ops.tensor_list_element_shape(
      forward_list, shape_type=dtypes.int32)
  reserved = gen_list_ops.tensor_list_reserve(
      forward_element_shape, list_size, element_dtype=ditem.dtype)
  list_grad = gen_list_ops.tensor_list_set_item(
      reserved, index=forward_index, item=ditem)
  return list_grad, None, None
260
261
@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem.

  Args:
    op: The forward TensorListSetItem op.
    dlist: Gradient with respect to the output list.

  Returns:
    A triple `(list_grad, None, item_grad)`.
  """
  _, index, item = op.inputs
  # List gradient: `dlist` with zeros written at `index`.
  zero_item = array_ops.zeros_like(item)
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=zero_item)
  # Item gradient: the entry of `dlist` at `index`.
  item_shape = array_ops.shape(item)
  item_grad = tensor_list_get_item(
      dlist, index, element_shape=item_shape, element_dtype=item.dtype)
  return list_grad, None, item_grad
275
276
@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  """Gradient for TensorListResize: resizes `dlist` to the input list's length.

  Args:
    op: The forward TensorListResize op.
    dlist: Gradient with respect to the resized list.

  Returns:
    A pair `(list_grad, None)`.
  """
  forward_list, _ = op.inputs
  original_size = gen_list_ops.tensor_list_length(forward_list)
  list_grad = gen_list_ops.tensor_list_resize(dlist, original_size)
  return list_grad, None
282
283
@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Reserves a list matching the forward input list's length and element shape,
  then scatters `dtensor` into it at the gathered indices.

  Args:
    op: The forward TensorListGather op.
    dtensor: Gradient with respect to the gathered tensor.

  Returns:
    A triple `(list_grad, None, None)`.
  """
  forward_list, indices, _ = op.inputs
  forward_element_shape = gen_list_ops.tensor_list_element_shape(
      forward_list, shape_type=dtypes.int32)
  forward_length = gen_list_ops.tensor_list_length(forward_list)
  reserved = tensor_list_reserve(
      forward_element_shape, forward_length, dtensor.dtype)
  list_grad = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=reserved)
  return list_grad, None, None
295
296
@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter and TensorListScatterV2.

  Args:
    op: The forward scatter op.
    dlist: Gradient with respect to the output list.

  Returns:
    `(tensor_grad, None, None, None)` for TensorListScatterV2, otherwise
    `(tensor_grad, None, None)`.
  """
  scattered, indices = op.inputs[0], op.inputs[1]
  # Gather with the runtime shape of the scattered tensor minus its first
  # dimension as the element shape.
  trailing_dims = array_ops.slice(array_ops.shape(scattered), [1], [-1])
  tensor_grad = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=trailing_dims,
      element_dtype=scattered.dtype)
  if op.type != "TensorListScatterV2":
    return tensor_grad, None, None
  return tensor_grad, None, None, None
312
313
@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList.

  Args:
    op: The forward op; its inputs are unpacked as (list, tensor, indices).
    dlist: Gradient with respect to the output list.

  Returns:
    A triple `(list_grad, tensor_grad, None)`. `tensor_grad` is gathered from
    `dlist` at `indices`; `list_grad` is `dlist` with zeros scattered at
    `indices`.
  """
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  zeros = array_ops.zeros_like(tensor)
  # Only `tensor`, `indices` and `input_handle` are meaningful when scattering
  # into an existing list; no `element_shape` is passed.
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None
326
327
def _build_element_shape(shape):
  """Converts shape to a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor` it is returned as-is. We do not perform a
  type check here.

  If shape is None or a TensorShape with unknown rank, -1 is returned.

  If shape is a scalar, an int32 tensor built from the empty list is returned.
  Note we do not directly return an empty list since ops.convert_to_tensor
  would convert it to a float32 which is not a valid type for element_shape.

  If shape is a sequence of dims, None's in the list are replaced with -1. We
  do not check the dtype of the other dims.

  Args:
    shape: Could be None, Tensor, TensorShape or a list of dims (each dim could
      be a None, scalar or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape

  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    if shape:
      shape = shape.as_list()
    else:
      shape = None

  # Shape is unknown.
  if shape is None:
    return -1

  # Shape is a scalar.
  if not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)

  def _normalize_dim(dim):
    # Unwrap `Dimension` objects; Tensors and plain values pass through.
    if isinstance(dim, tensor_shape.Dimension):
      dim = dim.value
    return -1 if dim is None else dim

  # Shape is a sequence of dimensions. Convert None dims to -1.
  return [_normalize_dim(dim) for dim in shape]
372