Lines Matching refs:Gradient

75 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
81 virtual int64 NumElements(Gradient* tensor) const = 0;
85 virtual Gradient* AggregateGradients(
86 gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
89 virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;
92 virtual Gradient* Ones(const TapeTensor& tensor) const = 0;
97 gtl::ArraySlice<Gradient*> output_gradients,
98 std::vector<Gradient*>* result) const = 0;
102 virtual void MarkAsResult(Gradient* gradient) const = 0;
105 virtual void DeleteGradient(Gradient* gradient) const = 0;
110 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
143 const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
147 gtl::ArraySlice<Gradient*> output_gradients,
148 std::vector<Gradient*>* result);
184 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
185 bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord( in ShouldRecord()
199 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
200 void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch( in Watch()
205 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
206 void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation( in RecordOperation()
236 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
237 void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace( in DeleteTrace()
396 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
398 const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, in InitialGradients() argument
401 gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape, in InitialGradients()
403 gtl::FlatMap<int64, std::vector<Gradient*>>* result) { in InitialGradients()
474 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
475 Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient( in ComputeGradient()
476 const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, in ComputeGradient() argument
480 gtl::ArraySlice<Gradient*> output_gradients, in ComputeGradient()
481 std::vector<Gradient*>* result) { in ComputeGradient()
488 gtl::FlatMap<int64, std::vector<Gradient*>> gradients; in ComputeGradient()
525 std::vector<Gradient*> out_gradients; in ComputeGradient()
542 Gradient* new_gradients = nullptr; in ComputeGradient()
558 std::vector<Gradient*> in_gradients; in ComputeGradient()
574 for (Gradient* grad : out_gradients) { in ComputeGradient()
597 Gradient* grad = vspace.AggregateGradients(unaggregated_grads); in ComputeGradient()
653 Gradient* grad = vspace.AggregateGradients(grad_it->second); in ComputeGradient()