1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/task-context.h"
18 
19 #include <stdlib.h>
20 
21 #include <string>
22 
23 #include "util/base/integral_types.h"
24 #include "util/base/logging.h"
25 #include "util/strings/numbers.h"
26 
27 namespace libtextclassifier {
28 namespace nlp_core {
29 
30 namespace {
ParseInt32WithDefault(const std::string & s,int32 defval)31 int32 ParseInt32WithDefault(const std::string &s, int32 defval) {
32   int32 value = defval;
33   return ParseInt32(s.c_str(), &value) ? value : defval;
34 }
35 
ParseInt64WithDefault(const std::string & s,int64 defval)36 int64 ParseInt64WithDefault(const std::string &s, int64 defval) {
37   int64 value = defval;
38   return ParseInt64(s.c_str(), &value) ? value : defval;
39 }
40 
ParseDoubleWithDefault(const std::string & s,double defval)41 double ParseDoubleWithDefault(const std::string &s, double defval) {
42   double value = defval;
43   return ParseDouble(s.c_str(), &value) ? value : defval;
44 }
45 }  // namespace
46 
GetInput(const std::string & name)47 TaskInput *TaskContext::GetInput(const std::string &name) {
48   // Return existing input if it exists.
49   for (int i = 0; i < spec_.input_size(); ++i) {
50     if (spec_.input(i).name() == name) return spec_.mutable_input(i);
51   }
52 
53   // Create new input.
54   TaskInput *input = spec_.add_input();
55   input->set_name(name);
56   return input;
57 }
58 
GetInput(const std::string & name,const std::string & file_format,const std::string & record_format)59 TaskInput *TaskContext::GetInput(const std::string &name,
60                                  const std::string &file_format,
61                                  const std::string &record_format) {
62   TaskInput *input = GetInput(name);
63   if (!file_format.empty()) {
64     bool found = false;
65     for (int i = 0; i < input->file_format_size(); ++i) {
66       if (input->file_format(i) == file_format) found = true;
67     }
68     if (!found) input->add_file_format(file_format);
69   }
70   if (!record_format.empty()) {
71     bool found = false;
72     for (int i = 0; i < input->record_format_size(); ++i) {
73       if (input->record_format(i) == record_format) found = true;
74     }
75     if (!found) input->add_record_format(record_format);
76   }
77   return input;
78 }
79 
SetParameter(const std::string & name,const std::string & value)80 void TaskContext::SetParameter(const std::string &name,
81                                const std::string &value) {
82   TC_LOG(INFO) << "SetParameter(" << name << ", " << value << ")";
83 
84   // If the parameter already exists update the value.
85   for (int i = 0; i < spec_.parameter_size(); ++i) {
86     if (spec_.parameter(i).name() == name) {
87       spec_.mutable_parameter(i)->set_value(value);
88       return;
89     }
90   }
91 
92   // Add new parameter.
93   TaskSpec::Parameter *param = spec_.add_parameter();
94   param->set_name(name);
95   param->set_value(value);
96 }
97 
GetParameter(const std::string & name) const98 std::string TaskContext::GetParameter(const std::string &name) const {
99   // First try to find parameter in task specification.
100   for (int i = 0; i < spec_.parameter_size(); ++i) {
101     if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
102   }
103 
104   // Parameter not found, return empty std::string.
105   return "";
106 }
107 
GetIntParameter(const std::string & name) const108 int TaskContext::GetIntParameter(const std::string &name) const {
109   std::string value = GetParameter(name);
110   return ParseInt32WithDefault(value, 0);
111 }
112 
GetInt64Parameter(const std::string & name) const113 int64 TaskContext::GetInt64Parameter(const std::string &name) const {
114   std::string value = GetParameter(name);
115   return ParseInt64WithDefault(value, 0);
116 }
117 
GetBoolParameter(const std::string & name) const118 bool TaskContext::GetBoolParameter(const std::string &name) const {
119   std::string value = GetParameter(name);
120   return value == "true";
121 }
122 
GetFloatParameter(const std::string & name) const123 double TaskContext::GetFloatParameter(const std::string &name) const {
124   std::string value = GetParameter(name);
125   return ParseDoubleWithDefault(value, 0.0);
126 }
127 
Get(const std::string & name,const char * defval) const128 std::string TaskContext::Get(const std::string &name,
129                              const char *defval) const {
130   // First try to find parameter in task specification.
131   for (int i = 0; i < spec_.parameter_size(); ++i) {
132     if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
133   }
134 
135   // Parameter not found, return default value.
136   return defval;
137 }
138 
Get(const std::string & name,const std::string & defval) const139 std::string TaskContext::Get(const std::string &name,
140                              const std::string &defval) const {
141   return Get(name, defval.c_str());
142 }
143 
Get(const std::string & name,int defval) const144 int TaskContext::Get(const std::string &name, int defval) const {
145   std::string value = Get(name, "");
146   return ParseInt32WithDefault(value, defval);
147 }
148 
Get(const std::string & name,int64 defval) const149 int64 TaskContext::Get(const std::string &name, int64 defval) const {
150   std::string value = Get(name, "");
151   return ParseInt64WithDefault(value, defval);
152 }
153 
Get(const std::string & name,double defval) const154 double TaskContext::Get(const std::string &name, double defval) const {
155   std::string value = Get(name, "");
156   return ParseDoubleWithDefault(value, defval);
157 }
158 
Get(const std::string & name,bool defval) const159 bool TaskContext::Get(const std::string &name, bool defval) const {
160   std::string value = Get(name, "");
161   return value.empty() ? defval : value == "true";
162 }
163 
InputFile(const TaskInput & input)164 std::string TaskContext::InputFile(const TaskInput &input) {
165   if (input.part_size() == 0) {
166     TC_LOG(ERROR) << "No file for TaskInput " << input.name();
167     return "";
168   }
169   if (input.part_size() > 1) {
170     TC_LOG(ERROR) << "Ambiguous: multiple files for TaskInput " << input.name();
171   }
172   return input.part(0).file_pattern();
173 }
174 
Supports(const TaskInput & input,const std::string & file_format,const std::string & record_format)175 bool TaskContext::Supports(const TaskInput &input,
176                            const std::string &file_format,
177                            const std::string &record_format) {
178   // Check file format.
179   if (input.file_format_size() > 0) {
180     bool found = false;
181     for (int i = 0; i < input.file_format_size(); ++i) {
182       if (input.file_format(i) == file_format) {
183         found = true;
184         break;
185       }
186     }
187     if (!found) return false;
188   }
189 
190   // Check record format.
191   if (input.record_format_size() > 0) {
192     bool found = false;
193     for (int i = 0; i < input.record_format_size(); ++i) {
194       if (input.record_format(i) == record_format) {
195         found = true;
196         break;
197       }
198     }
199     if (!found) return false;
200   }
201 
202   return true;
203 }
204 
205 }  // namespace nlp_core
206 }  // namespace libtextclassifier
207