1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_KERNELS_H_ 17 #define TENSORFLOW_C_KERNELS_H_ 18 19 #include "tensorflow/c/c_api.h" 20 21 #ifdef __cplusplus 22 extern "C" { 23 #endif 24 25 // -------------------------------------------------------------------------- 26 // C API for TensorFlow Kernels. 27 // 28 // This API allows developers to register custom kernel implementations for 29 // TensorFlow. 30 // 31 // See c_api.h header comments for a discussion about API conventions. 32 // 33 // Users wishing to extend TensorFlow with new kernels will call 34 // `TF_NewKernelBuilder`. The resulting kernel builder can be registered with 35 // `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided 36 // kernels when necessary. 37 38 typedef struct TF_KernelBuilder TF_KernelBuilder; 39 typedef struct TF_OpKernelConstruction TF_OpKernelConstruction; 40 typedef struct TF_OpKernelContext TF_OpKernelContext; 41 42 // Allocates a new kernel builder and returns a pointer to it. 43 // 44 // If non-null, TensorFlow will call create_func when it needs to instantiate 45 // the kernel. The pointer returned by create_func will be passed to 46 // compute_func and delete_func, thereby functioning as a "this" pointer for 47 // referring to kernel instances. 48 // 49 // The TF_OpKernelConstruction pointer passed to create_func is owned by 50 // TensorFlow and will be deleted once create_func returns. It must not be used 51 // after this. 52 // 53 // When TensorFlow needs to perform a computation with this kernel, it will 54 // call compute_func. This function will receive the pointer returned by 55 // create_func (or null if no create_func was provided), along with the inputs 56 // to the computation. 57 // 58 // The TF_OpKernelContext pointer received by compute_func is owned by 59 // TensorFlow and will be deleted once compute_func returns. It must not be used 60 // after this. 61 // 62 // Finally, when TensorFlow no longer needs the kernel, it will call 63 // delete_func if one is provided. This function will receive the pointer 64 // returned in `create_func` or nullptr if no `create_func` was provided. 65 // 66 // The caller should pass the result of this function to 67 // TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for 68 // some reason, the kernel builder will not be registered, the caller should 69 // delete it with TF_DeleteKernelBuilder. 70 TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder( 71 const char* op_name, const char* device_name, 72 void* (*create_func)(TF_OpKernelConstruction*), 73 void (*compute_func)(void*, TF_OpKernelContext*), 74 void (*delete_func)(void*)); 75 76 // Register the given kernel builder with the TensorFlow runtime. If 77 // registration fails, the given status will be populated. 78 // 79 // This call takes ownership of the `builder` pointer. 80 TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name, 81 TF_KernelBuilder* builder, 82 TF_Status* status); 83 84 // Deletes the given TF_KernelBuilder. This should be called only if the kernel 85 // builder is not registered with TensorFlow via TF_RegisterKernelBuilder. 86 TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); 87 88 // -------------------------------------------------------------------------- 89 // OpKernelContext routines 90 91 // TF_NumInputs returns the number of inputs available in ctx. 92 TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); 93 94 // TF_NumOutputs returns the number of outputs to be placed in *ctx by the 95 // kernel. 96 TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx); 97 98 // Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is 99 // populated and its ownership is passed to the caller. In any other case, 100 // *tensor is not modified. 101 // 102 // If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE. 103 TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i, 104 TF_Tensor** tensor, TF_Status* status); 105 106 // Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but 107 // TF_OK, ctx is left unmodified. 108 // 109 // If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE. 110 TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, 111 const TF_Tensor* tensor, 112 TF_Status* status); 113 114 // Notifies the given OpKernelConstruction that kernel construction has failed. 115 TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure( 116 TF_OpKernelConstruction* ctx, TF_Status* status); 117 118 // Notifies the given OpKernelContext that the kernel's compute function has 119 // failed. 120 TF_CAPI_EXPORT extern void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, 121 TF_Status* status); 122 123 // Returns the expected output data type of the ith output. If i < 0 or 124 // i >= TF_NumOutputs(ctx), the program aborts. 125 TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( 126 TF_OpKernelContext* ctx, int i); 127 128 // Returns the step ID of the given context. 129 TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); 130 131 // Interprets the named kernel construction attribute as a TF_DataType and 132 // places it into *val. *status is set to TF_OK. 133 // 134 // If the attribute could not be found or could not be interpreted as 135 // TF_DataType, *status is populated with an error. 136 TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType( 137 TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val, 138 TF_Status* status); 139 140 #ifdef __cplusplus 141 } /* end extern "C" */ 142 #endif 143 144 #endif // TENSORFLOW_C_KERNELS_H_ 145