Home
last modified time | relevance | path

Searched refs:max_logits (Results 1 – 2 of 2) sorted by relevance

/external/tensorflow/tensorflow/core/kernels/
Dsoftmax_op_gpu.cu.cc73 const T* max_logits, T* output, in GenerateNormalizedProb() argument
88 max_val[i] = strict_cast<U>(ldg(max_logits + row)); in GenerateNormalizedProb()
112 const Eigen::half* max_logits, Eigen::half* output, const int num_rows, in GenerateNormalizedProb() argument
134 max_val[i] = strict_cast<float>(ldg(max_logits + row[i])); in GenerateNormalizedProb()
150 max_val[i] = strict_cast<float>(ldg(max_logits + row[i])); in GenerateNormalizedProb()
165 const T* __restrict__ max_logits, in SubtractAndExpFunctor()
167 : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {} in SubtractAndExpFunctor()
215 Tensor max_logits; in Compute() local
219 softmax_out->shape(), &max_logits)); in Compute()
227 context, const_cast<T*>(max_logits.flat<T>().data()), in Compute()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf.cc2724 auto max_logits = in matchAndRewrite() local
2728 CommonPrefixBroadcast(loc, logits, max_logits, rewriter); in matchAndRewrite()