/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Ops for operating with sets. They are not checked in
// to TensorFlow because we would first like to demonstrate successful
// end-to-end use of these ops in eval and polish the API a bit, like taking
// two SparseTensor rather than one dense and one sparse.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <map>
#include <numeric>
// TODO(ptucker): Consider switching back to hash_set - I had trouble getting it
// to work with string values.
#include <set>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using ShapeArray = sparse::SparseTensor::ShapeArray;
using VarDimArray = sparse::SparseTensor::VarDimArray;

// Validate rank >= 2.
void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) {
  const auto rank = shape.dims();
  OP_REQUIRES(ctx, rank >= 2,
              errors::InvalidArgument("Invalid rank ", rank, "."));
}

// Return group shape, which is the 1st n-1 dimensions of shape.
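// For example, a shape of [2, 3, 4] yields a group shape of [2, 3].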
Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
  if (input_shape.size() < 2) {
    // TODO(irving): Why can't 2 be 1 here?
    return errors::InvalidArgument("Shape [", str_util::Join(input_shape, ","),
                                   "] has rank ", input_shape.size(), " < 2");
  }
  // grouped_shape is input_shape[:-1]
  *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1);
  return Status::OK();
}

// Build `SparseTensor` from indices, values, and shape in inputs
// [base_index, base_index + 3), and validate its rank and indices.
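// The three inputs are expected in order: indices at `base_index`, values at
// `base_index + 1`, and the dense shape vector at `base_index + 2`.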
Status SparseTensorFromContext(OpKernelContext* ctx, const int32 base_index,
                               bool validate_indices,
                               sparse::SparseTensor* tensor) {
  // Assume row-major order.
  const TensorShape shape =
      TensorShape(ctx->input(base_index + 2).vec<int64>());
  CheckRankAtLeast2(ctx, shape);
  std::vector<int64> order(shape.dims());
  std::iota(order.begin(), order.end(), 0);

  return sparse::SparseTensor::Create(
      ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor);
}

// TODO(ptucker): CheckGroup is just a sanity check on the result of
// SparseTensor.group, consider removing.
// `sparse_tensor_shape` is the shape of the `SparseTensor` from which group
// was created, and is used to sanity check the indices in `group`.
template <typename T>
void CheckGroup(OpKernelContext* ctx, const sparse::Group& group,
                const VarDimArray& sparse_tensor_shape) {
  const auto& indices = group.indices();
  const auto& values = group.values<T>();

  // Sanity check: group is non-empty, and indices and values are same size.
  const auto num_values = values.dimension(0);
  OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group."));
  OP_REQUIRES(
      ctx, indices.dimension(0) == num_values,
      errors::Internal("shape[0] of group indices ", indices.dimension(0),
                       " != values ", num_values, "."));

  // Sanity check: valid indices.
  const auto group_rank = indices.dimension(1);
  const auto expected_rank = sparse_tensor_shape.size();
  OP_REQUIRES(ctx, expected_rank == group_rank,
              errors::Internal("Rank expected ", expected_rank, ", got ",
                               group_rank, "."));
  for (int32 j = 0; j < expected_rank; ++j) {
    const auto dim_size = sparse_tensor_shape[j];
    OP_REQUIRES(
        ctx, dim_size > 0,
        errors::Internal("Invalid dim_size[", j, "] = ", dim_size, "."));
    for (int64 i = 0; i < num_values; ++i) {
      const auto index = indices(i, j);
      OP_REQUIRES(ctx, dim_size > index,
                  errors::Internal("indices[", i, ", ", j, "] expected < ",
                                   dim_size, ", got ", index, "."));
    }
  }
}

// This lets us calculate the row-major index into flattened output.
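// For example, a shape of [2, 3, 4] yields strides of [12, 4, 1].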
const ShapeArray Strides(const VarDimArray& shape) {
  ShapeArray result(shape.size());
  int64 product = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    result[i] = product;
    product *= shape[i];
  }
  return result;
}

// TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to
// eliminate the intermediate `values` data structure - iterate once to
// determine `num_values`, allocate output tensors, then write results directly
// to output tensors.

// TODO(ptucker): Consider sharding work across multiple threads. See
// SparseCrossOp for an example.

// Output `SparseTensor` of shape `output_shape`. `sets` contains a map of
// group indices (i.e., values for all but the last dimension of `output_shape`)
// to set values, each of which will occupy the last dimension of
// `output_shape`.
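// For example, with `output_shape` [2, 2, m], a map entry {0, 1} -> {a, b}
// emits indices [0, 1, 0] and [0, 1, 1] with values a and b.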
template <typename T>
void OutputSparseTensor(OpKernelContext* ctx, const TensorShape& output_shape,
                        const int64 num_values,
                        const std::map<std::vector<int64>, std::set<T>>& sets) {
  // Allocate 3 output tensors for sparse data.
  Tensor *out_indices_t, *out_values_t, *out_shape_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          0, TensorShape({num_values, output_shape.dims()}),
                          &out_indices_t));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          2, TensorShape({output_shape.dims()}), &out_shape_t));
  auto out_indices_mat = out_indices_t->matrix<int64>();
  auto out_values_flat = out_values_t->vec<T>();

  // For each set, write its indices and values to output tensors.
  int64 value_index = 0;
  for (auto it = sets.begin(); it != sets.end(); ++it) {
    const auto& group_indices = it->first;
    OP_REQUIRES(
        ctx, group_indices.size() == output_shape.dims() - 1,
        errors::Internal("Invalid number of indices ", group_indices.size(),
                         ", expected ", output_shape.dims() - 1, "."));
    const auto& set = it->second;

    // For each set item, write its indices and value to output tensors.
    int64 group_value_index = 0;
    for (auto value = set.begin(); value != set.end();
         ++value, ++value_index, ++group_value_index) {
      // First n-1 dimensions are the group, last dimension is the position in
      // the set.
      for (int32 i = 0; i < group_indices.size(); ++i) {
        out_indices_mat(value_index, i) = group_indices[i];
      }
      out_indices_mat(value_index, group_indices.size()) = group_value_index;

      out_values_flat(value_index) = *value;
    }
  }

  // Write output shape.
  auto out_shape_flat = out_shape_t->vec<int64>();
  for (int32 i = 0; i < output_shape.dims(); ++i) {
    out_shape_flat(i) = output_shape.dim_size(i);
  }
}

bool ValidateIndicesFromContext(OpKernelConstruction* ctx) {
  bool result;
  if (ctx->GetAttr("validate_indices", &result).ok()) {
    return result;
  }
  return true;
}

// Populate `result` set from group in `tensor`. "Group" is defined by
// `group_indices`, which are values for the first n-1 dimensions of
// `input_tensor`. `input_strides` is provided to avoid recalculating it
// multiple times, and is used to calculate the flat index into `input_tensor`
// values.
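// For example, with `input_strides` [12, 4, 1] and `group_indices` [1, 2], the
// group's values start at flat offset 1 * 12 + 2 * 4 = 20 and span the last
// dimension of `input_tensor`.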
template <typename T>
void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
                            const VarDimArray& input_strides,
                            const std::vector<int64>& group_indices,
                            std::set<T>* result) {
  OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1,
              errors::Internal("group_indices.size ", group_indices.size(),
                               ", != input_strides.size-1 ",
                               input_strides.size() - 1, "."));
  result->clear();
  auto input_flat = input_tensor.flat<T>();
  const auto start = std::inner_product(
      group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
  const TensorShape& input_shape = input_tensor.shape();
  const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
  for (int64 i = start; i < end; ++i) {
    result->insert(input_flat(i));
  }
}

// Populate `result` set from `group`. `sparse_tensor_shape` is the shape of the
// `SparseTensor` from which group was created, and is used to sanity check the
// indices in `group`.
template <typename T>
void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group,
                             const VarDimArray& sparse_tensor_shape,
                             std::set<T>* result) {
  CheckGroup<T>(ctx, group, sparse_tensor_shape);
  result->clear();
  const auto& group_values = group.values<T>();
  for (int64 i = 0; i < group_values.size(); ++i) {
    result->insert(group_values(i));
  }
}

template <typename T>
class SetSizeOp : public OpKernel {
 public:
  explicit SetSizeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  const bool validate_indices_;
};

template <typename T>
void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
  sparse::SparseTensor set_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set_st));
  OP_REQUIRES_OK(ctx, set_st.IndicesValid());

  // Output shape is same as input except for last dimension, which reduces
  // to the set size of values along that dimension.
  ShapeArray output_shape;
  OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
  const auto output_strides = Strides(output_shape);

  TensorShape output_shape_ts;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::MakeShape(output_shape, &output_shape_ts));
  Tensor* out_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape_ts, &out_t));
  auto out = out_t->flat<int32>();
  out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0.0));

  // Group by all but last dimension, create a set of group values, and add set
  // size to output.
  VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
  std::set<T> group_set;
  for (const auto& group : set_st.group(group_ix)) {
    PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);

    const auto group_key = group.group();
    const auto output_index = std::inner_product(
        group_key.begin(), group_key.end(), output_strides.begin(), 0LL);
    out(output_index) = group_set.size();
  }
}

#define _SET_SIZE_REGISTER_KERNEL_BUILDER(T)                     \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SetSizeOp<T>);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int32);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int64);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(string);
#undef _SET_SIZE_REGISTER_KERNEL_BUILDER

enum InputTypes {
  DENSE_DENSE = 0,
  DENSE_SPARSE = 1,
  SPARSE_SPARSE = 2,
};

enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 };

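// Parse the `set_operation` attr into a `SetOperation` value. Construction
// fails if the attr is missing or unrecognized.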
SetOperation SetOperationFromContext(OpKernelConstruction* ctx) {
  string set_operation_str;
  if (!ctx->GetAttr("set_operation", &set_operation_str).ok()) {
    ctx->CtxFailure(errors::InvalidArgument("Missing set_operation."));
  } else {
    std::transform(set_operation_str.begin(), set_operation_str.end(),
                   set_operation_str.begin(), ::tolower);
    if ("a-b" == set_operation_str) {
      return A_MINUS_B;
    }
    if ("b-a" == set_operation_str) {
      return B_MINUS_A;
    }
    if ("intersection" == set_operation_str) {
      return INTERSECTION;
    }
    if ("union" != set_operation_str) {
      ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation ",
                                              set_operation_str, "."));
    }
  }
  // NOTE: This is not the default, this function fails if no 'set_operation'
  // attribute is provided.
  return UNION;
}

// Abstract base class for performing set operations across the last dimension
// of 2 input tensors.
template <typename T>
class SetOperationOp : public OpKernel {
 public:
  SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types)
      : OpKernel(ctx),
        set_operation_(SetOperationFromContext(ctx)),
        validate_indices_(ValidateIndicesFromContext(ctx)),
        input_types_(input_types) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  void ApplySetOperation(const std::set<T>& set1, const std::set<T>& set2,
                         std::set<T>* result) const;
  void ComputeDenseToDense(OpKernelContext* ctx) const;
  void ComputeDenseToSparse(OpKernelContext* ctx) const;
  void ComputeSparseToSparse(OpKernelContext* ctx) const;
  const SetOperation set_operation_;
  const bool validate_indices_;
  const InputTypes input_types_;
};

template <typename T>
void SetOperationOp<T>::ApplySetOperation(const std::set<T>& set1,
                                          const std::set<T>& set2,
                                          std::set<T>* result) const {
  switch (set_operation_) {
    case A_MINUS_B:
      std::set_difference(set1.begin(), set1.end(), set2.begin(), set2.end(),
                          std::inserter(*result, result->begin()));
      break;
    case B_MINUS_A:
      std::set_difference(set2.begin(), set2.end(), set1.begin(), set1.end(),
                          std::inserter(*result, result->begin()));
      break;
    case INTERSECTION:
      std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(),
                            std::inserter(*result, result->begin()));
      break;
    case UNION:
      std::set_union(set1.begin(), set1.end(), set2.begin(), set2.end(),
                     std::inserter(*result, result->begin()));
      break;
  }
}

// Validate shapes have the same dimensions.
Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) {
  if (shape1 != shape2) {
    return errors::InvalidArgument("Mismatched shapes [",
                                   str_util::Join(shape1, ","), "] vs [",
                                   str_util::Join(shape2, ","), "]");
  }
  return Status::OK();
}

// Validate ranks are the same, and all but last dimension are the same.
// Return GroupShape.
Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2,
                            ShapeArray* group_shape) {
  ShapeArray group_shape_1;
  TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1));
  ShapeArray group_shape_2;
  TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2));
  TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2));
  *group_shape = group_shape_1;
  return Status::OK();
}

// Split `flat_group_index` into separate dimensions based on `group_shape`.
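// For example, a flat_group_index of 5 with group_shape [2, 3] yields
// group_indices [1, 2].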
void PopulateGroupIndices(const int64 flat_group_index, VarDimArray group_shape,
                          std::vector<int64>* group_indices) {
  group_indices->clear();
  int64 running_flat_group_index = flat_group_index;
  for (int group_dim_index = group_shape.size() - 1; group_dim_index >= 0;
       --group_dim_index) {
    const auto group_dim = group_shape[group_dim_index];
    group_indices->insert(group_indices->begin(),
                          running_flat_group_index % group_dim);
    running_flat_group_index /= group_dim;
  }
}

ShapeArray TensorShapeToArray(const TensorShape& t) {
  ShapeArray vec(t.dims());
  for (int i = 0; i < t.dims(); ++i) vec[i] = t.dim_size(i);
  return vec;
}

// `ctx` contains set1 and set2 dense tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const Tensor& set2_t = ctx->input(1);
  // The following should stay in sync with `_dense_to_dense_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToDenseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  const auto shape1 = TensorShapeToArray(set1_t.shape());
  const auto shape2 = TensorShapeToArray(set2_t.shape());
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(shape1, shape2, &group_shape));

  const auto set1_strides = Strides(shape1);
  const auto set2_strides = Strides(shape2);

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  std::vector<int64> group_indices;
  int64 num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64 flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);
    PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices,
                              &set2_group_set);

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// `ctx` contains dense set1 and sparse set2 tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st));
  OP_REQUIRES_OK(ctx, set2_st.IndicesValid());
  // The following should stay in sync with `_dense_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(TensorShapeToArray(set1_t.shape()),
                                           set2_st.shape(), &group_shape));

  const ShapeArray set1_strides = Strides(TensorShapeToArray(set1_t.shape()));

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();
  std::vector<int64> group_indices;
  int64 num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64 flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);

    // Get values from set1.
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (set2_group_it != set2_grouper.end()) {
      const auto& group = *set2_group_it;
      const auto set2_group_indices = group.group();
      OP_REQUIRES(
          ctx, set2_group_indices.size() == group_indices.size(),
          errors::InvalidArgument("Invalid number of group indices ",
                                  set2_group_indices.size(), ", expected ",
                                  group_indices.size(), "."));
      bool group_match = true;
      for (int32 i = 0; group_match && (i < set2_group_indices.size()); ++i) {
        if (set2_group_indices[i] != group_indices[i]) {
          group_match = false;
        }
      }
      if (group_match) {
        PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(),
                                   &set2_group_set);
        ++set2_group_it;
      }
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// This is used to determine which group iterator is less than the other, based
// on row-major ordering of indices.
// An empty index list indicates end of iteration, which is interpreted as "max"
// for the purposes of comparison; i.e., non-empty < empty.
// Return 0 if both groups are empty, or both non-empty with the same values.
// Return <0 if set1 < set2, or set2 is empty.
// Return >0 if set1 > set2, or set1 is empty.
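// For example, comparing {0, 1} with {0, 2} yields a negative result, so set1
// is processed first.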
void CompareGroups(OpKernelContext* ctx,
                   const std::vector<int64>& set1_group_indices,
                   const std::vector<int64>& set2_group_indices,
                   int64* result) {
  if (set1_group_indices.empty()) {
    *result = set2_group_indices.empty() ? 0 : 1;
    return;
  }
  if (set2_group_indices.empty()) {
    *result = set1_group_indices.empty() ? 0 : -1;
    return;
  }
  OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(),
              errors::InvalidArgument("Mismatched group dims ",
                                      set1_group_indices.size(), " vs ",
                                      set2_group_indices.size(), "."));
  for (int32 i = 0; i < set1_group_indices.size(); ++i) {
    *result = set1_group_indices[i] - set2_group_indices[i];
    if (*result != 0) {
      return;
    }
  }
}

// Empty indices vector represents iteration end in `CompareGroups`.
const std::vector<int64> GROUP_ITER_END;

// `ctx` contains set1 and set2 sparse tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
  sparse::SparseTensor set1_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st));
  OP_REQUIRES_OK(ctx, set1_st.IndicesValid());

  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st));
  OP_REQUIRES_OK(ctx, set2_st.IndicesValid());

  // The following should stay in sync with `_sparse_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `SparseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(set1_st.shape(), set2_st.shape(),
                                           &group_shape));

  const ShapeArray set1_strides = Strides(set1_st.shape());
  const ShapeArray set2_strides = Strides(set2_st.shape());

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set1_grouper =
      set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
  auto set1_group_it = set1_grouper.begin();
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();

  // Group by rows, and iterate over rows of both sets in parallel, creating a
  // set for each row.
  while ((set1_group_it != set1_grouper.end()) ||
         (set2_group_it != set2_grouper.end())) {
    const std::vector<int64>& set1_group_indices =
        (set1_group_it == set1_grouper.end()) ? GROUP_ITER_END
                                              : (*set1_group_it).group();
    const std::vector<int64>& set2_group_indices =
        (set2_group_it == set2_grouper.end()) ? GROUP_ITER_END
                                              : (*set2_group_it).group();

    int64 compare_groups;
    CompareGroups(ctx, set1_group_indices, set2_group_indices, &compare_groups);
    const std::vector<int64>* group_indices = nullptr;

    // Get values from set1, if applicable.
    set1_group_set.clear();
    if (compare_groups <= 0) {
      PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(),
                                 &set1_group_set);
      ++set1_group_it;
      group_indices = &set1_group_indices;
    }

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (compare_groups >= 0) {
      PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(),
                                 &set2_group_set);
      ++set2_group_it;
      group_indices = &set2_group_indices;
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[*group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// Given set1 of shape [b, n1] and set2 of shape [b, n2], populate the result
// sparse tensor with [b, n3] values, where each row `i` contains the result of
// the set operation on elements from set1[i] and set2[i]. `n3` is the number
// of elements in that result row.
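// For example, with set_operation "a-b", set1 row {1, 2, 3} and set2 row
// {2, 4} produce the result row {1, 3}.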
template <typename T>
void SetOperationOp<T>::Compute(OpKernelContext* ctx) {
  switch (input_types_) {
    case DENSE_DENSE:
      ComputeDenseToDense(ctx);
      break;
    case DENSE_SPARSE:
      ComputeDenseToSparse(ctx);
      break;
    case SPARSE_SPARSE:
      ComputeSparseToSparse(ctx);
      break;
  }
}

template <typename T>
class DenseToDenseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_DENSE) {}
};

#define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation")       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<T>("T"),            \
                          DenseToDenseSetOperationOp<T>);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class DenseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_SPARSE) {}
};

#define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          DenseToSparseSetOperationOp<T>);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class SparseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, SPARSE_SPARSE) {}
};

#define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T"),              \
                          SparseToSparseSetOperationOp<T>);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

}  // namespace tensorflow