1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "actions/ranker.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-ranker.h"
#endif
#include "actions/zlib-utils.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
#endif
32 
33 namespace libtextclassifier3 {
34 namespace {
35 
SortByScoreAndType(std::vector<ActionSuggestion> * actions)36 void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
37   std::sort(actions->begin(), actions->end(),
38             [](const ActionSuggestion& a, const ActionSuggestion& b) {
39               return a.score > b.score ||
40                      (a.score >= b.score && a.type < b.type);
41             });
42 }
43 
// Generic three-way comparison built on operator< and operator>.
// Returns -1 if `left` orders before `right`, 1 if it orders after, and 0
// otherwise.
template <typename T>
int Compare(const T& left, const T& right) {
  if (left < right) {
    return -1;
  }
  return (left > right) ? 1 : 0;
}

// String specialization: forwards to std::string::compare, so the result is
// negative / zero / positive rather than strictly -1 / 0 / 1.
template <>
int Compare(const std::string& left, const std::string& right) {
  const int order = left.compare(right);
  return order;
}
59 
60 template <>
Compare(const MessageTextSpan & span,const MessageTextSpan & other)61 int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
62   if (const int value = Compare(span.message_index, other.message_index)) {
63     return value;
64   }
65   if (const int value = Compare(span.span.first, other.span.first)) {
66     return value;
67   }
68   if (const int value = Compare(span.span.second, other.span.second)) {
69     return value;
70   }
71   return 0;
72 }
73 
IsSameSpan(const MessageTextSpan & span,const MessageTextSpan & other)74 bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
75   return Compare(span, other) == 0;
76 }
77 
TextSpansIntersect(const MessageTextSpan & span,const MessageTextSpan & other)78 bool TextSpansIntersect(const MessageTextSpan& span,
79                         const MessageTextSpan& other) {
80   return span.message_index == other.message_index &&
81          SpansOverlap(span.span, other.span);
82 }
83 
84 template <>
Compare(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)85 int Compare(const ActionSuggestionAnnotation& annotation,
86             const ActionSuggestionAnnotation& other) {
87   if (const int value = Compare(annotation.span, other.span)) {
88     return value;
89   }
90   if (const int value = Compare(annotation.name, other.name)) {
91     return value;
92   }
93   if (const int value =
94           Compare(annotation.entity.collection, other.entity.collection)) {
95     return value;
96   }
97   return 0;
98 }
99 
100 // Checks whether two annotations can be considered equivalent.
IsEquivalentActionAnnotation(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)101 bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
102                                   const ActionSuggestionAnnotation& other) {
103   return Compare(annotation, other) == 0;
104 }
105 
106 // Compares actions based on annotations.
CompareAnnotationsOnly(const ActionSuggestion & action,const ActionSuggestion & other)107 int CompareAnnotationsOnly(const ActionSuggestion& action,
108                            const ActionSuggestion& other) {
109   if (const int value =
110           Compare(action.annotations.size(), other.annotations.size())) {
111     return value;
112   }
113   for (int i = 0; i < action.annotations.size(); i++) {
114     if (const int value =
115             Compare(action.annotations[i], other.annotations[i])) {
116       return value;
117     }
118   }
119   return 0;
120 }
121 
122 // Checks whether two actions have the same annotations.
HaveEquivalentAnnotations(const ActionSuggestion & action,const ActionSuggestion & other)123 bool HaveEquivalentAnnotations(const ActionSuggestion& action,
124                                const ActionSuggestion& other) {
125   return CompareAnnotationsOnly(action, other) == 0;
126 }
127 
128 template <>
Compare(const ActionSuggestion & action,const ActionSuggestion & other)129 int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
130   if (const int value = Compare(action.type, other.type)) {
131     return value;
132   }
133   if (const int value = Compare(action.response_text, other.response_text)) {
134     return value;
135   }
136   if (const int value = Compare(action.serialized_entity_data,
137                                 other.serialized_entity_data)) {
138     return value;
139   }
140   return CompareAnnotationsOnly(action, other);
141 }
142 
143 // Checks whether two action suggestions can be considered equivalent.
IsEquivalentActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)144 bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
145                                   const ActionSuggestion& other) {
146   return Compare(action, other) == 0;
147 }
148 
149 // Checks whether any action is equivalent to the given one.
IsAnyActionEquivalent(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)150 bool IsAnyActionEquivalent(const ActionSuggestion& action,
151                            const std::vector<ActionSuggestion>& actions) {
152   for (const ActionSuggestion& other : actions) {
153     if (IsEquivalentActionSuggestion(action, other)) {
154       return true;
155     }
156   }
157   return false;
158 }
159 
IsConflicting(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)160 bool IsConflicting(const ActionSuggestionAnnotation& annotation,
161                    const ActionSuggestionAnnotation& other) {
162   // Two annotations are conflicting if they are different but refer to
163   // overlapping spans in the conversation.
164   return (!IsEquivalentActionAnnotation(annotation, other) &&
165           TextSpansIntersect(annotation.span, other.span));
166 }
167 
168 // Checks whether two action suggestions can be considered conflicting.
IsConflictingActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)169 bool IsConflictingActionSuggestion(const ActionSuggestion& action,
170                                    const ActionSuggestion& other) {
171   // Actions are considered conflicting, iff they refer to the same text span,
172   // but were not generated from the same annotation.
173   if (action.annotations.empty() || other.annotations.empty()) {
174     return false;
175   }
176   for (const ActionSuggestionAnnotation& annotation : action.annotations) {
177     for (const ActionSuggestionAnnotation& other_annotation :
178          other.annotations) {
179       if (IsConflicting(annotation, other_annotation)) {
180         return true;
181       }
182     }
183   }
184   return false;
185 }
186 
187 // Checks whether any action is considered conflicting with the given one.
IsAnyActionConflicting(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)188 bool IsAnyActionConflicting(const ActionSuggestion& action,
189                             const std::vector<ActionSuggestion>& actions) {
190   for (const ActionSuggestion& other : actions) {
191     if (IsConflictingActionSuggestion(action, other)) {
192       return true;
193     }
194   }
195   return false;
196 }
197 
198 }  // namespace
199 
200 std::unique_ptr<ActionsSuggestionsRanker>
CreateActionsSuggestionsRanker(const RankingOptions * options,ZlibDecompressor * decompressor,const std::string & smart_reply_action_type)201 ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
202     const RankingOptions* options, ZlibDecompressor* decompressor,
203     const std::string& smart_reply_action_type) {
204   auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
205       new ActionsSuggestionsRanker(options, smart_reply_action_type));
206 
207   if (!ranker->InitializeAndValidate(decompressor)) {
208     TC3_LOG(ERROR) << "Could not initialize action ranker.";
209     return nullptr;
210   }
211 
212   return ranker;
213 }
214 
InitializeAndValidate(ZlibDecompressor * decompressor)215 bool ActionsSuggestionsRanker::InitializeAndValidate(
216     ZlibDecompressor* decompressor) {
217   if (options_ == nullptr) {
218     TC3_LOG(ERROR) << "No ranking options specified.";
219     return false;
220   }
221 
222 #if !defined(TC3_DISABLE_LUA)
223   std::string lua_ranking_script;
224   if (GetUncompressedString(options_->lua_ranking_script(),
225                             options_->compressed_lua_ranking_script(),
226                             decompressor, &lua_ranking_script) &&
227       !lua_ranking_script.empty()) {
228     if (!Compile(lua_ranking_script, &lua_bytecode_)) {
229       TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
230       return false;
231     }
232   }
233 #endif
234 
235   return true;
236 }
237 
RankActions(const Conversation & conversation,ActionsSuggestionsResponse * response,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const238 bool ActionsSuggestionsRanker::RankActions(
239     const Conversation& conversation, ActionsSuggestionsResponse* response,
240     const reflection::Schema* entity_data_schema,
241     const reflection::Schema* annotations_entity_data_schema) const {
242   if (options_->deduplicate_suggestions() ||
243       options_->deduplicate_suggestions_by_span()) {
244     // First order suggestions by priority score for deduplication.
245     std::sort(
246         response->actions.begin(), response->actions.end(),
247         [](const ActionSuggestion& a, const ActionSuggestion& b) {
248           return a.priority_score > b.priority_score ||
249                  (a.priority_score >= b.priority_score && a.score > b.score);
250         });
251 
252     // Deduplicate, keeping the higher score actions.
253     if (options_->deduplicate_suggestions()) {
254       std::vector<ActionSuggestion> deduplicated_actions;
255       for (const ActionSuggestion& candidate : response->actions) {
256         // Check whether we already have an equivalent action.
257         if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
258           deduplicated_actions.push_back(std::move(candidate));
259         }
260       }
261       response->actions = std::move(deduplicated_actions);
262     }
263 
264     // Resolve conflicts between conflicting actions referring to the same
265     // text span.
266     if (options_->deduplicate_suggestions_by_span()) {
267       std::vector<ActionSuggestion> deduplicated_actions;
268       for (const ActionSuggestion& candidate : response->actions) {
269         // Check whether we already have a conflicting action.
270         if (!IsAnyActionConflicting(candidate, deduplicated_actions)) {
271           deduplicated_actions.push_back(std::move(candidate));
272         }
273       }
274       response->actions = std::move(deduplicated_actions);
275     }
276   }
277 
278   // Suppress smart replies if actions are present.
279   if (options_->suppress_smart_replies_with_actions()) {
280     std::vector<ActionSuggestion> non_smart_reply_actions;
281     for (const ActionSuggestion& action : response->actions) {
282       if (action.type != smart_reply_action_type_) {
283         non_smart_reply_actions.push_back(std::move(action));
284       }
285     }
286     response->actions = std::move(non_smart_reply_actions);
287   }
288 
289   // Group by annotation if specified.
290   if (options_->group_by_annotations()) {
291     auto group_id = std::map<
292         ActionSuggestion, int,
293         std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
294         [](const ActionSuggestion& action, const ActionSuggestion& other) {
295           return (CompareAnnotationsOnly(action, other) < 0);
296         }};
297     typedef std::vector<ActionSuggestion> ActionSuggestionGroup;
298     std::vector<ActionSuggestionGroup> groups;
299 
300     // Group actions by the annotation set they are based of.
301     for (const ActionSuggestion& action : response->actions) {
302       // Treat actions with no annotations idependently.
303       if (action.annotations.empty()) {
304         groups.emplace_back(1, action);
305         continue;
306       }
307 
308       auto it = group_id.find(action);
309       if (it != group_id.end()) {
310         groups[it->second].push_back(action);
311       } else {
312         group_id[action] = groups.size();
313         groups.emplace_back(1, action);
314       }
315     }
316 
317     // Sort within each group by score.
318     for (std::vector<ActionSuggestion>& group : groups) {
319       SortByScoreAndType(&group);
320     }
321 
322     // Sort groups by maximum score.
323     std::sort(groups.begin(), groups.end(),
324               [](const std::vector<ActionSuggestion>& a,
325                  const std::vector<ActionSuggestion>& b) {
326                 return a.begin()->score > b.begin()->score ||
327                        (a.begin()->score >= b.begin()->score &&
328                         a.begin()->type < b.begin()->type);
329               });
330 
331     // Flatten result.
332     const size_t num_actions = response->actions.size();
333     response->actions.clear();
334     response->actions.reserve(num_actions);
335     for (const std::vector<ActionSuggestion>& actions : groups) {
336       response->actions.insert(response->actions.end(), actions.begin(),
337                                actions.end());
338     }
339 
340   } else {
341     // Order suggestions independently by score.
342     SortByScoreAndType(&response->actions);
343   }
344 
345 #if !defined(TC3_DISABLE_LUA)
346   // Run lua ranking snippet, if provided.
347   if (!lua_bytecode_.empty()) {
348     auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
349         conversation, lua_bytecode_, entity_data_schema,
350         annotations_entity_data_schema, response);
351     if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
352       TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
353       return false;
354     }
355   }
356 #endif
357 
358   return true;
359 }
360 
361 }  // namespace libtextclassifier3
362