1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
17
#include <memory>
#include <set>
#include <string>

#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/util.h"
33
34 namespace tensorflow {
35 namespace tfcompile {
36
37 using xla::llvm_ir::AsStringRef;
38
AddEmbeddedProtocolBufferToLlvmModule(llvm::Module * module,const::tensorflow::protobuf::MessageLite & proto,absl::string_view unique_identifier,string * protobuf_array_symbol_name,int64 * protobuf_array_size)39 static void AddEmbeddedProtocolBufferToLlvmModule(
40 llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
41 absl::string_view unique_identifier, string* protobuf_array_symbol_name,
42 int64* protobuf_array_size) {
43 string protobuf_array_contents = proto.SerializeAsString();
44 *protobuf_array_symbol_name =
45 absl::StrCat(unique_identifier, "_protobuf_array_contents");
46 *protobuf_array_size = protobuf_array_contents.size();
47
48 llvm::Constant* protobuf_array_initializer =
49 llvm::ConstantDataArray::getString(module->getContext(),
50 AsStringRef(protobuf_array_contents),
51 /*AddNull=*/false);
52 new llvm::GlobalVariable(
53 *module, protobuf_array_initializer->getType(),
54 /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage,
55 protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
56 }
57
CreateCPPShimExpression(absl::string_view qualified_cpp_protobuf_name,absl::string_view protobuf_array_symbol_name,int64 protobuf_array_size)58 static string CreateCPPShimExpression(
59 absl::string_view qualified_cpp_protobuf_name,
60 absl::string_view protobuf_array_symbol_name, int64 protobuf_array_size) {
61 string code =
62 "[]() {\n"
63 " {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n"
64 " proto->ParseFromArray(&{{ARRAY_SYMBOL}}[0], {{ARRAY_SIZE}});\n"
65 " return proto;\n"
66 " }()";
67
68 return absl::StrReplaceAll(
69 code,
70 {
71 {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)},
72 {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)},
73 {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)},
74 });
75 }
76
CodegenModule(llvm::TargetMachine * target_machine,std::unique_ptr<llvm::Module> module)77 static StatusOr<string> CodegenModule(llvm::TargetMachine* target_machine,
78 std::unique_ptr<llvm::Module> module) {
79 llvm::SmallVector<char, 0> stream_buffer;
80 llvm::raw_svector_ostream ostream(stream_buffer);
81 llvm::legacy::PassManager codegen_passes;
82
83 if (target_machine->addPassesToEmitFile(codegen_passes, ostream, nullptr,
84 llvm::CGFT_ObjectFile)) {
85 return xla::InternalError(
86 "Could not create pass pipeline to generate object file");
87 }
88
89 codegen_passes.run(*module);
90
91 return string(stream_buffer.begin(), stream_buffer.end());
92 }
93
94 static StatusOr<std::unique_ptr<llvm::TargetMachine>>
GetTargetMachineFromTriple(absl::string_view target_triple)95 GetTargetMachineFromTriple(absl::string_view target_triple) {
96 std::string error;
97 std::string normalized_triple =
98 llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)));
99 const llvm::Target* target =
100 llvm::TargetRegistry::lookupTarget(normalized_triple, error);
101 if (target == nullptr) {
102 return xla::InternalError("TargetRegistry::lookupTarget failed: %s",
103 error.c_str());
104 }
105
106 return absl::WrapUnique(target->createTargetMachine(
107 normalized_triple, /*CPU=*/"",
108 /*Features=*/"", llvm::TargetOptions(), llvm::None));
109 }
110
CreateEmbeddedProtocolBuffers(absl::string_view target_triple,absl::Span<const ProtobufToEmbed> protobufs_to_embed)111 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
112 absl::string_view target_triple,
113 absl::Span<const ProtobufToEmbed> protobufs_to_embed) {
114 TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
115 GetTargetMachineFromTriple(target_triple));
116
117 llvm::LLVMContext llvm_context;
118 std::unique_ptr<llvm::Module> module_with_serialized_proto =
119 absl::make_unique<llvm::Module>("embedded_data_module", llvm_context);
120
121 EmbeddedProtocolBuffers result;
122
123 for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) {
124 string cpp_shim, cpp_variable_decl;
125 if (protobuf_to_embed.message) {
126 string protobuf_array_symbol_name;
127 int64 protobuf_array_size;
128
129 AddEmbeddedProtocolBufferToLlvmModule(
130 module_with_serialized_proto.get(), *protobuf_to_embed.message,
131 protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name,
132 &protobuf_array_size);
133 cpp_shim = CreateCPPShimExpression(
134 protobuf_to_embed.qualified_cpp_protobuf_name,
135 protobuf_array_symbol_name, protobuf_array_size);
136
137 cpp_variable_decl =
138 absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];");
139 } else {
140 cpp_shim = "nullptr";
141 }
142 result.cpp_shims.push_back({cpp_shim, cpp_variable_decl});
143 }
144
145 TF_ASSIGN_OR_RETURN(result.object_file_data,
146 CodegenModule(target_machine.get(),
147 std::move(module_with_serialized_proto)));
148 return result;
149 }
150
151 } // namespace tfcompile
152 } // namespace tensorflow
153