1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op.h" 17 #include "tensorflow/core/framework/shape_inference.h" 18 19 namespace tensorflow { 20 21 using shape_inference::DimensionHandle; 22 using shape_inference::InferenceContext; 23 using shape_inference::ShapeHandle; 24 25 // CTC is Connectionist Temporal Classification. See util/ctc/ for details. 26 27 REGISTER_OP("CTCLoss") 28 .Input("inputs: T") 29 .Input("labels_indices: int64") 30 .Input("labels_values: int32") 31 .Input("sequence_length: int32") 32 .Attr("preprocess_collapse_repeated: bool = false") 33 .Attr("ctc_merge_repeated: bool = true") 34 .Attr("ignore_longer_outputs_than_inputs: bool = false") 35 .Output("loss: T") 36 .Output("gradient: T") 37 .Attr("T: {float, double} = DT_FLOAT") __anon7dc6794d0102(InferenceContext* c) 38 .SetShapeFn([](InferenceContext* c) { 39 ShapeHandle inputs; 40 ShapeHandle labels_indices; 41 ShapeHandle labels_values; 42 ShapeHandle sequence_length; 43 44 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); 45 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices)); 46 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values)); 47 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length)); 48 49 DimensionHandle unused; 50 TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0), 51 c->Dim(labels_values, 0), &unused)); 52 53 // Get batch size from inputs and sequence_length, and update inputs 54 // with the merged batch_size since it is returned. 55 DimensionHandle batch_size; 56 TF_RETURN_IF_ERROR( 57 c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); 58 TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs)); 59 60 c->set_output(0, c->Vector(batch_size)); 61 c->set_output(1, inputs); 62 return Status::OK(); 63 }); 64 65 REGISTER_OP("CTCLossV2") 66 .Input("inputs: float") 67 .Input("labels_indices: int64") 68 .Input("labels_values: int32") 69 .Input("sequence_length: int32") 70 .Attr("preprocess_collapse_repeated: bool = false") 71 .Attr("ctc_merge_repeated: bool = true") 72 .Attr("ignore_longer_outputs_than_inputs: bool = false") 73 .Output("loss: float") 74 .Output("gradient: float") __anon7dc6794d0202(InferenceContext* c) 75 .SetShapeFn([](InferenceContext* c) { 76 ShapeHandle inputs; 77 ShapeHandle labels_indices; 78 ShapeHandle labels_values; 79 ShapeHandle sequence_length; 80 81 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); 82 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices)); 83 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values)); 84 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length)); 85 86 DimensionHandle unused; 87 TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0), 88 c->Dim(labels_values, 0), &unused)); 89 90 // Get batch size from inputs and sequence_length, and update inputs 91 // with the merged batch_size since it is returned. 92 DimensionHandle batch_size; 93 TF_RETURN_IF_ERROR( 94 c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); 95 TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs)); 96 97 c->set_output(0, c->Vector(batch_size)); 98 c->set_output(1, inputs); 99 return Status::OK(); 100 }); 101 102 REGISTER_OP("CTCGreedyDecoder") 103 .Input("inputs: T") 104 .Input("sequence_length: int32") 105 .Attr("merge_repeated: bool = false") 106 .Output("decoded_indices: int64") 107 .Output("decoded_values: int64") 108 .Output("decoded_shape: int64") 109 .Output("log_probability: T") 110 .Attr("T: {float, double} = DT_FLOAT") __anon7dc6794d0302(InferenceContext* c) 111 .SetShapeFn([](InferenceContext* c) { 112 ShapeHandle inputs; 113 ShapeHandle sequence_length; 114 115 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); 116 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length)); 117 118 // Get batch size from inputs and sequence_length. 119 DimensionHandle batch_size; 120 TF_RETURN_IF_ERROR( 121 c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); 122 123 DimensionHandle total_decoded_outputs = c->UnknownDim(); 124 c->set_output(0, c->Matrix(total_decoded_outputs, 2)); 125 c->set_output(1, c->Vector(total_decoded_outputs)); 126 c->set_output(2, c->Vector(2)); 127 c->set_output(3, c->Matrix(batch_size, 1)); 128 return Status::OK(); 129 }); 130 131 REGISTER_OP("CTCBeamSearchDecoder") 132 .Input("inputs: T") 133 .Input("sequence_length: int32") 134 .Attr("beam_width: int >= 1") 135 .Attr("top_paths: int >= 1") 136 .Attr("merge_repeated: bool = true") 137 .Output("decoded_indices: top_paths * int64") 138 .Output("decoded_values: top_paths * int64") 139 .Output("decoded_shape: top_paths * int64") 140 .Output("log_probability: T") 141 .Attr("T: {float, double} = DT_FLOAT") __anon7dc6794d0402(InferenceContext* c) 142 .SetShapeFn([](InferenceContext* c) { 143 ShapeHandle inputs; 144 ShapeHandle sequence_length; 145 146 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); 147 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length)); 148 149 // Get batch size from inputs and sequence_length. 150 DimensionHandle batch_size; 151 TF_RETURN_IF_ERROR( 152 c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); 153 154 int32 top_paths; 155 TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths)); 156 157 // Outputs. 158 int out_idx = 0; 159 for (int i = 0; i < top_paths; ++i) { // decoded_indices 160 c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2)); 161 } 162 for (int i = 0; i < top_paths; ++i) { // decoded_values 163 c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim)); 164 } 165 ShapeHandle shape_v = c->Vector(2); 166 for (int i = 0; i < top_paths; ++i) { // decoded_shape 167 c->set_output(out_idx++, shape_v); 168 } 169 c->set_output(out_idx++, c->Matrix(batch_size, top_paths)); 170 return Status::OK(); 171 }); 172 173 } // namespace tensorflow 174