1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "tensorflow/c/c_api_internal.h"
19 #include "tensorflow/c/kernels.h"
20 #include "tensorflow/c/tf_status_helper.h"
21 #include "tensorflow/core/framework/kernel_def_builder.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23
24 // This file forms the basis of a stable ABI for third-party kernel
25 // implementations. It is crucial that changes to this file are made cautiously
26 // and with a focus on maintaining both source and binary compatibility.
27
28 struct TF_KernelBuilder {
29 ::tensorflow::KernelDefBuilder* cc_builder;
30
31 void* (*create_function)(TF_OpKernelConstruction*);
32 void (*compute_function)(void*, TF_OpKernelContext*);
33 void (*delete_function)(void*);
34 };
35
TF_NewKernelBuilder(const char * op_name,const char * device_name,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))36 TF_KernelBuilder* TF_NewKernelBuilder(
37 const char* op_name, const char* device_name,
38 void* (*create_func)(TF_OpKernelConstruction*),
39 void (*compute_func)(void*, TF_OpKernelContext*),
40 void (*delete_func)(void*)) {
41 TF_KernelBuilder* result = new TF_KernelBuilder;
42 result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
43 result->cc_builder->Device(device_name);
44 result->create_function = create_func;
45 result->compute_function = compute_func;
46 result->delete_function = delete_func;
47 return result;
48 }
49
TF_DeleteKernelBuilder(TF_KernelBuilder * builder)50 void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
51 if (builder != nullptr) {
52 delete builder->cc_builder;
53 delete builder;
54 }
55 }
56
57 namespace tensorflow {
58 namespace {
59
60 // An OpKernel whose methods delegate to C function pointers.
61 class COpKernel : public OpKernel {
62 public:
COpKernel(OpKernelConstruction * ctx,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))63 explicit COpKernel(OpKernelConstruction* ctx,
64 void* (*create_func)(TF_OpKernelConstruction*),
65 void (*compute_func)(void*, TF_OpKernelContext*),
66 void (*delete_func)(void*))
67 : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
68 if (create_func != nullptr) {
69 c_kernel_ =
70 (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
71 } else {
72 c_kernel_ = nullptr;
73 }
74 }
75
Compute(OpKernelContext * ctx)76 void Compute(OpKernelContext* ctx) override {
77 (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
78 }
79
~COpKernel()80 ~COpKernel() override {
81 if (delete_func_ != nullptr) {
82 (*delete_func_)(c_kernel_);
83 }
84 }
85
86 private:
87 void (*compute_func_)(void*, TF_OpKernelContext* context);
88 void (*delete_func_)(void*);
89 void* c_kernel_;
90 };
91
92 // A KernelFactory that returns COpKernel instances.
93 class KernelBuilderFactory
94 : public ::tensorflow::kernel_factory::OpKernelFactory {
95 public:
KernelBuilderFactory(TF_KernelBuilder * builder)96 explicit KernelBuilderFactory(TF_KernelBuilder* builder)
97 : builder_(builder) {}
Create(::tensorflow::OpKernelConstruction * context)98 ::tensorflow::OpKernel* Create(
99 ::tensorflow::OpKernelConstruction* context) override {
100 return new ::tensorflow::COpKernel(context, builder_->create_function,
101 builder_->compute_function,
102 builder_->delete_function);
103 }
~KernelBuilderFactory()104 ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }
105
106 private:
107 TF_KernelBuilder* builder_;
108 };
109 } // namespace
110 } // namespace tensorflow
111
TF_RegisterKernelBuilder(const char * name,TF_KernelBuilder * builder,TF_Status * status)112 void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
113 TF_Status* status) {
114 using tensorflow::register_kernel::Name;
115
116 tensorflow::kernel_factory::OpKernelRegistrar(
117 builder->cc_builder->Build(), name,
118 absl::make_unique<tensorflow::KernelBuilderFactory>(builder));
119
120 TF_SetStatus(status, TF_OK, "");
121 }
122
TF_NumInputs(TF_OpKernelContext * ctx)123 int TF_NumInputs(TF_OpKernelContext* ctx) {
124 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
125 return cc_ctx->num_inputs();
126 }
127
TF_NumOutputs(TF_OpKernelContext * ctx)128 int TF_NumOutputs(TF_OpKernelContext* ctx) {
129 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
130 return cc_ctx->num_outputs();
131 }
132
TF_GetInput(TF_OpKernelContext * ctx,int i,TF_Tensor ** tensor,TF_Status * status)133 void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
134 TF_Status* status) {
135 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
136 if (i < 0 || i >= cc_ctx->num_inputs()) {
137 TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
138 return;
139 }
140 const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
141 TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
142 if (TF_GetCode(status) == TF_OK) {
143 *tensor = result;
144 }
145 }
146
TF_SetOutput(TF_OpKernelContext * ctx,int i,const TF_Tensor * tensor,TF_Status * status)147 void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
148 TF_Status* status) {
149 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
150 if (i < 0 || i >= cc_ctx->num_inputs()) {
151 TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
152 return;
153 }
154 ::tensorflow::Tensor cc_tensor;
155 ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
156 TF_SetStatus(status, TF_OK, "");
157 ::tensorflow::Set_TF_Status_from_Status(status, s);
158 if (s.ok()) {
159 cc_ctx->set_output(i, cc_tensor);
160 }
161 }
162
TF_OpKernelConstruction_Failure(TF_OpKernelConstruction * ctx,TF_Status * status)163 void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
164 TF_Status* status) {
165 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
166 ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
167 cc_ctx->CtxFailure(s);
168 }
169
TF_OpKernelContext_Failure(TF_OpKernelContext * ctx,TF_Status * status)170 void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
171 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
172 ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
173 cc_ctx->CtxFailure(s);
174 }
175
176 #define DEFINE_TF_GETATTR(func, c_type, cc_type) \
177 void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
178 const char* attr_name, \
179 c_type* val, TF_Status* status) { \
180 TF_SetStatus(status, TF_OK, ""); \
181 cc_type v; \
182 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
183 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); \
184 ::tensorflow::Set_TF_Status_from_Status(status, s); \
185 if (s.ok()) { \
186 *val = static_cast<c_type>(v); \
187 } \
188 }
189
DEFINE_TF_GETATTR(Type,TF_DataType,tensorflow::DataType)190 DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
191
192 TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
193 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
194 return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
195 }
196
TF_StepId(TF_OpKernelContext * ctx)197 int64_t TF_StepId(TF_OpKernelContext* ctx) {
198 return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
199 }
200