1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
26 namespace {
27 
SparseSparseMinOrMaxShapeFn(InferenceContext * c)28 Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) {
29   ShapeHandle unused;
30   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
31   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
32   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));  // a_shape
33   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused));  // b_indices
34   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused));  // b_values
35   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused));  // b_shape
36   c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
37                              InferenceContext::kUnknownDim));
38   c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
39   return Status::OK();
40 }
41 
42 }  // namespace
43 
44 REGISTER_OP("SparseAddGrad")
45     .Input("backprop_val_grad: T")
46     .Input("a_indices: int64")
47     .Input("b_indices: int64")
48     .Input("sum_indices: int64")
49     .Output("a_val_grad: T")
50     .Output("b_val_grad: T")
51     .Attr("T: numbertype")
__anone6e195410202(InferenceContext* c) 52     .SetShapeFn([](InferenceContext* c) {
53       ShapeHandle a_indices;
54       ShapeHandle b_indices;
55       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices));
56       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices));
57       c->set_output(0, c->Vector(c->Dim(a_indices, 0)));
58       c->set_output(1, c->Vector(c->Dim(b_indices, 0)));
59       return Status::OK();
60     });
61 
62 REGISTER_OP("SparseAdd")
63     .Input("a_indices: int64")
64     .Input("a_values: T")
65     .Input("a_shape: int64")
66     .Input("b_indices: int64")
67     .Input("b_values: T")
68     .Input("b_shape: int64")
69     .Input("thresh: Treal")
70     .Output("sum_indices: int64")
71     .Output("sum_values: T")
72     .Output("sum_shape: int64")
73     .Attr("T: numbertype")
74     .Attr("Treal: realnumbertype")
__anone6e195410302(InferenceContext* c) 75     .SetShapeFn([](InferenceContext* c) {
76       ShapeHandle a_shape;
77       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
78       c->set_output(
79           0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0)));
80       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
81       c->set_output(2, a_shape);
82       return Status::OK();
83     });
84 
85 REGISTER_OP("SparseTensorDenseMatMul")
86     .Input("a_indices: Tindices")
87     .Input("a_values: T")
88     .Input("a_shape: int64")
89     .Input("b: T")
90     .Output("product: T")
91     .Attr("T: type")
92     .Attr("Tindices: {int32,int64} = DT_INT64")
93     .Attr("adjoint_a: bool = false")
94     .Attr("adjoint_b: bool = false")
__anone6e195410402(InferenceContext* c) 95     .SetShapeFn([](InferenceContext* c) {
96       DimensionHandle unused_dim;
97       ShapeHandle unused;
98       ShapeHandle b;
99       ShapeHandle a_shape;
100       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
101       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
102       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape));
103       TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape));
104       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b));
105 
106       bool adjoint_a;
107       bool adjoint_b;
108       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
109       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
110 
111       DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1);
112       DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0);
113       DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1);
114       DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0);
115       TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim));
116       c->set_output(0, c->Matrix(output_left, output_right));
117       return Status::OK();
118     });
119 
120 REGISTER_OP("SerializeSparse")
121     .Input("sparse_indices: int64")
122     .Input("sparse_values: T")
123     .Input("sparse_shape: int64")
124     .Attr("T: type")
125     .Output("serialized_sparse: out_type")
126     .Attr("out_type: {string, variant} = DT_STRING")
__anone6e195410502(InferenceContext* c) 127     .SetShapeFn([](InferenceContext* c) {
128       ShapeHandle unused;
129       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
130       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
131       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
132       c->set_output(0, c->Vector(3));
133       return Status::OK();
134     });
135 
136 REGISTER_OP("SerializeManySparse")
137     .Input("sparse_indices: int64")
138     .Input("sparse_values: T")
139     .Input("sparse_shape: int64")
140     .Attr("T: type")
141     .Output("serialized_sparse: out_type")
142     .Attr("out_type: {string, variant} = DT_STRING")
__anone6e195410602(InferenceContext* c) 143     .SetShapeFn([](InferenceContext* c) {
144       ShapeHandle unused;
145       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
146       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
147       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
148       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3));
149       return Status::OK();
150     });
151 
152 REGISTER_OP("DeserializeSparse")
153     .Input("serialized_sparse: Tserialized")
154     .Output("sparse_indices: int64")
155     .Output("sparse_values: dtype")
156     .Output("sparse_shape: int64")
157     .Attr("dtype: type")
158     .Attr("Tserialized: {string, variant} = DT_STRING")
__anone6e195410702(InferenceContext* c) 159     .SetShapeFn([](InferenceContext* c) {
160       // serialized sparse is [?, ..., ?, 3] vector.
161       DimensionHandle unused;
162       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
163       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
164                                  InferenceContext::kUnknownDim));
165       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
166       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
167       return Status::OK();
168     });
169 
170 REGISTER_OP("DeserializeManySparse")
171     .Input("serialized_sparse: string")
172     .Output("sparse_indices: int64")
173     .Output("sparse_values: dtype")
174     .Output("sparse_shape: int64")
175     .Attr("dtype: type")
__anone6e195410802(InferenceContext* c) 176     .SetShapeFn([](InferenceContext* c) {
177       // serialized sparse is [?,3] matrix.
178       ShapeHandle serialized_sparse;
179       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse));
180       DimensionHandle unused;
181       TF_RETURN_IF_ERROR(
182           c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused));
183 
184       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
185                                  InferenceContext::kUnknownDim));
186       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
187       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
188       return Status::OK();
189     });
190 
191 REGISTER_OP("SparseToDense")
192     .Input("sparse_indices: Tindices")
193     .Input("output_shape: Tindices")
194     .Input("sparse_values: T")
195     .Input("default_value: T")
196     .Attr("validate_indices: bool = true")
197     .Attr("T: type")
198     .Output("dense: T")
199     .Attr("Tindices: {int32, int64}")
__anone6e195410902(InferenceContext* c) 200     .SetShapeFn([](InferenceContext* c) {
201       ShapeHandle out;
202       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
203       c->set_output(0, out);
204       return Status::OK();
205     });
206 
207 REGISTER_OP("SparseConcat")
208     .Input("indices: N * int64")
209     .Input("values: N * T")
210     .Input("shapes: N * int64")
211     .Output("output_indices: int64")
212     .Output("output_values: T")
213     .Output("output_shape: int64")
214     .Attr("concat_dim: int")
215     .Attr("N: int >= 2")
216     .Attr("T: type")
__anone6e195410a02(InferenceContext* c) 217     .SetShapeFn([](InferenceContext* c) {
218       // These accumulates the sum.
219       DimensionHandle output_row_count = c->MakeDim(0ll);
220 
221       // These are only merged.
222       DimensionHandle output_ind_cols = c->UnknownDim();
223       ShapeHandle output_shape = c->UnknownShape();
224 
225       const int n = c->num_inputs() / 3;
226       for (int i = 0; i < n; i++) {
227         ShapeHandle ind;
228         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind));
229         ShapeHandle val;
230         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val));
231         ShapeHandle shape;
232         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape));
233 
234         // Add to output_ind_rows.
235         DimensionHandle num_dim;
236         TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim));
237         TF_RETURN_IF_ERROR(
238             c->Add(output_row_count, num_dim, &output_row_count));
239 
240         // Merge into output_ind_cols and output_shape.
241         TF_RETURN_IF_ERROR(
242             c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols));
243         TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape));
244       }
245 
246       c->set_output(0, c->Matrix(output_row_count, output_ind_cols));
247       c->set_output(1, c->Vector(output_row_count));
248       c->set_output(2, output_shape);
249       return Status::OK();
250     });
251 
252 REGISTER_OP("SparseCross")
253     .Input("indices: N * int64")
254     .Input("values: sparse_types")
255     .Input("shapes: N * int64")
256     .Input("dense_inputs: dense_types")
257     .Output("output_indices: int64")
258     .Output("output_values: out_type")
259     .Output("output_shape: int64")
260     .Attr("N: int >= 0")
261     .Attr("hashed_output: bool")
262     .Attr("num_buckets: int >= 0")
263     .Attr("hash_key: int")
264     .Attr("sparse_types: list({int64, string}) >= 0")
265     .Attr("dense_types: list({int64, string}) >= 0")
266     .Attr("out_type: {int64, string}")
267     .Attr("internal_type: {int64, string}")
__anone6e195410b02(shape_inference::InferenceContext* c) 268     .SetShapeFn([](shape_inference::InferenceContext* c) {
269       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
270       c->set_output(1, c->Vector(c->UnknownDim()));
271       c->set_output(2, c->Vector(2));
272       return Status::OK();
273     });
274 
275 REGISTER_OP("SparseCrossV2")
276     .Input("indices: N * int64")
277     .Input("values: sparse_types")
278     .Input("shapes: N * int64")
279     .Input("dense_inputs: dense_types")
280     .Input("sep: string")
281     .Output("output_indices: int64")
282     .Output("output_values: string")
283     .Output("output_shape: int64")
284     .Attr("N: int >= 0")
285     .Attr("sparse_types: list({int64, string}) >= 0")
286     .Attr("dense_types: list({int64, string}) >= 0")
__anone6e195410c02(shape_inference::InferenceContext* c) 287     .SetShapeFn([](shape_inference::InferenceContext* c) {
288       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
289       c->set_output(1, c->Vector(c->UnknownDim()));
290       c->set_output(2, c->Vector(2));
291       return Status::OK();
292     });
293 
294 REGISTER_OP("SparseCrossHashed")
295     .Input("indices: N * int64")
296     .Input("values: sparse_types")
297     .Input("shapes: N * int64")
298     .Input("dense_inputs: dense_types")
299     .Input("num_buckets: int64")
300     .Input("strong_hash: bool")
301     .Input("salt: int64")
302     .Output("output_indices: int64")
303     .Output("output_values: int64")
304     .Output("output_shape: int64")
305     .Attr("N: int >= 0")
306     .Attr("sparse_types: list({int64, string}) >= 0")
307     .Attr("dense_types: list({int64, string}) >= 0")
__anone6e195410d02(shape_inference::InferenceContext* c) 308     .SetShapeFn([](shape_inference::InferenceContext* c) {
309       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
310       c->set_output(1, c->Vector(c->UnknownDim()));
311       c->set_output(2, c->Vector(2));
312       return Status::OK();
313     });
314 
315 REGISTER_OP("SparseSplit")
316     .Input("split_dim: int64")
317     .Input("indices: int64")
318     .Input("values: T")
319     .Input("shape: int64")
320     .Output("output_indices: num_split * int64")
321     .Output("output_values:  num_split * T")
322     .Output("output_shape:   num_split * int64")
323     .Attr("num_split: int >= 1")
324     .Attr("T: type")
__anone6e195410e02(InferenceContext* c) 325     .SetShapeFn([](InferenceContext* c) {
326       ShapeHandle input_shape = c->input(3);
327       ShapeHandle output_indices =
328           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
329       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
330       ShapeHandle output_shape = input_shape;
331 
332       // Copy the outputs into the output ranges.
333       int num_splits = c->num_outputs() / 3;
334       int out_idx = 0;
335       for (int i = 0; i < num_splits; ++i)
336         c->set_output(out_idx++, output_indices);
337       for (int i = 0; i < num_splits; ++i)
338         c->set_output(out_idx++, output_values);
339       for (int i = 0; i < num_splits; ++i)
340         c->set_output(out_idx++, output_shape);
341       return Status::OK();
342     });
343 
344 REGISTER_OP("SparseSliceGrad")
345     .Input("backprop_val_grad: T")
346     .Input("input_indices: int64")
347     .Input("input_start: int64")
348     .Input("output_indices: int64")
349     .Output("val_grad: T")
350     .Attr("T: numbertype")
__anone6e195410f02(InferenceContext* c) 351     .SetShapeFn([](InferenceContext* c) {
352       ShapeHandle indices;
353       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
354       c->set_output(0, c->Vector(c->Dim(indices, 0)));
355       return Status::OK();
356     });
357 
358 REGISTER_OP("SparseSlice")
359     .Input("indices: int64")
360     .Input("values: T")
361     .Input("shape: int64")
362     .Input("start: int64")
363     .Input("size: int64")
364     .Output("output_indices: int64")
365     .Output("output_values: T")
366     .Output("output_shape: int64")
367     .Attr("T: type")
__anone6e195411002(InferenceContext* c) 368     .SetShapeFn([](InferenceContext* c) {
369       ShapeHandle input_shape = c->input(2);
370       ShapeHandle output_indices =
371           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
372       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
373       ShapeHandle output_shape = input_shape;
374 
375       c->set_output(0, output_indices);
376       c->set_output(1, output_values);
377       c->set_output(2, output_shape);
378       return Status::OK();
379     });
380 
381 REGISTER_OP("SparseReorder")
382     .Input("input_indices: int64")
383     .Input("input_values: T")
384     .Input("input_shape: int64")
385     .Output("output_indices: int64")
386     .Output("output_values: T")
387     .Attr("T: type")
__anone6e195411102(InferenceContext* c) 388     .SetShapeFn([](InferenceContext* c) {
389       ShapeHandle indices;
390       ShapeHandle values;
391       ShapeHandle unused;
392 
393       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
394       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));
395       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
396 
397       c->set_output(0, indices);
398       c->set_output(1, values);
399       return Status::OK();
400     });
401 
402 REGISTER_OP("SparseReshape")
403     .Input("input_indices: int64")
404     .Input("input_shape: int64")
405     .Input("new_shape: int64")
406     .Output("output_indices: int64")
407     .Output("output_shape: int64")
__anone6e195411202(InferenceContext* c) 408     .SetShapeFn([](InferenceContext* c) {
409       ShapeHandle indices;
410       ShapeHandle unused;
411       ShapeHandle new_shape;
412 
413       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
414       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
415       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape));
416 
417       c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0)));
418       c->set_output(1, new_shape);
419       return Status::OK();
420     });
421 
422 REGISTER_OP("SparseTensorDenseAdd")
423     .Input("a_indices: Tindices")
424     .Input("a_values: T")
425     .Input("a_shape: Tindices")
426     .Input("b: T")
427     .Output("output: T")
428     .Attr("T: numbertype")
429     .Attr("Tindices: {int32, int64}")
__anone6e195411302(InferenceContext* c) 430     .SetShapeFn([](InferenceContext* c) {
431       c->set_output(0, c->input(3));
432       return Status::OK();
433     });
434 
435 REGISTER_OP("SparseReduceMax")
436     .Input("input_indices: int64")
437     .Input("input_values: T")
438     .Input("input_shape: int64")
439     .Input("reduction_axes: int32")
440     .Attr("keep_dims: bool = False")
441     .Output("output: T")
442     .Attr("T: realnumbertype")
443     .SetShapeFn(shape_inference::SparseReduceShapeFn);
444 
445 REGISTER_OP("SparseReduceMaxSparse")
446     .Input("input_indices: int64")
447     .Input("input_values: T")
448     .Input("input_shape: int64")
449     .Input("reduction_axes: int32")
450     .Attr("keep_dims: bool = False")
451     .Output("output_indices: int64")
452     .Output("output_values: T")
453     .Output("output_shape: int64")
454     .Attr("T: realnumbertype")
455     .SetShapeFn(shape_inference::UnknownShape);
456 
457 REGISTER_OP("SparseReduceSum")
458     .Input("input_indices: int64")
459     .Input("input_values: T")
460     .Input("input_shape: int64")
461     .Input("reduction_axes: int32")
462     .Attr("keep_dims: bool = False")
463     .Output("output: T")
464     .Attr("T: numbertype")
465     .SetShapeFn(shape_inference::SparseReduceShapeFn);
466 
467 REGISTER_OP("SparseReduceSumSparse")
468     .Input("input_indices: int64")
469     .Input("input_values: T")
470     .Input("input_shape: int64")
471     .Input("reduction_axes: int32")
472     .Attr("keep_dims: bool = False")
473     .Output("output_indices: int64")
474     .Output("output_values: T")
475     .Output("output_shape: int64")
476     .Attr("T: numbertype")
477     .SetShapeFn(shape_inference::UnknownShape);
478 
479 #define SPARSE_DENSE_CWISE_SIGNATURE()                           \
480   Input("sp_indices: int64")                                     \
481       .Input("sp_values: T")                                     \
482       .Input("sp_shape: int64")                                  \
483       .Input("dense: T")                                         \
484       .Output("output: T")                                       \
485       .Attr("T: numbertype")                                     \
486       .SetShapeFn([](InferenceContext* c) {                      \
487         ShapeHandle input;                                       \
488         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); \
489         c->set_output(0, c->Vector(c->Dim(input, 0)));           \
490         return Status::OK();                                     \
491       })
492 
493 REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE();
494 
495 REGISTER_OP("SparseDenseCwiseDiv").SPARSE_DENSE_CWISE_SIGNATURE();
496 
497 REGISTER_OP("SparseDenseCwiseAdd").SPARSE_DENSE_CWISE_SIGNATURE();
498 
499 #undef SPARSE_DENSE_CWISE_SIGNATURE
500 
501 REGISTER_OP("SparseSoftmax")
502     .Input("sp_indices: int64")
503     .Input("sp_values: T")
504     .Input("sp_shape: int64")
505     .Output("output: T")
506     .Attr("T: {float, double}")
__anone6e195411402(InferenceContext* c) 507     .SetShapeFn([](InferenceContext* c) {
508       ShapeHandle unused;
509       ShapeHandle values;
510       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // sp_indices
511       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));  // sp_values
512       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
513       c->set_output(0, values);
514       return Status::OK();
515     });
516 
517 REGISTER_OP("SparseSparseMaximum")
518     .Input("a_indices: int64")
519     .Input("a_values: T")
520     .Input("a_shape: int64")
521     .Input("b_indices: int64")
522     .Input("b_values: T")
523     .Input("b_shape: int64")
524     .Output("output_indices: int64")
525     .Output("output_values: T")
526     .Attr("T: realnumbertype")
527     .SetShapeFn(SparseSparseMinOrMaxShapeFn);
528 
529 REGISTER_OP("SparseSparseMinimum")
530     .Input("a_indices: int64")
531     .Input("a_values: T")
532     .Input("a_shape: int64")
533     .Input("b_indices: int64")
534     .Input("b_values: T")
535     .Input("b_shape: int64")
536     .Output("output_indices: int64")
537     .Output("output_values: T")
538     .Attr("T: numbertype")
539     .SetShapeFn(SparseSparseMinOrMaxShapeFn);
540 
541 REGISTER_OP("AddSparseToTensorsMap")
542     .Input("sparse_indices: int64")
543     .Input("sparse_values: T")
544     .Input("sparse_shape: int64")
545     .Output("sparse_handle: int64")
546     .Attr("T: type")
547     .Attr("container: string = ''")
548     .Attr("shared_name: string = ''")
549     .SetIsStateful()
__anone6e195411502(InferenceContext* c) 550     .SetShapeFn([](InferenceContext* c) {
551       ShapeHandle unused;
552       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
553       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
554       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
555       c->set_output(0, c->Scalar());
556       return Status::OK();
557     });
558 
559 REGISTER_OP("AddManySparseToTensorsMap")
560     .Input("sparse_indices: int64")
561     .Input("sparse_values: T")
562     .Input("sparse_shape: int64")
563     .Output("sparse_handles: int64")
564     .Attr("T: type")
565     .Attr("container: string = ''")
566     .Attr("shared_name: string = ''")
567     .SetIsStateful()
__anone6e195411602(InferenceContext* c) 568     .SetShapeFn([](InferenceContext* c) {
569       ShapeHandle unused;
570       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
571       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
572       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
573       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
574       return Status::OK();
575     });
576 
577 REGISTER_OP("TakeManySparseFromTensorsMap")
578     .Input("sparse_handles: int64")
579     .Output("sparse_indices: int64")
580     .Output("sparse_values: dtype")
581     .Output("sparse_shape: int64")
582     .Attr("dtype: type")
583     .Attr("container: string = ''")
584     .Attr("shared_name: string = ''")
585     .SetIsStateful()
__anone6e195411702(InferenceContext* c) 586     .SetShapeFn([](InferenceContext* c) {
587       // serialized sparse is [?,1] matrix.
588       ShapeHandle sparse_handles;
589       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles));
590 
591       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
592                                  InferenceContext::kUnknownDim));
593       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
594       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
595       return Status::OK();
596     });
597 
598 REGISTER_OP("SparseFillEmptyRows")
599     .Input("indices: int64")
600     .Input("values: T")
601     .Input("dense_shape: int64")
602     .Input("default_value: T")
603     .Output("output_indices: int64")
604     .Output("output_values: T")
605     .Output("empty_row_indicator: bool")
606     .Output("reverse_index_map: int64")
607     .Attr("T: type")
__anone6e195411802(InferenceContext* c) 608     .SetShapeFn([](InferenceContext* c) {
609       ShapeHandle input_indices = c->input(0);
610       TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
611       ShapeHandle input_values = c->input(1);
612       TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values));
613       ShapeHandle input_shape = c->input(2);
614       TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape));
615       ShapeHandle default_value = c->input(3);
616       TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value));
617       DimensionHandle N = c->Dim(input_indices, 0);
618       TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N));
619       DimensionHandle unused_dim;
620       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
621                                   c->Dim(input_shape, 0), &unused_dim));
622       ShapeHandle output_indices =
623           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
624       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
625       ShapeHandle constant_input_shape;
626       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &constant_input_shape));
627       ShapeHandle empty_row_indicator =
628           c->Vector(c->Dim(constant_input_shape, 0));
629       ShapeHandle reverse_index_map = c->Vector(N);
630       c->set_output(0, output_indices);
631       c->set_output(1, output_values);
632       c->set_output(2, empty_row_indicator);
633       c->set_output(3, reverse_index_map);
634       return Status::OK();
635     });
636 
637 REGISTER_OP("SparseFillEmptyRowsGrad")
638     .Input("reverse_index_map: int64")
639     .Input("grad_values: T")
640     .Output("d_values: T")
641     .Output("d_default_value: T")
642     .Attr("T: type")
__anone6e195411902(InferenceContext* c) 643     .SetShapeFn([](InferenceContext* c) {
644       ShapeHandle reverse_index_map = c->input(0);
645       TF_RETURN_IF_ERROR(c->WithRank(reverse_index_map, 1, &reverse_index_map));
646       ShapeHandle grad_values = c->input(1);
647       TF_RETURN_IF_ERROR(c->WithRank(grad_values, 1, &grad_values));
648       c->set_output(0, reverse_index_map);
649       c->set_output(1, c->Scalar());
650       return Status::OK();
651     });
652 
653 }  // namespace tensorflow
654