1 // Copyright (C) 2020 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "host-common/MediaSnapshotHelper.h"
16 #include "host-common/H264NaluParser.h"
17 #include "host-common/VpxFrameParser.h"
18 
#define MEDIA_SNAPSHOT_DEBUG 0

// SNAPSHOT_DPRINT(fmt, ...): debug logging for this file, compiled in only
// when MEDIA_SNAPSHOT_DEBUG is non-zero. Both variants expand to a single
// statement (do { } while (0)) so the macro behaves like a function call:
// the caller's trailing ';' completes it, and it is safe as the body of an
// unbraced if/else.
#if MEDIA_SNAPSHOT_DEBUG
#define SNAPSHOT_DPRINT(fmt, ...)                                       \
    do {                                                                \
        fprintf(stderr, "media-snapshot-helper: %s:%d " fmt "\n",       \
                __func__, __LINE__, ##__VA_ARGS__);                     \
    } while (0)
#else
#define SNAPSHOT_DPRINT(fmt, ...) \
    do {                          \
    } while (0)
#endif
28 
29 namespace android {
30 namespace emulation {
31 
32 using PacketInfo = MediaSnapshotState::PacketInfo;
33 using ColorAspects = MediaSnapshotState::ColorAspects;
34 
savePacket(const uint8_t * frame,size_t szBytes,uint64_t inputPts)35 void MediaSnapshotHelper::savePacket(const uint8_t* frame,
36                                      size_t szBytes,
37                                      uint64_t inputPts) {
38     if (mType == CodecType::H264) {
39         return saveH264Packet(frame, szBytes, inputPts);
40     } else {
41         return saveVPXPacket(frame, szBytes, inputPts);
42     }
43 }
44 
saveVPXPacket(const uint8_t * data,size_t len,uint64_t user_priv)45 void MediaSnapshotHelper::saveVPXPacket(const uint8_t* data,
46                                         size_t len,
47                                         uint64_t user_priv) {
48     const bool enableSnapshot = true;
49     if (enableSnapshot) {
50         VpxFrameParser fparser(mType == CodecType::VP8 ? 8 : 9, data,
51                                (size_t)len);
52         SNAPSHOT_DPRINT("found frame type: %s frame",
53                         fparser.isKeyFrame() ? "KEY" : "NON-KEY");
54         std::vector<uint8_t> v;
55         v.assign(data, data + len);
56         bool isIFrame = fparser.isKeyFrame();
57         if (isIFrame) {
58             mSnapshotState.savedPackets.clear();
59         }
60         const bool saveOK = mSnapshotState.savePacket(v, user_priv);
61         if (saveOK) {
62             SNAPSHOT_DPRINT("saving packet of isze %d; total is %d",
63                             (int)(v.size()),
64                             (int)(mSnapshotState.savedPackets.size()));
65         } else {
66             SNAPSHOT_DPRINT("saving packet; has duplicate, skip; total is %d",
67                             (int)(mSnapshotState.savedPackets.size()));
68         }
69     }
70 }
71 
saveH264Packet(const uint8_t * frame,size_t szBytes,uint64_t inputPts)72 void MediaSnapshotHelper::saveH264Packet(const uint8_t* frame,
73                                          size_t szBytes,
74                                          uint64_t inputPts) {
75     const bool enableSnapshot = true;
76     if (enableSnapshot) {
77         std::vector<uint8_t> v;
78         v.assign(frame, frame + szBytes);
79         bool hasSps = H264NaluParser::checkSpsFrame(frame, szBytes);
80         if (hasSps) {
81             SNAPSHOT_DPRINT("create new snapshot state");
82             MediaSnapshotState newSnapshotState{};
83             // we need to keep the frames, the guest might not have retrieved
84             // them yet; otherwise, we might loose some frames
85             std::swap(newSnapshotState.savedFrames, mSnapshotState.savedFrames);
86             std::swap(newSnapshotState, mSnapshotState);
87             mSnapshotState.saveSps(v);
88         } else {
89             bool hasPps = H264NaluParser::checkPpsFrame(frame, szBytes);
90             if (hasPps) {
91                 mSnapshotState.savePps(v);
92                 mSnapshotState.savedPackets.clear();
93                 mSnapshotState.savedDecodedFrame.data.clear();
94             } else {
95                 bool isIFrame = H264NaluParser::checkIFrame(frame, szBytes);
96                 if (isIFrame) {
97                     mSnapshotState.savedPackets.clear();
98                 }
99                 mSnapshotState.savePacket(std::move(v), inputPts);
100                 SNAPSHOT_DPRINT("saving packet; total is %d",
101                                 (int)(mSnapshotState.savedPackets.size()));
102             }
103         }
104     }
105 }
save(base::Stream * stream) const106 void MediaSnapshotHelper::save(base::Stream* stream) const {
107     SNAPSHOT_DPRINT("saving packets now %d",
108                     (int)(mSnapshotState.savedPackets.size()));
109     if (mType == CodecType::H264) {
110         stream->putBe32(264);
111     } else if (mType == CodecType::VP8) {
112         stream->putBe32(8);
113     } else if (mType == CodecType::VP9) {
114         stream->putBe32(9);
115     }
116     mSnapshotState.save(stream);
117 }
118 
replay(std::function<void (uint8_t * data,size_t len,uint64_t pts)> oneShotDecode)119 void MediaSnapshotHelper::replay(
120         std::function<void(uint8_t* data, size_t len, uint64_t pts)>
121                 oneShotDecode) {
122     if (mSnapshotState.sps.size() > 0) {
123         oneShotDecode(mSnapshotState.sps.data(), mSnapshotState.sps.size(), 0);
124         if (mSnapshotState.pps.size() > 0) {
125             oneShotDecode(mSnapshotState.pps.data(), mSnapshotState.pps.size(),
126                           0);
127             for (int i = 0; i < mSnapshotState.savedPackets.size(); ++i) {
128                 MediaSnapshotState::PacketInfo& pkt =
129                         mSnapshotState.savedPackets[i];
130                 SNAPSHOT_DPRINT("reloading frame %d size %d", i,
131                                 (int)pkt.data.size());
132                 oneShotDecode(pkt.data.data(), pkt.data.size(), pkt.pts);
133             }
134         }
135     }
136 }
137 
load(base::Stream * stream,std::function<void (uint8_t * data,size_t len,uint64_t pts)> oneShotDecode)138 void MediaSnapshotHelper::load(
139         base::Stream* stream,
140         std::function<void(uint8_t* data, size_t len, uint64_t pts)>
141                 oneShotDecode) {
142     int type = stream->getBe32();
143     if (type == 264) {
144         mType = CodecType::H264;
145     } else if (type == 8) {
146         mType = CodecType::VP8;
147     } else if (type == 9) {
148         mType = CodecType::VP9;
149     }
150 
151     mSnapshotState.load(stream);
152 
153     SNAPSHOT_DPRINT("loaded packets %d, now restore decoder",
154                     (int)(mSnapshotState.savedPackets.size()));
155 
156     replay(oneShotDecode);
157 
158     SNAPSHOT_DPRINT("Done loading snapshots frames\n\n");
159 }
160 
161 }  // namespace emulation
162 }  // namespace android
163