# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparse tensors."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.util.tf_export import tf_export

# pylint: disable=protected-access
_TensorLike = ops._TensorLike
_eval_using_default_session = ops._eval_using_default_session
_override_helper = ops._override_helper
# pylint: enable=protected-access


@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
  """Represents a sparse tensor.

  TensorFlow represents a sparse tensor as three separate dense tensors:
  `indices`, `values`, and `dense_shape`.  In Python, the three tensors are
  collected into a `SparseTensor` class for ease of use.  If you have separate
  `indices`, `values`, and `dense_shape` tensors, wrap them in a `SparseTensor`
  object before passing to the ops below.

  Concretely, the sparse tensor `SparseTensor(indices, values, dense_shape)`
  comprises the following components, where `N` and `ndims` are the number
  of values and number of dimensions in the `SparseTensor`, respectively:

  * `indices`: A 2-D int64 tensor of shape `[N, ndims]`, which specifies
    the indices of the elements in the sparse tensor that contain nonzero
    values (elements are zero-indexed). For example, `indices=[[1,3], [2,4]]`
    specifies that the elements with indexes of [1,3] and [2,4] have
    nonzero values.

  * `values`: A 1-D tensor of any type and shape `[N]`, which supplies the
    values for each element in `indices`. For example, given
    `indices=[[1,3], [2,4]]`, the parameter `values=[18, 3.6]` specifies
    that element [1,3] of the sparse tensor has a value of 18, and element
    [2,4] of the tensor has a value of 3.6.

  * `dense_shape`: A 1-D int64 tensor of shape `[ndims]`, which specifies
    the dense shape of the sparse tensor. Takes a list indicating the number
    of elements in each dimension. For example, `dense_shape=[3,6]` specifies
    a two-dimensional 3x6 tensor, `dense_shape=[2,3,4]` specifies a
    three-dimensional 2x3x4 tensor, and `dense_shape=[9]` specifies a
    one-dimensional tensor with 9 elements.

  The corresponding dense tensor satisfies:

  ```python
  dense.shape = dense_shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention, `indices` should be sorted in row-major order (or equivalently
  lexicographic order on the tuples `indices[i]`). This is not enforced when
  `SparseTensor` objects are constructed, but most ops assume correct ordering.
  If the ordering of sparse tensor `st` is wrong, a fixed version can be
  obtained by calling `tf.sparse_reorder(st)`.

  Example: The sparse tensor

  ```python
  SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  ```

  represents the dense tensor

  ```python
  [[1, 0, 0, 0]
   [0, 0, 2, 0]
   [0, 0, 0, 0]]
  ```
  """

  @classmethod
  def from_value(cls, sparse_tensor_value):
    """Builds a `SparseTensor` from a sparse value.

    Args:
      sparse_tensor_value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      A `SparseTensor` built from the `indices`, `values` and `dense_shape`
      attributes of `sparse_tensor_value`.

    Raises:
      TypeError: If `sparse_tensor_value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if not is_sparse(sparse_tensor_value):
      # Wrap the operand in a 1-tuple: a bare tuple argument (for example a
      # plain `(indices, values, shape)` triple) would otherwise be unpacked
      # as the format arguments and fail with an unrelated formatting error.
      raise TypeError("Neither a SparseTensor nor SparseTensorValue: %s." %
                      (sparse_tensor_value,))
    return SparseTensor(
        indices=sparse_tensor_value.indices,
        values=sparse_tensor_value.values,
        dense_shape=sparse_tensor_value.dense_shape)

  def __init__(self, indices, values, dense_shape):
    """Creates a `SparseTensor`.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      dense_shape: A 1-D int64 tensor of shape `[ndims]`.

    Raises:
      ValueError: If the statically known shapes of the three components
        have the wrong rank or disagree with each other.
    """
    with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
      indices = ops.convert_to_tensor(
          indices, name="indices", dtype=dtypes.int64)
      # TODO(touts): Consider adding mutable_values() when 'values'
      # is a VariableOp and updating users of SparseTensor.
      values = ops.internal_convert_to_tensor(values, name="values")
      dense_shape = ops.convert_to_tensor(
          dense_shape, name="dense_shape", dtype=dtypes.int64)
    self._indices = indices
    self._values = values
    self._dense_shape = dense_shape

    # Static rank checks on each component.
    indices_shape = indices.get_shape().with_rank(2)
    values_shape = values.get_shape().with_rank(1)
    dense_shape_shape = dense_shape.get_shape().with_rank(1)

    # Assert number of rows in indices match the number of elements in values.
    indices_shape.dims[0].merge_with(values_shape.dims[0])
    # Assert number of columns in indices matches the number of elements in
    # dense_shape.
    indices_shape.dims[1].merge_with(dense_shape_shape.dims[0])

  def get_shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Method form of the `shape` property.

    Returns:
      A `TensorShape` object.
    """
    return self.shape

  @property
  def indices(self):
    """The indices of non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the
        number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values in the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._values.dtype

  @property
  def dense_shape(self):
    """A 1-D Tensor of int64 representing the shape of the dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Returns:
      A `TensorShape` object.
    """
    return tensor_util.constant_value_as_shape(self._dense_shape)

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and dense_shape tensors."""
    return self._indices.graph

  def __str__(self):
    return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
        self._indices, self._values, self._dense_shape)

  def eval(self, feed_dict=None, session=None):
    """Evaluates this sparse tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `SparseTensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values.
        See `tf.Session.run` for a
        description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this sparse
        tensor. If none, the default session will be used.

    Returns:
      A `SparseTensorValue` object.
    """
    indices, values, dense_shape = _eval_using_default_session(
        [self.indices, self.values, self.dense_shape], feed_dict, self.graph,
        session)
    return SparseTensorValue(indices, values, dense_shape)

  @staticmethod
  def _override_operator(operator, func):
    """Registers `func` as the implementation of `operator` on this class."""
    _override_helper(SparseTensor, operator, func)

  def _to_components(self):
    """Returns the `(indices, values, dense_shape)` component tensors."""
    return (self._indices, self._values, self._dense_shape)

  @classmethod
  def _from_components(cls, components):
    """Rebuilds an instance from an `(indices, values, dense_shape)` tuple."""
    return cls(*components)

  def _shape_invariant_to_components(self, shape=None):
    """Expands a shape invariant into one shape per component tensor.

    Args:
      shape: (Optional.) A rank-1 `TensorShape` of the form
        `TensorShape([r])`. Defaults to the static shape of `dense_shape`;
        a shape of unknown rank is treated as `TensorShape([None])`.

    Returns:
      A list of three `TensorShape`s, for `indices`, `values` and
      `dense_shape` respectively.

    Raises:
      ValueError: If `shape` has a known rank other than 1.
    """
    if shape is None:
      shape = self.dense_shape.shape
    if shape.ndims is None:
      shape = tensor_shape.TensorShape([None])
    if shape.ndims != 1:
      raise ValueError("Shape invariant for SparseTensor must have the form "
                       "TensorShape([r]), got %r" % shape)
    rank = tensor_shape.dimension_value(shape[0])
    return [tensor_shape.TensorShape([None, rank]),  # indices
            tensor_shape.TensorShape([None]),  # values
            tensor_shape.TensorShape([rank])]  # dense_shape

  @property
  def _is_graph_tensor(self):
    """Whether the `values` component exposes a `graph` attribute."""
    return hasattr(self._values, "graph")


SparseTensorValue = collections.namedtuple(
    "SparseTensorValue", ["indices", "values", "dense_shape"])
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)


@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Returns `value` as a `SparseTensor` if it is sparse, else as a `Tensor`.

  A `SparseTensorValue` is first wrapped into a `SparseTensor`. Sparse inputs
  are returned as-is after the dtype check; everything else goes through the
  regular `Tensor` conversion.

  Args:
    value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
      registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If `value` is sparse and its dtype is incompatible with
      `dtype`.
  """
  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)

  if isinstance(value, SparseTensorValue):
    value = SparseTensor.from_value(value)

  # Dense path: anything that is not a SparseTensor by now is converted with
  # the ordinary tensor conversion machinery.
  if not isinstance(value, SparseTensor):
    return ops.internal_convert_to_tensor(value, dtype=dtype, name=name)

  # Sparse path: no conversion is performed, only a compatibility check.
  if dtype and not dtype.is_compatible_with(value.dtype):
    requested_name = dtype.name
    actual_name = value.dtype.name
    raise RuntimeError(
        "Sparse dtype: requested = %s, actual = %s" % (
            requested_name, actual_name))
  return value


def is_sparse(x):
  """Tells whether `x` is a sparse tensor or sparse tensor value.

  Args:
    x: A python object to check.

  Returns:
    `True` iff `x` is a `tf.SparseTensor` or `tf.SparseTensorValue`.
  """
  sparse_types = (SparseTensor, SparseTensorValue)
  return isinstance(x, sparse_types)