1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 
19 namespace tensorflow {
20 
21 using shape_inference::DimensionHandle;
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24 
25 // CTC is Connectionist Temporal Classification.  See util/ctc/ for details.
26 
27 REGISTER_OP("CTCLoss")
28     .Input("inputs: T")
29     .Input("labels_indices: int64")
30     .Input("labels_values: int32")
31     .Input("sequence_length: int32")
32     .Attr("preprocess_collapse_repeated: bool = false")
33     .Attr("ctc_merge_repeated: bool = true")
34     .Attr("ignore_longer_outputs_than_inputs: bool = false")
35     .Output("loss: T")
36     .Output("gradient: T")
37     .Attr("T: {float, double} = DT_FLOAT")
__anon7dc6794d0102(InferenceContext* c) 38     .SetShapeFn([](InferenceContext* c) {
39       ShapeHandle inputs;
40       ShapeHandle labels_indices;
41       ShapeHandle labels_values;
42       ShapeHandle sequence_length;
43 
44       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
45       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices));
46       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values));
47       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length));
48 
49       DimensionHandle unused;
50       TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0),
51                                   c->Dim(labels_values, 0), &unused));
52 
53       // Get batch size from inputs and sequence_length, and update inputs
54       // with the merged batch_size since it is returned.
55       DimensionHandle batch_size;
56       TF_RETURN_IF_ERROR(
57           c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
58       TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs));
59 
60       c->set_output(0, c->Vector(batch_size));
61       c->set_output(1, inputs);
62       return Status::OK();
63     });
64 
65 REGISTER_OP("CTCLossV2")
66     .Input("inputs: float")
67     .Input("labels_indices: int64")
68     .Input("labels_values: int32")
69     .Input("sequence_length: int32")
70     .Attr("preprocess_collapse_repeated: bool = false")
71     .Attr("ctc_merge_repeated: bool = true")
72     .Attr("ignore_longer_outputs_than_inputs: bool = false")
73     .Output("loss: float")
74     .Output("gradient: float")
__anon7dc6794d0202(InferenceContext* c) 75     .SetShapeFn([](InferenceContext* c) {
76       ShapeHandle inputs;
77       ShapeHandle labels_indices;
78       ShapeHandle labels_values;
79       ShapeHandle sequence_length;
80 
81       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
82       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices));
83       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values));
84       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length));
85 
86       DimensionHandle unused;
87       TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0),
88                                   c->Dim(labels_values, 0), &unused));
89 
90       // Get batch size from inputs and sequence_length, and update inputs
91       // with the merged batch_size since it is returned.
92       DimensionHandle batch_size;
93       TF_RETURN_IF_ERROR(
94           c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
95       TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs));
96 
97       c->set_output(0, c->Vector(batch_size));
98       c->set_output(1, inputs);
99       return Status::OK();
100     });
101 
102 REGISTER_OP("CTCGreedyDecoder")
103     .Input("inputs: T")
104     .Input("sequence_length: int32")
105     .Attr("merge_repeated: bool = false")
106     .Output("decoded_indices: int64")
107     .Output("decoded_values: int64")
108     .Output("decoded_shape: int64")
109     .Output("log_probability: T")
110     .Attr("T: {float, double} = DT_FLOAT")
__anon7dc6794d0302(InferenceContext* c) 111     .SetShapeFn([](InferenceContext* c) {
112       ShapeHandle inputs;
113       ShapeHandle sequence_length;
114 
115       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
116       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));
117 
118       // Get batch size from inputs and sequence_length.
119       DimensionHandle batch_size;
120       TF_RETURN_IF_ERROR(
121           c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
122 
123       DimensionHandle total_decoded_outputs = c->UnknownDim();
124       c->set_output(0, c->Matrix(total_decoded_outputs, 2));
125       c->set_output(1, c->Vector(total_decoded_outputs));
126       c->set_output(2, c->Vector(2));
127       c->set_output(3, c->Matrix(batch_size, 1));
128       return Status::OK();
129     });
130 
131 REGISTER_OP("CTCBeamSearchDecoder")
132     .Input("inputs: T")
133     .Input("sequence_length: int32")
134     .Attr("beam_width: int >= 1")
135     .Attr("top_paths: int >= 1")
136     .Attr("merge_repeated: bool = true")
137     .Output("decoded_indices: top_paths * int64")
138     .Output("decoded_values: top_paths * int64")
139     .Output("decoded_shape: top_paths * int64")
140     .Output("log_probability: T")
141     .Attr("T: {float, double} = DT_FLOAT")
__anon7dc6794d0402(InferenceContext* c) 142     .SetShapeFn([](InferenceContext* c) {
143       ShapeHandle inputs;
144       ShapeHandle sequence_length;
145 
146       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
147       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));
148 
149       // Get batch size from inputs and sequence_length.
150       DimensionHandle batch_size;
151       TF_RETURN_IF_ERROR(
152           c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
153 
154       int32 top_paths;
155       TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths));
156 
157       // Outputs.
158       int out_idx = 0;
159       for (int i = 0; i < top_paths; ++i) {  // decoded_indices
160         c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2));
161       }
162       for (int i = 0; i < top_paths; ++i) {  // decoded_values
163         c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim));
164       }
165       ShapeHandle shape_v = c->Vector(2);
166       for (int i = 0; i < top_paths; ++i) {  // decoded_shape
167         c->set_output(out_idx++, shape_v);
168       }
169       c->set_output(out_idx++, c->Matrix(batch_size, top_paths));
170       return Status::OK();
171     });
172 
173 }  // namespace tensorflow
174