Home
last modified time | relevance | path

Searched refs:RnnDescriptor (Results 1 – 10 of 10) sorted by relevance

/external/tensorflow/tensorflow/stream_executor/cuda/
Dcuda_dnn.h50 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
74 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
92 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
110 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
128 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
153 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
178 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
Dcuda_dnn.cc443 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>; typedef
492 RnnDescriptor CreateRnnDescriptor() { in CreateRnnDescriptor()
495 return RnnDescriptor(result); in CreateRnnDescriptor()
1025 typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1057 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
1058 CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc, in CudnnRnnDescriptor()
1099 gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor(); in Create()
1229 gpu::RnnDescriptor rnn_desc_;
1288 dnn::RnnDescriptor::ParamsRegions* weights) { in CheckAndFetchProjectionWeights()
1348 dnn::RnnDescriptor::ParamsRegion region = {reinterpret_cast<int64>(offset), in CheckAndFetchProjectionWeights()
[all …]
/external/tensorflow/tensorflow/core/kernels/
Dcudnn_rnn_ops.cc85 using se::dnn::RnnDescriptor;
559 std::unique_ptr<RnnDescriptor> rnn_desc;
786 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, in DoForward()
847 OpKernelContext* context, const RnnDescriptor& rnn_desc, in DoBackward()
929 const std::vector<RnnDescriptor::ParamsRegion>& params, in RestoreParams()
1010 std::unique_ptr<RnnDescriptor>* rnn_desc) { in ExtractCudnnRNNParamsInfo()
1059 std::unique_ptr<RnnDescriptor>* rnn_desc, in CreateRnnDescriptor()
1085 RnnStateCache* cache, RnnDescriptor** rnn_desc, in GetCachedRnnDescriptor()
1127 std::unique_ptr<RnnDescriptor> rnn_desc; in Compute()
1200 std::unique_ptr<RnnDescriptor> rnn_desc; in Compute()
[all …]
/external/tensorflow/tensorflow/stream_executor/rocm/
Drocm_dnn.h84 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
101 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
119 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
137 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
155 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
180 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
205 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
Drocm_dnn.cc1727 typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
1728 typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1760 class MIOpenRnnDescriptor : public MIOpenDescriptorCommon<dnn::RnnDescriptor> {
2514 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2545 return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>( in createRnnDescriptor()
2578 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnForward()
2621 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnForward()
2663 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnForward()
2685 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnBackward()
2737 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnBackward()
[all …]
/external/tensorflow/tensorflow/stream_executor/
Ddnn.h157 class RnnDescriptor {
164 virtual ~RnnDescriptor() {} in ~RnnDescriptor()
2109 virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2185 virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnForward()
2206 virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnForward()
2227 virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnForward()
2289 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnBackward()
2317 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnBackward()
2345 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnBackward()
Dstream.h1779 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1798 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1816 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1837 const dnn::RnnDescriptor &rnn_desc,
1862 Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
1887 Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
Dstream_executor_pimpl.h415 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
Dstream_executor_pimpl.cc364 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
Dstream.cc4539 const dnn::RnnDescriptor &rnn_desc, in ThenRnnForward()
4573 const dnn::RnnDescriptor &rnn_desc, in ThenRnnForward()
4606 const dnn::RnnDescriptor &rnn_desc, in ThenRnnForward()
4640 const dnn::RnnDescriptor &rnn_desc, in ThenRnnBackward()
4685 const dnn::RnnDescriptor &rnn_desc, in ThenRnnBackward()
4729 const dnn::RnnDescriptor &rnn_desc, in ThenRnnBackward()