1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of tf.sets."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import sparse_tensor
25from tensorflow.python.ops import gen_set_ops
26from tensorflow.python.util.tf_export import tf_export
27
28
# Value dtypes accepted by the set functions in this module; inputs whose base
# dtype is not listed here are rejected with a `TypeError`.
_VALID_DTYPES = {
    dtypes.int8,
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    dtypes.uint8,
    dtypes.uint16,
    dtypes.string,
}
32
33
@tf_export("sets.size", v1=["sets.size", "sets.set_size"])
def set_size(a, validate_indices=True):
  """Count the unique elements along the last dimension of `a`.

  Args:
    a: `SparseTensor` whose indices are sorted in row-major order.
    validate_indices: Whether the order and range of the sparse indices in `a`
      should be validated.

  Returns:
    An `int32` `Tensor` holding the set sizes. If `a` has rank `n`, the result
    has rank `n-1` and shares the first `n-1` dimensions of `a`. Each value is
    the number of unique elements in the corresponding `[0...n-1]` dimension
    of `a`.

  Raises:
    TypeError: If `a` is not a `SparseTensor`, or if the dtype of its values
      is not one of the supported set dtypes.
  """
  sparse_a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")

  # Only sparse inputs are accepted; a dense result of the conversion is
  # rejected here.
  if not isinstance(sparse_a, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % sparse_a)

  values_dtype = sparse_a.values.dtype
  if values_dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % values_dtype)

  return gen_set_ops.set_size(
      sparse_a.indices,
      sparse_a.values,
      sparse_a.dense_shape,
      validate_indices)
59
# Register every set op type used in this module as non-differentiable.
for _op_type in (
    "SetSize",
    "DenseToDenseSetOperation",
    "DenseToSparseSetOperation",
    "SparseToSparseSetOperation",
):
  ops.NotDifferentiable(_op_type)
del _op_type  # Keep the loop variable out of the module namespace.
66
67
def _convert_to_tensors_or_sparse_tensors(a, b):
  """Convert both inputs to tensor types, swapping them when required.

  The set ops accept a dense,sparse pair but not a sparse,dense pair, so a
  sparse `a` combined with a dense `b` is returned in swapped order.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`.
    b: `Tensor` or `SparseTensor` of the same type as `a`.

  Returns:
    Tuple of `(a, b, flipped)`. `a` and `b` are the inputs converted to
    `Tensor` or `SparseTensor`; `flipped` is `True` when they were swapped so
    that the pair is dense,sparse rather than sparse,dense.

  Raises:
    TypeError: If the dtype of `a` is not a supported set dtype, or if the
      base dtypes of `a` and `b` differ.
  """
  # `a` is converted and validated before `b` is touched, so a bad `a` is
  # reported first.
  lhs = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  if lhs.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % lhs.dtype)

  rhs = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
  if rhs.dtype.base_dtype != lhs.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (lhs.dtype, rhs.dtype))

  needs_flip = (
      isinstance(lhs, sparse_tensor.SparseTensor) and
      not isinstance(rhs, sparse_tensor.SparseTensor))
  return (rhs, lhs, True) if needs_flip else (lhs, rhs, False)
91
92
def _set_operation(a, b, set_operation, validate_indices=True):
  """Apply a set operation to the last dimension of `a` and `b`.

  Every dimension of `a` and `b` except the last must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
        must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be a
        `SparseTensor` whenever `a` is a `SparseTensor`. If sparse, indices
        must be sorted in row-major order.
    set_operation: String naming the set operation. See
        SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether the order and range of the sparse indices in
        `a` and `b` should be validated.

  Returns:
    A `SparseTensor` of the same rank as `a` and `b`, with all but the last
    dimension the same. Elements along the last dimension hold the results of
    the set operation.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  a_is_sparse = isinstance(a, sparse_tensor.SparseTensor)
  b_is_sparse = isinstance(b, sparse_tensor.SparseTensor)

  # There is no sparse,dense op; callers must pass that pair the other way
  # round.
  if a_is_sparse and not b_is_sparse:
    raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                     "Please flip the order of your inputs.")

  if a_is_sparse:
    outputs = gen_set_ops.sparse_to_sparse_set_operation(
        a.indices, a.values, a.dense_shape,
        b.indices, b.values, b.dense_shape,
        set_operation=set_operation,
        validate_indices=validate_indices)
  elif b_is_sparse:
    outputs = gen_set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.dense_shape,
        set_operation=set_operation,
        validate_indices=validate_indices)
  else:
    outputs = gen_set_ops.dense_to_dense_set_operation(
        a, b,
        set_operation=set_operation,
        validate_indices=validate_indices)

  out_indices, out_values, out_shape = outputs
  return sparse_tensor.SparseTensor(out_indices, out_values, out_shape)
134
135
@tf_export(
    "sets.intersection", v1=["sets.intersection", "sets.set_intersection"])
def set_intersection(a, b, validate_indices=True):
  """Compute the intersection of the sets in the last dimension of `a` and `b`.

  Every dimension of `a` and `b` except the last must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Encode the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

    # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

    # The intersection is taken for each aligned pair of sets.
    tf.sets.intersection(a, b)

    # The result is equivalent to either of:
    #
    # np.array([[{1}, {}], [{4}, {5, 6}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((1, 0, 0), 4),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
        must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
        must be sorted in row-major order.
    validate_indices: Whether the order and range of the sparse indices in
        `a` and `b` should be validated.

  Returns:
    A `SparseTensor` of the same rank as `a` and `b`, with all but the last
    dimension the same. Elements along the last dimension hold the
    intersections.

  Raises:
    TypeError: If the dtype of `a` is not a supported set dtype, or if `a` and
        `b` have different dtypes.
  """
  # Intersection is symmetric, so whether the operands were swapped does not
  # matter here.
  first, second, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(first, second, "intersection", validate_indices)
202
203
@tf_export(
    "sets.difference", v1=["sets.difference", "sets.set_difference"])
def set_difference(a, b, aminusb=True, validate_indices=True):
  """Compute the difference of the sets in the last dimension of `a` and `b`.

  Every dimension of `a` and `b` except the last must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Encode the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

    # b = np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

    # The difference is taken for each aligned pair of sets.
    tf.sets.difference(a, b)

    # The result is equivalent to either of:
    #
    # np.array([[{2}, {3}], [{}, {}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 2),
    #     ((0, 1, 0), 3),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
        must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
        must be sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether the order and range of the sparse indices in
        `a` and `b` should be validated.

  Returns:
    A `SparseTensor` of the same rank as `a` and `b`, with all but the last
    dimension the same. Elements along the last dimension hold the
    differences.

  Raises:
    TypeError: If the dtype of `a` is not a supported set dtype, or if `a` and
        `b` have different dtypes.
  """
  first, second, swapped = _convert_to_tensors_or_sparse_tensors(a, b)
  # Difference is not symmetric: if the operands were swapped to obtain a
  # dense,sparse pair, the requested direction has to be inverted as well.
  if swapped:
    direction = "b-a" if aminusb else "a-b"
  else:
    direction = "a-b" if aminusb else "b-a"
  return _set_operation(first, second, direction, validate_indices)
274
275
@tf_export(
    "sets.union", v1=["sets.union", "sets.set_union"])
def set_union(a, b, validate_indices=True):
  """Compute the union of the sets in the last dimension of `a` and `b`.

  Every dimension of `a` and `b` except the last must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # a = [[{1, 2}, {3}], [{4}, {5, 6}]]
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

    # b = [[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]]
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

    # The union is taken for each aligned pair of sets.
    tf.sets.union(a, b)

    # The result is equivalent to either of:
    #
    # np.array([[{1, 2, 3}, {2, 3}], [{4, 5}, {5, 6, 7, 8}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((0, 0, 1), 2),
    #     ((0, 0, 2), 3),
    #     ((0, 1, 0), 2),
    #     ((0, 1, 1), 3),
    #     ((1, 0, 0), 4),
    #     ((1, 0, 1), 5),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    #     ((1, 1, 2), 7),
    #     ((1, 1, 3), 8),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
        must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
        must be sorted in row-major order.
    validate_indices: Whether the order and range of the sparse indices in
        `a` and `b` should be validated.

  Returns:
    A `SparseTensor` of the same rank as `a` and `b`, with all but the last
    dimension the same. Elements along the last dimension hold the unions.

  Raises:
    TypeError: If the dtype of `a` is not a supported set dtype, or if `a` and
        `b` have different dtypes.
  """
  # Union is symmetric, so whether the operands were swapped does not matter
  # here.
  first, second, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(first, second, "union", validate_indices)
351