1 /*
2  * Copyright (C) 2010 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "include/DRMExtractor.h"

#include <arpa/inet.h>

#include <memory>

#include <utils/String8.h>
#include <media/stagefright/foundation/ADebug.h>
#include <media/stagefright/Utils.h>
#include <media/stagefright/DataSource.h>
#include <media/stagefright/MediaSource.h>
#include <media/stagefright/MediaDefs.h>
#include <media/stagefright/MetaData.h>
#include <media/stagefright/MediaErrors.h>
#include <media/stagefright/MediaBuffer.h>

#include <drm/drm_framework_common.h>
#include <utils/Errors.h>
32 
33 
34 namespace android {
35 
36 class DRMSource : public MediaSource {
37 public:
38     DRMSource(const sp<IMediaSource> &mediaSource,
39             const sp<DecryptHandle> &decryptHandle,
40             DrmManagerClient *managerClient,
41             int32_t trackId, DrmBuffer *ipmpBox);
42 
43     virtual status_t start(MetaData *params = NULL);
44     virtual status_t stop();
45     virtual sp<MetaData> getFormat();
46     virtual status_t read(
47             MediaBuffer **buffer, const ReadOptions *options = NULL);
48 
49 protected:
50     virtual ~DRMSource();
51 
52 private:
53     sp<IMediaSource> mOriginalMediaSource;
54     sp<DecryptHandle> mDecryptHandle;
55     DrmManagerClient* mDrmManagerClient;
56     size_t mTrackId;
57     mutable Mutex mDRMLock;
58     size_t mNALLengthSize;
59     bool mWantsNALFragments;
60 
61     DRMSource(const DRMSource &);
62     DRMSource &operator=(const DRMSource &);
63 };
64 
65 ////////////////////////////////////////////////////////////////////////////////
66 
DRMSource(const sp<IMediaSource> & mediaSource,const sp<DecryptHandle> & decryptHandle,DrmManagerClient * managerClient,int32_t trackId,DrmBuffer * ipmpBox)67 DRMSource::DRMSource(const sp<IMediaSource> &mediaSource,
68         const sp<DecryptHandle> &decryptHandle,
69         DrmManagerClient *managerClient,
70         int32_t trackId, DrmBuffer *ipmpBox)
71     : mOriginalMediaSource(mediaSource),
72       mDecryptHandle(decryptHandle),
73       mDrmManagerClient(managerClient),
74       mTrackId(trackId),
75       mNALLengthSize(0),
76       mWantsNALFragments(false) {
77     CHECK(mDrmManagerClient);
78     mDrmManagerClient->initializeDecryptUnit(
79             mDecryptHandle, trackId, ipmpBox);
80 
81     const char *mime;
82     bool success = getFormat()->findCString(kKeyMIMEType, &mime);
83     CHECK(success);
84 
85     if (!strcasecmp(mime, MEDIA_MIMETYPE_VIDEO_AVC)) {
86         uint32_t type;
87         const void *data;
88         size_t size;
89         CHECK(getFormat()->findData(kKeyAVCC, &type, &data, &size));
90 
91         const uint8_t *ptr = (const uint8_t *)data;
92 
93         CHECK(size >= 7);
94         CHECK_EQ(ptr[0], 1);  // configurationVersion == 1
95 
96         // The number of bytes used to encode the length of a NAL unit.
97         mNALLengthSize = 1 + (ptr[4] & 3);
98     }
99 }
100 
~DRMSource()101 DRMSource::~DRMSource() {
102     Mutex::Autolock autoLock(mDRMLock);
103     mDrmManagerClient->finalizeDecryptUnit(mDecryptHandle, mTrackId);
104 }
105 
start(MetaData * params)106 status_t DRMSource::start(MetaData *params) {
107     int32_t val;
108     if (params && params->findInt32(kKeyWantsNALFragments, &val)
109         && val != 0) {
110         mWantsNALFragments = true;
111     } else {
112         mWantsNALFragments = false;
113     }
114 
115    return mOriginalMediaSource->start(params);
116 }
117 
stop()118 status_t DRMSource::stop() {
119     return mOriginalMediaSource->stop();
120 }
121 
getFormat()122 sp<MetaData> DRMSource::getFormat() {
123     return mOriginalMediaSource->getFormat();
124 }
125 
read(MediaBuffer ** buffer,const ReadOptions * options)126 status_t DRMSource::read(MediaBuffer **buffer, const ReadOptions *options) {
127     Mutex::Autolock autoLock(mDRMLock);
128     status_t err;
129     if ((err = mOriginalMediaSource->read(buffer, options)) != OK) {
130         return err;
131     }
132 
133     size_t len = (*buffer)->range_length();
134 
135     char *src = (char *)(*buffer)->data() + (*buffer)->range_offset();
136 
137     DrmBuffer encryptedDrmBuffer(src, len);
138     DrmBuffer decryptedDrmBuffer;
139     decryptedDrmBuffer.length = len;
140     decryptedDrmBuffer.data = new char[len];
141     DrmBuffer *pDecryptedDrmBuffer = &decryptedDrmBuffer;
142 
143     if ((err = mDrmManagerClient->decrypt(mDecryptHandle, mTrackId,
144             &encryptedDrmBuffer, &pDecryptedDrmBuffer)) != NO_ERROR) {
145 
146         if (decryptedDrmBuffer.data) {
147             delete [] decryptedDrmBuffer.data;
148             decryptedDrmBuffer.data = NULL;
149         }
150 
151         return err;
152     }
153     CHECK(pDecryptedDrmBuffer == &decryptedDrmBuffer);
154 
155     const char *mime;
156     CHECK(getFormat()->findCString(kKeyMIMEType, &mime));
157 
158     if (!strcasecmp(mime, MEDIA_MIMETYPE_VIDEO_AVC) && !mWantsNALFragments) {
159         uint8_t *dstData = (uint8_t*)src;
160         size_t srcOffset = 0;
161         size_t dstOffset = 0;
162 
163         len = decryptedDrmBuffer.length;
164         while (srcOffset < len) {
165             CHECK(srcOffset + mNALLengthSize <= len);
166             size_t nalLength = 0;
167             const uint8_t* data = (const uint8_t*)(&decryptedDrmBuffer.data[srcOffset]);
168 
169             switch (mNALLengthSize) {
170                 case 1:
171                     nalLength = *data;
172                     break;
173                 case 2:
174                     nalLength = U16_AT(data);
175                     break;
176                 case 3:
177                     nalLength = ((size_t)data[0] << 16) | U16_AT(&data[1]);
178                     break;
179                 case 4:
180                     nalLength = U32_AT(data);
181                     break;
182                 default:
183                     CHECK(!"Should not be here.");
184                     break;
185             }
186 
187             srcOffset += mNALLengthSize;
188 
189             size_t end = srcOffset + nalLength;
190             if (end > len || end < srcOffset) {
191                 if (decryptedDrmBuffer.data) {
192                     delete [] decryptedDrmBuffer.data;
193                     decryptedDrmBuffer.data = NULL;
194                 }
195 
196                 return ERROR_MALFORMED;
197             }
198 
199             if (nalLength == 0) {
200                 continue;
201             }
202 
203             if (dstOffset > SIZE_MAX - 4 ||
204                 dstOffset + 4 > SIZE_MAX - nalLength ||
205                 dstOffset + 4 + nalLength > (*buffer)->size()) {
206                 (*buffer)->release();
207                 (*buffer) = NULL;
208                 if (decryptedDrmBuffer.data) {
209                     delete [] decryptedDrmBuffer.data;
210                     decryptedDrmBuffer.data = NULL;
211                 }
212                 return ERROR_MALFORMED;
213             }
214 
215             dstData[dstOffset++] = 0;
216             dstData[dstOffset++] = 0;
217             dstData[dstOffset++] = 0;
218             dstData[dstOffset++] = 1;
219             memcpy(&dstData[dstOffset], &decryptedDrmBuffer.data[srcOffset], nalLength);
220             srcOffset += nalLength;
221             dstOffset += nalLength;
222         }
223 
224         CHECK_EQ(srcOffset, len);
225         (*buffer)->set_range((*buffer)->range_offset(), dstOffset);
226 
227     } else {
228         memcpy(src, decryptedDrmBuffer.data, decryptedDrmBuffer.length);
229         (*buffer)->set_range((*buffer)->range_offset(), decryptedDrmBuffer.length);
230     }
231 
232     if (decryptedDrmBuffer.data) {
233         delete [] decryptedDrmBuffer.data;
234         decryptedDrmBuffer.data = NULL;
235     }
236 
237     return OK;
238 }
239 
240 ////////////////////////////////////////////////////////////////////////////////
241 
DRMExtractor(const sp<DataSource> & source,const char * mime)242 DRMExtractor::DRMExtractor(const sp<DataSource> &source, const char* mime)
243     : mDataSource(source),
244       mDecryptHandle(NULL),
245       mDrmManagerClient(NULL) {
246     mOriginalExtractor = MediaExtractor::Create(source, mime);
247     mOriginalExtractor->setDrmFlag(true);
248     mOriginalExtractor->getMetaData()->setInt32(kKeyIsDRM, 1);
249 
250     source->getDrmInfo(mDecryptHandle, &mDrmManagerClient);
251 }
252 
~DRMExtractor()253 DRMExtractor::~DRMExtractor() {
254 }
255 
countTracks()256 size_t DRMExtractor::countTracks() {
257     return mOriginalExtractor->countTracks();
258 }
259 
getTrack(size_t index)260 sp<IMediaSource> DRMExtractor::getTrack(size_t index) {
261     sp<IMediaSource> originalMediaSource = mOriginalExtractor->getTrack(index);
262     originalMediaSource->getFormat()->setInt32(kKeyIsDRM, 1);
263 
264     int32_t trackID;
265     CHECK(getTrackMetaData(index, 0)->findInt32(kKeyTrackID, &trackID));
266 
267     DrmBuffer ipmpBox;
268     ipmpBox.data = mOriginalExtractor->getDrmTrackInfo(trackID, &(ipmpBox.length));
269     CHECK(ipmpBox.length > 0);
270 
271     return interface_cast<IMediaSource>(
272             new DRMSource(originalMediaSource, mDecryptHandle, mDrmManagerClient,
273             trackID, &ipmpBox));
274 }
275 
getTrackMetaData(size_t index,uint32_t flags)276 sp<MetaData> DRMExtractor::getTrackMetaData(size_t index, uint32_t flags) {
277     return mOriginalExtractor->getTrackMetaData(index, flags);
278 }
279 
getMetaData()280 sp<MetaData> DRMExtractor::getMetaData() {
281     return mOriginalExtractor->getMetaData();
282 }
283 
SniffDRM(const sp<DataSource> & source,String8 * mimeType,float * confidence,sp<AMessage> *)284 bool SniffDRM(
285     const sp<DataSource> &source, String8 *mimeType, float *confidence,
286         sp<AMessage> *) {
287     sp<DecryptHandle> decryptHandle = source->DrmInitialization();
288 
289     if (decryptHandle != NULL) {
290         if (decryptHandle->decryptApiType == DecryptApiType::CONTAINER_BASED) {
291             *mimeType = String8("drm+container_based+") + decryptHandle->mimeType;
292             *confidence = 10.0f;
293         } else if (decryptHandle->decryptApiType == DecryptApiType::ELEMENTARY_STREAM_BASED) {
294             *mimeType = String8("drm+es_based+") + decryptHandle->mimeType;
295             *confidence = 10.0f;
296         } else {
297             return false;
298         }
299 
300         return true;
301     }
302 
303     return false;
304 }
305 } //namespace android
306 
307