1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
26 namespace {
27 
SparseSparseMinOrMaxShapeFn(InferenceContext * c)28 Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) {
29   ShapeHandle unused;
30   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
31   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
32   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));  // a_shape
33   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused));  // b_indices
34   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused));  // b_values
35   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused));  // b_shape
36   c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
37                              InferenceContext::kUnknownDim));
38   c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
39   return Status::OK();
40 }
41 
42 }  // namespace
43 
44 REGISTER_OP("SparseAddGrad")
45     .Input("backprop_val_grad: T")
46     .Input("a_indices: int64")
47     .Input("b_indices: int64")
48     .Input("sum_indices: int64")
49     .Output("a_val_grad: T")
50     .Output("b_val_grad: T")
51     .Attr("T: numbertype")
__anone6e195410202(InferenceContext* c) 52     .SetShapeFn([](InferenceContext* c) {
53       ShapeHandle a_indices;
54       ShapeHandle b_indices;
55       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices));
56       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices));
57       c->set_output(0, c->Vector(c->Dim(a_indices, 0)));
58       c->set_output(1, c->Vector(c->Dim(b_indices, 0)));
59       return Status::OK();
60     });
61 
62 REGISTER_OP("SparseAdd")
63     .Input("a_indices: int64")
64     .Input("a_values: T")
65     .Input("a_shape: int64")
66     .Input("b_indices: int64")
67     .Input("b_values: T")
68     .Input("b_shape: int64")
69     .Input("thresh: Treal")
70     .Output("sum_indices: int64")
71     .Output("sum_values: T")
72     .Output("sum_shape: int64")
73     .Attr("T: numbertype")
74     .Attr("Treal: realnumbertype")
__anone6e195410302(InferenceContext* c) 75     .SetShapeFn([](InferenceContext* c) {
76       ShapeHandle a_shape;
77       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
78       c->set_output(
79           0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0)));
80       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
81       c->set_output(2, a_shape);
82       return Status::OK();
83     });
84 
85 REGISTER_OP("SparseTensorDenseMatMul")
86     .Input("a_indices: Tindices")
87     .Input("a_values: T")
88     .Input("a_shape: int64")
89     .Input("b: T")
90     .Output("product: T")
91     .Attr("T: type")
92     .Attr("Tindices: {int32,int64} = DT_INT64")
93     .Attr("adjoint_a: bool = false")
94     .Attr("adjoint_b: bool = false")
__anone6e195410402(InferenceContext* c) 95     .SetShapeFn([](InferenceContext* c) {
96       DimensionHandle unused_dim;
97       ShapeHandle unused;
98       ShapeHandle b;
99       ShapeHandle a_shape;
100       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
101       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
102       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape));
103       TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape));
104       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b));
105 
106       bool adjoint_a;
107       bool adjoint_b;
108       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
109       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
110 
111       DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1);
112       DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0);
113       DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1);
114       DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0);
115       TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim));
116       c->set_output(0, c->Matrix(output_left, output_right));
117       return Status::OK();
118     });
119 
120 REGISTER_OP("SerializeSparse")
121     .Input("sparse_indices: int64")
122     .Input("sparse_values: T")
123     .Input("sparse_shape: int64")
124     .Attr("T: type")
125     .Output("serialized_sparse: out_type")
126     .Attr("out_type: {string, variant} = DT_STRING")
__anone6e195410502(InferenceContext* c) 127     .SetShapeFn([](InferenceContext* c) {
128       ShapeHandle unused;
129       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
130       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
131       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
132       c->set_output(0, c->Vector(3));
133       return Status::OK();
134     });
135 
136 REGISTER_OP("SerializeManySparse")
137     .Input("sparse_indices: int64")
138     .Input("sparse_values: T")
139     .Input("sparse_shape: int64")
140     .Attr("T: type")
141     .Output("serialized_sparse: out_type")
142     .Attr("out_type: {string, variant} = DT_STRING")
__anone6e195410602(InferenceContext* c) 143     .SetShapeFn([](InferenceContext* c) {
144       ShapeHandle unused;
145       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
146       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
147       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
148       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3));
149       return Status::OK();
150     });
151 
152 REGISTER_OP("DeserializeSparse")
153     .Input("serialized_sparse: Tserialized")
154     .Output("sparse_indices: int64")
155     .Output("sparse_values: dtype")
156     .Output("sparse_shape: int64")
157     .Attr("dtype: type")
158     .Attr("Tserialized: {string, variant} = DT_STRING")
__anone6e195410702(InferenceContext* c) 159     .SetShapeFn([](InferenceContext* c) {
160       // serialized sparse is [?, ..., ?, 3] vector.
161       DimensionHandle unused;
162       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
163       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
164                                  InferenceContext::kUnknownDim));
165       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
166       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
167       return Status::OK();
168     });
169 
170 REGISTER_OP("DeserializeManySparse")
171     .Input("serialized_sparse: string")
172     .Output("sparse_indices: int64")
173     .Output("sparse_values: dtype")
174     .Output("sparse_shape: int64")
175     .Attr("dtype: type")
__anone6e195410802(InferenceContext* c) 176     .SetShapeFn([](InferenceContext* c) {
177       // serialized sparse is [?,3] matrix.
178       ShapeHandle serialized_sparse;
179       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse));
180       DimensionHandle unused;
181       TF_RETURN_IF_ERROR(
182           c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused));
183 
184       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
185                                  InferenceContext::kUnknownDim));
186       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
187       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
188       return Status::OK();
189     });
190 
191 REGISTER_OP("SparseToDense")
192     .Input("sparse_indices: Tindices")
193     .Input("output_shape: Tindices")
194     .Input("sparse_values: T")
195     .Input("default_value: T")
196     .Attr("validate_indices: bool = true")
197     .Attr("T: type")
198     .Output("dense: T")
199     .Attr("Tindices: {int32, int64}")
__anone6e195410902(InferenceContext* c) 200     .SetShapeFn([](InferenceContext* c) {
201       ShapeHandle out;
202       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
203       c->set_output(0, out);
204       return Status::OK();
205     });
206 
207 REGISTER_OP("SparseConcat")
208     .Input("indices: N * int64")
209     .Input("values: N * T")
210     .Input("shapes: N * int64")
211     .Output("output_indices: int64")
212     .Output("output_values: T")
213     .Output("output_shape: int64")
214     .Attr("concat_dim: int")
215     .Attr("N: int >= 2")
216     .Attr("T: type")
__anone6e195410a02(InferenceContext* c) 217     .SetShapeFn([](InferenceContext* c) {
218       // These accumulates the sum.
219       DimensionHandle output_row_count = c->MakeDim(0ll);
220 
221       // These are only merged.
222       DimensionHandle output_ind_cols = c->UnknownDim();
223       ShapeHandle output_shape = c->UnknownShape();
224 
225       const int n = c->num_inputs() / 3;
226       for (int i = 0; i < n; i++) {
227         ShapeHandle ind;
228         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind));
229         ShapeHandle val;
230         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val));
231         ShapeHandle shape;
232         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape));
233 
234         // Add to output_ind_rows.
235         DimensionHandle num_dim;
236         TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim));
237         TF_RETURN_IF_ERROR(
238             c->Add(output_row_count, num_dim, &output_row_count));
239 
240         // Merge into output_ind_cols and output_shape.
241         TF_RETURN_IF_ERROR(
242             c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols));
243         TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape));
244       }
245 
246       c->set_output(0, c->Matrix(output_row_count, output_ind_cols));
247       c->set_output(1, c->Vector(output_row_count));
248       c->set_output(2, output_shape);
249       return Status::OK();
250     });
251 
252 REGISTER_OP("SparseCross")
253     .Input("indices: N * int64")
254     .Input("values: sparse_types")
255     .Input("shapes: N * int64")
256     .Input("dense_inputs: dense_types")
257     .Output("output_indices: int64")
258     .Output("output_values: out_type")
259     .Output("output_shape: int64")
260     .Attr("N: int >= 0")
261     .Attr("hashed_output: bool")
262     .Attr("num_buckets: int >= 0")
263     .Attr("hash_key: int")
264     .Attr("sparse_types: list({int64, string}) >= 0")
265     .Attr("dense_types: list({int64, string}) >= 0")
266     .Attr("out_type: {int64, string}")
267     .Attr("internal_type: {int64, string}")
__anone6e195410b02(shape_inference::InferenceContext* c) 268     .SetShapeFn([](shape_inference::InferenceContext* c) {
269       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
270       c->set_output(1, c->Vector(c->UnknownDim()));
271       c->set_output(2, c->Vector(2));
272       return Status::OK();
273     });
274 
275 REGISTER_OP("SparseSplit")
276     .Input("split_dim: int64")
277     .Input("indices: int64")
278     .Input("values: T")
279     .Input("shape: int64")
280     .Output("output_indices: num_split * int64")
281     .Output("output_values:  num_split * T")
282     .Output("output_shape:   num_split * int64")
283     .Attr("num_split: int >= 1")
284     .Attr("T: type")
__anone6e195410c02(InferenceContext* c) 285     .SetShapeFn([](InferenceContext* c) {
286       ShapeHandle input_shape = c->input(3);
287       ShapeHandle output_indices =
288           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
289       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
290       ShapeHandle output_shape = input_shape;
291 
292       // Copy the outputs into the output ranges.
293       int num_splits = c->num_outputs() / 3;
294       int out_idx = 0;
295       for (int i = 0; i < num_splits; ++i)
296         c->set_output(out_idx++, output_indices);
297       for (int i = 0; i < num_splits; ++i)
298         c->set_output(out_idx++, output_values);
299       for (int i = 0; i < num_splits; ++i)
300         c->set_output(out_idx++, output_shape);
301       return Status::OK();
302     });
303 
304 REGISTER_OP("SparseSliceGrad")
305     .Input("backprop_val_grad: T")
306     .Input("input_indices: int64")
307     .Input("input_start: int64")
308     .Input("output_indices: int64")
309     .Output("val_grad: T")
310     .Attr("T: numbertype")
__anone6e195410d02(InferenceContext* c) 311     .SetShapeFn([](InferenceContext* c) {
312       ShapeHandle indices;
313       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
314       c->set_output(0, c->Vector(c->Dim(indices, 0)));
315       return Status::OK();
316     });
317 
318 REGISTER_OP("SparseSlice")
319     .Input("indices: int64")
320     .Input("values: T")
321     .Input("shape: int64")
322     .Input("start: int64")
323     .Input("size: int64")
324     .Output("output_indices: int64")
325     .Output("output_values: T")
326     .Output("output_shape: int64")
327     .Attr("T: type")
__anone6e195410e02(InferenceContext* c) 328     .SetShapeFn([](InferenceContext* c) {
329       ShapeHandle input_shape = c->input(2);
330       ShapeHandle output_indices =
331           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
332       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
333       ShapeHandle output_shape = input_shape;
334 
335       c->set_output(0, output_indices);
336       c->set_output(1, output_values);
337       c->set_output(2, output_shape);
338       return Status::OK();
339     });
340 
341 REGISTER_OP("SparseReorder")
342     .Input("input_indices: int64")
343     .Input("input_values: T")
344     .Input("input_shape: int64")
345     .Output("output_indices: int64")
346     .Output("output_values: T")
347     .Attr("T: type")
__anone6e195410f02(InferenceContext* c) 348     .SetShapeFn([](InferenceContext* c) {
349       ShapeHandle indices;
350       ShapeHandle values;
351       ShapeHandle unused;
352 
353       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
354       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));
355       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
356 
357       c->set_output(0, indices);
358       c->set_output(1, values);
359       return Status::OK();
360     });
361 
362 REGISTER_OP("SparseReshape")
363     .Input("input_indices: int64")
364     .Input("input_shape: int64")
365     .Input("new_shape: int64")
366     .Output("output_indices: int64")
367     .Output("output_shape: int64")
__anone6e195411002(InferenceContext* c) 368     .SetShapeFn([](InferenceContext* c) {
369       ShapeHandle indices;
370       ShapeHandle unused;
371       ShapeHandle new_shape;
372 
373       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
374       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
375       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape));
376 
377       c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0)));
378       c->set_output(1, new_shape);
379       return Status::OK();
380     });
381 
382 REGISTER_OP("SparseTensorDenseAdd")
383     .Input("a_indices: Tindices")
384     .Input("a_values: T")
385     .Input("a_shape: Tindices")
386     .Input("b: T")
387     .Output("output: T")
388     .Attr("T: numbertype")
389     .Attr("Tindices: {int32, int64}")
__anone6e195411102(InferenceContext* c) 390     .SetShapeFn([](InferenceContext* c) {
391       c->set_output(0, c->input(3));
392       return Status::OK();
393     });
394 
395 REGISTER_OP("SparseReduceMax")
396     .Input("input_indices: int64")
397     .Input("input_values: T")
398     .Input("input_shape: int64")
399     .Input("reduction_axes: int32")
400     .Attr("keep_dims: bool = False")
401     .Output("output: T")
402     .Attr("T: realnumbertype")
403     .SetShapeFn(shape_inference::SparseReduceShapeFn);
404 
405 REGISTER_OP("SparseReduceMaxSparse")
406     .Input("input_indices: int64")
407     .Input("input_values: T")
408     .Input("input_shape: int64")
409     .Input("reduction_axes: int32")
410     .Attr("keep_dims: bool = False")
411     .Output("output_indices: int64")
412     .Output("output_values: T")
413     .Output("output_shape: int64")
414     .Attr("T: realnumbertype")
415     .SetShapeFn(shape_inference::UnknownShape);
416 
417 REGISTER_OP("SparseReduceSum")
418     .Input("input_indices: int64")
419     .Input("input_values: T")
420     .Input("input_shape: int64")
421     .Input("reduction_axes: int32")
422     .Attr("keep_dims: bool = False")
423     .Output("output: T")
424     .Attr("T: numbertype")
425     .SetShapeFn(shape_inference::SparseReduceShapeFn);
426 
427 REGISTER_OP("SparseReduceSumSparse")
428     .Input("input_indices: int64")
429     .Input("input_values: T")
430     .Input("input_shape: int64")
431     .Input("reduction_axes: int32")
432     .Attr("keep_dims: bool = False")
433     .Output("output_indices: int64")
434     .Output("output_values: T")
435     .Output("output_shape: int64")
436     .Attr("T: numbertype")
437     .SetShapeFn(shape_inference::UnknownShape);
438 
439 #define SPARSE_DENSE_CWISE_SIGNATURE()                           \
440   Input("sp_indices: int64")                                     \
441       .Input("sp_values: T")                                     \
442       .Input("sp_shape: int64")                                  \
443       .Input("dense: T")                                         \
444       .Output("output: T")                                       \
445       .Attr("T: numbertype")                                     \
446       .SetShapeFn([](InferenceContext* c) {                      \
447         ShapeHandle input;                                       \
448         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); \
449         c->set_output(0, c->Vector(c->Dim(input, 0)));           \
450         return Status::OK();                                     \
451       })
452 
453 REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE();
454 
455 REGISTER_OP("SparseDenseCwiseDiv").SPARSE_DENSE_CWISE_SIGNATURE();
456 
457 REGISTER_OP("SparseDenseCwiseAdd").SPARSE_DENSE_CWISE_SIGNATURE();
458 
459 #undef SPARSE_DENSE_CWISE_SIGNATURE
460 
461 REGISTER_OP("SparseSoftmax")
462     .Input("sp_indices: int64")
463     .Input("sp_values: T")
464     .Input("sp_shape: int64")
465     .Output("output: T")
466     .Attr("T: {float, double}")
__anone6e195411202(InferenceContext* c) 467     .SetShapeFn([](InferenceContext* c) {
468       ShapeHandle unused;
469       ShapeHandle values;
470       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // sp_indices
471       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));  // sp_values
472       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
473       c->set_output(0, values);
474       return Status::OK();
475     });
476 
477 REGISTER_OP("SparseSparseMaximum")
478     .Input("a_indices: int64")
479     .Input("a_values: T")
480     .Input("a_shape: int64")
481     .Input("b_indices: int64")
482     .Input("b_values: T")
483     .Input("b_shape: int64")
484     .Output("output_indices: int64")
485     .Output("output_values: T")
486     .Attr("T: realnumbertype")
487     .SetShapeFn(SparseSparseMinOrMaxShapeFn);
488 
489 REGISTER_OP("SparseSparseMinimum")
490     .Input("a_indices: int64")
491     .Input("a_values: T")
492     .Input("a_shape: int64")
493     .Input("b_indices: int64")
494     .Input("b_values: T")
495     .Input("b_shape: int64")
496     .Output("output_indices: int64")
497     .Output("output_values: T")
498     .Attr("T: numbertype")
499     .SetShapeFn(SparseSparseMinOrMaxShapeFn);
500 
501 REGISTER_OP("AddSparseToTensorsMap")
502     .Input("sparse_indices: int64")
503     .Input("sparse_values: T")
504     .Input("sparse_shape: int64")
505     .Output("sparse_handle: int64")
506     .Attr("T: type")
507     .Attr("container: string = ''")
508     .Attr("shared_name: string = ''")
509     .SetIsStateful()
__anone6e195411302(InferenceContext* c) 510     .SetShapeFn([](InferenceContext* c) {
511       ShapeHandle unused;
512       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
513       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
514       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
515       c->set_output(0, c->Scalar());
516       return Status::OK();
517     });
518 
519 REGISTER_OP("AddManySparseToTensorsMap")
520     .Input("sparse_indices: int64")
521     .Input("sparse_values: T")
522     .Input("sparse_shape: int64")
523     .Output("sparse_handles: int64")
524     .Attr("T: type")
525     .Attr("container: string = ''")
526     .Attr("shared_name: string = ''")
527     .SetIsStateful()
__anone6e195411402(InferenceContext* c) 528     .SetShapeFn([](InferenceContext* c) {
529       ShapeHandle unused;
530       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
531       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
532       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
533       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
534       return Status::OK();
535     });
536 
537 REGISTER_OP("TakeManySparseFromTensorsMap")
538     .Input("sparse_handles: int64")
539     .Output("sparse_indices: int64")
540     .Output("sparse_values: dtype")
541     .Output("sparse_shape: int64")
542     .Attr("dtype: type")
543     .Attr("container: string = ''")
544     .Attr("shared_name: string = ''")
545     .SetIsStateful()
__anone6e195411502(InferenceContext* c) 546     .SetShapeFn([](InferenceContext* c) {
547       // serialized sparse is [?,1] matrix.
548       ShapeHandle sparse_handles;
549       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles));
550 
551       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
552                                  InferenceContext::kUnknownDim));
553       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
554       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
555       return Status::OK();
556     });
557 
558 REGISTER_OP("SparseFillEmptyRows")
559     .Input("indices: int64")
560     .Input("values: T")
561     .Input("dense_shape: int64")
562     .Input("default_value: T")
563     .Output("output_indices: int64")
564     .Output("output_values: T")
565     .Output("empty_row_indicator: bool")
566     .Output("reverse_index_map: int64")
567     .Attr("T: type")
__anone6e195411602(InferenceContext* c) 568     .SetShapeFn([](InferenceContext* c) {
569       ShapeHandle input_indices = c->input(0);
570       TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
571       ShapeHandle input_values = c->input(1);
572       TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values));
573       ShapeHandle input_shape = c->input(2);
574       TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape));
575       ShapeHandle default_value = c->input(3);
576       TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value));
577       DimensionHandle N = c->Dim(input_indices, 0);
578       TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N));
579       DimensionHandle unused_dim;
580       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
581                                   c->Dim(input_shape, 0), &unused_dim));
582       ShapeHandle output_indices =
583           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
584       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
585       ShapeHandle constant_input_shape;
586       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &constant_input_shape));
587       ShapeHandle empty_row_indicator =
588           c->Vector(c->Dim(constant_input_shape, 0));
589       ShapeHandle reverse_index_map = c->Vector(N);
590       c->set_output(0, output_indices);
591       c->set_output(1, output_values);
592       c->set_output(2, empty_row_indicator);
593       c->set_output(3, reverse_index_map);
594       return Status::OK();
595     });
596 
597 REGISTER_OP("SparseFillEmptyRowsGrad")
598     .Input("reverse_index_map: int64")
599     .Input("grad_values: T")
600     .Output("d_values: T")
601     .Output("d_default_value: T")
602     .Attr("T: type")
__anone6e195411702(InferenceContext* c) 603     .SetShapeFn([](InferenceContext* c) {
604       ShapeHandle reverse_index_map = c->input(0);
605       TF_RETURN_IF_ERROR(c->WithRank(reverse_index_map, 1, &reverse_index_map));
606       ShapeHandle grad_values = c->input(1);
607       TF_RETURN_IF_ERROR(c->WithRank(grad_values, 1, &grad_values));
608       c->set_output(0, reverse_index_map);
609       c->set_output(1, c->Scalar());
610       return Status::OK();
611     });
612 
613 }  // namespace tensorflow
614