Home
last modified time | relevance | path

Searched refs:bwdParams (Results 1 – 6 of 6) sorted by relevance

/external/tensorflow/tensorflow/core/kernels/mkl/
Dmkl_pooling_ops_common.cc129 void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) { in Setup() argument
130 DCHECK(bwdParams.alg_kind == mkldnn::algorithm::pooling_max || in Setup()
131 bwdParams.alg_kind == mkldnn::algorithm::pooling_avg || in Setup()
132 bwdParams.alg_kind == mkldnn::algorithm::pooling_avg_include_padding || in Setup()
133 bwdParams.alg_kind == mkldnn::algorithm::pooling_avg_exclude_padding) in Setup()
135 context_.alg_kind = bwdParams.alg_kind; in Setup()
138 context_.src_md.reset(new memory::desc({bwdParams.src_dims}, MklDnnType<T>(), in Setup()
140 context_.src_md.reset(new memory::desc(bwdParams.src_md.data)); in Setup()
141 context_.dst_md.reset(new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(), in Setup()
142 bwdParams.native_format in Setup()
[all …]
Dmkl_pooling_ops_common.h214 explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams) in MklPoolingBwdPrimitive() argument
216 if (context_.bwd == nullptr) Setup(bwdParams); in MklPoolingBwdPrimitive()
241 void Setup(const MklPoolingParams& bwdParams);
301 static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) { in Get() argument
308 bwdParams)); in Get()
310 pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams); in Get()
312 bwdParams, pooling_backward); in Get()
330 static string CreateKey(const MklPoolingParams& bwdParams) { in CreateKey() argument
334 key_creator.AddAsKey(bwdParams.src_dims); in CreateKey()
335 key_creator.AddAsKey(bwdParams.dst_dims); in CreateKey()
[all …]
Dmkl_relu_op.cc242 explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams) in MklEltwiseBwdPrimitive() argument
246 Setup(bwdParams); in MklEltwiseBwdPrimitive()
335 void Setup(const MklEltwiseBwdParams<T>& bwdParams) { in Setup() argument
337 context_.src_md.reset(new memory::desc(bwdParams.common_md.data)); in Setup()
338 context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data)); in Setup()
344 prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md, in Setup()
345 bwdParams.alpha, bwdParams.beta)); in Setup()
348 bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md, in Setup()
349 bwdParams.alpha, bwdParams.beta)); in Setup()
365 {{bwdParams.forward_input_type, *context_.src_mem}, in Setup()
[all …]
Dmkl_fused_batch_norm_op.cc402 explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams) in MklFusedBatchNormBwdPrimitive() argument
404 if (context_.bn_bwd == nullptr) Setup(bwdParams); in MklFusedBatchNormBwdPrimitive()
519 void Setup(const MklBatchNormBwdParams& bwdParams) { in Setup() argument
521 bwdParams.training in Setup()
526 auto src_md = bwdParams.src_md; in Setup()
527 auto diff_dst_md = bwdParams.diff_dst_md; in Setup()
528 auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(), in Setup()
530 auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(), in Setup()
532 auto weights_desc = memory::desc({2, bwdParams.depth}, MklDnnType<U>(), in Setup()
538 auto bn_flags = bwdParams.training in Setup()
[all …]
Dmkl_avgpooling_op.cc242 MklPoolingParams bwdParams( in Compute() local
250 MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); in Compute()
Dmkl_maxpooling_op.cc297 MklPoolingParams bwdParams( in Compute() local
304 MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); in Compute()