1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.videocodec.cts;
18 
19 import static android.videocodec.cts.VideoEncoderInput.RES_YUV_MAP;
20 import static android.videocodec.cts.VideoEncoderInput.getRawResource;
21 
22 import static org.junit.Assert.assertEquals;
23 import static org.junit.Assert.assertTrue;
24 import static org.junit.Assert.fail;
25 import static org.junit.Assume.assumeNotNull;
26 import static org.junit.Assume.assumeTrue;
27 
28 import android.graphics.Rect;
29 import android.mediav2.common.cts.CompareStreams;
30 import android.mediav2.common.cts.EncoderConfigParams;
31 import android.mediav2.common.cts.RawResource;
32 import android.util.Log;
33 import android.util.Pair;
34 
35 import org.junit.After;
36 import org.junit.AfterClass;
37 import org.junit.Before;
38 import org.junit.BeforeClass;
39 
40 import java.io.File;
41 import java.io.IOException;
42 import java.util.ArrayList;
43 import java.util.List;
44 import java.util.Locale;
45 import java.util.Map;
46 import java.util.function.Predicate;
47 
48 /**
49  * Wrapper class for testing quality regression.
50  */
51 public class VideoEncoderQualityRegressionTestBase {
52     private static final String TAG = VideoEncoderQualityRegressionTestBase.class.getSimpleName();
53     private static final ArrayList<String> mTmpFiles = new ArrayList<>();
54     protected static ArrayList<VideoEncoderInput.CompressedResource> RESOURCES = new ArrayList<>();
55 
56     protected final String mCodecName;
57     protected final String mMediaType;
58     protected final VideoEncoderInput.CompressedResource mCRes;
59     protected final String mAllTestParams;
60 
61     static {
62         System.loadLibrary("ctsvideoqualityutils_jni");
63     }
64 
VideoEncoderQualityRegressionTestBase(String encoder, String mediaType, VideoEncoderInput.CompressedResource cRes, String allTestParams)65     VideoEncoderQualityRegressionTestBase(String encoder, String mediaType,
66             VideoEncoderInput.CompressedResource cRes, String allTestParams) {
67         mCodecName = encoder;
68         mMediaType = mediaType;
69         mCRes = cRes;
70         mAllTestParams = allTestParams;
71     }
72 
73     /**
74      * Decodes a compressed resource to get YUV file and logs the list of files currently residing
75      * in the cache.
76      */
77     @BeforeClass
decodeResourcesToYuv()78     public static void decodeResourcesToYuv() {
79         VideoEncoderValidationTestBase.decodeStreamsToYuv(RESOURCES, RES_YUV_MAP, TAG);
80     }
81 
82     /**
83      * Clean up the raw resource.
84      */
85     @AfterClass
cleanUpResources()86     public static void cleanUpResources() {
87         VideoEncoderValidationTestBase.cleanUpResources();
88     }
89 
90     @Before
setUp()91     public void setUp() {
92         assumeNotNull("no raw resource found for testing : "
93                 + VideoEncoderValidationTestBase.DIAGNOSTICS, getRawResource(mCRes));
94     }
95 
96     @After
tearDown()97     public void tearDown() {
98         for (String tmpFile : mTmpFiles) {
99             File tmp = new File(tmpFile);
100             if (tmp.exists()) assertTrue("unable to delete file " + tmpFile, tmp.delete());
101         }
102         mTmpFiles.clear();
103     }
104 
getVideoEncoderCfgParams(String mediaType, int width, int height, int bitRate, int bitRateMode, int keyFrameInterval, int frameRate, int maxBFrames, Pair<String, Boolean> feature)105     protected static EncoderConfigParams getVideoEncoderCfgParams(String mediaType, int width,
106             int height, int bitRate, int bitRateMode, int keyFrameInterval, int frameRate,
107             int maxBFrames, Pair<String, Boolean> feature) {
108         EncoderConfigParams.Builder foreman = new EncoderConfigParams.Builder(mediaType)
109                 .setWidth(width)
110                 .setHeight(height)
111                 .setBitRate(bitRate)
112                 .setBitRateMode(bitRateMode)
113                 .setKeyFrameInterval(keyFrameInterval)
114                 .setFrameRate(frameRate)
115                 .setMaxBFrames(maxBFrames);
116         if (feature != null) {
117             foreman.setFeature(feature.first, feature.second);
118         }
119         return foreman.build();
120     }
121 
nativeGetBDRate(double[] qualitiesA, double[] ratesA, double[] qualitiesB, double[] ratesB, boolean selBdSnr, StringBuilder retMsg)122     private native double nativeGetBDRate(double[] qualitiesA, double[] ratesA, double[] qualitiesB,
123             double[] ratesB, boolean selBdSnr, StringBuilder retMsg);
124 
getQualityRegressionForCfgs(List<EncoderConfigParams[]> cfgsUnion, VideoEncoderValidationTestBase[] testInstances, String[] encoderNames, RawResource res, int frameLimit, int frameRate, Map<Long, List<Rect>> frameCropRects, boolean setLoopBack, Predicate<Double> predicate)125     protected void getQualityRegressionForCfgs(List<EncoderConfigParams[]> cfgsUnion,
126             VideoEncoderValidationTestBase[] testInstances, String[] encoderNames, RawResource res,
127             int frameLimit, int frameRate, Map<Long, List<Rect>> frameCropRects,
128             boolean setLoopBack, Predicate<Double> predicate)
129             throws IOException, InterruptedException {
130         assertEquals("Quality comparison is done between two sets", 2, cfgsUnion.size());
131         assertTrue("Minimum of 4 points are required for polynomial curve fitting",
132                 cfgsUnion.get(0).length >= 4);
133         double[][] psnrs = new double[cfgsUnion.size()][cfgsUnion.get(0).length];
134         double[][] rates = new double[cfgsUnion.size()][cfgsUnion.get(0).length];
135         for (int i = 0; i < cfgsUnion.size(); i++) {
136             EncoderConfigParams[] cfgs = cfgsUnion.get(i);
137             String mediaType = cfgs[0].mMediaType;
138             testInstances[i].setLoopBack(setLoopBack);
139             for (int j = 0; j < cfgs.length; j++) {
140                 testInstances[i].encodeToMemory(encoderNames[i], cfgs[j], res, frameLimit, true,
141                         true);
142                 mTmpFiles.add(testInstances[i].getMuxedOutputFilePath());
143                 assertEquals("encoder did not encode the requested number of frames \n", frameLimit,
144                         testInstances[i].getOutputCount());
145                 int outSize = testInstances[i].getOutputManager().getOutStreamSize();
146                 double achievedBitRate = ((double) outSize * 8 * frameRate) / (1000 * frameLimit);
147                 CompareStreams cs = null;
148                 try {
149                     cs = new CompareStreams(res, mediaType,
150                             testInstances[i].getMuxedOutputFilePath(), frameCropRects, true, true);
151                     final double[] globalPSNR = cs.getGlobalPSNR();
152                     double weightedPSNR = (6 * globalPSNR[0] + globalPSNR[1] + globalPSNR[2]) / 8;
153                     psnrs[i][j] = weightedPSNR;
154                     rates[i][j] = achievedBitRate;
155                 } finally {
156                     if (cs != null) cs.cleanUp();
157                 }
158                 testInstances[i].deleteMuxedFile();
159             }
160         }
161         StringBuilder retMsg = new StringBuilder();
162         double bdRate = nativeGetBDRate(psnrs[0], rates[0], psnrs[1], rates[1], false, retMsg);
163         if (retMsg.length() != 0) fail(retMsg.toString());
164         for (int i = 0; i < psnrs.length; i++) {
165             retMsg.append(String.format("\nBitrate GlbPsnr Set %d\n", i));
166             for (int j = 0; j < psnrs[i].length; j++) {
167                 retMsg.append(String.format("{%f, %f},\n", rates[i][j], psnrs[i][j]));
168             }
169         }
170         retMsg.append(String.format(Locale.getDefault(), "bd rate: %f", bdRate));
171         Log.d(TAG, retMsg.toString());
172         assumeTrue(retMsg.toString(), predicate.test(bdRate));
173     }
174 }
175