1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19
20 namespace tensorflow {
21
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25
26 namespace {
27
SparseSparseMinOrMaxShapeFn(InferenceContext * c)28 Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) {
29 ShapeHandle unused;
30 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices
31 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values
32 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); // a_shape
33 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused)); // b_indices
34 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused)); // b_values
35 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused)); // b_shape
36 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
37 InferenceContext::kUnknownDim));
38 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
39 return Status::OK();
40 }
41
42 } // namespace
43
44 REGISTER_OP("SparseAddGrad")
45 .Input("backprop_val_grad: T")
46 .Input("a_indices: int64")
47 .Input("b_indices: int64")
48 .Input("sum_indices: int64")
49 .Output("a_val_grad: T")
50 .Output("b_val_grad: T")
51 .Attr("T: numbertype")
__anone6e195410202(InferenceContext* c) 52 .SetShapeFn([](InferenceContext* c) {
53 ShapeHandle a_indices;
54 ShapeHandle b_indices;
55 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices));
56 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices));
57 c->set_output(0, c->Vector(c->Dim(a_indices, 0)));
58 c->set_output(1, c->Vector(c->Dim(b_indices, 0)));
59 return Status::OK();
60 });
61
62 REGISTER_OP("SparseAdd")
63 .Input("a_indices: int64")
64 .Input("a_values: T")
65 .Input("a_shape: int64")
66 .Input("b_indices: int64")
67 .Input("b_values: T")
68 .Input("b_shape: int64")
69 .Input("thresh: Treal")
70 .Output("sum_indices: int64")
71 .Output("sum_values: T")
72 .Output("sum_shape: int64")
73 .Attr("T: numbertype")
74 .Attr("Treal: realnumbertype")
__anone6e195410302(InferenceContext* c) 75 .SetShapeFn([](InferenceContext* c) {
76 ShapeHandle a_shape;
77 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
78 c->set_output(
79 0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0)));
80 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
81 c->set_output(2, a_shape);
82 return Status::OK();
83 });
84
85 REGISTER_OP("SparseTensorDenseMatMul")
86 .Input("a_indices: Tindices")
87 .Input("a_values: T")
88 .Input("a_shape: int64")
89 .Input("b: T")
90 .Output("product: T")
91 .Attr("T: type")
92 .Attr("Tindices: {int32,int64} = DT_INT64")
93 .Attr("adjoint_a: bool = false")
94 .Attr("adjoint_b: bool = false")
__anone6e195410402(InferenceContext* c) 95 .SetShapeFn([](InferenceContext* c) {
96 DimensionHandle unused_dim;
97 ShapeHandle unused;
98 ShapeHandle b;
99 ShapeHandle a_shape;
100 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices
101 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values
102 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape));
103 TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape));
104 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b));
105
106 bool adjoint_a;
107 bool adjoint_b;
108 TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
109 TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
110
111 DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1);
112 DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0);
113 DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1);
114 DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0);
115 TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim));
116 c->set_output(0, c->Matrix(output_left, output_right));
117 return Status::OK();
118 });
119
120 REGISTER_OP("SerializeSparse")
121 .Input("sparse_indices: int64")
122 .Input("sparse_values: T")
123 .Input("sparse_shape: int64")
124 .Attr("T: type")
125 .Output("serialized_sparse: out_type")
126 .Attr("out_type: {string, variant} = DT_STRING")
__anone6e195410502(InferenceContext* c) 127 .SetShapeFn([](InferenceContext* c) {
128 ShapeHandle unused;
129 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
130 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
131 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
132 c->set_output(0, c->Vector(3));
133 return Status::OK();
134 });
135
136 REGISTER_OP("SerializeManySparse")
137 .Input("sparse_indices: int64")
138 .Input("sparse_values: T")
139 .Input("sparse_shape: int64")
140 .Attr("T: type")
141 .Output("serialized_sparse: out_type")
142 .Attr("out_type: {string, variant} = DT_STRING")
__anone6e195410602(InferenceContext* c) 143 .SetShapeFn([](InferenceContext* c) {
144 ShapeHandle unused;
145 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
146 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
147 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
148 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3));
149 return Status::OK();
150 });
151
152 REGISTER_OP("DeserializeSparse")
153 .Input("serialized_sparse: Tserialized")
154 .Output("sparse_indices: int64")
155 .Output("sparse_values: dtype")
156 .Output("sparse_shape: int64")
157 .Attr("dtype: type")
158 .Attr("Tserialized: {string, variant} = DT_STRING")
__anone6e195410702(InferenceContext* c) 159 .SetShapeFn([](InferenceContext* c) {
160 // serialized sparse is [?, ..., ?, 3] vector.
161 DimensionHandle unused;
162 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
163 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
164 InferenceContext::kUnknownDim));
165 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
166 c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
167 return Status::OK();
168 });
169
170 REGISTER_OP("DeserializeManySparse")
171 .Input("serialized_sparse: string")
172 .Output("sparse_indices: int64")
173 .Output("sparse_values: dtype")
174 .Output("sparse_shape: int64")
175 .Attr("dtype: type")
__anone6e195410802(InferenceContext* c) 176 .SetShapeFn([](InferenceContext* c) {
177 // serialized sparse is [?,3] matrix.
178 ShapeHandle serialized_sparse;
179 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse));
180 DimensionHandle unused;
181 TF_RETURN_IF_ERROR(
182 c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused));
183
184 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
185 InferenceContext::kUnknownDim));
186 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
187 c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
188 return Status::OK();
189 });
190
191 REGISTER_OP("SparseToDense")
192 .Input("sparse_indices: Tindices")
193 .Input("output_shape: Tindices")
194 .Input("sparse_values: T")
195 .Input("default_value: T")
196 .Attr("validate_indices: bool = true")
197 .Attr("T: type")
198 .Output("dense: T")
199 .Attr("Tindices: {int32, int64}")
__anone6e195410902(InferenceContext* c) 200 .SetShapeFn([](InferenceContext* c) {
201 ShapeHandle out;
202 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
203 c->set_output(0, out);
204 return Status::OK();
205 });
206
207 REGISTER_OP("SparseConcat")
208 .Input("indices: N * int64")
209 .Input("values: N * T")
210 .Input("shapes: N * int64")
211 .Output("output_indices: int64")
212 .Output("output_values: T")
213 .Output("output_shape: int64")
214 .Attr("concat_dim: int")
215 .Attr("N: int >= 2")
216 .Attr("T: type")
__anone6e195410a02(InferenceContext* c) 217 .SetShapeFn([](InferenceContext* c) {
218 // These accumulates the sum.
219 DimensionHandle output_row_count = c->MakeDim(0ll);
220
221 // These are only merged.
222 DimensionHandle output_ind_cols = c->UnknownDim();
223 ShapeHandle output_shape = c->UnknownShape();
224
225 const int n = c->num_inputs() / 3;
226 for (int i = 0; i < n; i++) {
227 ShapeHandle ind;
228 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind));
229 ShapeHandle val;
230 TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val));
231 ShapeHandle shape;
232 TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape));
233
234 // Add to output_ind_rows.
235 DimensionHandle num_dim;
236 TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim));
237 TF_RETURN_IF_ERROR(
238 c->Add(output_row_count, num_dim, &output_row_count));
239
240 // Merge into output_ind_cols and output_shape.
241 TF_RETURN_IF_ERROR(
242 c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols));
243 TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape));
244 }
245
246 c->set_output(0, c->Matrix(output_row_count, output_ind_cols));
247 c->set_output(1, c->Vector(output_row_count));
248 c->set_output(2, output_shape);
249 return Status::OK();
250 });
251
252 REGISTER_OP("SparseCross")
253 .Input("indices: N * int64")
254 .Input("values: sparse_types")
255 .Input("shapes: N * int64")
256 .Input("dense_inputs: dense_types")
257 .Output("output_indices: int64")
258 .Output("output_values: out_type")
259 .Output("output_shape: int64")
260 .Attr("N: int >= 0")
261 .Attr("hashed_output: bool")
262 .Attr("num_buckets: int >= 0")
263 .Attr("hash_key: int")
264 .Attr("sparse_types: list({int64, string}) >= 0")
265 .Attr("dense_types: list({int64, string}) >= 0")
266 .Attr("out_type: {int64, string}")
267 .Attr("internal_type: {int64, string}")
__anone6e195410b02(shape_inference::InferenceContext* c) 268 .SetShapeFn([](shape_inference::InferenceContext* c) {
269 c->set_output(0, c->Matrix(c->UnknownDim(), 2));
270 c->set_output(1, c->Vector(c->UnknownDim()));
271 c->set_output(2, c->Vector(2));
272 return Status::OK();
273 });
274
275 REGISTER_OP("SparseSplit")
276 .Input("split_dim: int64")
277 .Input("indices: int64")
278 .Input("values: T")
279 .Input("shape: int64")
280 .Output("output_indices: num_split * int64")
281 .Output("output_values: num_split * T")
282 .Output("output_shape: num_split * int64")
283 .Attr("num_split: int >= 1")
284 .Attr("T: type")
__anone6e195410c02(InferenceContext* c) 285 .SetShapeFn([](InferenceContext* c) {
286 ShapeHandle input_shape = c->input(3);
287 ShapeHandle output_indices =
288 c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
289 ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
290 ShapeHandle output_shape = input_shape;
291
292 // Copy the outputs into the output ranges.
293 int num_splits = c->num_outputs() / 3;
294 int out_idx = 0;
295 for (int i = 0; i < num_splits; ++i)
296 c->set_output(out_idx++, output_indices);
297 for (int i = 0; i < num_splits; ++i)
298 c->set_output(out_idx++, output_values);
299 for (int i = 0; i < num_splits; ++i)
300 c->set_output(out_idx++, output_shape);
301 return Status::OK();
302 });
303
304 REGISTER_OP("SparseSliceGrad")
305 .Input("backprop_val_grad: T")
306 .Input("input_indices: int64")
307 .Input("input_start: int64")
308 .Input("output_indices: int64")
309 .Output("val_grad: T")
310 .Attr("T: numbertype")
__anone6e195410d02(InferenceContext* c) 311 .SetShapeFn([](InferenceContext* c) {
312 ShapeHandle indices;
313 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
314 c->set_output(0, c->Vector(c->Dim(indices, 0)));
315 return Status::OK();
316 });
317
318 REGISTER_OP("SparseSlice")
319 .Input("indices: int64")
320 .Input("values: T")
321 .Input("shape: int64")
322 .Input("start: int64")
323 .Input("size: int64")
324 .Output("output_indices: int64")
325 .Output("output_values: T")
326 .Output("output_shape: int64")
327 .Attr("T: type")
__anone6e195410e02(InferenceContext* c) 328 .SetShapeFn([](InferenceContext* c) {
329 ShapeHandle input_shape = c->input(2);
330 ShapeHandle output_indices =
331 c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
332 ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
333 ShapeHandle output_shape = input_shape;
334
335 c->set_output(0, output_indices);
336 c->set_output(1, output_values);
337 c->set_output(2, output_shape);
338 return Status::OK();
339 });
340
341 REGISTER_OP("SparseReorder")
342 .Input("input_indices: int64")
343 .Input("input_values: T")
344 .Input("input_shape: int64")
345 .Output("output_indices: int64")
346 .Output("output_values: T")
347 .Attr("T: type")
__anone6e195410f02(InferenceContext* c) 348 .SetShapeFn([](InferenceContext* c) {
349 ShapeHandle indices;
350 ShapeHandle values;
351 ShapeHandle unused;
352
353 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
354 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));
355 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
356
357 c->set_output(0, indices);
358 c->set_output(1, values);
359 return Status::OK();
360 });
361
362 REGISTER_OP("SparseReshape")
363 .Input("input_indices: int64")
364 .Input("input_shape: int64")
365 .Input("new_shape: int64")
366 .Output("output_indices: int64")
367 .Output("output_shape: int64")
__anone6e195411002(InferenceContext* c) 368 .SetShapeFn([](InferenceContext* c) {
369 ShapeHandle indices;
370 ShapeHandle unused;
371 ShapeHandle new_shape;
372
373 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
374 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
375 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape));
376
377 c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0)));
378 c->set_output(1, new_shape);
379 return Status::OK();
380 });
381
382 REGISTER_OP("SparseTensorDenseAdd")
383 .Input("a_indices: Tindices")
384 .Input("a_values: T")
385 .Input("a_shape: Tindices")
386 .Input("b: T")
387 .Output("output: T")
388 .Attr("T: numbertype")
389 .Attr("Tindices: {int32, int64}")
__anone6e195411102(InferenceContext* c) 390 .SetShapeFn([](InferenceContext* c) {
391 c->set_output(0, c->input(3));
392 return Status::OK();
393 });
394
// SparseReduceMax: takes a sparse operand as (input_indices, input_values,
// input_shape) plus int32 reduction_axes, and produces a single dense
// `output`. keep_dims defaults to False. Shape inference is delegated to
// shape_inference::SparseReduceShapeFn from common_shape_fns.h.
REGISTER_OP("SparseReduceMax")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);
404
// SparseReduceMaxSparse: same inputs and attrs as SparseReduceMax, but the
// result is emitted in sparse form (output_indices, output_values,
// output_shape). No shape information is inferred for the outputs
// (shape_inference::UnknownShape).
REGISTER_OP("SparseReduceMaxSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnknownShape);
416
// SparseReduceSum: takes a sparse operand as (input_indices, input_values,
// input_shape) plus int32 reduction_axes, and produces a single dense
// `output`. Unlike SparseReduceMax, T accepts any numbertype. Shape
// inference is delegated to shape_inference::SparseReduceShapeFn.
REGISTER_OP("SparseReduceSum")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);
426
// SparseReduceSumSparse: same inputs and attrs as SparseReduceSum, but the
// result is emitted in sparse form (output_indices, output_values,
// output_shape). No shape information is inferred for the outputs
// (shape_inference::UnknownShape).
REGISTER_OP("SparseReduceSumSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnknownShape);
438
439 #define SPARSE_DENSE_CWISE_SIGNATURE() \
440 Input("sp_indices: int64") \
441 .Input("sp_values: T") \
442 .Input("sp_shape: int64") \
443 .Input("dense: T") \
444 .Output("output: T") \
445 .Attr("T: numbertype") \
446 .SetShapeFn([](InferenceContext* c) { \
447 ShapeHandle input; \
448 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); \
449 c->set_output(0, c->Vector(c->Dim(input, 0))); \
450 return Status::OK(); \
451 })
452
453 REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE();
454
455 REGISTER_OP("SparseDenseCwiseDiv").SPARSE_DENSE_CWISE_SIGNATURE();
456
457 REGISTER_OP("SparseDenseCwiseAdd").SPARSE_DENSE_CWISE_SIGNATURE();
458
459 #undef SPARSE_DENSE_CWISE_SIGNATURE
460
461 REGISTER_OP("SparseSoftmax")
462 .Input("sp_indices: int64")
463 .Input("sp_values: T")
464 .Input("sp_shape: int64")
465 .Output("output: T")
466 .Attr("T: {float, double}")
__anone6e195411202(InferenceContext* c) 467 .SetShapeFn([](InferenceContext* c) {
468 ShapeHandle unused;
469 ShapeHandle values;
470 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // sp_indices
471 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values)); // sp_values
472 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
473 c->set_output(0, values);
474 return Status::OK();
475 });
476
// SparseSparseMaximum: combines two sparse operands, given as (a_indices,
// a_values, a_shape) and (b_indices, b_values, b_shape), into
// output_indices / output_values. T is restricted to realnumbertype. Shape
// inference uses SparseSparseMinOrMaxShapeFn from the anonymous namespace in
// this file.
REGISTER_OP("SparseSparseMaximum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);
488
// SparseSparseMinimum: same signature as SparseSparseMaximum, sharing
// SparseSparseMinOrMaxShapeFn for shape inference.
// NOTE(review): T here is `numbertype`, while SparseSparseMaximum uses
// `realnumbertype`; confirm the asymmetry is intended. It cannot be changed
// without altering the registered OpDef.
REGISTER_OP("SparseSparseMinimum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: numbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);
500
501 REGISTER_OP("AddSparseToTensorsMap")
502 .Input("sparse_indices: int64")
503 .Input("sparse_values: T")
504 .Input("sparse_shape: int64")
505 .Output("sparse_handle: int64")
506 .Attr("T: type")
507 .Attr("container: string = ''")
508 .Attr("shared_name: string = ''")
509 .SetIsStateful()
__anone6e195411302(InferenceContext* c) 510 .SetShapeFn([](InferenceContext* c) {
511 ShapeHandle unused;
512 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
513 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
514 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
515 c->set_output(0, c->Scalar());
516 return Status::OK();
517 });
518
519 REGISTER_OP("AddManySparseToTensorsMap")
520 .Input("sparse_indices: int64")
521 .Input("sparse_values: T")
522 .Input("sparse_shape: int64")
523 .Output("sparse_handles: int64")
524 .Attr("T: type")
525 .Attr("container: string = ''")
526 .Attr("shared_name: string = ''")
527 .SetIsStateful()
__anone6e195411402(InferenceContext* c) 528 .SetShapeFn([](InferenceContext* c) {
529 ShapeHandle unused;
530 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
531 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
532 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
533 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
534 return Status::OK();
535 });
536
537 REGISTER_OP("TakeManySparseFromTensorsMap")
538 .Input("sparse_handles: int64")
539 .Output("sparse_indices: int64")
540 .Output("sparse_values: dtype")
541 .Output("sparse_shape: int64")
542 .Attr("dtype: type")
543 .Attr("container: string = ''")
544 .Attr("shared_name: string = ''")
545 .SetIsStateful()
__anone6e195411502(InferenceContext* c) 546 .SetShapeFn([](InferenceContext* c) {
547 // serialized sparse is [?,1] matrix.
548 ShapeHandle sparse_handles;
549 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles));
550
551 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
552 InferenceContext::kUnknownDim));
553 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
554 c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
555 return Status::OK();
556 });
557
558 REGISTER_OP("SparseFillEmptyRows")
559 .Input("indices: int64")
560 .Input("values: T")
561 .Input("dense_shape: int64")
562 .Input("default_value: T")
563 .Output("output_indices: int64")
564 .Output("output_values: T")
565 .Output("empty_row_indicator: bool")
566 .Output("reverse_index_map: int64")
567 .Attr("T: type")
__anone6e195411602(InferenceContext* c) 568 .SetShapeFn([](InferenceContext* c) {
569 ShapeHandle input_indices = c->input(0);
570 TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
571 ShapeHandle input_values = c->input(1);
572 TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values));
573 ShapeHandle input_shape = c->input(2);
574 TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape));
575 ShapeHandle default_value = c->input(3);
576 TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value));
577 DimensionHandle N = c->Dim(input_indices, 0);
578 TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N));
579 DimensionHandle unused_dim;
580 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
581 c->Dim(input_shape, 0), &unused_dim));
582 ShapeHandle output_indices =
583 c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
584 ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
585 ShapeHandle constant_input_shape;
586 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &constant_input_shape));
587 ShapeHandle empty_row_indicator =
588 c->Vector(c->Dim(constant_input_shape, 0));
589 ShapeHandle reverse_index_map = c->Vector(N);
590 c->set_output(0, output_indices);
591 c->set_output(1, output_values);
592 c->set_output(2, empty_row_indicator);
593 c->set_output(3, reverse_index_map);
594 return Status::OK();
595 });
596
597 REGISTER_OP("SparseFillEmptyRowsGrad")
598 .Input("reverse_index_map: int64")
599 .Input("grad_values: T")
600 .Output("d_values: T")
601 .Output("d_default_value: T")
602 .Attr("T: type")
__anone6e195411702(InferenceContext* c) 603 .SetShapeFn([](InferenceContext* c) {
604 ShapeHandle reverse_index_map = c->input(0);
605 TF_RETURN_IF_ERROR(c->WithRank(reverse_index_map, 1, &reverse_index_map));
606 ShapeHandle grad_values = c->input(1);
607 TF_RETURN_IF_ERROR(c->WithRank(grad_values, 1, &grad_values));
608 c->set_output(0, reverse_index_map);
609 c->set_output(1, c->Scalar());
610 return Status::OK();
611 });
612
613 } // namespace tensorflow
614