1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/shape_inference.h" 19 20 namespace tensorflow { 21 22 using shape_inference::DimensionHandle; 23 using shape_inference::InferenceContext; 24 using shape_inference::ShapeHandle; 25 26 REGISTER_OP("SetSize") 27 .Input("set_indices: int64") 28 .Input("set_values: T") 29 .Input("set_shape: int64") 30 .Attr("validate_indices: bool = true") 31 .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}") 32 .Output("size: int32") 33 .SetShapeFn(shape_inference::UnknownShape); 34 35 REGISTER_OP("DenseToDenseSetOperation") 36 .Input("set1: T") 37 .Input("set2: T") 38 .Attr("set_operation: string") 39 .Attr("validate_indices: bool = true") 40 .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}") 41 .Output("result_indices: int64") 42 .Output("result_values: T") 43 .Output("result_shape: int64") __anon9b840f3f0102(InferenceContext* c) 44 .SetShapeFn([](InferenceContext* c) { 45 if (c->num_inputs() != 2) { 46 return errors::InvalidArgument("len(inputs) != 2."); 47 } 48 // The following should stay in sync with `ComputeDenseToDense` shape 49 // assertions in kernels/set_kernels.cc. 50 // Dimension n contains the set values to be compared, so ranks must be 51 // >= 2, and the first n-1 dimensions of inputs and output must be 52 // compatible. 53 DimensionHandle output_rank; 54 ShapeHandle input0_shape = c->input(0); 55 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape)); 56 if (c->RankKnown(input0_shape)) { 57 const int32 input0_rank = c->Rank(input0_shape); 58 ShapeHandle input1_shape = c->input(1); 59 TF_RETURN_IF_ERROR( 60 c->WithRank(input1_shape, input0_rank, &input1_shape)); 61 if (c->RankKnown(input1_shape)) { 62 // If both ranks are specified, the first n-1 dims must be compatible. 63 const int32 rank = c->Rank(input1_shape); 64 ShapeHandle group0_shape; 65 TF_RETURN_IF_ERROR( 66 c->Subshape(input0_shape, 0, rank - 1, &group0_shape)); 67 ShapeHandle group1_shape; 68 TF_RETURN_IF_ERROR( 69 c->Subshape(input1_shape, 0, rank - 1, &group1_shape)); 70 ShapeHandle unused_shape; 71 TF_RETURN_IF_ERROR( 72 c->Merge(group0_shape, group1_shape, &unused_shape)); 73 } 74 output_rank = c->MakeDim(input0_rank); 75 } else { 76 ShapeHandle input1_shape = c->input(1); 77 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input1_shape, 2, &input1_shape)); 78 if (c->RankKnown(input1_shape)) { 79 output_rank = c->MakeDim(c->Rank(input1_shape)); 80 } else { 81 output_rank = c->UnknownDim(); 82 } 83 } 84 85 c->set_output(0, c->Matrix(c->UnknownDim(), output_rank)); 86 c->set_output(1, c->Vector(c->UnknownDim())); 87 c->set_output(2, c->Vector(output_rank)); 88 return Status::OK(); 89 }); 90 91 REGISTER_OP("DenseToSparseSetOperation") 92 .Input("set1: T") 93 .Input("set2_indices: int64") 94 .Input("set2_values: T") 95 .Input("set2_shape: int64") 96 .Attr("set_operation: string") 97 .Attr("validate_indices: bool = true") 98 .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}") 99 .Output("result_indices: int64") 100 .Output("result_values: T") 101 .Output("result_shape: int64") __anon9b840f3f0202(InferenceContext* c) 102 .SetShapeFn([](InferenceContext* c) { 103 if (c->num_inputs() != 4) { 104 return errors::InvalidArgument("len(inputs) != 4."); 105 } 106 // The following should stay in sync with `ComputeDenseToSparse` shape 107 // assertions in kernels/set_kernels.cc. 108 // Ranks must be compatible, and be >= 2. 109 ShapeHandle input1_shape_shape = c->input(3); 110 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( 111 c, c->input(1), c->input(2), input1_shape_shape)); 112 113 DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0); 114 115 DimensionHandle output_rank_dim; 116 ShapeHandle input0_shape = c->input(0); 117 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape)); 118 if (c->RankKnown(input0_shape)) { 119 const int32 input0_rank = c->Rank(input0_shape); 120 TF_RETURN_IF_ERROR( 121 c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim)); 122 output_rank_dim = c->MakeDim(input0_rank); 123 } else if (c->ValueKnown(input1_rank_dim)) { 124 output_rank_dim = input1_rank_dim; 125 } else { 126 output_rank_dim = c->UnknownDim(); 127 } 128 129 c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim)); 130 c->set_output(1, c->Vector(c->UnknownDim())); 131 c->set_output(2, c->Vector(output_rank_dim)); 132 return Status::OK(); 133 }); 134 135 REGISTER_OP("SparseToSparseSetOperation") 136 .Input("set1_indices: int64") 137 .Input("set1_values: T") 138 .Input("set1_shape: int64") 139 .Input("set2_indices: int64") 140 .Input("set2_values: T") 141 .Input("set2_shape: int64") 142 .Attr("set_operation: string") 143 .Attr("validate_indices: bool = true") 144 .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}") 145 .Output("result_indices: int64") 146 .Output("result_values: T") 147 .Output("result_shape: int64") __anon9b840f3f0302(InferenceContext* c) 148 .SetShapeFn([](InferenceContext* c) { 149 if (c->num_inputs() != 6) { 150 return errors::InvalidArgument("len(inputs) != 6."); 151 } 152 // The following should stay in sync with `ComputeSparseToSparse` shape 153 // assertions in kernels/set_kernels.cc. 154 // Ranks must be compatible, and be >= 2. 155 ShapeHandle input0_shape_shape = c->input(2); 156 ShapeHandle input1_shape_shape = c->input(5); 157 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( 158 c, c->input(0), c->input(1), input0_shape_shape)); 159 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( 160 c, c->input(3), c->input(4), input1_shape_shape)); 161 162 DimensionHandle input0_rank_dim = c->Dim(input0_shape_shape, 0); 163 DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0); 164 DimensionHandle output_rank_dim; 165 if (c->ValueKnown(input0_rank_dim)) { 166 const int64 input0_rank = c->Value(input0_rank_dim); 167 if (input0_rank < 2) { 168 return errors::InvalidArgument("Input 0, expected rank >= 2, got ", 169 input0_rank, "."); 170 } 171 TF_RETURN_IF_ERROR( 172 c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim)); 173 output_rank_dim = input0_rank_dim; 174 } else if (c->ValueKnown(input1_rank_dim)) { 175 const int64 input1_rank = c->Value(input1_rank_dim); 176 if (input1_rank < 2) { 177 return errors::InvalidArgument("Input 1, expected rank >= 2, got ", 178 input1_rank, "."); 179 } 180 output_rank_dim = input1_rank_dim; 181 } else { 182 output_rank_dim = c->UnknownDim(); 183 } 184 185 c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim)); 186 c->set_output(1, c->Vector(c->UnknownDim())); 187 c->set_output(2, c->Vector(output_rank_dim)); 188 return Status::OK(); 189 }); 190 191 } // namespace tensorflow 192