Searched refs:max_logits (Results 1 – 2 of 2) sorted by relevance
/external/tensorflow/tensorflow/core/kernels/ |
D | softmax_op_gpu.cu.cc | 73 const T* max_logits, T* output, in GenerateNormalizedProb() argument 88 max_val[i] = strict_cast<U>(ldg(max_logits + row)); in GenerateNormalizedProb() 112 const Eigen::half* max_logits, Eigen::half* output, const int num_rows, in GenerateNormalizedProb() argument 134 max_val[i] = strict_cast<float>(ldg(max_logits + row[i])); in GenerateNormalizedProb() 150 max_val[i] = strict_cast<float>(ldg(max_logits + row[i])); in GenerateNormalizedProb() 165 const T* __restrict__ max_logits, in SubtractAndExpFunctor() 167 : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {} in SubtractAndExpFunctor() 215 Tensor max_logits; in Compute() local 219 softmax_out->shape(), &max_logits)); in Compute() 227 context, const_cast<T*>(max_logits.flat<T>().data()), in Compute() [all …]
|
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/ |
D | legalize_tf.cc | 2724 auto max_logits = in matchAndRewrite() local 2728 CommonPrefixBroadcast(loc, logits, max_logits, rewriter); in matchAndRewrite()
|