1 /*
2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
#include <stddef.h>
#include <stdint.h>

#include <memory>
#include <ostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "api/transport/rtp/dependency_descriptor.h"
#include "api/video/video_frame_type.h"
#include "modules/video_coding/chain_diff_calculator.h"
#include "modules/video_coding/codecs/av1/create_scalability_structure.h"
#include "modules/video_coding/codecs/av1/scalable_video_controller.h"
#include "modules/video_coding/frame_dependencies_calculator.h"
#include "test/gmock.h"
#include "test/gtest.h"
27
28 namespace webrtc {
29 namespace {
30
31 using ::testing::AllOf;
32 using ::testing::Contains;
33 using ::testing::Each;
34 using ::testing::Field;
35 using ::testing::Ge;
36 using ::testing::IsEmpty;
37 using ::testing::Le;
38 using ::testing::Lt;
39 using ::testing::Not;
40 using ::testing::SizeIs;
41 using ::testing::TestWithParam;
42 using ::testing::Values;
43
// Parameter of one ScalabilityStructureTest instantiation.
struct SvcTestParam {
  // Scalability structure name handed to CreateScalabilityStructure(),
  // e.g. "L1T2".
  std::string name;
  // Number of NextFrameConfig() rounds GenerateAllFrames() performs.
  int num_temporal_units;

  // Streams only the structure name; gtest uses operator<< when it prints a
  // test parameter.
  friend std::ostream& operator<<(std::ostream& out, const SvcTestParam& p) {
    out << p.name;
    return out;
  }
};
52
53 class ScalabilityStructureTest : public TestWithParam<SvcTestParam> {
54 public:
GenerateAllFrames()55 std::vector<GenericFrameInfo> GenerateAllFrames() {
56 std::vector<GenericFrameInfo> frames;
57
58 FrameDependenciesCalculator frame_deps_calculator;
59 ChainDiffCalculator chain_diff_calculator;
60 std::unique_ptr<ScalableVideoController> structure_controller =
61 CreateScalabilityStructure(GetParam().name);
62 FrameDependencyStructure structure =
63 structure_controller->DependencyStructure();
64 for (int i = 0; i < GetParam().num_temporal_units; ++i) {
65 for (auto& layer_frame :
66 structure_controller->NextFrameConfig(/*reset=*/false)) {
67 int64_t frame_id = static_cast<int64_t>(frames.size());
68 bool is_keyframe = layer_frame.IsKeyframe();
69 absl::optional<GenericFrameInfo> frame_info =
70 structure_controller->OnEncodeDone(std::move(layer_frame));
71 EXPECT_TRUE(frame_info.has_value());
72 if (is_keyframe) {
73 chain_diff_calculator.Reset(frame_info->part_of_chain);
74 }
75 frame_info->chain_diffs =
76 chain_diff_calculator.From(frame_id, frame_info->part_of_chain);
77 for (int64_t base_frame_id : frame_deps_calculator.FromBuffersUsage(
78 is_keyframe ? VideoFrameType::kVideoFrameKey
79 : VideoFrameType::kVideoFrameDelta,
80 frame_id, frame_info->encoder_buffers)) {
81 EXPECT_LT(base_frame_id, frame_id);
82 EXPECT_GE(base_frame_id, 0);
83 frame_info->frame_diffs.push_back(frame_id - base_frame_id);
84 }
85
86 frames.push_back(*std::move(frame_info));
87 }
88 }
89 return frames;
90 }
91 };
92
// Checks the static limits of the dependency structure:
// - the number of decode targets and chains fits the descriptor limits;
// - when chains are used, every decode target is protected by an existing
//   chain;
// - the number of templates does not exceed
//   DependencyDescriptor::kMaxTemplates.
TEST_P(ScalabilityStructureTest,
       NumberOfDecodeTargetsAndChainsAreInRangeAndConsistent) {
  FrameDependencyStructure structure =
      CreateScalabilityStructure(GetParam().name)->DependencyStructure();
  EXPECT_GT(structure.num_decode_targets, 0);
  EXPECT_LE(structure.num_decode_targets,
            DependencyDescriptor::kMaxDecodeTargets);
  EXPECT_GE(structure.num_chains, 0);
  EXPECT_LE(structure.num_chains, structure.num_decode_targets);
  if (structure.num_chains == 0) {
    EXPECT_THAT(structure.decode_target_protected_by_chain, IsEmpty());
  } else {
    EXPECT_THAT(structure.decode_target_protected_by_chain,
                AllOf(SizeIs(structure.num_decode_targets), Each(Ge(0)),
                      Each(Lt(structure.num_chains))));
  }
  // kMaxTemplates is treated as an inclusive upper bound, like
  // kMaxDecodeTargets above: a structure with exactly kMaxTemplates templates
  // is accepted. NOTE(review): assumes the descriptor can address
  // kMaxTemplates distinct template ids; confirm against the spec.
  EXPECT_THAT(structure.templates,
              SizeIs(Le(size_t{DependencyDescriptor::kMaxTemplates})));
}
112
// The first template must be on layer (0,0). Every later template must be
// reachable from its predecessor in one next_layer_idc step:
// - 0: same layer;
// - 1: next temporal layer of the same spatial layer;
// - 2: temporal layer 0 of the next spatial layer.
TEST_P(ScalabilityStructureTest, TemplatesAreSortedByLayerId) {
  FrameDependencyStructure structure =
      CreateScalabilityStructure(GetParam().name)->DependencyStructure();
  ASSERT_THAT(structure.templates, Not(IsEmpty()));
  const auto& first_templates = structure.templates.front();
  EXPECT_EQ(first_templates.spatial_id, 0);
  EXPECT_EQ(first_templates.temporal_id, 0);

  for (size_t idx = 1; idx < structure.templates.size(); ++idx) {
    const auto& prev = structure.templates[idx - 1];
    const auto& curr = structure.templates[idx];
    const bool same_spatial = curr.spatial_id == prev.spatial_id;
    // next_layer_idc == 0.
    const bool same_layer =
        same_spatial && curr.temporal_id == prev.temporal_id;
    // next_layer_idc == 1.
    const bool next_temporal =
        same_spatial && curr.temporal_id == prev.temporal_id + 1;
    // next_layer_idc == 2.
    const bool next_spatial =
        curr.spatial_id == prev.spatial_id + 1 && curr.temporal_id == 0;
    if (same_layer || next_temporal || next_spatial) {
      continue;
    }
    // No next_layer_idc value can express any other transition.
    ADD_FAILURE() << "Invalid templates order. Template #" << idx
                  << " with layer (" << curr.spatial_id << ","
                  << curr.temporal_id << ") follows template with layer ("
                  << prev.spatial_id << "," << prev.temporal_id << ").";
  }
}
143
// Each template carries exactly one decode target indication per decode
// target and exactly one chain diff per chain of the structure.
TEST_P(ScalabilityStructureTest, TemplatesMatchNumberOfDecodeTargetsAndChains) {
  FrameDependencyStructure structure =
      CreateScalabilityStructure(GetParam().name)->DependencyStructure();
  const auto has_dti_per_decode_target =
      Field(&FrameDependencyTemplate::decode_target_indications,
            SizeIs(structure.num_decode_targets));
  const auto has_diff_per_chain = Field(&FrameDependencyTemplate::chain_diffs,
                                        SizeIs(structure.num_chains));
  EXPECT_THAT(structure.templates,
              Each(AllOf(has_dti_per_decode_target, has_diff_per_chain)));
}
154
// Every generated frame has non-negative layer ids. Its decode target
// indications and chain membership vectors are sized by the structure.
TEST_P(ScalabilityStructureTest, FrameInfoMatchesFrameDependencyStructure) {
  FrameDependencyStructure structure =
      CreateScalabilityStructure(GetParam().name)->DependencyStructure();
  std::vector<GenericFrameInfo> frame_infos = GenerateAllFrames();
  size_t frame_id = 0;
  for (const GenericFrameInfo& frame : frame_infos) {
    EXPECT_GE(frame.spatial_id, 0) << " for frame " << frame_id;
    EXPECT_GE(frame.temporal_id, 0) << " for frame " << frame_id;
    EXPECT_THAT(frame.decode_target_indications,
                SizeIs(structure.num_decode_targets))
        << " for frame " << frame_id;
    EXPECT_THAT(frame.part_of_chain, SizeIs(structure.num_chains))
        << " for frame " << frame_id;
    // Incremented outside the failure messages: gtest evaluates streamed
    // messages only when an expectation fails.
    ++frame_id;
  }
}
170
// For every generated frame, the structure must contain a template that
// compares equal to that frame's info.
TEST_P(ScalabilityStructureTest, ThereIsAPerfectTemplateForEachFrame) {
  FrameDependencyStructure structure =
      CreateScalabilityStructure(GetParam().name)->DependencyStructure();
  std::vector<GenericFrameInfo> frame_infos = GenerateAllFrames();
  size_t frame_id = 0;
  for (const GenericFrameInfo& frame_info : frame_infos) {
    EXPECT_THAT(structure.templates, Contains(frame_info))
        << " for frame " << frame_id;
    // Kept out of the streamed message, which only runs on failure.
    ++frame_id;
  }
}
180
// A frame may only reference frames whose spatial and temporal ids are not
// greater than its own.
TEST_P(ScalabilityStructureTest, FrameDependsOnSameOrLowerLayer) {
  std::vector<GenericFrameInfo> frame_infos = GenerateAllFrames();

  int64_t frame_id = 0;
  for (const GenericFrameInfo& frame : frame_infos) {
    for (int frame_diff : frame.frame_diffs) {
      const int64_t base_frame_id = frame_id - frame_diff;
      const GenericFrameInfo& base_frame = frame_infos[base_frame_id];
      EXPECT_GE(frame.spatial_id, base_frame.spatial_id)
          << "Frame " << frame_id << " depends on frame " << base_frame_id;
      EXPECT_GE(frame.temporal_id, base_frame.temporal_id)
          << "Frame " << frame_id << " depends on frame " << base_frame_id;
    }
    ++frame_id;
  }
}
197
// For every decode target, a frame that is part of it may only reference
// frames that are also part of it and are not discardable for it.
TEST_P(ScalabilityStructureTest, NoFrameDependsOnDiscardableOrNotPresent) {
  std::vector<GenericFrameInfo> frame_infos = GenerateAllFrames();
  FrameDependencyStructure structure =
      CreateScalabilityStructure(GetParam().name)->DependencyStructure();

  for (int dt = 0; dt < structure.num_decode_targets; ++dt) {
    int64_t frame_id = 0;
    for (const GenericFrameInfo& frame : frame_infos) {
      const bool frame_in_dt = frame.decode_target_indications[dt] !=
                               DecodeTargetIndication::kNotPresent;
      // Frames outside this decode target impose no constraint on it.
      if (frame_in_dt) {
        for (int frame_diff : frame.frame_diffs) {
          const int64_t base_frame_id = frame_id - frame_diff;
          const auto& base_frame = frame_infos[base_frame_id];
          EXPECT_NE(base_frame.decode_target_indications[dt],
                    DecodeTargetIndication::kNotPresent)
              << "Frame " << frame_id << " depends on frame " << base_frame_id
              << " that is not part of decode target#" << dt;
          EXPECT_NE(base_frame.decode_target_indications[dt],
                    DecodeTargetIndication::kDiscardable)
              << "Frame " << frame_id << " depends on frame " << base_frame_id
              << " that is discardable for decode target#" << dt;
        }
      }
      ++frame_id;
    }
  }
}
226
// Checks switch indications: for every decode target, no frame after a kSwitch
// frame may directly reference a frame that precedes the switch frame, unless
// the switch frame itself depends on that earlier frame, directly or
// indirectly.
TEST_P(ScalabilityStructureTest, NoFrameDependsThroughSwitchIndication) {
  FrameDependencyStructure structure =
      CreateScalabilityStructure(GetParam().name)->DependencyStructure();
  std::vector<GenericFrameInfo> frame_infos = GenerateAllFrames();
  int64_t num_frames = frame_infos.size();
  std::vector<std::set<int64_t>> full_deps(num_frames);

  // For each frame, calculate the set of all frames it depends on, both
  // directly and indirectly. This relies on every frame diff pointing to an
  // earlier frame, so full_deps of each base frame is already complete when
  // it is read.
  for (int64_t frame_id = 0; frame_id < num_frames; ++frame_id) {
    std::set<int64_t> all_base_frames;
    for (int frame_diff : frame_infos[frame_id].frame_diffs) {
      int64_t base_frame_id = frame_id - frame_diff;
      all_base_frames.insert(base_frame_id);
      const auto& indirect = full_deps[base_frame_id];
      all_base_frames.insert(indirect.begin(), indirect.end());
    }
    full_deps[frame_id] = std::move(all_base_frames);
  }

  // Now check the switch indication: frames after the switch indication
  // mustn't depend on any additional frames before the switch indication.
  for (int dt = 0; dt < structure.num_decode_targets; ++dt) {
    for (int64_t switch_frame_id = 0; switch_frame_id < num_frames;
         ++switch_frame_id) {
      if (frame_infos[switch_frame_id].decode_target_indications[dt] !=
          DecodeTargetIndication::kSwitch) {
        continue;
      }
      for (int64_t later_frame_id = switch_frame_id + 1;
           later_frame_id < num_frames; ++later_frame_id) {
        if (frame_infos[later_frame_id].decode_target_indications[dt] ==
            DecodeTargetIndication::kNotPresent) {
          continue;
        }
        for (int frame_diff : frame_infos[later_frame_id].frame_diffs) {
          int64_t early_frame_id = later_frame_id - frame_diff;
          if (early_frame_id < switch_frame_id) {
            EXPECT_THAT(full_deps[switch_frame_id], Contains(early_frame_id))
                << "For decode target #" << dt << " frame " << later_frame_id
                << " depends on the frame " << early_frame_id
                << " that switch indication frame " << switch_frame_id
                << " doesn't directly or indirectly depend on.";
          }
        }
      }
    }
  }
}
276
277 INSTANTIATE_TEST_SUITE_P(
278 Svc,
279 ScalabilityStructureTest,
280 Values(SvcTestParam{"L1T2", /*num_temporal_units=*/4},
281 SvcTestParam{"L1T3", /*num_temporal_units=*/8},
282 SvcTestParam{"L2T1", /*num_temporal_units=*/3},
283 SvcTestParam{"L2T1_KEY", /*num_temporal_units=*/3},
284 SvcTestParam{"L3T1", /*num_temporal_units=*/3},
285 SvcTestParam{"L3T3", /*num_temporal_units=*/8},
286 SvcTestParam{"S2T1", /*num_temporal_units=*/3},
287 SvcTestParam{"L2T2", /*num_temporal_units=*/4},
288 SvcTestParam{"L2T2_KEY", /*num_temporal_units=*/4},
289 SvcTestParam{"L2T2_KEY_SHIFT", /*num_temporal_units=*/4}),
__anonda24c3930202(const testing::TestParamInfo<SvcTestParam>& info) 290 [](const testing::TestParamInfo<SvcTestParam>& info) {
291 return info.param.name;
292 });
293
294 } // namespace
295 } // namespace webrtc
296