1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_C_KERNELS_H_
17 #define TENSORFLOW_C_KERNELS_H_
18 
19 #include "tensorflow/c/c_api.h"
20 
21 #ifdef __cplusplus
22 extern "C" {
23 #endif
24 
25 // --------------------------------------------------------------------------
26 // C API for TensorFlow Kernels.
27 //
28 // This API allows developers to register custom kernel implementations for
29 // TensorFlow.
30 //
31 // See c_api.h header comments for a discussion about API conventions.
32 //
33 // Users wishing to extend TensorFlow with new kernels will call
34 // `TF_NewKernelBuilder`. The resulting kernel builder can be registered with
35 // `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided
36 // kernels when necessary.
37 
// Opaque handle types. These structs are not defined in this header; callers
// only ever hold and pass pointers to them.
struct TF_KernelBuilder;
struct TF_OpKernelConstruction;
struct TF_OpKernelContext;

typedef struct TF_KernelBuilder TF_KernelBuilder;
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
typedef struct TF_OpKernelContext TF_OpKernelContext;
41 
42 // Allocates a new kernel builder and returns a pointer to it.
43 //
44 // If non-null, TensorFlow will call create_func when it needs to instantiate
45 // the kernel. The pointer returned by create_func will be passed to
46 // compute_func and delete_func, thereby functioning as a "this" pointer for
47 // referring to kernel instances.
48 //
49 // The TF_OpKernelConstruction pointer passed to create_func is owned by
50 // TensorFlow and will be deleted once create_func returns. It must not be used
51 // after this.
52 //
53 // When TensorFlow needs to perform a computation with this kernel, it will
54 // call compute_func. This function will receive the pointer returned by
55 // create_func (or null if no create_func was provided), along with the inputs
56 // to the computation.
57 //
58 // The TF_OpKernelContext pointer received by compute_func is owned by
59 // TensorFlow and will be deleted once compute_func returns. It must not be used
60 // after this.
61 //
62 // Finally, when TensorFlow no longer needs the kernel, it will call
63 // delete_func if one is provided. This function will receive the pointer
64 // returned in `create_func` or nullptr if no `create_func` was provided.
65 //
66 // The caller should pass the result of this function to
67 // TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for
68 // some reason, the kernel builder will not be registered, the caller should
69 // delete it with TF_DeleteKernelBuilder.
70 TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder(
71     const char* op_name, const char* device_name,
72     void* (*create_func)(TF_OpKernelConstruction*),
73     void (*compute_func)(void*, TF_OpKernelContext*),
74     void (*delete_func)(void*));
75 
76 // Register the given kernel builder with the TensorFlow runtime. If
77 // registration fails, the given status will be populated.
78 //
79 // This call takes ownership of the `builder` pointer.
80 TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name,
81                                                     TF_KernelBuilder* builder,
82                                                     TF_Status* status);
83 
84 // Deletes the given TF_KernelBuilder. This should be called only if the kernel
85 // builder is not registered with TensorFlow via TF_RegisterKernelBuilder.
86 TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
87 
88 // --------------------------------------------------------------------------
// OpKernelContext and OpKernelConstruction routines
90 
91 // TF_NumInputs returns the number of inputs available in ctx.
92 TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
93 
94 // TF_NumOutputs returns the number of outputs to be placed in *ctx by the
95 // kernel.
96 TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx);
97 
98 // Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is
99 // populated and its ownership is passed to the caller. In any other case,
100 // *tensor is not modified.
101 //
102 // If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE.
103 TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i,
104                                        TF_Tensor** tensor, TF_Status* status);
105 
106 // Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but
107 // TF_OK, ctx is left unmodified.
108 //
109 // If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE.
110 TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i,
111                                         const TF_Tensor* tensor,
112                                         TF_Status* status);
113 
114 // Notifies the given OpKernelConstruction that kernel construction has failed.
115 TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure(
116     TF_OpKernelConstruction* ctx, TF_Status* status);
117 
118 // Notifies the given OpKernelContext that the kernel's compute function has
119 // failed.
120 TF_CAPI_EXPORT extern void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx,
121                                                       TF_Status* status);
122 
123 // Returns the expected output data type of the ith output. If i < 0 or
124 // i >= TF_NumOutputs(ctx), the program aborts.
125 TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
126     TF_OpKernelContext* ctx, int i);
127 
128 // Returns the step ID of the given context.
129 TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
130 
131 // Interprets the named kernel construction attribute as a TF_DataType and
132 // places it into *val. *status is set to TF_OK.
133 //
134 // If the attribute could not be found or could not be interpreted as
135 // TF_DataType, *status is populated with an error.
136 TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType(
137     TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val,
138     TF_Status* status);
139 
140 #ifdef __cplusplus
141 } /* end extern "C" */
142 #endif
143 
144 #endif  // TENSORFLOW_C_KERNELS_H_
145