1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "tensorflow/c/c_api_internal.h"
19 #include "tensorflow/c/kernels.h"
20 #include "tensorflow/c/tf_status_helper.h"
21 #include "tensorflow/core/framework/kernel_def_builder.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 
24 // This file forms the basis of a stable ABI for third-party kernel
25 // implementations. It is crucial that changes to this file are made cautiously
26 // and with a focus on maintaining both source and binary compatibility.
27 
28 struct TF_KernelBuilder {
29   ::tensorflow::KernelDefBuilder* cc_builder;
30 
31   void* (*create_function)(TF_OpKernelConstruction*);
32   void (*compute_function)(void*, TF_OpKernelContext*);
33   void (*delete_function)(void*);
34 };
35 
TF_NewKernelBuilder(const char * op_name,const char * device_name,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))36 TF_KernelBuilder* TF_NewKernelBuilder(
37     const char* op_name, const char* device_name,
38     void* (*create_func)(TF_OpKernelConstruction*),
39     void (*compute_func)(void*, TF_OpKernelContext*),
40     void (*delete_func)(void*)) {
41   TF_KernelBuilder* result = new TF_KernelBuilder;
42   result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
43   result->cc_builder->Device(device_name);
44   result->create_function = create_func;
45   result->compute_function = compute_func;
46   result->delete_function = delete_func;
47   return result;
48 }
49 
TF_DeleteKernelBuilder(TF_KernelBuilder * builder)50 void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
51   if (builder != nullptr) {
52     delete builder->cc_builder;
53     delete builder;
54   }
55 }
56 
57 namespace tensorflow {
58 namespace {
59 
60 // An OpKernel whose methods delegate to C function pointers.
61 class COpKernel : public OpKernel {
62  public:
COpKernel(OpKernelConstruction * ctx,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))63   explicit COpKernel(OpKernelConstruction* ctx,
64                      void* (*create_func)(TF_OpKernelConstruction*),
65                      void (*compute_func)(void*, TF_OpKernelContext*),
66                      void (*delete_func)(void*))
67       : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
68     if (create_func != nullptr) {
69       c_kernel_ =
70           (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
71     } else {
72       c_kernel_ = nullptr;
73     }
74   }
75 
Compute(OpKernelContext * ctx)76   void Compute(OpKernelContext* ctx) override {
77     (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
78   }
79 
~COpKernel()80   ~COpKernel() override {
81     if (delete_func_ != nullptr) {
82       (*delete_func_)(c_kernel_);
83     }
84   }
85 
86  private:
87   void (*compute_func_)(void*, TF_OpKernelContext* context);
88   void (*delete_func_)(void*);
89   void* c_kernel_;
90 };
91 
92 // A KernelFactory that returns COpKernel instances.
93 class KernelBuilderFactory
94     : public ::tensorflow::kernel_factory::OpKernelFactory {
95  public:
KernelBuilderFactory(TF_KernelBuilder * builder)96   explicit KernelBuilderFactory(TF_KernelBuilder* builder)
97       : builder_(builder) {}
Create(::tensorflow::OpKernelConstruction * context)98   ::tensorflow::OpKernel* Create(
99       ::tensorflow::OpKernelConstruction* context) override {
100     return new ::tensorflow::COpKernel(context, builder_->create_function,
101                                        builder_->compute_function,
102                                        builder_->delete_function);
103   }
~KernelBuilderFactory()104   ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }
105 
106  private:
107   TF_KernelBuilder* builder_;
108 };
109 }  // namespace
110 }  // namespace tensorflow
111 
TF_RegisterKernelBuilder(const char * name,TF_KernelBuilder * builder,TF_Status * status)112 void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
113                               TF_Status* status) {
114   using tensorflow::register_kernel::Name;
115 
116   tensorflow::kernel_factory::OpKernelRegistrar(
117       builder->cc_builder->Build(), name,
118       absl::make_unique<tensorflow::KernelBuilderFactory>(builder));
119 
120   TF_SetStatus(status, TF_OK, "");
121 }
122 
TF_NumInputs(TF_OpKernelContext * ctx)123 int TF_NumInputs(TF_OpKernelContext* ctx) {
124   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
125   return cc_ctx->num_inputs();
126 }
127 
TF_NumOutputs(TF_OpKernelContext * ctx)128 int TF_NumOutputs(TF_OpKernelContext* ctx) {
129   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
130   return cc_ctx->num_outputs();
131 }
132 
TF_GetInput(TF_OpKernelContext * ctx,int i,TF_Tensor ** tensor,TF_Status * status)133 void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
134                  TF_Status* status) {
135   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
136   if (i < 0 || i >= cc_ctx->num_inputs()) {
137     TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
138     return;
139   }
140   const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
141   TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
142   if (TF_GetCode(status) == TF_OK) {
143     *tensor = result;
144   }
145 }
146 
TF_SetOutput(TF_OpKernelContext * ctx,int i,const TF_Tensor * tensor,TF_Status * status)147 void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
148                   TF_Status* status) {
149   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
150   if (i < 0 || i >= cc_ctx->num_inputs()) {
151     TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
152     return;
153   }
154   ::tensorflow::Tensor cc_tensor;
155   ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
156   TF_SetStatus(status, TF_OK, "");
157   ::tensorflow::Set_TF_Status_from_Status(status, s);
158   if (s.ok()) {
159     cc_ctx->set_output(i, cc_tensor);
160   }
161 }
162 
TF_OpKernelConstruction_Failure(TF_OpKernelConstruction * ctx,TF_Status * status)163 void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
164                                      TF_Status* status) {
165   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
166   ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
167   cc_ctx->CtxFailure(s);
168 }
169 
TF_OpKernelContext_Failure(TF_OpKernelContext * ctx,TF_Status * status)170 void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
171   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
172   ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
173   cc_ctx->CtxFailure(s);
174 }
175 
176 #define DEFINE_TF_GETATTR(func, c_type, cc_type)                               \
177   void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx,     \
178                                              const char* attr_name,            \
179                                              c_type* val, TF_Status* status) { \
180     TF_SetStatus(status, TF_OK, "");                                           \
181     cc_type v;                                                                 \
182     auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
183     ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);                   \
184     ::tensorflow::Set_TF_Status_from_Status(status, s);                        \
185     if (s.ok()) {                                                              \
186       *val = static_cast<c_type>(v);                                           \
187     }                                                                          \
188   }
189 
DEFINE_TF_GETATTR(Type,TF_DataType,tensorflow::DataType)190 DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
191 
192 TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
193   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
194   return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
195 }
196 
TF_StepId(TF_OpKernelContext * ctx)197 int64_t TF_StepId(TF_OpKernelContext* ctx) {
198   return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
199 }
200