/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {
Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));  // a_shape
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused));  // b_indices
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused));  // b_values
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused));  // b_shape
  c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
                             InferenceContext::kUnknownDim));
  c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
  return Status::OK();
}

}  // namespace

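// Gradient of SparseAdd with respect to the operands' values: each output is
// a vector with one entry per nonzero of the corresponding input (the first
// dimension of a_indices / b_indices).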
REGISTER_OP("SparseAddGrad")
    .Input("backprop_val_grad: T")
    .Input("a_indices: int64")
    .Input("b_indices: int64")
    .Input("sum_indices: int64")
    .Output("a_val_grad: T")
    .Output("b_val_grad: T")
    .Attr("T: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle a_indices;
      ShapeHandle b_indices;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices));
      c->set_output(0, c->Vector(c->Dim(a_indices, 0)));
      c->set_output(1, c->Vector(c->Dim(b_indices, 0)));
      return Status::OK();
    });

REGISTER_OP("SparseAdd")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Input("thresh: Treal")
    .Output("sum_indices: int64")
    .Output("sum_values: T")
    .Output("sum_shape: int64")
    .Attr("T: numbertype")
    .Attr("Treal: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle a_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
      c->set_output(
          0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0)));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, a_shape);
      return Status::OK();
    });

REGISTER_OP("SparseTensorDenseMatMul")
    .Input("a_indices: Tindices")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b: T")
    .Output("product: T")
    .Attr("T: type")
    .Attr("Tindices: {int32,int64} = DT_INT64")
    .Attr("adjoint_a: bool = false")
    .Attr("adjoint_b: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      DimensionHandle unused_dim;
      ShapeHandle unused;
      ShapeHandle b;
      ShapeHandle a_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape));
      TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b));

      bool adjoint_a;
      bool adjoint_b;
      TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
      TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));

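      // The product is [outer dim of a] x [outer dim of b]; the inner
      // dimensions of a and b (after any requested adjoints) must agree.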
      DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1);
      DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0);
      DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1);
      DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0);
      TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim));
      c->set_output(0, c->Matrix(output_left, output_right));
      return Status::OK();
    });

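// A single SparseTensor serializes to a length-3 vector holding its indices,
// values, and shape components.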
REGISTER_OP("SerializeSparse")
    .Input("sparse_indices: int64")
    .Input("sparse_values: T")
    .Input("sparse_shape: int64")
    .Attr("T: type")
    .Output("serialized_sparse: out_type")
    .Attr("out_type: {string, variant} = DT_STRING")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Vector(3));
      return Status::OK();
    });

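// The minibatch variant serializes each element of the batch (the first
// dimension of the input SparseTensor) to 3 components, giving an [N, 3]
// output.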
REGISTER_OP("SerializeManySparse")
    .Input("sparse_indices: int64")
    .Input("sparse_values: T")
    .Input("sparse_shape: int64")
    .Attr("T: type")
    .Output("serialized_sparse: out_type")
    .Attr("out_type: {string, variant} = DT_STRING")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3));
      return Status::OK();
    });

REGISTER_OP("DeserializeSparse")
    .Input("serialized_sparse: Tserialized")
    .Output("sparse_indices: int64")
    .Output("sparse_values: dtype")
    .Output("sparse_shape: int64")
    .Attr("dtype: type")
    .Attr("Tserialized: {string, variant} = DT_STRING")
    .SetShapeFn([](InferenceContext* c) {
      // serialized_sparse is a [?, ..., ?, 3] tensor.
      DimensionHandle unused;
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
                                 InferenceContext::kUnknownDim));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
      return Status::OK();
    });

REGISTER_OP("DeserializeManySparse")
    .Input("serialized_sparse: string")
    .Output("sparse_indices: int64")
    .Output("sparse_values: dtype")
    .Output("sparse_shape: int64")
    .Attr("dtype: type")
    .SetShapeFn([](InferenceContext* c) {
      // serialized_sparse is a [?, 3] matrix.
      ShapeHandle serialized_sparse;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse));
      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused));

      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
                                 InferenceContext::kUnknownDim));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
      return Status::OK();
    });

REGISTER_OP("SparseToDense")
    .Input("sparse_indices: Tindices")
    .Input("output_shape: Tindices")
    .Input("sparse_values: T")
    .Input("default_value: T")
    .Attr("validate_indices: bool = true")
    .Attr("T: type")
    .Output("dense: T")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
      c->set_output(0, out);
      return Status::OK();
    });

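// Shape inference for SparseConcat: the number of output nonzeros is the sum
// of the inputs' nonzero counts, and the index width (rank) must agree across
// all inputs.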
REGISTER_OP("SparseConcat")
    .Input("indices: N * int64")
    .Input("values: N * T")
    .Input("shapes: N * int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("concat_dim: int")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // Accumulates the total number of nonzeros across the inputs.
      DimensionHandle output_row_count = c->MakeDim(0ll);

      // These are only merged across the inputs, not accumulated.
      DimensionHandle output_ind_cols = c->UnknownDim();
      ShapeHandle output_shape = c->UnknownShape();

      const int n = c->num_inputs() / 3;
      for (int i = 0; i < n; i++) {
        ShapeHandle ind;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind));
        ShapeHandle val;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val));
        ShapeHandle shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape));

        // Add this input's nonzero count to output_row_count.
        DimensionHandle num_dim;
        TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim));
        TF_RETURN_IF_ERROR(
            c->Add(output_row_count, num_dim, &output_row_count));

        // Merge into output_ind_cols and output_shape.
        TF_RETURN_IF_ERROR(
            c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols));
        TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape));
      }

      c->set_output(0, c->Matrix(output_row_count, output_ind_cols));
      c->set_output(1, c->Vector(output_row_count));
      c->set_output(2, output_shape);
      return Status::OK();
    });

REGISTER_OP("SparseCross")
    .Input("indices: N * int64")
    .Input("values: sparse_types")
    .Input("shapes: N * int64")
    .Input("dense_inputs: dense_types")
    .Output("output_indices: int64")
    .Output("output_values: out_type")
    .Output("output_shape: int64")
    .Attr("N: int >= 0")
    .Attr("hashed_output: bool")
    .Attr("num_buckets: int >= 0")
    .Attr("hash_key: int")
    .Attr("sparse_types: list({int64, string}) >= 0")
    .Attr("dense_types: list({int64, string}) >= 0")
    .Attr("out_type: {int64, string}")
    .Attr("internal_type: {int64, string}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Matrix(c->UnknownDim(), 2));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(2));
      return Status::OK();
    });

REGISTER_OP("SparseCrossV2")
    .Input("indices: N * int64")
    .Input("values: sparse_types")
    .Input("shapes: N * int64")
    .Input("dense_inputs: dense_types")
    .Input("sep: string")
    .Output("output_indices: int64")
    .Output("output_values: string")
    .Output("output_shape: int64")
    .Attr("N: int >= 0")
    .Attr("sparse_types: list({int64, string}) >= 0")
    .Attr("dense_types: list({int64, string}) >= 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Matrix(c->UnknownDim(), 2));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(2));
      return Status::OK();
    });

REGISTER_OP("SparseCrossHashed")
    .Input("indices: N * int64")
    .Input("values: sparse_types")
    .Input("shapes: N * int64")
    .Input("dense_inputs: dense_types")
    .Input("num_buckets: int64")
    .Input("strong_hash: bool")
    .Input("salt: int64")
    .Output("output_indices: int64")
    .Output("output_values: int64")
    .Output("output_shape: int64")
    .Attr("N: int >= 0")
    .Attr("sparse_types: list({int64, string}) >= 0")
    .Attr("dense_types: list({int64, string}) >= 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Matrix(c->UnknownDim(), 2));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(2));
      return Status::OK();
    });

REGISTER_OP("SparseSplit")
    .Input("split_dim: int64")
    .Input("indices: int64")
    .Input("values: T")
    .Input("shape: int64")
    .Output("output_indices: num_split * int64")
    .Output("output_values: num_split * T")
    .Output("output_shape: num_split * int64")
    .Attr("num_split: int >= 1")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input_shape = c->input(3);
      ShapeHandle output_indices = c->Matrix(InferenceContext::kUnknownDim,
                                             c->NumElements(input_shape));
      ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
      ShapeHandle output_shape = input_shape;

      // Copy the outputs into the output ranges.
      int num_splits = c->num_outputs() / 3;
      int out_idx = 0;
      for (int i = 0; i < num_splits; ++i)
        c->set_output(out_idx++, output_indices);
      for (int i = 0; i < num_splits; ++i)
        c->set_output(out_idx++, output_values);
      for (int i = 0; i < num_splits; ++i)
        c->set_output(out_idx++, output_shape);
      return Status::OK();
    });

REGISTER_OP("SparseSliceGrad")
    .Input("backprop_val_grad: T")
    .Input("input_indices: int64")
    .Input("input_start: int64")
    .Input("output_indices: int64")
    .Output("val_grad: T")
    .Attr("T: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
      c->set_output(0, c->Vector(c->Dim(indices, 0)));
      return Status::OK();
    });

REGISTER_OP("SparseSlice")
    .Input("indices: int64")
    .Input("values: T")
    .Input("shape: int64")
    .Input("start: int64")
    .Input("size: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input_shape = c->input(2);
      ShapeHandle output_indices = c->Matrix(InferenceContext::kUnknownDim,
                                             c->NumElements(input_shape));
      ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
      ShapeHandle output_shape = input_shape;

      c->set_output(0, output_indices);
      c->set_output(1, output_values);
      c->set_output(2, output_shape);
      return Status::OK();
    });

REGISTER_OP("SparseReorder")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices;
      ShapeHandle values;
      ShapeHandle unused;

      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));

      c->set_output(0, indices);
      c->set_output(1, values);
      return Status::OK();
    });

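// Reshape preserves the number of nonzeros; the output index width and the
// length of the output shape vector come from new_shape.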
REGISTER_OP("SparseReshape")
    .Input("input_indices: int64")
    .Input("input_shape: int64")
    .Input("new_shape: int64")
    .Output("output_indices: int64")
    .Output("output_shape: int64")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices;
      ShapeHandle unused;
      ShapeHandle new_shape;

      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape));

      c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0)));
      c->set_output(1, new_shape);
      return Status::OK();
    });

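// Adding a SparseTensor to a dense tensor produces a dense result with the
// dense operand's shape.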
REGISTER_OP("SparseTensorDenseAdd")
    .Input("a_indices: Tindices")
    .Input("a_values: T")
    .Input("a_shape: Tindices")
    .Input("b: T")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->input(3));
      return Status::OK();
    });

REGISTER_OP("SparseReduceMax")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);

REGISTER_OP("SparseReduceMaxSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("SparseReduceSum")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);

REGISTER_OP("SparseReduceSumSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnknownShape);

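// The three sparse-dense cwise ops share the same signature and shape
// function (the output is a vector with one value per nonzero of the sparse
// operand), so the common boilerplate is factored into a macro.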
#define SPARSE_DENSE_CWISE_SIGNATURE()                            \
  Input("sp_indices: int64")                                      \
      .Input("sp_values: T")                                      \
      .Input("sp_shape: int64")                                   \
      .Input("dense: T")                                          \
      .Output("output: T")                                        \
      .Attr("T: numbertype")                                      \
      .SetShapeFn([](InferenceContext* c) {                       \
        ShapeHandle input;                                        \
        TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));  \
        c->set_output(0, c->Vector(c->Dim(input, 0)));            \
        return Status::OK();                                      \
      })

REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE();

REGISTER_OP("SparseDenseCwiseDiv").SPARSE_DENSE_CWISE_SIGNATURE();

REGISTER_OP("SparseDenseCwiseAdd").SPARSE_DENSE_CWISE_SIGNATURE();

#undef SPARSE_DENSE_CWISE_SIGNATURE

REGISTER_OP("SparseSoftmax")
    .Input("sp_indices: int64")
    .Input("sp_values: T")
    .Input("sp_shape: int64")
    .Output("output: T")
    .Attr("T: {float, double}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      ShapeHandle values;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // sp_indices
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));  // sp_values
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, values);
      return Status::OK();
    });

REGISTER_OP("SparseSparseMaximum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);

REGISTER_OP("SparseSparseMinimum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: numbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);

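// The *TensorsMap ops stash SparseTensors in a shared container (keyed by the
// container/shared_name attrs) and return int64 handles to the stored
// tensors.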
REGISTER_OP("AddSparseToTensorsMap")
    .Input("sparse_indices: int64")
    .Input("sparse_values: T")
    .Input("sparse_shape: int64")
    .Output("sparse_handle: int64")
    .Attr("T: type")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("AddManySparseToTensorsMap")
    .Input("sparse_indices: int64")
    .Input("sparse_values: T")
    .Input("sparse_shape: int64")
    .Output("sparse_handles: int64")
    .Attr("T: type")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
      return Status::OK();
    });

REGISTER_OP("TakeManySparseFromTensorsMap")
    .Input("sparse_handles: int64")
    .Output("sparse_indices: int64")
    .Output("sparse_values: dtype")
    .Output("sparse_shape: int64")
    .Attr("dtype: type")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // sparse_handles is a [?] vector of handles.
      ShapeHandle sparse_handles;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles));

      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
                                 InferenceContext::kUnknownDim));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
      return Status::OK();
    });

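// Fills empty rows of the input SparseTensor with default_value.
// empty_row_indicator has one entry per row of the dense shape, and
// reverse_index_map has one entry per input nonzero.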
REGISTER_OP("SparseFillEmptyRows")
    .Input("indices: int64")
    .Input("values: T")
    .Input("dense_shape: int64")
    .Input("default_value: T")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("empty_row_indicator: bool")
    .Output("reverse_index_map: int64")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input_indices = c->input(0);
      TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
      ShapeHandle input_values = c->input(1);
      TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values));
      ShapeHandle input_shape = c->input(2);
      TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape));
      ShapeHandle default_value = c->input(3);
      TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value));
      DimensionHandle N = c->Dim(input_indices, 0);
      TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N));
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
                                  c->Dim(input_shape, 0), &unused_dim));
      ShapeHandle output_indices = c->Matrix(InferenceContext::kUnknownDim,
                                             c->NumElements(input_shape));
      ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
      ShapeHandle constant_input_shape;
      TF_RETURN_IF_ERROR(
          c->MakeShapeFromShapeTensor(2, &constant_input_shape));
      ShapeHandle empty_row_indicator =
          c->Vector(c->Dim(constant_input_shape, 0));
      ShapeHandle reverse_index_map = c->Vector(N);
      c->set_output(0, output_indices);
      c->set_output(1, output_values);
      c->set_output(2, empty_row_indicator);
      c->set_output(3, reverse_index_map);
      return Status::OK();
    });

REGISTER_OP("SparseFillEmptyRowsGrad")
    .Input("reverse_index_map: int64")
    .Input("grad_values: T")
    .Output("d_values: T")
    .Output("d_default_value: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle reverse_index_map = c->input(0);
      TF_RETURN_IF_ERROR(
          c->WithRank(reverse_index_map, 1, &reverse_index_map));
      ShapeHandle grad_values = c->input(1);
      TF_RETURN_IF_ERROR(c->WithRank(grad_values, 1, &grad_values));
      c->set_output(0, reverse_index_map);
      c->set_output(1, c->Scalar());
      return Status::OK();
    });

}  // namespace tensorflow