1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/grammar-actions.h"
18 
19 #include "actions/feature-processor.h"
20 #include "actions/utils.h"
21 #include "annotator/types.h"
22 #include "utils/base/arena.h"
23 #include "utils/base/statusor.h"
24 #include "utils/utf8/unicodetext.h"
25 
26 namespace libtextclassifier3 {
27 
GrammarActions(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules,const MutableFlatbufferBuilder * entity_data_builder,const std::string & smart_reply_action_type)28 GrammarActions::GrammarActions(
29     const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
30     const MutableFlatbufferBuilder* entity_data_builder,
31     const std::string& smart_reply_action_type)
32     : unilib_(*unilib),
33       grammar_rules_(grammar_rules),
34       tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)),
35       entity_data_builder_(entity_data_builder),
36       analyzer_(unilib, grammar_rules->rules(), tokenizer_.get()),
37       smart_reply_action_type_(smart_reply_action_type) {}
38 
InstantiateActionsFromMatch(const grammar::TextContext & text_context,const int message_index,const grammar::Derivation & derivation,std::vector<ActionSuggestion> * result) const39 bool GrammarActions::InstantiateActionsFromMatch(
40     const grammar::TextContext& text_context, const int message_index,
41     const grammar::Derivation& derivation,
42     std::vector<ActionSuggestion>* result) const {
43   const RulesModel_::GrammarRules_::RuleMatch* rule_match =
44       grammar_rules_->rule_match()->Get(derivation.rule_id);
45   if (rule_match == nullptr || rule_match->action_id() == nullptr) {
46     TC3_LOG(ERROR) << "No rule action defined.";
47     return false;
48   }
49 
50   // Gather active capturing matches.
51   std::unordered_map<uint16, const grammar::ParseTree*> capturing_matches;
52   for (const grammar::MappingNode* mapping_node :
53        grammar::SelectAllOfType<grammar::MappingNode>(
54            derivation.parse_tree, grammar::ParseTree::Type::kMapping)) {
55     capturing_matches[mapping_node->id] = mapping_node;
56   }
57 
58   // Instantiate actions from the rule match.
59   for (const uint16 action_id : *rule_match->action_id()) {
60     const RulesModel_::RuleActionSpec* action_spec =
61         grammar_rules_->actions()->Get(action_id);
62     std::vector<ActionSuggestionAnnotation> annotations;
63 
64     std::unique_ptr<MutableFlatbuffer> entity_data =
65         entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
66                                         : nullptr;
67 
68     // Set information from capturing matches.
69     if (action_spec->capturing_group() != nullptr) {
70       for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
71            *action_spec->capturing_group()) {
72         auto it = capturing_matches.find(group->group_id());
73         if (it == capturing_matches.end()) {
74           // Capturing match is not active, skip.
75           continue;
76         }
77 
78         const grammar::ParseTree* capturing_match = it->second;
79         const UnicodeText match_text =
80             text_context.Span(capturing_match->codepoint_span);
81         UnicodeText normalized_match_text =
82             NormalizeMatchText(unilib_, group, match_text);
83 
84         if (!MergeEntityDataFromCapturingMatch(
85                 group, normalized_match_text.ToUTF8String(),
86                 entity_data.get())) {
87           TC3_LOG(ERROR)
88               << "Could not merge entity data from a capturing match.";
89           return false;
90         }
91 
92         // Add smart reply suggestions.
93         SuggestTextRepliesFromCapturingMatch(entity_data_builder_, group,
94                                              normalized_match_text,
95                                              smart_reply_action_type_, result);
96 
97         // Add annotation.
98         ActionSuggestionAnnotation annotation;
99         if (FillAnnotationFromCapturingMatch(
100                 /*span=*/capturing_match->codepoint_span, group,
101                 /*message_index=*/message_index, match_text.ToUTF8String(),
102                 &annotation)) {
103           if (group->use_annotation_match()) {
104             std::vector<const grammar::AnnotationNode*> annotations =
105                 grammar::SelectAllOfType<grammar::AnnotationNode>(
106                     capturing_match, grammar::ParseTree::Type::kAnnotation);
107             if (annotations.size() != 1) {
108               TC3_LOG(ERROR) << "Could not get annotation for match.";
109               return false;
110             }
111             annotation.entity = *annotations.front()->annotation;
112           }
113           annotations.push_back(std::move(annotation));
114         }
115       }
116     }
117 
118     if (action_spec->action() != nullptr) {
119       ActionSuggestion suggestion;
120       suggestion.annotations = annotations;
121       FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
122                              &suggestion);
123       result->push_back(std::move(suggestion));
124     }
125   }
126   return true;
127 }
SuggestActions(const Conversation & conversation,std::vector<ActionSuggestion> * result) const128 bool GrammarActions::SuggestActions(
129     const Conversation& conversation,
130     std::vector<ActionSuggestion>* result) const {
131   if (grammar_rules_->rules()->rules() == nullptr ||
132       conversation.messages.back().text.empty()) {
133     // Nothing to do.
134     return true;
135   }
136 
137   std::vector<Locale> locales;
138   if (!ParseLocales(conversation.messages.back().detected_text_language_tags,
139                     &locales)) {
140     TC3_LOG(ERROR) << "Could not parse locales of input text.";
141     return false;
142   }
143 
144   const int message_index = conversation.messages.size() - 1;
145   grammar::TextContext text = analyzer_.BuildTextContextForInput(
146       UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false),
147       locales);
148   text.annotations = conversation.messages.back().annotations;
149 
150   UnsafeArena arena(/*block_size=*/16 << 10);
151   StatusOr<std::vector<grammar::EvaluatedDerivation>> evaluated_derivations =
152       analyzer_.Parse(text, &arena);
153   // TODO(b/171294882): Return the status here and below.
154   if (!evaluated_derivations.ok()) {
155     TC3_LOG(ERROR) << "Could not run grammar analyzer: "
156                    << evaluated_derivations.status().error_message();
157     return false;
158   }
159 
160   for (const grammar::EvaluatedDerivation& evaluated_derivation :
161        evaluated_derivations.ValueOrDie()) {
162     if (!InstantiateActionsFromMatch(text, message_index, evaluated_derivation,
163                                      result)) {
164       TC3_LOG(ERROR) << "Could not instantiate actions from a grammar match.";
165       return false;
166     }
167   }
168 
169   return true;
170 }
171 
172 }  // namespace libtextclassifier3
173