/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.evaluators;

import java.util.List;

/**
 * Inference evaluator for the TTS model.
 *
 * This validates that the Mel-cep distortion and log F0 error are within the limits.
 * The largest value of each metric seen across evaluated sequences is tracked and
 * reported through {@link #AddValidationResult}.
 */
public class MelCepLogF0 extends BaseSequenceEvaluator {

    private static final float MEL_CEP_DISTORTION_LIMIT = 4f;
    private static final float LOG_F0_ERROR_LIMIT = 0.01f;

    // The TTS model predicts 4 frames per inference.
    // For each frame, there are 40 amplitude values, 7 aperiodicity values,
    // 1 log F0 value and 1 voicing value, stored in that order.
    private static final int FRAMES_PER_INFERENCE = 4;
    private static final int AMPLITUDE_DIMENSION = 40;
    private static final int APERIODICITY_DIMENSION = 7;
    private static final int LOG_F0_DIMENSION = 1;
    private static final int VOICING_DIMENSION = 1;
    private static final int FRAME_OUTPUT_DIMENSION = AMPLITUDE_DIMENSION
            + APERIODICITY_DIMENSION + LOG_F0_DIMENSION + VOICING_DIMENSION;
    // The threshold to classify if a frame is voiced (above threshold) or unvoiced (below).
    private static final float VOICED_THRESHOLD = 0f;

    private float mMaxMelCepDistortion = 0f;
    private float mMaxLogF0Error = 0f;

    /**
     * Computes both metrics for one sequence, records a validation error for each metric
     * that exceeds its limit, and updates the running maxima.
     *
     * @param outputs             model outputs, one row per inference
     * @param expectedOutputs     reference outputs, indexed the same way as {@code outputs}
     * @param outValidationErrors list that receives a message for every exceeded limit
     */
    @Override
    protected void EvaluateSequenceAccuracy(float[][] outputs, float[][] expectedOutputs,
            List<String> outValidationErrors) {
        float melCepDistortion = calculateMelCepDistortion(outputs, expectedOutputs);
        if (melCepDistortion > MEL_CEP_DISTORTION_LIMIT) {
            outValidationErrors.add("Mel-cep distortion exceeded the limit: "
                    + melCepDistortion);
        }
        mMaxMelCepDistortion = Math.max(mMaxMelCepDistortion, melCepDistortion);

        float logF0Error = calculateLogF0Error(outputs, expectedOutputs);
        if (logF0Error > LOG_F0_ERROR_LIMIT) {
            outValidationErrors.add("Log F0 error exceeded the limit: " + logF0Error);
        }
        mMaxLogF0Error = Math.max(mMaxLogF0Error, logF0Error);
    }

    /**
     * Appends the maximum Mel-cep distortion and maximum log F0 error seen so far.
     *
     * @param keys   list that receives the result names
     * @param values list that receives the matching result values, in the same order
     */
    @Override
    protected void AddValidationResult(List<String> keys, List<Float> values) {
        keys.add("max_mel_cep_distortion");
        values.add(mMaxMelCepDistortion);
        keys.add("max_log_f0_error");
        values.add(mMaxLogF0Error);
    }

    /**
     * Returns the root-mean-square difference between the amplitude values of
     * {@code outputs} and {@code expectedOutputs}, over every frame of every inference.
     * The first amplitude element of each frame is excluded.
     *
     * @return the distortion, or 0 when {@code outputs} contains no inferences
     */
    private static float calculateMelCepDistortion(float[][] outputs,
            float[][] expectedOutputs) {
        int inferenceCount = outputs.length;
        if (inferenceCount == 0) {
            // Without this guard the division below is 0f / 0 == NaN, and Math.max() in the
            // caller would then turn the running maximum into NaN for good. Report "no
            // error" instead, as calculateLogF0Error does when it has no samples.
            return 0f;
        }

        float squaredError = 0;
        for (int inferenceIndex = 0; inferenceIndex < inferenceCount; ++inferenceIndex) {
            float[] actual = outputs[inferenceIndex];
            float[] expected = expectedOutputs[inferenceIndex];
            for (int frameIndex = 0; frameIndex < FRAMES_PER_INFERENCE; ++frameIndex) {
                int frameStart = frameIndex * FRAME_OUTPUT_DIMENSION;
                // Mel-Cep distortion skips the first amplitude element.
                for (int amplitudeIndex = 1; amplitudeIndex < AMPLITUDE_DIMENSION;
                        ++amplitudeIndex) {
                    int i = frameStart + amplitudeIndex;
                    float diff = actual[i] - expected[i];
                    // Square in double precision, then narrow back into the float sum.
                    squaredError += (double) diff * diff;
                }
            }
        }

        int sampleCount = inferenceCount * FRAMES_PER_INFERENCE * (AMPLITUDE_DIMENSION - 1);
        return (float) Math.sqrt(squaredError / sampleCount);
    }

    /**
     * Returns the root-mean-square difference between the log F0 values of
     * {@code outputs} and {@code expectedOutputs}. Only frames whose voicing value is
     * above {@link #VOICED_THRESHOLD} in both arrays are counted.
     *
     * @return the error, or 0 when no frame is voiced in both arrays
     */
    private static float calculateLogF0Error(float[][] outputs, float[][] expectedOutputs) {
        int inferenceCount = outputs.length;
        float squaredError = 0;
        int count = 0;
        for (int inferenceIndex = 0; inferenceIndex < inferenceCount; ++inferenceIndex) {
            float[] actual = outputs[inferenceIndex];
            float[] expected = expectedOutputs[inferenceIndex];
            for (int frameIndex = 0; frameIndex < FRAMES_PER_INFERENCE; ++frameIndex) {
                // Within a frame, log F0 follows the amplitude and aperiodicity values,
                // and the voicing value follows log F0.
                int f0Index = frameIndex * FRAME_OUTPUT_DIMENSION + AMPLITUDE_DIMENSION
                        + APERIODICITY_DIMENSION;
                int voicedIndex = f0Index + LOG_F0_DIMENSION;
                if (actual[voicedIndex] > VOICED_THRESHOLD
                        && expected[voicedIndex] > VOICED_THRESHOLD) {
                    float diff = actual[f0Index] - expected[f0Index];
                    squaredError += (double) diff * diff;
                    ++count;
                }
            }
        }

        if (count == 0) {
            return 0f;
        }
        return (float) Math.sqrt(squaredError / count);
    }
}