1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops to manipulate lists of tensors."""
16
17# pylint: disable=g-bad-name
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numpy as np
23
24from tensorflow.core.framework import types_pb2
25from tensorflow.python.framework import cpp_shape_inference_pb2
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_list_ops
32from tensorflow.python.ops import handle_data_util
33# go/tf-wildcard-import
34# pylint: disable=wildcard-import
35from tensorflow.python.ops.gen_list_ops import *
36# pylint: enable=wildcard-import
37from tensorflow.python.util.lazy_loader import LazyLoader
38
# `control_flow_ops` is loaded lazily through `LazyLoader` because importing it
# eagerly would close the following import cycle:
# list_ops -> control_flow_ops -> tensor_array_ops -> list_ops
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")


# Register these list ops as non-differentiable; this module defines no
# gradient functions for them.
ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")
49
50
def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Creates an empty TensorList.

  Args:
    element_shape: Shape of the list elements, in any form accepted by
      `_build_element_shape` (None, Tensor, TensorShape or a list of dims).
    element_dtype: The dtype of the list elements.
    max_num_elements: Optional cap on the list length. `None` is forwarded
      to the op as -1.
    name: Optional op name.

  Returns:
    The list handle returned by `gen_list_ops.empty_tensor_list`.
  """
  capacity = -1 if max_num_elements is None else max_num_elements
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.empty_tensor_list(
      element_shape=shape_arg,
      element_dtype=element_dtype,
      max_num_elements=capacity,
      name=name)
63
64
def _set_handle_data(list_handle, element_shape, element_dtype):
  """Attaches element shape/dtype handle data to an eager list handle.

  This keeps eager handles consistent with graph handles. Handles that are
  not `EagerTensor`s are left untouched.

  Args:
    list_handle: The TensorList handle returned by a list op.
    element_shape: Element shape. A tensor-valued shape is recorded as an
      unknown shape; any non-`TensorShape` value is wrapped in `TensorShape`.
    element_dtype: A `DType` describing the list elements.
  """
  # TODO(b/169968286): It would be better if we had a consistent story for
  # creating handle data from eager operations (shared with VarHandleOp).
  if not isinstance(list_handle, ops.EagerTensor):
    return

  if tensor_util.is_tf_type(element_shape):
    # A tensor-valued shape is recorded as fully unknown.
    static_shape = tensor_shape.TensorShape(None)
  elif isinstance(element_shape, tensor_shape.TensorShape):
    static_shape = element_shape
  else:
    static_shape = tensor_shape.TensorShape(element_shape)

  shape_inference = cpp_shape_inference_pb2.CppShapeInferenceResult
  shape_and_type = shape_inference.HandleShapeAndType(
      shape=static_shape.as_proto(),
      dtype=element_dtype.as_datatype_enum,
      specialized_type=types_pb2.ST_TENSOR_LIST)
  handle_data = shape_inference.HandleData()
  handle_data.is_set = True
  handle_data.shape_and_type.append(shape_and_type)
  list_handle._handle_data = handle_data  # pylint: disable=protected-access
82
83
def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Creates a TensorList with `num_elements` reserved elements.

  Args:
    element_shape: Element shape, in any form accepted by
      `_build_element_shape`.
    num_elements: Number of elements to reserve.
    element_dtype: The dtype of the list elements.
    name: Optional op name.

  Returns:
    The list handle. In eager mode it carries element shape/dtype handle
    data.
  """
  shape_arg = _build_element_shape(element_shape)
  handle = gen_list_ops.tensor_list_reserve(
      element_shape=shape_arg,
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
  # TODO(b/169968286): gen_ops needs to ensure the metadata is properly
  # populated for eager operations.
  _set_handle_data(handle, element_shape, element_dtype)
  return handle
94
95
def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Creates a TensorList whose elements are the dim-0 slices of `tensor`.

  Args:
    tensor: A value convertible to a `Tensor`. Its leading dimension indexes
      the list elements.
    element_shape: Shape of each element, in any form accepted by
      `_build_element_shape`.
    name: Optional op name.

  Returns:
    The list handle. In eager mode it carries element shape/dtype handle
    data.
  """
  tensor = ops.convert_to_tensor(tensor)
  result = gen_list_ops.tensor_list_from_tensor(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      name=name)
  # Handle data describes a single *element*, i.e. `tensor` without its
  # leading (list-length) dimension. This matches the convention used by the
  # gradients in this module (`dtensor.shape[1:]`,
  # `slice(shape(t), [1], [-1])`). Recording the full `tensor.shape` here
  # would claim every element has the shape of the whole input.
  _set_handle_data(result, tensor.shape[1:], tensor.dtype)
  return result
104
105
def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  """Reads the element stored at `index` in the list `input_handle`.

  Args:
    input_handle: The TensorList handle.
    index: Position of the element to read.
    element_dtype: The dtype of the list elements.
    element_shape: Optional element shape; None means unknown (see
      `_build_element_shape`).
    name: Optional op name.

  Returns:
    The element tensor.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=shape_arg,
      element_dtype=element_dtype,
      name=name)
114
115
def tensor_list_pop_back(input_handle, element_dtype, name=None):
  """Pops the last element off the list `input_handle`.

  The element shape is passed to the op as -1, i.e. unknown.

  Args:
    input_handle: The TensorList handle.
    element_dtype: The dtype of the list elements.
    name: Optional op name.

  Returns:
    The outputs of `gen_list_ops.tensor_list_pop_back`: the shortened list
    handle and the popped element.
  """
  unknown_element_shape = -1
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=unknown_element_shape,
      element_dtype=element_dtype,
      name=name)
122
123
def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  """Gathers the elements of `input_handle` at `indices` into a tensor.

  Args:
    input_handle: The TensorList handle.
    indices: Positions of the elements to gather.
    element_dtype: The dtype of the list elements.
    element_shape: Optional element shape; None means unknown.
    name: Optional op name.

  Returns:
    The gathered tensor.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=shape_arg,
      element_dtype=element_dtype,
      name=name)
135
136
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Returns a TensorList created or updated by scattering `tensor`.

  Args:
    tensor: Values to scatter; converted with `ops.convert_to_tensor`.
    indices: Indices into the list at which to write.
    element_shape: Element shape used only when a new list is created
      (`input_handle` is None); ignored otherwise.
    input_handle: Optional existing list to scatter into. If None, a new
      list is created.
    name: Optional op name.

  Returns:
    The resulting list handle.
  """
  tensor = ops.convert_to_tensor(tensor)
  if input_handle is None:
    result = gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
    _set_handle_data(result, element_shape, tensor.dtype)
  else:
    result = gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle, tensor=tensor, indices=indices, name=name)
    # The updated list inherits the handle data of the list it came from.
    handle_data_util.copy_handle_data(input_handle, result)
  return result
158
159
def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  """Stacks the elements of `input_handle` into a single tensor.

  Args:
    input_handle: The TensorList handle.
    element_dtype: The dtype of the list elements.
    num_elements: Expected number of elements, forwarded to the op as-is
      (default -1).
    element_shape: Optional element shape; None means unknown.
    name: Optional op name.

  Returns:
    The stacked tensor.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=shape_arg,
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)
171
172
def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  """Concatenates the elements of `input_handle` into one tensor.

  Args:
    input_handle: The TensorList handle.
    element_dtype: The dtype of the list elements.
    element_shape: Optional element shape; None means unknown.
    name: Optional op name.

  Returns:
    The first output of TensorListConcatV2, the concatenated tensor.
  """
  shape_arg = _build_element_shape(element_shape)
  no_leading_dims = ops.convert_to_tensor([], dtype=dtypes.int64)
  outputs = gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=shape_arg,
      leading_dims=no_leading_dims,
      name=name)
  # The second (lengths) output is ignored here; it is only used during
  # gradient computation.
  return outputs[0]
183
184
def tensor_list_split(tensor, element_shape, lengths, name=None):
  """Splits `tensor` into a TensorList along its leading dimension.

  Args:
    tensor: The tensor to split.
    element_shape: Element shape, in any form accepted by
      `_build_element_shape`.
    lengths: Leading-dimension size of each resulting element.
    name: Optional op name.

  Returns:
    The resulting list handle.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_split(
      tensor=tensor,
      element_shape=shape_arg,
      lengths=lengths,
      name=name)
191
192
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  Args:
    input_handle: The TensorList handle to update.
    index: Position to write.
    item: Value to store at `index`.
    resize_if_index_out_of_bounds: If True, the list is first resized to
      `index + 1` elements (under a `cond`) when `index >= length`.
    name: Optional op name.

  Returns:
    The updated list handle, carrying the handle data of the list it was
    written into.
  """
  if resize_if_index_out_of_bounds:
    current_handle = input_handle
    current_length = gen_list_ops.tensor_list_length(current_handle)

    def _grow_list():
      return gen_list_ops.tensor_list_resize(current_handle, index + 1)

    def _keep_list():
      return current_handle

    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= current_length, _grow_list, _keep_list)

  result = gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)
  handle_data_util.copy_handle_data(input_handle, result)
  return result
212
213
@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  """Gradient for TensorListPushBack: pop the pushed element off `dresult`.

  The pop's outputs serve as the gradients for the op's two inputs: the
  shortened list for the input list, and the popped element for the pushed
  tensor.
  """
  pushed_tensor = op.inputs[1]
  dtype = op.get_attr("element_dtype")
  return gen_list_ops.tensor_list_pop_back(
      dresult,
      element_shape=array_ops.shape(pushed_tensor),
      element_dtype=dtype)
220
221
@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  """Gradient for TensorListPopBack: push `delement` back onto `dlist`.

  Missing incoming gradients are filled in. A None `delement` becomes
  `zeros_like` of the popped element. A None `dlist` becomes an empty list
  with the forward output list's element shape.

  Returns:
    A pair: the list gradient for the input handle, and None for the
    element_shape input.
  """
  # Fill in `delement` first: the empty-list fallback below reads its dtype,
  # and previously this raised AttributeError when both incoming gradients
  # were None.
  if delement is None:
    delement = array_ops.zeros_like(op.outputs[1])
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  return gen_list_ops.tensor_list_push_back(dlist, delement), None
232
233
@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  """Gradient for TensorListStack: unstack `dtensor` back into a list.

  The element_shape input gets no gradient.
  """
  per_element_shape = dtensor.shape[1:]
  dlist = tensor_list_from_tensor(dtensor, element_shape=per_element_shape)
  return dlist, None
237
238
@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat and TensorListConcatV2.

  Splits `dtensor` back into a list using the forward op's lengths output.
  TensorListConcatV2 has two extra inputs (element_shape and leading_dims),
  and neither gets a gradient.
  """
  forward_list = op.inputs[0]
  element_shape = gen_list_ops.tensor_list_element_shape(
      forward_list, shape_type=dtypes.int32)
  dlist = tensor_list_split(
      dtensor, element_shape=element_shape, lengths=op.outputs[1])
  if op.type != "TensorListConcatV2":
    return dlist
  return dlist, None, None
252
253
@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  """Gradient for TensorListSplit: concatenate `dlist` back into a tensor.

  The concat uses element shape `[-1] + shape(tensor)[1:]`, with the forward
  `lengths` as the leading dims. Only `tensor` receives a gradient.
  """
  tensor, _, lengths = op.inputs
  trailing_dims = array_ops.slice(array_ops.shape(tensor), [1], [-1])
  concat_element_shape = array_ops.concat([[-1], trailing_dims], axis=0)
  outputs = gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=concat_element_shape,
      leading_dims=lengths,
      element_dtype=tensor.dtype)
  return outputs[0], None, None
264
265
@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor: stack `dlist` back into a tensor.

  A None `dlist` is replaced by an empty list with the forward list's element
  shape. If the forward input's leading dimension is statically known, it is
  passed as `num_elements`. The element_shape input gets no gradient.
  """
  t = op.inputs[0]
  dims = t.shape.dims
  num_elements = None
  if dims and dims[0].value is not None:
    num_elements = dims[0].value

  if dlist is None:
    forward_element_shape = gen_list_ops.tensor_list_element_shape(
        op.outputs[0], shape_type=dtypes.int32)
    dlist = empty_tensor_list(
        element_dtype=t.dtype, element_shape=forward_element_shape)

  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
      element_dtype=t.dtype,
      num_elements=num_elements)
  return tensor_grad, None
286
287
@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Reserves a list with the forward list's length and element shape, then
  writes `ditem` at the read index. The index and element_shape inputs get
  no gradient.
  """
  forward_list, index = op.inputs[0], op.inputs[1]
  length = gen_list_ops.tensor_list_length(forward_list)
  element_shape = gen_list_ops.tensor_list_element_shape(
      forward_list, shape_type=dtypes.int32)
  reserved = gen_list_ops.tensor_list_reserve(
      element_shape, length, element_dtype=ditem.dtype)
  list_grad = gen_list_ops.tensor_list_set_item(
      reserved, index=index, item=ditem)
  return list_grad, None, None
302
303
@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem.

  The input list receives `dlist` with the slot at `index` zeroed, since
  that slot was overwritten in the forward pass. `item` receives the element
  of `dlist` at `index`. The index input gets no gradient.
  """
  _, index, item = op.inputs
  zeros = array_ops.zeros_like(item)
  list_grad = gen_list_ops.tensor_list_set_item(dlist, index=index, item=zeros)
  item_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  return list_grad, None, item_grad
317
318
@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  """Gradient for TensorListResize: resize `dlist` back to the input length.

  The size input gets no gradient.
  """
  forward_list, _ = op.inputs
  original_length = gen_list_ops.tensor_list_length(forward_list)
  list_grad = gen_list_ops.tensor_list_resize(dlist, original_length)
  return list_grad, None
324
325
@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Reserves a list with the forward list's element shape and length, then
  scatters `dtensor` into it at the gathered `indices`. The indices and
  element_shape inputs get no gradient.
  """
  input_list, indices, _ = op.inputs
  element_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  length = gen_list_ops.tensor_list_length(input_list)
  reserved = tensor_list_reserve(element_shape, length, dtensor.dtype)
  list_grad = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=reserved)
  return list_grad, None, None
337
338
@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter and TensorListScatterV2.

  Gathers the scattered positions back out of `dlist`. Only the first input
  (`tensor`) gets a gradient. TensorListScatterV2 has one more input than
  TensorListScatter, so one more None is returned for it.
  """
  tensor, indices = op.inputs[0], op.inputs[1]
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  num_ungraded_inputs = 3 if op.type == "TensorListScatterV2" else 2
  return (dtensor,) + (None,) * num_ungraded_inputs
354
355
@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList.

  Returns gradients for the op's inputs `(input_handle, tensor, indices)`:
  - `input_handle`: `dlist` with the scattered positions zeroed, since those
    positions were overwritten in the forward pass;
  - `tensor`: the values of `dlist` gathered at `indices`;
  - `indices`: None.
  """
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  zeros = array_ops.zeros_like(tensor)
  # `tensor_list_scatter` only uses `element_shape` when creating a new list,
  # so it is left at its default here. Previously `indices` was passed in that
  # slot by mistake; it was silently ignored.
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None
368
369
def _build_element_shape(shape):
  """Converts shape to a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor` it is returned as-is. We do not perform a
  type check here.

  If shape is None or a TensorShape with unknown rank, -1 is returned.

  If shape is a numpy array/scalar, or an empty sequence (the shape of a
  scalar element), an int32 tensor is returned. We do not directly return an
  empty list, since ops.convert_to_tensor would convert it to a float32
  tensor, which is not a valid type for element_shape.

  If shape is a sequence of dims, None's (and unknown `Dimension`s) in the
  list are replaced with -1. We do not check the dtype of the other dims.

  Args:
    shape: Could be None, Tensor, TensorShape, a numpy value or a list of dims
      (each dim could be a None, scalar, Dimension or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape

  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    if not shape:
      return -1
    shape = shape.as_list()

  # Shape is unknown.
  if shape is None:
    return -1

  # Shape is a numpy value or an empty sequence. The numpy check must come
  # first: truth-testing a multi-element array would raise.
  if isinstance(shape, (np.ndarray, np.generic)) or not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)

  # Shape is a sequence of dimensions; map unknown dims to -1.
  dims = []
  for dim in shape:
    if dim is None:
      dims.append(-1)
    elif isinstance(dim, ops.Tensor):
      dims.append(dim)
    elif isinstance(dim, tensor_shape.Dimension):
      dims.append(-1 if dim.value is None else dim.value)
    else:
      dims.append(dim)
  return dims
414