1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_replace.h"
23 #include "llvm/ADT/Triple.h"
24 #include "llvm/IR/GlobalVariable.h"
25 #include "llvm/IR/LLVMContext.h"
26 #include "llvm/IR/LegacyPassManager.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/Support/TargetRegistry.h"
29 #include "llvm/Target/TargetMachine.h"
30 #include "llvm/Target/TargetOptions.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
32 #include "tensorflow/compiler/xla/util.h"
33 
34 namespace tensorflow {
35 namespace tfcompile {
36 
37 using xla::llvm_ir::AsStringRef;
38 
AddEmbeddedProtocolBufferToLlvmModule(llvm::Module * module,const::tensorflow::protobuf::MessageLite & proto,absl::string_view unique_identifier,string * protobuf_array_symbol_name,int64 * protobuf_array_size)39 static void AddEmbeddedProtocolBufferToLlvmModule(
40     llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
41     absl::string_view unique_identifier, string* protobuf_array_symbol_name,
42     int64* protobuf_array_size) {
43   string protobuf_array_contents = proto.SerializeAsString();
44   *protobuf_array_symbol_name =
45       absl::StrCat(unique_identifier, "_protobuf_array_contents");
46   *protobuf_array_size = protobuf_array_contents.size();
47 
48   llvm::Constant* protobuf_array_initializer =
49       llvm::ConstantDataArray::getString(module->getContext(),
50                                          AsStringRef(protobuf_array_contents),
51                                          /*AddNull=*/false);
52   new llvm::GlobalVariable(
53       *module, protobuf_array_initializer->getType(),
54       /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage,
55       protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
56 }
57 
CreateCPPShimExpression(absl::string_view qualified_cpp_protobuf_name,absl::string_view protobuf_array_symbol_name,int64 protobuf_array_size)58 static string CreateCPPShimExpression(
59     absl::string_view qualified_cpp_protobuf_name,
60     absl::string_view protobuf_array_symbol_name, int64 protobuf_array_size) {
61   string code =
62       "[]() {\n"
63       "    {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n"
64       "    proto->ParseFromArray(&{{ARRAY_SYMBOL}}[0], {{ARRAY_SIZE}});\n"
65       "    return proto;\n"
66       "  }()";
67 
68   return absl::StrReplaceAll(
69       code,
70       {
71           {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)},
72           {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)},
73           {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)},
74       });
75 }
76 
CodegenModule(llvm::TargetMachine * target_machine,std::unique_ptr<llvm::Module> module)77 static StatusOr<string> CodegenModule(llvm::TargetMachine* target_machine,
78                                       std::unique_ptr<llvm::Module> module) {
79   llvm::SmallVector<char, 0> stream_buffer;
80   llvm::raw_svector_ostream ostream(stream_buffer);
81   llvm::legacy::PassManager codegen_passes;
82 
83   if (target_machine->addPassesToEmitFile(codegen_passes, ostream, nullptr,
84                                           llvm::CGFT_ObjectFile)) {
85     return xla::InternalError(
86         "Could not create pass pipeline to generate object file");
87   }
88 
89   codegen_passes.run(*module);
90 
91   return string(stream_buffer.begin(), stream_buffer.end());
92 }
93 
94 static StatusOr<std::unique_ptr<llvm::TargetMachine>>
GetTargetMachineFromTriple(absl::string_view target_triple)95 GetTargetMachineFromTriple(absl::string_view target_triple) {
96   std::string error;
97   std::string normalized_triple =
98       llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)));
99   const llvm::Target* target =
100       llvm::TargetRegistry::lookupTarget(normalized_triple, error);
101   if (target == nullptr) {
102     return xla::InternalError("TargetRegistry::lookupTarget failed: %s",
103                               error.c_str());
104   }
105 
106   return absl::WrapUnique(target->createTargetMachine(
107       normalized_triple, /*CPU=*/"",
108       /*Features=*/"", llvm::TargetOptions(), llvm::None));
109 }
110 
CreateEmbeddedProtocolBuffers(absl::string_view target_triple,absl::Span<const ProtobufToEmbed> protobufs_to_embed)111 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
112     absl::string_view target_triple,
113     absl::Span<const ProtobufToEmbed> protobufs_to_embed) {
114   TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
115                       GetTargetMachineFromTriple(target_triple));
116 
117   llvm::LLVMContext llvm_context;
118   std::unique_ptr<llvm::Module> module_with_serialized_proto =
119       absl::make_unique<llvm::Module>("embedded_data_module", llvm_context);
120 
121   EmbeddedProtocolBuffers result;
122 
123   for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) {
124     string cpp_shim, cpp_variable_decl;
125     if (protobuf_to_embed.message) {
126       string protobuf_array_symbol_name;
127       int64 protobuf_array_size;
128 
129       AddEmbeddedProtocolBufferToLlvmModule(
130           module_with_serialized_proto.get(), *protobuf_to_embed.message,
131           protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name,
132           &protobuf_array_size);
133       cpp_shim = CreateCPPShimExpression(
134           protobuf_to_embed.qualified_cpp_protobuf_name,
135           protobuf_array_symbol_name, protobuf_array_size);
136 
137       cpp_variable_decl =
138           absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];");
139     } else {
140       cpp_shim = "nullptr";
141     }
142     result.cpp_shims.push_back({cpp_shim, cpp_variable_decl});
143   }
144 
145   TF_ASSIGN_OR_RETURN(result.object_file_data,
146                       CodegenModule(target_machine.get(),
147                                     std::move(module_with_serialized_proto)));
148   return result;
149 }
150 
151 }  // namespace tfcompile
152 }  // namespace tensorflow
153