/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Ops for operating with sets. They are not checked in to TensorFlow because
// we would first like to demonstrate successful end-to-end use of these ops in
// eval, and polish the API a bit, e.g. take two SparseTensors rather than one
// dense and one sparse.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
// TODO(ptucker): Consider switching back to hash_set - I had trouble getting it
// to work with string values.
#include <set>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using ShapeArray = sparse::SparseTensor::ShapeArray;
using VarDimArray = sparse::SparseTensor::VarDimArray;
// Validate rank >= 2.
void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) {
  const auto rank = shape.dims();
  OP_REQUIRES(ctx, rank >= 2,
              errors::InvalidArgument("Invalid rank ", rank, "."));
}

// Return group shape, which is the first n-1 dimensions of shape.
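// For example, an input shape of [2,3,4] yields the group shape [2,3].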
Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
  if (input_shape.size() < 2) {
    // TODO(irving): Why can't 2 be 1 here?
    return errors::InvalidArgument("Shape [", str_util::Join(input_shape, ","),
                                   "] has rank ", input_shape.size(), " < 2");
  }
  // grouped_shape is input_shape[:-1]
  *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1);
  return Status::OK();
}

// Build `SparseTensor` from indices, values, and shape in inputs
// [base_index, base_index + 3), and validate its rank and indices.
Status SparseTensorFromContext(OpKernelContext* ctx, const int32 base_index,
                               bool validate_indices,
                               sparse::SparseTensor* tensor) {
  // Assume row-major order.
  const TensorShape shape =
      TensorShape(ctx->input(base_index + 2).vec<int64>());
  CheckRankAtLeast2(ctx, shape);
  std::vector<int64> order(shape.dims());
  std::iota(order.begin(), order.end(), 0);

  return sparse::SparseTensor::Create(
      ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor);
}

// TODO(ptucker): CheckGroup is just a sanity check on the result of
// SparseTensor.group, consider removing.
// `sparse_tensor_shape` is the shape of the `SparseTensor` from which group
// was created, and is used to sanity check the indices in `group`.
template <typename T>
void CheckGroup(OpKernelContext* ctx, const sparse::Group& group,
                const VarDimArray& sparse_tensor_shape) {
  const auto& indices = group.indices();
  const auto& values = group.values<T>();

  // Sanity check: group is non-empty, and indices and values are same size.
  const auto num_values = values.dimension(0);
  OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group."));
  OP_REQUIRES(
      ctx, indices.dimension(0) == num_values,
      errors::Internal("shape[0] of group indices ", indices.dimension(0),
                       " != values ", num_values, "."));

  // Sanity check: valid indices.
  const auto group_rank = indices.dimension(1);
  const auto expected_rank = sparse_tensor_shape.size();
  OP_REQUIRES(ctx, expected_rank == group_rank,
              errors::Internal("Rank expected ", expected_rank, ", got ",
                               group_rank, "."));
  for (int32 j = 0; j < expected_rank; ++j) {
    const auto dim_size = sparse_tensor_shape[j];
    OP_REQUIRES(
        ctx, dim_size > 0,
        errors::Internal("Invalid dim_size[", j, "] = ", dim_size, "."));
    for (int64 i = 0; i < num_values; ++i) {
      const auto index = indices(i, j);
      OP_REQUIRES(ctx, dim_size > index,
                  errors::Internal("indices[", i, ", ", j, "] expected < ",
                                   dim_size, ", got ", index, "."));
    }
  }
}

// Return the row-major strides of `shape`, which let us calculate the flat
// index into the flattened output.
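// For example, Strides({2,3,4}) returns {12,4,1}, so the value at indices
// [i,j,k] has flat index i*12 + j*4 + k.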
const ShapeArray Strides(const VarDimArray& shape) {
  ShapeArray result(shape.size());
  int64 product = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    result[i] = product;
    product *= shape[i];
  }
  return result;
}

// TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to
// eliminate the intermediate `values` data structure - iterate once to
// determine `num_values`, allocate output tensors, then write results directly
// to output tensors.

// TODO(ptucker): Consider sharding work across multiple threads. See
// SparseCrossOp for an example.

// Output `SparseTensor` of shape `output_shape`. `sets` contains a map of
// group indices (i.e., values for all but the last dimension of `output_shape`)
// to set values, each of which will occupy the last dimension of
// `output_shape`.
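// For example, with output_shape [2,2] and sets {{0}: {a,b}, {1}: {c}}, the
// output is indices [[0,0],[0,1],[1,0]], values [a,b,c], and shape [2,2].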
template <typename T>
void OutputSparseTensor(OpKernelContext* ctx, const TensorShape& output_shape,
                        const int64 num_values,
                        const std::map<std::vector<int64>, std::set<T>>& sets) {
  // Allocate 3 output tensors for sparse data.
  Tensor *out_indices_t, *out_values_t, *out_shape_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          0, TensorShape({num_values, output_shape.dims()}),
                          &out_indices_t));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          2, TensorShape({output_shape.dims()}), &out_shape_t));
  auto out_indices_mat = out_indices_t->matrix<int64>();
  auto out_values_flat = out_values_t->vec<T>();

  // For each set, write its indices and values to output tensors.
  int64 value_index = 0;
  for (auto it = sets.begin(); it != sets.end(); ++it) {
    const auto& group_indices = it->first;
    OP_REQUIRES(
        ctx, group_indices.size() == output_shape.dims() - 1,
        errors::Internal("Invalid number of indices ", group_indices.size(),
                         ", expected ", output_shape.dims() - 1, "."));
    const auto& set = it->second;

    // For each set item, write its indices and value to output tensors.
    int64 group_value_index = 0;
    for (auto value = set.begin(); value != set.end();
         ++value, ++value_index, ++group_value_index) {
      // First n-1 dimensions are the group, last dimension is the position in
      // the set.
      for (int32 i = 0; i < group_indices.size(); ++i) {
        out_indices_mat(value_index, i) = group_indices[i];
      }
      out_indices_mat(value_index, group_indices.size()) = group_value_index;

      out_values_flat(value_index) = *value;
    }
  }

  // Write output shape.
  auto out_shape_flat = out_shape_t->vec<int64>();
  for (int32 i = 0; i < output_shape.dims(); ++i) {
    out_shape_flat(i) = output_shape.dim_size(i);
  }
}

// Read the `validate_indices` attr from `ctx`, defaulting to true if the attr
// is not set.
bool ValidateIndicesFromContext(OpKernelConstruction* ctx) {
  bool result;
  if (ctx->GetAttr("validate_indices", &result).ok()) {
    return result;
  }
  return true;
}

// Populate `result` set from group in `input_tensor`. "Group" is defined by
// `group_indices`, which are values for the first n-1 dimensions of
// `input_tensor`. `input_strides` is provided to avoid recalculating it
// multiple times, and is used to calculate the flat index into `input_tensor`
// values.
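// For example, for an input_tensor of shape [2,3,4] (strides {12,4,1}) and
// group_indices {1,2}, this reads the 4 values at flat indices [20, 24).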
template <typename T>
void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
                            const VarDimArray& input_strides,
                            const std::vector<int64>& group_indices,
                            std::set<T>* result) {
  OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1,
              errors::Internal("group_indices.size ", group_indices.size(),
                               " != input_strides.size-1 ",
                               input_strides.size() - 1, "."));
  result->clear();
  auto input_flat = input_tensor.flat<T>();
  const auto start = std::inner_product(
      group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
  const TensorShape& input_shape = input_tensor.shape();
  const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
  for (int64 i = start; i < end; ++i) {
    result->insert(input_flat(i));
  }
}

// Populate `result` set from `group`. `sparse_tensor_shape` is the shape of the
// `SparseTensor` from which group was created, and is used to sanity check the
// indices in `group`.
template <typename T>
void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group,
                             const VarDimArray& sparse_tensor_shape,
                             std::set<T>* result) {
  CheckGroup<T>(ctx, group, sparse_tensor_shape);
  result->clear();
  const auto& group_values = group.values<T>();
  for (int64 i = 0; i < group_values.size(); ++i) {
    result->insert(group_values(i));
  }
}

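// Op that outputs the number of unique values along the last dimension of a
// `SparseTensor`. For example, a [2,3] input with values {1,2,2} in row 0 and
// {3} in row 1 produces the dense int32 output [2, 1].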
template <typename T>
class SetSizeOp : public OpKernel {
 public:
  explicit SetSizeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  const bool validate_indices_;
};

template <typename T>
void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
  sparse::SparseTensor set_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set_st));
  OP_REQUIRES_OK(ctx, set_st.IndicesValid());

  // Output shape is same as input except for last dimension, which reduces
  // to the set size of values along that dimension.
  ShapeArray output_shape;
  OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
  const auto output_strides = Strides(output_shape);

  TensorShape output_shape_ts;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::MakeShape(output_shape, &output_shape_ts));
  Tensor* out_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape_ts, &out_t));
  auto out = out_t->flat<int32>();
  out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0));

  // Group by all but last dimension, create a set of group values, and add set
  // size to output.
  VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
  std::set<T> group_set;
  for (const auto& group : set_st.group(group_ix)) {
    PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);

    const auto group_key = group.group();
    const auto output_index = std::inner_product(
        group_key.begin(), group_key.end(), output_strides.begin(), 0LL);
    out(output_index) = group_set.size();
  }
}

#define _SET_SIZE_REGISTER_KERNEL_BUILDER(T)                     \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SetSizeOp<T>);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int32);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int64);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(string);
#undef _SET_SIZE_REGISTER_KERNEL_BUILDER

enum InputTypes {
  DENSE_DENSE = 0,
  DENSE_SPARSE = 1,
  SPARSE_SPARSE = 2,
};

enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 };

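// Parse the `set_operation` attr. The match is case-insensitive: "a-b" maps to
// A_MINUS_B, "b-a" to B_MINUS_A, "intersection" to INTERSECTION, and "union"
// to UNION; anything else fails the op.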
SetOperation SetOperationFromContext(OpKernelConstruction* ctx) {
  string set_operation_str;
  if (!ctx->GetAttr("set_operation", &set_operation_str).ok()) {
    ctx->CtxFailure(errors::InvalidArgument("Missing set_operation."));
  } else {
    std::transform(set_operation_str.begin(), set_operation_str.end(),
                   set_operation_str.begin(), ::tolower);
    if ("a-b" == set_operation_str) {
      return A_MINUS_B;
    }
    if ("b-a" == set_operation_str) {
      return B_MINUS_A;
    }
    if ("intersection" == set_operation_str) {
      return INTERSECTION;
    }
    if ("union" != set_operation_str) {
      ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation ",
                                              set_operation_str, "."));
    }
  }
  // NOTE: This is not the default; this function fails if no 'set_operation'
  // attribute is provided.
  return UNION;
}

// Abstract base class for performing set operations across the last dimension
// of 2 input tensors.
template <typename T>
class SetOperationOp : public OpKernel {
 public:
  SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types)
      : OpKernel(ctx),
        set_operation_(SetOperationFromContext(ctx)),
        validate_indices_(ValidateIndicesFromContext(ctx)),
        input_types_(input_types) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  void ApplySetOperation(const std::set<T>& set1, const std::set<T>& set2,
                         std::set<T>* result) const;
  void ComputeDenseToDense(OpKernelContext* ctx) const;
  void ComputeDenseToSparse(OpKernelContext* ctx) const;
  void ComputeSparseToSparse(OpKernelContext* ctx) const;
  const SetOperation set_operation_;
  const bool validate_indices_;
  const InputTypes input_types_;
};

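// Apply the configured set operation to `set1` and `set2`, writing the result
// to `result`. For example, with set1 = {1, 2, 3} and set2 = {2, 4}:
// A_MINUS_B yields {1, 3}, B_MINUS_A yields {4}, INTERSECTION yields {2}, and
// UNION yields {1, 2, 3, 4}.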
template <typename T>
void SetOperationOp<T>::ApplySetOperation(const std::set<T>& set1,
                                          const std::set<T>& set2,
                                          std::set<T>* result) const {
  switch (set_operation_) {
    case A_MINUS_B:
      std::set_difference(set1.begin(), set1.end(), set2.begin(), set2.end(),
                          std::inserter(*result, result->begin()));
      break;
    case B_MINUS_A:
      std::set_difference(set2.begin(), set2.end(), set1.begin(), set1.end(),
                          std::inserter(*result, result->begin()));
      break;
    case INTERSECTION:
      std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(),
                            std::inserter(*result, result->begin()));
      break;
    case UNION:
      std::set_union(set1.begin(), set1.end(), set2.begin(), set2.end(),
                     std::inserter(*result, result->begin()));
      break;
  }
}

// Validate shapes have the same dimensions.
Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) {
  if (shape1 != shape2) {
    return errors::InvalidArgument("Mismatched shapes [",
                                   str_util::Join(shape1, ","), "] vs [",
                                   str_util::Join(shape2, ","), "]");
  }
  return Status::OK();
}

// Validate ranks are the same, and all but last dimension are the same.
// Return GroupShape.
Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2,
                            ShapeArray* group_shape) {
  ShapeArray group_shape_1;
  TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1));
  ShapeArray group_shape_2;
  TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2));
  TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2));
  *group_shape = group_shape_1;
  return Status::OK();
}
// Split `flat_group_index` into separate dimensions based on `group_shape`.
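// For example, flat_group_index 5 with group_shape [2,3] yields group indices
// [1,2].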
void PopulateGroupIndices(const int64 flat_group_index, VarDimArray group_shape,
                          std::vector<int64>* group_indices) {
  group_indices->clear();
  int64 running_flat_group_index = flat_group_index;
  for (int group_dim_index = group_shape.size() - 1; group_dim_index >= 0;
       --group_dim_index) {
    const auto group_dim = group_shape[group_dim_index];
    group_indices->insert(group_indices->begin(),
                          running_flat_group_index % group_dim);
    running_flat_group_index /= group_dim;
  }
}

ShapeArray TensorShapeToArray(const TensorShape& t) {
  ShapeArray vec(t.dims());
  for (int i = 0; i < t.dims(); ++i) vec[i] = t.dim_size(i);
  return vec;
}

// `ctx` contains set1 and set2 dense tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
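// For example, with set1 = [[1, 2], [3, 4]], set2 = [[1, 3], [4, 5]], and
// set_operation "intersection", the output is a `SparseTensor` with dense
// shape [2, 1] and values [1, 4].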
template <typename T>
void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const Tensor& set2_t = ctx->input(1);
  // The following should stay in sync with `_dense_to_dense_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToDenseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  const auto shape1 = TensorShapeToArray(set1_t.shape());
  const auto shape2 = TensorShapeToArray(set2_t.shape());
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(shape1, shape2, &group_shape));

  const auto set1_strides = Strides(shape1);
  const auto set2_strides = Strides(shape2);

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  std::vector<int64> group_indices;
  int64 num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64 flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);
    PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices,
                              &set2_group_set);

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// `ctx` contains dense set1 and sparse set2 tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st));
  OP_REQUIRES_OK(ctx, set2_st.IndicesValid());
  // The following should stay in sync with `_dense_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(TensorShapeToArray(set1_t.shape()),
                                           set2_st.shape(), &group_shape));

  const ShapeArray set1_strides = Strides(TensorShapeToArray(set1_t.shape()));

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();
  std::vector<int64> group_indices;
  int64 num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64 flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);

    // Get values from set1.
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (set2_group_it != set2_grouper.end()) {
      const auto& group = *set2_group_it;
      const auto set2_group_indices = group.group();
      OP_REQUIRES(
          ctx, set2_group_indices.size() == group_indices.size(),
          errors::InvalidArgument("Invalid number of group indices ",
                                  set2_group_indices.size(), ", expected ",
                                  group_indices.size(), "."));
      bool group_match = true;
      for (int32 i = 0; group_match && (i < set2_group_indices.size()); ++i) {
        if (set2_group_indices[i] != group_indices[i]) {
          group_match = false;
        }
      }
      if (group_match) {
        PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(),
                                   &set2_group_set);
        ++set2_group_it;
      }
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// This is used to determine which group iterator is less than the other, based
// on row-major ordering of indices.
// An empty index list indicates end of iteration, which is interpreted as
// "max" for the purposes of comparison; i.e., non-empty < empty.
// Set `result` to 0 if both groups are empty, or both are non-empty with the
// same values.
// Set `result` to <0 if set1 < set2, or set2 is empty.
// Set `result` to >0 if set2 < set1, or set1 is empty.
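// For example, comparing group indices {0,1} and {0,2} sets `result` to -1,
// so set1 sorts first.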
void CompareGroups(OpKernelContext* ctx,
                   const std::vector<int64>& set1_group_indices,
                   const std::vector<int64>& set2_group_indices,
                   int64* result) {
  if (set1_group_indices.empty()) {
    *result = set2_group_indices.empty() ? 0 : 1;
    return;
  }
  if (set2_group_indices.empty()) {
    *result = set1_group_indices.empty() ? 0 : -1;
    return;
  }
  OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(),
              errors::InvalidArgument("Mismatched group dims ",
                                      set1_group_indices.size(), " vs ",
                                      set2_group_indices.size(), "."));
  for (int32 i = 0; i < set1_group_indices.size(); ++i) {
    *result = set1_group_indices[i] - set2_group_indices[i];
    if (*result != 0) {
      return;
    }
  }
}

// Empty indices vector represents iteration end in `CompareGroups`.
const std::vector<int64> GROUP_ITER_END;

// `ctx` contains set1 and set2 sparse tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
  sparse::SparseTensor set1_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st));
  OP_REQUIRES_OK(ctx, set1_st.IndicesValid());

  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st));
  OP_REQUIRES_OK(ctx, set2_st.IndicesValid());

  // The following should stay in sync with `_sparse_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `SparseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(set1_st.shape(), set2_st.shape(),
                                           &group_shape));

  const ShapeArray set1_strides = Strides(set1_st.shape());
  const ShapeArray set2_strides = Strides(set2_st.shape());

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set1_grouper =
      set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
  auto set1_group_it = set1_grouper.begin();
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();

  // Group by rows, and iterate over rows of both sets in parallel, creating a
  // set for each row.
  while ((set1_group_it != set1_grouper.end()) ||
         (set2_group_it != set2_grouper.end())) {
    const std::vector<int64>& set1_group_indices =
        (set1_group_it == set1_grouper.end()) ? GROUP_ITER_END
                                              : (*set1_group_it).group();
    const std::vector<int64>& set2_group_indices =
        (set2_group_it == set2_grouper.end()) ? GROUP_ITER_END
                                              : (*set2_group_it).group();

    int64 compare_groups;
    CompareGroups(ctx, set1_group_indices, set2_group_indices, &compare_groups);
    const std::vector<int64>* group_indices = nullptr;

    // Get values from set1, if applicable.
    set1_group_set.clear();
    if (compare_groups <= 0) {
      PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(),
                                 &set1_group_set);
      ++set1_group_it;
      group_indices = &set1_group_indices;
    }

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (compare_groups >= 0) {
      PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(),
                                 &set2_group_set);
      ++set2_group_it;
      group_indices = &set2_group_indices;
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[*group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// Given set1 of shape [b, n1] and set2 of shape [b, n2], populate the result
// sparse tensor with [b, n3] values, where each row `i` contains the result of
// the set operation on elements from set1[i] and set2[i]. `n3` is the maximum
// result-set size across rows.
template <typename T>
void SetOperationOp<T>::Compute(OpKernelContext* ctx) {
  switch (input_types_) {
    case DENSE_DENSE:
      ComputeDenseToDense(ctx);
      break;
    case DENSE_SPARSE:
      ComputeDenseToSparse(ctx);
      break;
    case SPARSE_SPARSE:
      ComputeSparseToSparse(ctx);
      break;
  }
}

template <typename T>
class DenseToDenseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_DENSE) {}
};

#define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation")       \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<T>("T"),           \
                          DenseToDenseSetOperationOp<T>);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class DenseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_SPARSE) {}
};

#define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<T>("T"),            \
                          DenseToSparseSetOperationOp<T>);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class SparseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, SPARSE_SPARSE) {}
};

#define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          SparseToSparseSetOperationOp<T>);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

}  // namespace tensorflow