/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_ #define LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_ #include "actions/actions_model_generated.h" #include "actions/lua-utils.h" #include "actions/types.h" #include "utils/lua-utils.h" #include "utils/tensor-view.h" #include "utils/tflite-model-executor.h" namespace libtextclassifier3 { // Lua backed actions suggestions. class LuaActionsSuggestions : public LuaEnvironment { public: static std::unique_ptr CreateLuaActionsSuggestions( const std::string& snippet, const Conversation& conversation, const TfLiteModelExecutor* model_executor, const TensorflowLiteModelSpec* model_spec, const tflite::Interpreter* interpreter, const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema); bool SuggestActions(std::vector* actions); private: // Model tensor lua iterator. class TensorViewIterator : public LuaEnvironment::ItemIterator> { public: explicit TensorViewIterator() {} int Item(const TensorView* tensor, const int64 index, lua_State* state) const override; }; LuaActionsSuggestions( const std::string& snippet, const Conversation& conversation, const TfLiteModelExecutor* model_executor, const TensorflowLiteModelSpec* model_spec, const tflite::Interpreter* interpreter, const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema); bool Initialize(); const std::string& snippet_; const Conversation& conversation_; ConversationIterator conversation_iterator_; TensorViewIterator tensor_iterator_; TensorView actions_scores_; TensorView smart_reply_scores_; TensorView sensitivity_score_; TensorView triggering_score_; const reflection::Schema* actions_entity_data_schema_; const reflection::Schema* annotations_entity_data_schema_; }; } // namespace libtextclassifier3 #endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_