1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
// Registers the SetSize op.
// NOTE(review): judging only by their names, the three inputs look like the
// indices / values / dense-shape components of a sparse tensor holding the
// sets; the op's actual semantics live in the kernel, not here — confirm there.
// The order of the Input/Attr/Output calls below defines the OpDef, so do not
// reorder them.
26 REGISTER_OP("SetSize")
27     .Input("set_indices: int64")
28     .Input("set_values: T")
29     .Input("set_shape: int64")
30     .Attr("validate_indices: bool = true")
31     .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
32     .Output("size: int32")
       // No static shape is inferred for `size`: its shape is left fully
       // unknown at graph-construction time.
33     .SetShapeFn(shape_inference::UnknownShape);
34 
35 REGISTER_OP("DenseToDenseSetOperation")
36     .Input("set1: T")
37     .Input("set2: T")
38     .Attr("set_operation: string")
39     .Attr("validate_indices: bool = true")
40     .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
41     .Output("result_indices: int64")
42     .Output("result_values: T")
43     .Output("result_shape: int64")
__anon9b840f3f0102(InferenceContext* c) 44     .SetShapeFn([](InferenceContext* c) {
45       if (c->num_inputs() != 2) {
46         return errors::InvalidArgument("len(inputs) != 2.");
47       }
48       // The following should stay in sync with `ComputeDenseToDense` shape
49       // assertions in kernels/set_kernels.cc.
50       // Dimension n contains the set values to be compared, so ranks must be
51       // >= 2, and the first n-1 dimensions of inputs and output must be
52       // compatible.
53       DimensionHandle output_rank;
54       ShapeHandle input0_shape = c->input(0);
55       TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape));
56       if (c->RankKnown(input0_shape)) {
57         const int32 input0_rank = c->Rank(input0_shape);
58         ShapeHandle input1_shape = c->input(1);
59         TF_RETURN_IF_ERROR(
60             c->WithRank(input1_shape, input0_rank, &input1_shape));
61         if (c->RankKnown(input1_shape)) {
62           // If both ranks are specified, the first n-1 dims must be compatible.
63           const int32 rank = c->Rank(input1_shape);
64           ShapeHandle group0_shape;
65           TF_RETURN_IF_ERROR(
66               c->Subshape(input0_shape, 0, rank - 1, &group0_shape));
67           ShapeHandle group1_shape;
68           TF_RETURN_IF_ERROR(
69               c->Subshape(input1_shape, 0, rank - 1, &group1_shape));
70           ShapeHandle unused_shape;
71           TF_RETURN_IF_ERROR(
72               c->Merge(group0_shape, group1_shape, &unused_shape));
73         }
74         output_rank = c->MakeDim(input0_rank);
75       } else {
76         ShapeHandle input1_shape = c->input(1);
77         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input1_shape, 2, &input1_shape));
78         if (c->RankKnown(input1_shape)) {
79           output_rank = c->MakeDim(c->Rank(input1_shape));
80         } else {
81           output_rank = c->UnknownDim();
82         }
83       }
84 
85       c->set_output(0, c->Matrix(c->UnknownDim(), output_rank));
86       c->set_output(1, c->Vector(c->UnknownDim()));
87       c->set_output(2, c->Vector(output_rank));
88       return Status::OK();
89     });
90 
91 REGISTER_OP("DenseToSparseSetOperation")
92     .Input("set1: T")
93     .Input("set2_indices: int64")
94     .Input("set2_values: T")
95     .Input("set2_shape: int64")
96     .Attr("set_operation: string")
97     .Attr("validate_indices: bool = true")
98     .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
99     .Output("result_indices: int64")
100     .Output("result_values: T")
101     .Output("result_shape: int64")
__anon9b840f3f0202(InferenceContext* c) 102     .SetShapeFn([](InferenceContext* c) {
103       if (c->num_inputs() != 4) {
104         return errors::InvalidArgument("len(inputs) != 4.");
105       }
106       // The following should stay in sync with `ComputeDenseToSparse` shape
107       // assertions in kernels/set_kernels.cc.
108       // Ranks must be compatible, and be >= 2.
109       ShapeHandle input1_shape_shape = c->input(3);
110       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
111           c, c->input(1), c->input(2), input1_shape_shape));
112 
113       DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0);
114 
115       DimensionHandle output_rank_dim;
116       ShapeHandle input0_shape = c->input(0);
117       TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape));
118       if (c->RankKnown(input0_shape)) {
119         const int32 input0_rank = c->Rank(input0_shape);
120         TF_RETURN_IF_ERROR(
121             c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim));
122         output_rank_dim = c->MakeDim(input0_rank);
123       } else if (c->ValueKnown(input1_rank_dim)) {
124         output_rank_dim = input1_rank_dim;
125       } else {
126         output_rank_dim = c->UnknownDim();
127       }
128 
129       c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim));
130       c->set_output(1, c->Vector(c->UnknownDim()));
131       c->set_output(2, c->Vector(output_rank_dim));
132       return Status::OK();
133     });
134 
135 REGISTER_OP("SparseToSparseSetOperation")
136     .Input("set1_indices: int64")
137     .Input("set1_values: T")
138     .Input("set1_shape: int64")
139     .Input("set2_indices: int64")
140     .Input("set2_values: T")
141     .Input("set2_shape: int64")
142     .Attr("set_operation: string")
143     .Attr("validate_indices: bool = true")
144     .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
145     .Output("result_indices: int64")
146     .Output("result_values: T")
147     .Output("result_shape: int64")
__anon9b840f3f0302(InferenceContext* c) 148     .SetShapeFn([](InferenceContext* c) {
149       if (c->num_inputs() != 6) {
150         return errors::InvalidArgument("len(inputs) != 6.");
151       }
152       // The following should stay in sync with `ComputeSparseToSparse` shape
153       // assertions in kernels/set_kernels.cc.
154       // Ranks must be compatible, and be >= 2.
155       ShapeHandle input0_shape_shape = c->input(2);
156       ShapeHandle input1_shape_shape = c->input(5);
157       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
158           c, c->input(0), c->input(1), input0_shape_shape));
159       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
160           c, c->input(3), c->input(4), input1_shape_shape));
161 
162       DimensionHandle input0_rank_dim = c->Dim(input0_shape_shape, 0);
163       DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0);
164       DimensionHandle output_rank_dim;
165       if (c->ValueKnown(input0_rank_dim)) {
166         const int64 input0_rank = c->Value(input0_rank_dim);
167         if (input0_rank < 2) {
168           return errors::InvalidArgument("Input 0, expected rank >= 2, got ",
169                                          input0_rank, ".");
170         }
171         TF_RETURN_IF_ERROR(
172             c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim));
173         output_rank_dim = input0_rank_dim;
174       } else if (c->ValueKnown(input1_rank_dim)) {
175         const int64 input1_rank = c->Value(input1_rank_dim);
176         if (input1_rank < 2) {
177           return errors::InvalidArgument("Input 1, expected rank >= 2, got ",
178                                          input1_rank, ".");
179         }
180         output_rank_dim = input1_rank_dim;
181       } else {
182         output_rank_dim = c->UnknownDim();
183       }
184 
185       c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim));
186       c->set_output(1, c->Vector(c->UnknownDim()));
187       c->set_output(2, c->Vector(output_rank_dim));
188       return Status::OK();
189     });
190 
191 }  // namespace tensorflow
192