1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.evaluators;
18 
19 import java.util.List;
20 
21 /**
22  * Inference evaluator for the TTS model.
23  *
24  * This validates that the Mel-cep distortion and log F0 error are within the limits.
25  */
26 public class MelCepLogF0 extends BaseSequenceEvaluator {
27 
28     static private final float MEL_CEP_DISTORTION_LIMIT = 4f;
29     static private final float LOG_F0_ERROR_LIMIT = 0.01f;
30 
31     // The TTS model predicts 4 frames per inference.
32     // For each frame, there are 40 amplitude values, 7 aperiodicity values,
33     // 1 log F0 value and 1 voicing value.
34     static private final int FRAMES_PER_INFERENCE = 4;
35     static private final int AMPLITUDE_DIMENSION = 40;
36     static private final int APERIODICITY_DIMENSION = 7;
37     static private final int LOG_F0_DIMENSION = 1;
38     static private final int VOICING_DIMENSION = 1;
39     static private final int FRAME_OUTPUT_DIMENSION = AMPLITUDE_DIMENSION + APERIODICITY_DIMENSION +
40             LOG_F0_DIMENSION + VOICING_DIMENSION;
41     // The threshold to classify if a frame is voiced (above threshold) or unvoiced (below).
42     static private final float VOICED_THRESHOLD = 0f;
43 
44     private float mMaxMelCepDistortion = 0f;
45     private float mMaxLogF0Error = 0f;
46 
47     @Override
EvaluateSequenceAccuracy(float[][] outputs, float[][] expectedOutputs, List<String> outValidationErrors)48     protected void EvaluateSequenceAccuracy(float[][] outputs, float[][] expectedOutputs,
49             List<String> outValidationErrors) {
50         float melCepDistortion = calculateMelCepDistortion(outputs, expectedOutputs);
51         if (melCepDistortion > MEL_CEP_DISTORTION_LIMIT) {
52             outValidationErrors.add("Mel-cep distortion exceeded the limit: " +
53                     melCepDistortion);
54         }
55         mMaxMelCepDistortion = Math.max(mMaxMelCepDistortion, melCepDistortion);
56 
57         float logF0Error = calculateLogF0Error(outputs, expectedOutputs);
58         if (logF0Error > LOG_F0_ERROR_LIMIT) {
59             outValidationErrors.add("Log F0 error exceeded the limit: " + logF0Error);
60         }
61         mMaxLogF0Error = Math.max(mMaxLogF0Error, logF0Error);
62     }
63 
64     @Override
AddValidationResult(List<String> keys, List<Float> values)65     protected void AddValidationResult(List<String> keys, List<Float> values) {
66         keys.add("max_mel_cep_distortion");
67         values.add(mMaxMelCepDistortion);
68         keys.add("max_log_f0_error");
69         values.add(mMaxLogF0Error);
70     }
71 
calculateMelCepDistortion(float[][] outputs, float[][] expectedOutputs)72     private static float calculateMelCepDistortion(float[][] outputs, float[][] expectedOutputs) {
73         int inferenceCount = outputs.length;
74         float squared_error = 0;
75         for (int inferenceIndex = 0; inferenceIndex < inferenceCount; ++inferenceIndex) {
76             for (int frameIndex = 0; frameIndex < FRAMES_PER_INFERENCE; ++frameIndex) {
77                 // Mel-Cep distortion skips the first amplitude element.
78                 for (int amplitudeIndex = 1; amplitudeIndex < AMPLITUDE_DIMENSION;
79                      ++amplitudeIndex) {
80                     int i = frameIndex * FRAME_OUTPUT_DIMENSION + amplitudeIndex;
81                     squared_error += Math.pow(
82                             outputs[inferenceIndex][i] - expectedOutputs[inferenceIndex][i], 2);
83                 }
84             }
85         }
86 
87         return (float)Math.sqrt(squared_error /
88                 (inferenceCount * FRAMES_PER_INFERENCE * (AMPLITUDE_DIMENSION - 1)));
89     }
90 
calculateLogF0Error(float[][] outputs, float[][] expectedOutputs)91     private static float calculateLogF0Error(float[][] outputs, float[][] expectedOutputs) {
92         int inferenceCount = outputs.length;
93         float squared_error = 0;
94         int count = 0;
95         for (int inferenceIndex = 0; inferenceIndex < inferenceCount; ++inferenceIndex) {
96             for (int frameIndex = 0; frameIndex < FRAMES_PER_INFERENCE; ++frameIndex) {
97                 int f0Index = frameIndex * FRAME_OUTPUT_DIMENSION + AMPLITUDE_DIMENSION +
98                         APERIODICITY_DIMENSION;
99                 int voicedIndex = f0Index + LOG_F0_DIMENSION;
100                 if (outputs[inferenceIndex][voicedIndex] > VOICED_THRESHOLD &&
101                         expectedOutputs[inferenceIndex][voicedIndex] > VOICED_THRESHOLD) {
102                     squared_error += Math.pow(outputs[inferenceIndex][f0Index] -
103                             expectedOutputs[inferenceIndex][f0Index], 2);
104                     ++count;
105                 }
106             }
107         }
108         float logF0Error = 0f;
109         if (count > 0) {
110             logF0Error = (float)Math.sqrt(squared_error / count);
111         }
112         return logF0Error;
113     }
114 }
115