1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 
19 namespace tensorflow {
20 
21 using shape_inference::DimensionHandle;
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24 
25 namespace {
26 
CandidateSamplerShapeFn(InferenceContext * c)27 Status CandidateSamplerShapeFn(InferenceContext* c) {
28   int64 num_sampled;
29   TF_RETURN_IF_ERROR(c->GetAttr("num_sampled", &num_sampled));
30   int64 num_true;
31   TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
32 
33   ShapeHandle true_classes_shape;
34   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes_shape));
35   DimensionHandle batch_size = c->Dim(true_classes_shape, 0);
36 
37   ShapeHandle num_sampled_v = c->Vector(num_sampled);
38   c->set_output(0, num_sampled_v);
39   c->set_output(1, c->Matrix(batch_size, num_true));
40   c->set_output(2, num_sampled_v);
41   return Status::OK();
42 }
43 
44 }  // namespace
45 
// Registers UniformCandidateSampler.
//
// Input `true_classes` (int64) must be rank 2; output shapes come from
// CandidateSamplerShapeFn (`sampled_candidates` and `sampled_expected_count`
// are vectors of length `num_sampled`, `true_expected_count` is
// [batch_size, num_true]). `num_true`, `num_sampled` and `range_max` must be
// >= 1; `seed` and `seed2` default to 0. The name suggests candidates are
// drawn uniformly from an id range bounded by `range_max` — TODO confirm
// against the kernel. Marked stateful, so repeated executions are presumably
// not deduplicated or constant-folded.
REGISTER_OP("UniformCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    .SetIsStateful();
59 
// Registers LogUniformCandidateSampler.
//
// Same inputs, outputs, attrs and shape function (CandidateSamplerShapeFn)
// as UniformCandidateSampler; only the op name differs here. The name
// suggests a log-uniform distribution over an id range bounded by
// `range_max` — TODO confirm against the kernel. Marked stateful.
REGISTER_OP("LogUniformCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    .SetIsStateful();
73 
// Registers LearnedUnigramCandidateSampler.
//
// Same inputs, outputs, attrs and shape function (CandidateSamplerShapeFn)
// as UniformCandidateSampler. The name suggests a unigram distribution that
// is learned from the observed `true_classes` — TODO confirm against the
// kernel. Marked stateful.
REGISTER_OP("LearnedUnigramCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    .SetIsStateful();
87 
// Registers ThreadUnsafeUnigramCandidateSampler.
//
// Same inputs, outputs, attrs and shape function (CandidateSamplerShapeFn)
// as LearnedUnigramCandidateSampler. Per its name, the underlying
// implementation is presumably not safe for concurrent use — TODO confirm
// against the kernel. Marked stateful.
REGISTER_OP("ThreadUnsafeUnigramCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    .SetIsStateful();
101 
// Registers FixedUnigramCandidateSampler.
//
// Shares the input, outputs and shape function (CandidateSamplerShapeFn) of
// the other candidate samplers, and adds distribution-configuration attrs:
//   vocab_file       (string, default '')
//   distortion       (float, default 1.0)
//   num_reserved_ids (int, default 0)
//   num_shards       (int >= 1, default 1)
//   shard            (int >= 0, default 0)
//   unigrams         (list(float), default [])
// Presumably the fixed unigram distribution is loaded from either
// `vocab_file` or `unigrams`, optionally distorted and sharded — TODO confirm
// the exact semantics and precedence against the kernel. Marked stateful.
REGISTER_OP("FixedUnigramCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    .Attr("vocab_file: string = ''")
    .Attr("distortion: float = 1.0")
    .Attr("num_reserved_ids: int = 0")
    .Attr("num_shards: int >= 1 = 1")
    .Attr("shard: int >= 0 = 0")
    .Attr("unigrams: list(float) = []")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    .SetIsStateful();
121 
// Registers AllCandidateSampler.
//
// Uses the same input, outputs and shape function (CandidateSamplerShapeFn)
// as the other candidate samplers but, unlike them, has no `range_max` attr.
// The name suggests every candidate id is returned rather than a random
// subset — TODO confirm against the kernel (the op still declares `seed`
// and `seed2` and is marked stateful).
REGISTER_OP("AllCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    .SetIsStateful();
134 
135 REGISTER_OP("ComputeAccidentalHits")
136     .Input("true_classes: int64")
137     .Input("sampled_candidates: int64")
138     .Output("indices: int32")
139     .Output("ids: int64")
140     .Output("weights: float")
141     .Attr("num_true: int")
142     .Attr("seed: int = 0")
143     .Attr("seed2: int = 0")
__anon3d65770a0202(InferenceContext* c) 144     .SetShapeFn([](InferenceContext* c) {
145       int64 num_true;
146       TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
147 
148       // Validate true_classes, must be a matrix.
149       ShapeHandle true_classes;
150       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes));
151       DimensionHandle unused;
152       TF_RETURN_IF_ERROR(
153           c->WithValue(c->Dim(true_classes, 1), num_true, &unused));
154       // Validate sampled_candidates, must be a vector.
155       ShapeHandle sampled_candidates;
156       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sampled_candidates));
157 
158       // All three outputs are the same shape.
159       ShapeHandle v = c->Vector(InferenceContext::kUnknownDim);
160       c->set_output(0, v);
161       c->set_output(1, v);
162       c->set_output(2, v);
163       return Status::OK();
164     });
165 
166 }  // namespace tensorflow
167