1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operators specific to data structures: list append, subscripts, etc."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import list_ops
30from tensorflow.python.ops import tensor_array_ops
31
32
33# TODO(mdan): Once control flow supports objects, repackage as a class.
34
35
def new_list(iterable=None):
  """The list constructor.

  Args:
    iterable: Optional iterable of elements to fill the list with. None (the
      default) means that no elements are supplied.

  Returns:
    A list-like object. The exact return value depends on the initial elements:
    when at least one element is supplied the result comes from _py_list_new,
    otherwise it is whatever tf_tensor_list_new returns for an empty tuple.
  """
  # Compare against None explicitly rather than relying on truthiness: some
  # iterables (for example multi-element NumPy arrays) refuse to be evaluated
  # as a bool even though iterating over them is perfectly valid.
  elements = tuple(iterable) if iterable is not None else ()

  if elements:
    # When the list contains elements, it is assumed to be a "Python" lvalue
    # list.
    return _py_list_new(elements)
  return tf_tensor_list_new(elements)
55
56
def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a TensorArray creation.

  Args:
    elements: Iterable of values; each one is passed through
      ops.convert_to_tensor.
    element_dtype: Optional dtype. When elements are present it must equal
      their common dtype; when there are none it is required.
    element_shape: Optional shape. When the elements share one static shape,
      it is compared against that shape expressed as a tuple.

  Returns:
    A TensorArray to which every element was written at its positional index.

  Raises:
    ValueError: if the elements disagree on dtype or on shape, if a specified
      dtype or shape differs from the inferred one, or if the input is empty
      and no dtype was specified.
  """
  tensors = tuple(ops.convert_to_tensor(value) for value in elements)

  # Resolve the element dtype.
  inferred_dtype = None
  distinct_dtypes = {tensor.dtype for tensor in tensors}
  if len(distinct_dtypes) > 1:
    raise ValueError(
        'TensorArray requires all elements to have the same dtype:'
        ' {}'.format(tensors))
  if distinct_dtypes:
    inferred_dtype = next(iter(distinct_dtypes))
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, tensors, inferred_dtype))
  elif element_dtype is None:
    raise ValueError('dtype is required to create an empty TensorArray')

  # Resolve the element shape.
  inferred_shape = None
  distinct_shapes = {tuple(tensor.shape.as_list()) for tensor in tensors}
  if len(distinct_shapes) > 1:
    # TODO(mdan): We may want to allow different shapes with infer_shape=False.
    raise ValueError(
        'TensorArray requires all elements to have the same shape:'
        ' {}'.format(tensors))
  if distinct_shapes:
    inferred_shape = next(iter(distinct_shapes))
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, tensors, inferred_shape))

  # Explicitly specified values win over inferred ones.
  dtype = inferred_dtype if element_dtype is None else element_dtype
  shape = inferred_shape if element_shape is None else element_shape

  array = tensor_array_ops.TensorArray(
      dtype=dtype,
      size=len(tensors),
      dynamic_size=True,
      infer_shape=(shape is None),
      element_shape=shape)
  for index, tensor in enumerate(tensors):
    array = array.write(index, tensor)
  return array
105
106
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a Tensor list creation.

  Args:
    elements: Either a single TF value (as judged by tensor_util.is_tf_type),
      which is turned into a list with list_ops.tensor_list_from_tensor, or an
      iterable of values that are each converted with ops.convert_to_tensor
      and pushed onto a new list one by one.
    element_dtype: Optional dtype for the list elements.
    element_shape: Optional shape for the list elements. May not be given when
      elements is a single TF value.

  Returns:
    The staged tensor list.

  Raises:
    ValueError: if element_shape is given together with a single TF value, or
      if a specified dtype or shape is rejected by the consistency checks
      against the elements.
  """
  if tensor_util.is_tf_type(elements):
    if element_shape is not None:
      raise ValueError(
          'element shape may not be specified when creating list from tensor')
    # The leading dimension indexes the list entries; the rest is the shape of
    # each entry.
    return list_ops.tensor_list_from_tensor(
        elements, element_shape=array_ops.shape(elements)[1:])

  tensors = tuple(ops.convert_to_tensor(value) for value in elements)

  # Resolve the element dtype. Empty and heterogeneous lists use variant.
  inferred_dtype = dtypes.variant
  distinct_dtypes = {tensor.dtype for tensor in tensors}
  if len(distinct_dtypes) == 1:
    inferred_dtype = next(iter(distinct_dtypes))
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, tensors, inferred_dtype))
  elif distinct_dtypes and element_dtype is not None:
    # Heterogeneous lists are ok, but not combined with an explicit dtype.
    raise ValueError(
        'specified dtype {} is inconsistent with that of elements {}'.format(
            element_dtype, tensors))

  # Resolve the element shape.
  distinct_shapes = {tuple(tensor.shape.as_list()) for tensor in tensors}
  if len(distinct_shapes) == 1:
    inferred_shape = array_ops.shape(tensors[0])
    # NOTE(review): inferred_shape is the result of array_ops.shape, so this
    # compares the specified shape against a Tensor rather than a static
    # shape - confirm that this is the intended check.
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, tensors, inferred_shape))
  else:
    if distinct_shapes and element_shape is not None:
      # Heterogeneous lists are ok, but not combined with an explicit shape.
      raise ValueError(
          'specified shape {} is inconsistent with that of elements {}'.format(
              element_shape, tensors))
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention

  # Explicitly specified values win over inferred ones.
  dtype = inferred_dtype if element_dtype is None else element_dtype
  shape = inferred_shape if element_shape is None else element_shape

  shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
  tensor_list = list_ops.empty_tensor_list(
      element_shape=shape, element_dtype=dtype)
  for tensor in tensors:
    tensor_list = list_ops.tensor_list_push_back(tensor_list, tensor)
  return tensor_list
164
165
166def _py_list_new(elements):
167  """Overload of new_list that creates a Python list."""
168  return list(elements)
169
170
def list_append(list_, x):
  """The list append function.

  Dispatches on the type of list_: TensorArray objects, variant-dtype Tensors
  (tensor lists) and everything else (treated as a Python list-like object)
  each have their own overload.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Args:
    list_: An entity that supports append semantics.
    x: The element to append.

  Returns:
    Same as list_, after the append was performed.

  Raises:
    ValueError: if list_ is a Tensor whose dtype is not tf.variant.
  """
  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_append(list_, x)

  if not tensor_util.is_tf_type(list_):
    return _py_list_append(list_, x)

  if list_.dtype == dtypes.variant:
    return _tf_tensor_list_append(list_, x)

  raise ValueError(
      'tensor lists are expected to be Tensors with dtype=tf.variant,'
      ' instead found %s' % list_)
200
201
def _tf_tensor_list_append(list_, x):
  """Overload of list_append that stages a Tensor list push-back.

  A conditional is staged on the length of list_: if it has at least one
  element it is used as is, otherwise a new empty tensor list is created whose
  element shape and dtype are taken from x. The element is then pushed onto
  whichever list the conditional yields.

  Args:
    list_: The tensor list to append to.
    x: The element to append.

  Returns:
    The tensor list resulting from the push-back.
  """
  def empty_list_matching_x():
    x_tensor = ops.convert_to_tensor(x)
    return list_ops.empty_tensor_list(
        element_shape=array_ops.shape(x_tensor),
        element_dtype=x_tensor.dtype)

  has_elements = list_ops.tensor_list_length(list_) > 0
  target_list = control_flow_ops.cond(
      has_elements,
      lambda: list_,
      empty_list_matching_x,
  )
  return list_ops.tensor_list_push_back(target_list, x)
216
217
218def _tf_tensorarray_append(list_, x):
219  """Overload of list_append that stages a TensorArray write."""
220  return list_.write(list_.size(), x)
221
222
223def _py_list_append(list_, x):
224  """Overload of list_append that executes a Python list append."""
225  # Revert to the original call.
226  list_.append(x)
227  return list_
228
229
class ListPopOpts(
    collections.namedtuple(
        'ListPopOpts',
        ['element_dtype', 'element_shape'],
    )):
  """Options accepted by list_pop.

  Fields:
    element_dtype: Passed as the element dtype when popping from a tensor
      list; the tensor list overload rejects None.
    element_shape: Set as the shape of the popped element when popping from a
      tensor list; the tensor list overload rejects None.
  """
233
234
def list_pop(list_, i, opts):
  """The list pop function.

  Dispatches on the type of list_: variant-dtype Tensors (tensor lists) and
  anything that is not a TensorFlow type (treated as a Python list-like
  object) each have their own overload. TensorArray is rejected.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Args:
    list_: An entity that supports pop semantics.
    i: Optional index to pop from. May be None.
    opts: A ListPopOpts.

  Returns:
    Tuple (out_list_, x):
      out_list_: same as list_, after the removal was performed.
      x: the removed element value.

  Raises:
    ValueError: if list_ is a TensorArray, or a Tensor whose dtype is not
      tf.variant.
  """
  assert isinstance(opts, ListPopOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    raise ValueError('TensorArray does not support item removal')

  if not tensor_util.is_tf_type(list_):
    return _py_list_pop(list_, i)

  if list_.dtype == dtypes.variant:
    return _tf_tensor_list_pop(list_, i, opts)

  raise ValueError(
      'tensor lists are expected to be Tensors with dtype=tf.variant,'
      ' instead found %s' % list_)
270
271
272def _tf_tensor_list_pop(list_, i, opts):
273  """Overload of list_pop that stages a Tensor list pop."""
274  if i is not None:
275    raise NotImplementedError('tensor lists only support removing from the end')
276
277  if opts.element_dtype is None:
278    raise ValueError('cannot pop from a list without knowing its element '
279                     'type; use set_element_type to annotate it')
280  if opts.element_shape is None:
281    raise ValueError('cannot pop from a list without knowing its element '
282                     'shape; use set_element_type to annotate it')
283  list_out, x = list_ops.tensor_list_pop_back(
284      list_, element_dtype=opts.element_dtype)
285  x.set_shape(opts.element_shape)
286  return list_out, x
287
288
289def _py_list_pop(list_, i):
290  """Overload of list_pop that executes a Python list append."""
291  if i is None:
292    x = list_.pop()
293  else:
294    x = list_.pop(i)
295  return list_, x
296
297
298# TODO(mdan): Look into reducing duplication between all these containers.
class ListStackOpts(
    collections.namedtuple(
        'ListStackOpts',
        ['element_dtype', 'original_call'],
    )):
  """Options accepted by list_stack.

  Fields:
    element_dtype: Passed as the element dtype when stacking a tensor list;
      the tensor list overload rejects None.
    original_call: Callable invoked with the list when the target is neither a
      TensorArray nor a Tensor.
  """
303
304
def list_stack(list_, opts):
  """The list stack function.

  This does not have a direct correspondent in Python. It differs from the
  usual stacking idioms in the sense that it accepts a Tensor list, rather
  than a list of tensors. It can also accept TensorArray. Tensors whose dtype
  is not tf.variant are returned unchanged. When the target is anything else,
  the dispatcher relies on opts.original_call for fallback.

  Args:
    list_: The entity to stack.
    opts: A ListStackOpts object.

  Returns:
    The output of the stack operation, typically a Tensor.
  """
  assert isinstance(opts, ListStackOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_stack(list_)

  if not tensor_util.is_tf_type(list_):
    return _py_list_stack(list_, opts)

  if list_.dtype == dtypes.variant:
    return _tf_tensor_list_stack(list_, opts)

  # No-op for primitive Tensor arguments.
  return list_
333
334
335def _tf_tensorarray_stack(list_):
336  """Overload of list_stack that stages a TensorArray stack."""
337  return list_.stack()
338
339
340def _tf_tensor_list_stack(list_, opts):
341  """Overload of list_stack that stages a Tensor list write."""
342  if opts.element_dtype is None:
343    raise ValueError('cannot stack a list without knowing its element type;'
344                     ' use set_element_type to annotate it')
345  return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype)
346
347
348def _py_list_stack(list_, opts):
349  """Overload of list_stack that executes a Python list append."""
350  # Revert to the original call.
351  return opts.original_call(list_)
352