Searched refs:bwdParams (Results 1 – 6 of 6) sorted by relevance
/external/tensorflow/tensorflow/core/kernels/mkl/ |
D | mkl_pooling_ops_common.cc | 129 void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) { in Setup() argument 130 DCHECK(bwdParams.alg_kind == mkldnn::algorithm::pooling_max || in Setup() 131 bwdParams.alg_kind == mkldnn::algorithm::pooling_avg || in Setup() 132 bwdParams.alg_kind == mkldnn::algorithm::pooling_avg_include_padding || in Setup() 133 bwdParams.alg_kind == mkldnn::algorithm::pooling_avg_exclude_padding) in Setup() 135 context_.alg_kind = bwdParams.alg_kind; in Setup() 138 context_.src_md.reset(new memory::desc({bwdParams.src_dims}, MklDnnType<T>(), in Setup() 140 context_.src_md.reset(new memory::desc(bwdParams.src_md.data)); in Setup() 141 context_.dst_md.reset(new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(), in Setup() 142 bwdParams.native_format in Setup() [all …]
|
D | mkl_pooling_ops_common.h | 214 explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams) in MklPoolingBwdPrimitive() argument 216 if (context_.bwd == nullptr) Setup(bwdParams); in MklPoolingBwdPrimitive() 241 void Setup(const MklPoolingParams& bwdParams); 301 static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) { in Get() argument 308 bwdParams)); in Get() 310 pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams); in Get() 312 bwdParams, pooling_backward); in Get() 330 static string CreateKey(const MklPoolingParams& bwdParams) { in CreateKey() argument 334 key_creator.AddAsKey(bwdParams.src_dims); in CreateKey() 335 key_creator.AddAsKey(bwdParams.dst_dims); in CreateKey() [all …]
|
D | mkl_relu_op.cc | 242 explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams) in MklEltwiseBwdPrimitive() argument 246 Setup(bwdParams); in MklEltwiseBwdPrimitive() 335 void Setup(const MklEltwiseBwdParams<T>& bwdParams) { in Setup() argument 337 context_.src_md.reset(new memory::desc(bwdParams.common_md.data)); in Setup() 338 context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data)); in Setup() 344 prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md, in Setup() 345 bwdParams.alpha, bwdParams.beta)); in Setup() 348 bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md, in Setup() 349 bwdParams.alpha, bwdParams.beta)); in Setup() 365 {{bwdParams.forward_input_type, *context_.src_mem}, in Setup() [all …]
|
D | mkl_fused_batch_norm_op.cc | 402 explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams) in MklFusedBatchNormBwdPrimitive() argument 404 if (context_.bn_bwd == nullptr) Setup(bwdParams); in MklFusedBatchNormBwdPrimitive() 519 void Setup(const MklBatchNormBwdParams& bwdParams) { in Setup() argument 521 bwdParams.training in Setup() 526 auto src_md = bwdParams.src_md; in Setup() 527 auto diff_dst_md = bwdParams.diff_dst_md; in Setup() 528 auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(), in Setup() 530 auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(), in Setup() 532 auto weights_desc = memory::desc({2, bwdParams.depth}, MklDnnType<U>(), in Setup() 538 auto bn_flags = bwdParams.training in Setup() [all …]
|
D | mkl_avgpooling_op.cc | 242 MklPoolingParams bwdParams( in Compute() local 250 MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); in Compute()
|
D | mkl_maxpooling_op.cc | 297 MklPoolingParams bwdParams( in Compute() local 304 MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); in Compute()
|