1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Sparse tensors."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python import pywrap_tensorflow
24from tensorflow.python.framework import composite_tensor
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_util
29from tensorflow.python.util.tf_export import tf_export
30
# Module-local aliases for private helpers that live in `ops`; the code below
# (the `SparseTensor` base class, `eval`, and `_override_operator`) refers to
# them by these short names instead of spelling out `ops._...` each time.
# pylint: disable=protected-access
_override_helper = ops._override_helper
_eval_using_default_session = ops._eval_using_default_session
_TensorLike = ops._TensorLike
# pylint: enable=protected-access
36
37
@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
  """Represents a sparse tensor.

  TensorFlow represents a sparse tensor as three separate dense tensors:
  `indices`, `values`, and `dense_shape`.  In Python, the three tensors are
  collected into a `SparseTensor` class for ease of use.  If you have separate
  `indices`, `values`, and `dense_shape` tensors, wrap them in a `SparseTensor`
  object before passing to the ops below.

  Concretely, the sparse tensor `SparseTensor(indices, values, dense_shape)`
  comprises the following components, where `N` and `ndims` are the number
  of values and number of dimensions in the `SparseTensor`, respectively:

  * `indices`: A 2-D int64 tensor of shape `[N, ndims]`, which specifies
    the indices of the elements in the sparse tensor that contain nonzero
    values (elements are zero-indexed). For example, `indices=[[1,3], [2,4]]`
    specifies that the elements with indexes of [1,3] and [2,4] have
    nonzero values.

  * `values`: A 1-D tensor of any type and shape `[N]`, which supplies the
    values for each element in `indices`. For example, given
    `indices=[[1,3], [2,4]]`, the parameter `values=[18, 3.6]` specifies
    that element [1,3] of the sparse tensor has a value of 18, and element
    [2,4] of the tensor has a value of 3.6.

  * `dense_shape`: A 1-D int64 tensor of shape `[ndims]`, which specifies
    the shape of the represented dense tensor. Takes a list indicating the
    number of elements in each dimension. For example, `dense_shape=[3,6]`
    specifies a two-dimensional 3x6 tensor, `dense_shape=[2,3,4]` specifies a
    three-dimensional 2x3x4 tensor, and `dense_shape=[9]` specifies a
    one-dimensional tensor with 9 elements.

  The corresponding dense tensor satisfies:

  ```python
  dense.shape = dense_shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention, `indices` should be sorted in row-major order (or equivalently
  lexicographic order on the tuples `indices[i]`). This is not enforced when
  `SparseTensor` objects are constructed, but most ops assume correct ordering.
  If the ordering of sparse tensor `st` is wrong, a fixed version can be
  obtained by calling `tf.sparse_reorder(st)`.

  Example: The sparse tensor

  ```python
  SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  ```

  represents the dense tensor

  ```python
  [[1, 0, 0, 0],
   [0, 0, 2, 0],
   [0, 0, 0, 0]]
  ```
  """

  @classmethod
  def from_value(cls, sparse_tensor_value):
    """Builds a sparse tensor from a `SparseTensor` or `SparseTensorValue`.

    Args:
      sparse_tensor_value: A `SparseTensor` or `SparseTensorValue` whose
        `indices`, `values` and `dense_shape` are used as the components.

    Returns:
      A new instance of `cls` constructed from those three components.

    Raises:
      TypeError: If `sparse_tensor_value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if not is_sparse(sparse_tensor_value):
      # Wrap the offending value in a 1-tuple: if it is itself a tuple, a bare
      # `%` would try to unpack it and fail with an unrelated formatting error.
      raise TypeError("Neither a SparseTensor nor SparseTensorValue: %s." %
                      (sparse_tensor_value,))
    # Use `cls` (like `_from_components`) so subclasses get their own type.
    return cls(
        indices=sparse_tensor_value.indices,
        values=sparse_tensor_value.values,
        dense_shape=sparse_tensor_value.dense_shape)

  def __init__(self, indices, values, dense_shape):
    """Creates a `SparseTensor`.

    `indices` and `dense_shape` are converted to int64 tensors; `values` keeps
    whatever dtype the conversion infers. The static shapes of the three
    components are then checked against each other.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      dense_shape: A 1-D int64 tensor of shape `[ndims]`.

    Note: inconsistent static shapes are reported by the `TensorShape`
    helpers (`with_rank` / `merge_with`); presumably as a `ValueError` --
    confirm against `tensor_shape`.
    """
    with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
      indices = ops.convert_to_tensor(
          indices, name="indices", dtype=dtypes.int64)
      # TODO(touts): Consider adding mutable_values() when 'values'
      # is a VariableOp and updating users of SparseTensor.
      values = ops.internal_convert_to_tensor(values, name="values")
      dense_shape = ops.convert_to_tensor(
          dense_shape, name="dense_shape", dtype=dtypes.int64)
    self._indices = indices
    self._values = values
    self._dense_shape = dense_shape

    indices_shape = indices.get_shape().with_rank(2)
    values_shape = values.get_shape().with_rank(1)
    dense_shape_shape = dense_shape.get_shape().with_rank(1)

    # Assert number of rows in indices match the number of elements in values.
    indices_shape.dims[0].merge_with(values_shape.dims[0])
    # Assert number of columns in indices matches the number of elements in
    # dense_shape.
    indices_shape.dims[1].merge_with(dense_shape_shape.dims[0])

  def get_shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Equivalent to the `shape` property.

    Returns:
      A `TensorShape` object.
    """
    return self.shape

  @property
  def indices(self):
    """The indices of non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the
        number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values in the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self._values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._values.dtype

  @property
  def dense_shape(self):
    """A 1-D Tensor of int64 representing the shape of the dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Dimensions whose value cannot be determined statically from
    `dense_shape` are left unknown.

    Returns:
      A `TensorShape` object.
    """
    return tensor_util.constant_value_as_shape(self._dense_shape)

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and dense_shape tensors."""
    return self._indices.graph

  def __str__(self):
    return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
        self._indices, self._values, self._dense_shape)

  def eval(self, feed_dict=None, session=None):
    """Evaluates this sparse tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `SparseTensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values.
        See `tf.Session.run` for a
        description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this sparse
        tensor. If none, the default session will be used.

    Returns:
      A `SparseTensorValue` object.
    """
    indices, values, dense_shape = _eval_using_default_session(
        [self.indices, self.values, self.dense_shape], feed_dict, self.graph,
        session)
    return SparseTensorValue(indices, values, dense_shape)

  @staticmethod
  def _override_operator(operator, func):
    """Installs `func` as the implementation of `operator` on `SparseTensor`."""
    _override_helper(SparseTensor, operator, func)

  def _to_components(self):
    """Returns the `(indices, values, dense_shape)` component tensors."""
    return (self._indices, self._values, self._dense_shape)

  @classmethod
  def _from_components(cls, components):
    """Rebuilds an instance from the tuple produced by `_to_components`."""
    return cls(*components)

  def _shape_invariant_to_components(self, shape=None):
    """Maps a shape invariant for this tensor to per-component invariants.

    Args:
      shape: Optional `TensorShape` of the form `TensorShape([r])`, i.e. the
        shape of `dense_shape`. Defaults to `self.dense_shape.shape`; an
        unknown-rank shape is treated as `TensorShape([None])`.

    Returns:
      A list of three `TensorShape`s, for `indices`, `values` and
      `dense_shape` respectively. The number of values is always left
      unknown; the rank `r` (possibly `None`) is propagated.

    Raises:
      ValueError: If `shape` is not one-dimensional.
    """
    if shape is None:
      shape = self.dense_shape.shape
    if shape.ndims is None:
      shape = tensor_shape.TensorShape([None])
    if shape.ndims != 1:
      raise ValueError("Shape invariant for SparseTensor must have the form "
                       "TensorShape([r]), got %r" % shape)
    rank = tensor_shape.dimension_value(shape[0])
    return [tensor_shape.TensorShape([None, rank]),  # indices
            tensor_shape.TensorShape([None]),  # values
            tensor_shape.TensorShape([rank])]  # dense_shape

  @property
  def _is_graph_tensor(self):
    """Whether `values` exposes a `graph` attribute.

    Presumably used to distinguish graph tensors from eager ones -- confirm
    with callers.
    """
    return hasattr(self._values, 'graph')
253
254
# Plain-Python counterpart of `SparseTensor`: a namedtuple holding the same
# three components. `SparseTensor.eval` returns one of these.
SparseTensorValue = collections.namedtuple(
    "SparseTensorValue", ("indices", "values", "dense_shape"))
# Export under the v1 API name only.
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
# Make the type known to the C++ wrapper layer under its Python name.
pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)
259
260
@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Converts `value` to a `SparseTensor` or a `Tensor`.

  Sparse inputs (`SparseTensor` or `SparseTensorValue`) are returned as a
  `SparseTensor`; every other input goes through the regular `Tensor`
  conversion.

  Args:
    value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
      registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created. It is not used
      for sparse inputs.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If `value` is sparse and its dtype is incompatible with
      `dtype`.
  """
  requested_dtype = None if dtype is None else dtypes.as_dtype(dtype)

  if isinstance(value, SparseTensorValue):
    value = SparseTensor.from_value(value)

  # Dense path: defer entirely to the generic conversion machinery.
  if not isinstance(value, SparseTensor):
    return ops.internal_convert_to_tensor(
        value, dtype=requested_dtype, name=name)

  # Sparse path: the value is returned as-is, so only validate the dtype.
  if requested_dtype and not requested_dtype.is_compatible_with(value.dtype):
    raise RuntimeError(
        "Sparse dtype: requested = %s, actual = %s" % (
            requested_dtype.name, value.dtype.name))
  return value
290
291
def is_sparse(x):
  """Returns whether `x` is a `tf.SparseTensor` or `tf.SparseTensorValue`.

  Args:
    x: Any Python object.

  Returns:
    `True` if `x` is an instance of either sparse type, otherwise `False`.
  """
  return isinstance(x, SparseTensor) or isinstance(x, SparseTensorValue)
304