1 // Copyright (C) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_TESTING_COMMON_MATCHERS_H_
16 #define ICING_TESTING_COMMON_MATCHERS_H_
17 
18 #include <cmath>
19 
20 #include "icing/text_classifier/lib3/utils/base/status.h"
21 #include "icing/text_classifier/lib3/utils/base/status_macros.h"
22 #include "icing/text_classifier/lib3/utils/base/statusor.h"
23 #include "gmock/gmock.h"
24 #include "gtest/gtest.h"
25 #include "icing/absl_ports/str_join.h"
26 #include "icing/index/hit/doc-hit-info.h"
27 #include "icing/legacy/core/icing-string-util.h"
28 #include "icing/proto/search.pb.h"
29 #include "icing/schema/schema-store.h"
30 #include "icing/schema/section.h"
31 #include "icing/util/status-macros.h"
32 
33 namespace icing {
34 namespace lib {
35 
36 // Used to match Token(Token::Type type, std::string_view text)
37 MATCHER_P2(EqualsToken, type, text, "") {
38   std::string arg_string(arg.text.data(), arg.text.length());
39   if (arg.type != type || arg.text != text) {
40     *result_listener << IcingStringUtil::StringPrintf(
41         "(Expected: type=%d, text=\"%s\". Actual: type=%d, text=\"%s\")", type,
42         text, arg.type, arg_string.c_str());
43     return false;
44   }
45   return true;
46 }
47 
48 // Used to match a DocHitInfo
49 MATCHER_P2(EqualsDocHitInfo, document_id, section_ids, "") {
50   const DocHitInfo& actual = arg;
51   SectionIdMask section_mask = kSectionIdMaskNone;
52   for (SectionId section_id : section_ids) {
53     section_mask |= 1U << section_id;
54   }
55   *result_listener << IcingStringUtil::StringPrintf(
56       "(actual is {document_id=%d, section_mask=%d}, but expected was "
57       "{document_id=%d, section_mask=%d}.)",
58       actual.document_id(), actual.hit_section_ids_mask(), document_id,
59       section_mask);
60   return actual.document_id() == document_id &&
61          actual.hit_section_ids_mask() == section_mask;
62 }
63 
64 // Used to match a DocHitInfo
65 MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id,
66            section_ids_to_term_frequencies_map, "") {
67   const DocHitInfo& actual = arg;
68   SectionIdMask section_mask = kSectionIdMaskNone;
69 
70   bool term_frequency_as_expected = true;
71   std::vector<Hit::TermFrequency> expected_tfs;
72   std::vector<Hit::TermFrequency> actual_tfs;
73   for (auto itr = section_ids_to_term_frequencies_map.begin();
74        itr != section_ids_to_term_frequencies_map.end(); itr++) {
75     SectionId section_id = itr->first;
76     section_mask |= 1U << section_id;
77     expected_tfs.push_back(itr->second);
78     actual_tfs.push_back(actual.hit_term_frequency(section_id));
79     if (actual.hit_term_frequency(section_id) != itr->second) {
80       term_frequency_as_expected = false;
81     }
82   }
83   std::string actual_term_frequencies = absl_ports::StrCat(
84       "[", absl_ports::StrJoin(actual_tfs, ",", absl_ports::NumberFormatter()),
85       "]");
86   std::string expected_term_frequencies = absl_ports::StrCat(
87       "[",
88       absl_ports::StrJoin(expected_tfs, ",", absl_ports::NumberFormatter()),
89       "]");
90   *result_listener << IcingStringUtil::StringPrintf(
91       "(actual is {document_id=%d, section_mask=%d, term_frequencies=%s}, but "
92       "expected was "
93       "{document_id=%d, section_mask=%d, term_frequencies=%s}.)",
94       actual.document_id(), actual.hit_section_ids_mask(),
95       actual_term_frequencies.c_str(), document_id, section_mask,
96       expected_term_frequencies.c_str());
97   return actual.document_id() == document_id &&
98          actual.hit_section_ids_mask() == section_mask &&
99          term_frequency_as_expected;
100 }
101 
102 // Used to match a ScoredDocumentHit
103 MATCHER_P(EqualsScoredDocumentHit, expected_scored_document_hit, "") {
104   if (arg.document_id() != expected_scored_document_hit.document_id() ||
105       arg.hit_section_id_mask() !=
106           expected_scored_document_hit.hit_section_id_mask() ||
107       std::fabs(arg.score() - expected_scored_document_hit.score()) > 1e-6) {
108     *result_listener << IcingStringUtil::StringPrintf(
109         "Expected: document_id=%d, hit_section_id_mask=%d, score=%.2f. Actual: "
110         "document_id=%d, hit_section_id_mask=%d, score=%.2f",
111         expected_scored_document_hit.document_id(),
112         expected_scored_document_hit.hit_section_id_mask(),
113         expected_scored_document_hit.score(), arg.document_id(),
114         arg.hit_section_id_mask(), arg.score());
115     return false;
116   }
117   return true;
118 }
119 
120 MATCHER_P(EqualsSetSchemaResult, expected, "") {
121   const SchemaStore::SetSchemaResult& actual = arg;
122 
123   if (actual.success == expected.success &&
124       actual.index_incompatible == expected.index_incompatible &&
125       actual.old_schema_type_ids_changed ==
126           expected.old_schema_type_ids_changed &&
127       actual.schema_types_deleted_by_name ==
128           expected.schema_types_deleted_by_name &&
129       actual.schema_types_deleted_by_id ==
130           expected.schema_types_deleted_by_id &&
131       actual.schema_types_incompatible_by_name ==
132           expected.schema_types_incompatible_by_name &&
133       actual.schema_types_incompatible_by_id ==
134           expected.schema_types_incompatible_by_id) {
135     return true;
136   }
137 
138   // Format schema_type_ids_changed
139   std::string actual_old_schema_type_ids_changed = absl_ports::StrCat(
140       "[",
141       absl_ports::StrJoin(actual.old_schema_type_ids_changed, ",",
142                           absl_ports::NumberFormatter()),
143       "]");
144 
145   std::string expected_old_schema_type_ids_changed = absl_ports::StrCat(
146       "[",
147       absl_ports::StrJoin(expected.old_schema_type_ids_changed, ",",
148                           absl_ports::NumberFormatter()),
149       "]");
150 
151   // Format schema_types_deleted_by_name
152   std::string actual_schema_types_deleted_by_name = absl_ports::StrCat(
153       "[", absl_ports::StrJoin(actual.schema_types_deleted_by_name, ","), "]");
154 
155   std::string expected_schema_types_deleted_by_name = absl_ports::StrCat(
156       "[", absl_ports::StrJoin(expected.schema_types_deleted_by_name, ","),
157       "]");
158 
159   // Format schema_types_deleted_by_id
160   std::string actual_schema_types_deleted_by_id = absl_ports::StrCat(
161       "[",
162       absl_ports::StrJoin(actual.schema_types_deleted_by_id, ",",
163                           absl_ports::NumberFormatter()),
164       "]");
165 
166   std::string expected_schema_types_deleted_by_id = absl_ports::StrCat(
167       "[",
168       absl_ports::StrJoin(expected.schema_types_deleted_by_id, ",",
169                           absl_ports::NumberFormatter()),
170       "]");
171 
172   // Format schema_types_incompatible_by_name
173   std::string actual_schema_types_incompatible_by_name = absl_ports::StrCat(
174       "[", absl_ports::StrJoin(actual.schema_types_incompatible_by_name, ","),
175       "]");
176 
177   std::string expected_schema_types_incompatible_by_name = absl_ports::StrCat(
178       "[", absl_ports::StrJoin(expected.schema_types_incompatible_by_name, ","),
179       "]");
180 
181   // Format schema_types_incompatible_by_id
182   std::string actual_schema_types_incompatible_by_id = absl_ports::StrCat(
183       "[",
184       absl_ports::StrJoin(actual.schema_types_incompatible_by_id, ",",
185                           absl_ports::NumberFormatter()),
186       "]");
187 
188   std::string expected_schema_types_incompatible_by_id = absl_ports::StrCat(
189       "[",
190       absl_ports::StrJoin(expected.schema_types_incompatible_by_id, ",",
191                           absl_ports::NumberFormatter()),
192       "]");
193 
194   *result_listener << IcingStringUtil::StringPrintf(
195       "\nExpected {\n"
196       "\tsuccess=%d,\n"
197       "\tindex_incompatible=%d,\n"
198       "\told_schema_type_ids_changed=%s,\n"
199       "\tschema_types_deleted_by_name=%s,\n"
200       "\tschema_types_deleted_by_id=%s,\n"
201       "\tschema_types_incompatible_by_name=%s,\n"
202       "\tschema_types_incompatible_by_id=%s\n"
203       "}\n"
204       "Actual {\n"
205       "\tsuccess=%d,\n"
206       "\tindex_incompatible=%d,\n"
207       "\told_schema_type_ids_changed=%s,\n"
208       "\tschema_types_deleted_by_name=%s,\n"
209       "\tschema_types_deleted_by_id=%s,\n"
210       "\tschema_types_incompatible_by_name=%s,\n"
211       "\tschema_types_incompatible_by_id=%s\n"
212       "}\n",
213       expected.success, expected.index_incompatible,
214       expected_old_schema_type_ids_changed.c_str(),
215       expected_schema_types_deleted_by_name.c_str(),
216       expected_schema_types_deleted_by_id.c_str(),
217       expected_schema_types_incompatible_by_name.c_str(),
218       expected_schema_types_incompatible_by_id.c_str(), actual.success,
219       actual.index_incompatible, actual_old_schema_type_ids_changed.c_str(),
220       actual_schema_types_deleted_by_name.c_str(),
221       actual_schema_types_deleted_by_id.c_str(),
222       actual_schema_types_incompatible_by_name.c_str(),
223       actual_schema_types_incompatible_by_id.c_str());
224 
225   return false;
226 }
227 
StatusCodeToString(libtextclassifier3::StatusCode code)228 std::string StatusCodeToString(libtextclassifier3::StatusCode code) {
229   switch (code) {
230     case libtextclassifier3::StatusCode::OK:
231       return "OK";
232     case libtextclassifier3::StatusCode::CANCELLED:
233       return "CANCELLED";
234     case libtextclassifier3::StatusCode::UNKNOWN:
235       return "UNKNOWN";
236     case libtextclassifier3::StatusCode::INVALID_ARGUMENT:
237       return "INVALID_ARGUMENT";
238     case libtextclassifier3::StatusCode::DEADLINE_EXCEEDED:
239       return "DEADLINE_EXCEEDED";
240     case libtextclassifier3::StatusCode::NOT_FOUND:
241       return "NOT_FOUND";
242     case libtextclassifier3::StatusCode::ALREADY_EXISTS:
243       return "ALREADY_EXISTS";
244     case libtextclassifier3::StatusCode::PERMISSION_DENIED:
245       return "PERMISSION_DENIED";
246     case libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED:
247       return "RESOURCE_EXHAUSTED";
248     case libtextclassifier3::StatusCode::FAILED_PRECONDITION:
249       return "FAILED_PRECONDITION";
250     case libtextclassifier3::StatusCode::ABORTED:
251       return "ABORTED";
252     case libtextclassifier3::StatusCode::OUT_OF_RANGE:
253       return "OUT_OF_RANGE";
254     case libtextclassifier3::StatusCode::UNIMPLEMENTED:
255       return "UNIMPLEMENTED";
256     case libtextclassifier3::StatusCode::INTERNAL:
257       return "INTERNAL";
258     case libtextclassifier3::StatusCode::UNAVAILABLE:
259       return "UNAVAILABLE";
260     case libtextclassifier3::StatusCode::DATA_LOSS:
261       return "DATA_LOSS";
262     case libtextclassifier3::StatusCode::UNAUTHENTICATED:
263       return "UNAUTHENTICATED";
264     default:
265       return "";
266   }
267 }
268 
ProtoStatusCodeToString(StatusProto::Code code)269 std::string ProtoStatusCodeToString(StatusProto::Code code) {
270   switch (code) {
271     case StatusProto::OK:
272       return "OK";
273     case StatusProto::UNKNOWN:
274       return "UNKNOWN";
275     case StatusProto::INVALID_ARGUMENT:
276       return "INVALID_ARGUMENT";
277     case StatusProto::NOT_FOUND:
278       return "NOT_FOUND";
279     case StatusProto::ALREADY_EXISTS:
280       return "ALREADY_EXISTS";
281     case StatusProto::OUT_OF_SPACE:
282       return "OUT_OF_SPACE";
283     case StatusProto::FAILED_PRECONDITION:
284       return "FAILED_PRECONDITION";
285     case StatusProto::ABORTED:
286       return "ABORTED";
287     case StatusProto::INTERNAL:
288       return "INTERNAL";
289     case StatusProto::WARNING_DATA_LOSS:
290       return "WARNING_DATA_LOSS";
291     default:
292       return "";
293   }
294 }
295 
296 MATCHER(IsOk, "") {
297   libtextclassifier3::StatusAdapter adapter(arg);
298   if (adapter.status().ok()) {
299     return true;
300   }
301   *result_listener << IcingStringUtil::StringPrintf(
302       "Expected OK, actual was (%s:%s)",
303       StatusCodeToString(adapter.status().CanonicalCode()).c_str(),
304       adapter.status().error_message().c_str());
305   return false;
306 }
307 
308 MATCHER_P(IsOkAndHolds, matcher, "") {
309   if (!arg.ok()) {
310     *result_listener << IcingStringUtil::StringPrintf(
311         "Expected OK, actual was (%s:%s)",
312         StatusCodeToString(arg.status().CanonicalCode()).c_str(),
313         arg.status().error_message().c_str());
314     return false;
315   }
316   return ExplainMatchResult(matcher, arg.ValueOrDie(), result_listener);
317 }
318 
319 MATCHER_P(StatusIs, status_code, "") {
320   libtextclassifier3::StatusAdapter adapter(arg);
321   if (adapter.status().CanonicalCode() == status_code) {
322     return true;
323   }
324   *result_listener << IcingStringUtil::StringPrintf(
325       "Expected (%s:), actual was (%s:%s)",
326       StatusCodeToString(status_code).c_str(),
327       StatusCodeToString(adapter.status().CanonicalCode()).c_str(),
328       adapter.status().error_message().c_str());
329   return false;
330 }
331 
332 MATCHER_P2(StatusIs, status_code, error_matcher, "") {
333   libtextclassifier3::StatusAdapter adapter(arg);
334   if (adapter.status().CanonicalCode() != status_code) {
335     *result_listener << IcingStringUtil::StringPrintf(
336         "Expected (%s:), actual was (%s:%s)",
337         StatusCodeToString(status_code).c_str(),
338         StatusCodeToString(adapter.status().CanonicalCode()).c_str(),
339         adapter.status().error_message().c_str());
340     return false;
341   }
342   return ExplainMatchResult(error_matcher, adapter.status().error_message(),
343                             result_listener);
344 }
345 
346 MATCHER(ProtoIsOk, "") {
347   if (arg.code() == StatusProto::OK) {
348     return true;
349   }
350   *result_listener << IcingStringUtil::StringPrintf(
351       "Expected OK, actual was (%s:%s)",
352       ProtoStatusCodeToString(arg.code()).c_str(), arg.message().c_str());
353   return false;
354 }
355 
356 MATCHER_P(ProtoStatusIs, status_code, "") {
357   if (arg.code() == status_code) {
358     return true;
359   }
360   *result_listener << IcingStringUtil::StringPrintf(
361       "Expected (%s:), actual was (%s:%s)",
362       ProtoStatusCodeToString(status_code).c_str(),
363       ProtoStatusCodeToString(arg.code()).c_str(), arg.message().c_str());
364   return false;
365 }
366 
367 MATCHER_P2(ProtoStatusIs, status_code, error_matcher, "") {
368   if (arg.code() != status_code) {
369     *result_listener << IcingStringUtil::StringPrintf(
370         "Expected (%s:), actual was (%s:%s)",
371         ProtoStatusCodeToString(status_code).c_str(),
372         ProtoStatusCodeToString(arg.code()).c_str(), arg.message().c_str());
373     return false;
374   }
375   return ExplainMatchResult(error_matcher, arg.message(), result_listener);
376 }
377 
378 MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") {
379   SearchResultProto actual_copy = arg;
380   actual_copy.clear_query_stats();
381   actual_copy.clear_debug_info();
382   for (SearchResultProto::ResultProto& result :
383        *actual_copy.mutable_results()) {
384     result.clear_score();
385   }
386 
387   SearchResultProto expected_copy = expected;
388   expected_copy.clear_query_stats();
389   expected_copy.clear_debug_info();
390   for (SearchResultProto::ResultProto& result :
391        *expected_copy.mutable_results()) {
392     result.clear_score();
393   }
394   return ExplainMatchResult(testing::EqualsProto(expected_copy), actual_copy,
395                             result_listener);
396 }
397 
// TODO(tjbarron) Remove this once icing has switched to depend on TC3 Status
// Two-level token pasting: the outer macro lets its arguments (e.g.
// __COUNTER__) be macro-expanded before the inner macro pastes them.
// NOTE(review): keep these two definitions token-identical to any other
// definition of the same names (possibly icing/util/status-macros.h, which
// this header includes) so the redefinition stays benign.
#define ICING_STATUS_MACROS_CONCAT_NAME(x, y) \
  ICING_STATUS_MACROS_CONCAT_IMPL(x, y)
#define ICING_STATUS_MACROS_CONCAT_IMPL(x, y) x##y

// EXPECT / ASSERT that `func` satisfies the IsOk() matcher.
#define ICING_EXPECT_OK(func) EXPECT_THAT(func, IsOk())
#define ICING_ASSERT_OK(func) ASSERT_THAT(func, IsOk())

// Implementation detail of ICING_ASSERT_OK_AND_ASSIGN. It:
//   1. evaluates `expression` exactly once into the local `tmp_var`,
//   2. ASSERTs that tmp_var.status() is OK, and
//   3. move-assigns tmp_var's value (via ValueOrDie()) into `lhs`.
#define ICING_ASSERT_OK_AND_ASSIGN_IMPL(tmp_var, lhs, expression) \
  auto tmp_var = (expression);                                   \
  ICING_ASSERT_OK(tmp_var.status());                             \
  lhs = std::move(tmp_var).ValueOrDie()

// Evaluates `rexpr`, fails the current test via ASSERT if its status is not
// OK, and otherwise assigns the contained value to `lhs`. `lhs` is pasted in
// front of `= ...`, so it may also be a declaration such as `auto x`. The
// temporary gets a __COUNTER__-based name, so the macro can be used several
// times in one scope.
#define ICING_ASSERT_OK_AND_ASSIGN(lhs, rexpr)                             \
  ICING_ASSERT_OK_AND_ASSIGN_IMPL(                                         \
      ICING_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
      rexpr)
413 
414 }  // namespace lib
415 }  // namespace icing
416 
417 #endif  // ICING_TESTING_COMMON_MATCHERS_H_
418