1 /*
2  *  Copyright (c) 2012 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "webrtc/test/channel_transport/udp_socket2_manager_win.h"
12 
13 #include <assert.h>
14 #include <stdio.h>
15 
16 #include "webrtc/system_wrappers/include/aligned_malloc.h"
17 #include "webrtc/test/channel_transport/udp_socket2_win.h"
18 
19 namespace webrtc {
20 namespace test {
21 
22 uint32_t UdpSocket2ManagerWindows::_numOfActiveManagers = 0;
23 bool UdpSocket2ManagerWindows::_wsaInit = false;
24 
UdpSocket2ManagerWindows()25 UdpSocket2ManagerWindows::UdpSocket2ManagerWindows()
26     : UdpSocketManager(),
27       _id(-1),
28       _stopped(false),
29       _init(false),
30       _pCrit(CriticalSectionWrapper::CreateCriticalSection()),
31       _ioCompletionHandle(NULL),
32       _numActiveSockets(0),
33       _event(EventWrapper::Create())
34 {
35     _managerNumber = _numOfActiveManagers++;
36 
37     if(_numOfActiveManagers == 1)
38     {
39         WORD wVersionRequested = MAKEWORD(2, 2);
40         WSADATA wsaData;
41         _wsaInit = WSAStartup(wVersionRequested, &wsaData) == 0;
42         // TODO (hellner): seems safer to use RAII for this. E.g. what happens
43         //                 if a UdpSocket2ManagerWindows() created and destroyed
44         //                 without being initialized.
45     }
46 }
47 
~UdpSocket2ManagerWindows()48 UdpSocket2ManagerWindows::~UdpSocket2ManagerWindows()
49 {
50     WEBRTC_TRACE(kTraceDebug, kTraceTransport, _id,
51                  "UdpSocket2ManagerWindows(%d)::~UdpSocket2ManagerWindows()",
52                  _managerNumber);
53 
54     if(_init)
55     {
56         _pCrit->Enter();
57         if(_numActiveSockets)
58         {
59             _pCrit->Leave();
60             _event->Wait(INFINITE);
61         }
62         else
63         {
64             _pCrit->Leave();
65         }
66         StopWorkerThreads();
67 
68         for (WorkerList::iterator iter = _workerThreadsList.begin();
69              iter != _workerThreadsList.end(); ++iter) {
70           delete *iter;
71         }
72         _workerThreadsList.clear();
73         _ioContextPool.Free();
74 
75         _numOfActiveManagers--;
76         if(_ioCompletionHandle)
77         {
78             CloseHandle(_ioCompletionHandle);
79         }
80         if (_numOfActiveManagers == 0)
81         {
82             if(_wsaInit)
83             {
84                 WSACleanup();
85             }
86         }
87     }
88     if(_pCrit)
89     {
90         delete _pCrit;
91     }
92     if(_event)
93     {
94         delete _event;
95     }
96 }
97 
Init(int32_t id,uint8_t & numOfWorkThreads)98 bool UdpSocket2ManagerWindows::Init(int32_t id,
99                                     uint8_t& numOfWorkThreads) {
100   CriticalSectionScoped cs(_pCrit);
101   if ((_id != -1) || (_numOfWorkThreads != 0)) {
102       assert(_id != -1);
103       assert(_numOfWorkThreads != 0);
104       return false;
105   }
106   _id = id;
107   _numOfWorkThreads = numOfWorkThreads;
108   return true;
109 }
110 
Start()111 bool UdpSocket2ManagerWindows::Start()
112 {
113     WEBRTC_TRACE(kTraceDebug, kTraceTransport, _id,
114                  "UdpSocket2ManagerWindows(%d)::Start()",_managerNumber);
115     if(!_init)
116     {
117         StartWorkerThreads();
118     }
119 
120     if(!_init)
121     {
122         return false;
123     }
124     _pCrit->Enter();
125     // Start worker threads.
126     _stopped = false;
127     int32_t error = 0;
128     for (WorkerList::iterator iter = _workerThreadsList.begin();
129          iter != _workerThreadsList.end() && !error; ++iter) {
130       if(!(*iter)->Start())
131         error = 1;
132     }
133     if(error)
134     {
135         WEBRTC_TRACE(
136             kTraceError,
137             kTraceTransport,
138             _id,
139             "UdpSocket2ManagerWindows(%d)::Start() error starting worker\
140  threads",
141             _managerNumber);
142         _pCrit->Leave();
143         return false;
144     }
145     _pCrit->Leave();
146     return true;
147 }
148 
StartWorkerThreads()149 bool UdpSocket2ManagerWindows::StartWorkerThreads()
150 {
151     if(!_init)
152     {
153         _pCrit->Enter();
154 
155         _ioCompletionHandle = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL,
156                                                      0, 0);
157         if(_ioCompletionHandle == NULL)
158         {
159             int32_t error = GetLastError();
160             WEBRTC_TRACE(
161                 kTraceError,
162                 kTraceTransport,
163                 _id,
164                 "UdpSocket2ManagerWindows(%d)::StartWorkerThreads()"
165                 "_ioCompletioHandle == NULL: error:%d",
166                 _managerNumber,error);
167             _pCrit->Leave();
168             return false;
169         }
170 
171         // Create worker threads.
172         uint32_t i = 0;
173         bool error = false;
174         while(i < _numOfWorkThreads && !error)
175         {
176             UdpSocket2WorkerWindows* pWorker =
177                 new UdpSocket2WorkerWindows(_ioCompletionHandle);
178             if(pWorker->Init() != 0)
179             {
180                 error = true;
181                 delete pWorker;
182                 break;
183             }
184             _workerThreadsList.push_front(pWorker);
185             i++;
186         }
187         if(error)
188         {
189             WEBRTC_TRACE(
190                 kTraceError,
191                 kTraceTransport,
192                 _id,
193                 "UdpSocket2ManagerWindows(%d)::StartWorkerThreads() error "
194                 "creating work threads",
195                 _managerNumber);
196             // Delete worker threads.
197             for (WorkerList::iterator iter = _workerThreadsList.begin();
198                  iter != _workerThreadsList.end(); ++iter) {
199               delete *iter;
200             }
201             _workerThreadsList.clear();
202             _pCrit->Leave();
203             return false;
204         }
205         if(_ioContextPool.Init())
206         {
207             WEBRTC_TRACE(
208                 kTraceError,
209                 kTraceTransport,
210                 _id,
211                 "UdpSocket2ManagerWindows(%d)::StartWorkerThreads() error "
212                 "initiating _ioContextPool",
213                 _managerNumber);
214             _pCrit->Leave();
215             return false;
216         }
217         _init = true;
218         WEBRTC_TRACE(
219             kTraceDebug,
220             kTraceTransport,
221             _id,
222             "UdpSocket2ManagerWindows::StartWorkerThreads %d number of work "
223             "threads created and initialized",
224             _numOfWorkThreads);
225         _pCrit->Leave();
226     }
227     return true;
228 }
229 
Stop()230 bool UdpSocket2ManagerWindows::Stop()
231 {
232     WEBRTC_TRACE(kTraceDebug, kTraceTransport, _id,
233                  "UdpSocket2ManagerWindows(%d)::Stop()",_managerNumber);
234 
235     if(!_init)
236     {
237         return false;
238     }
239     _pCrit->Enter();
240     _stopped = true;
241     if(_numActiveSockets)
242     {
243         WEBRTC_TRACE(
244             kTraceError,
245             kTraceTransport,
246             _id,
247             "UdpSocket2ManagerWindows(%d)::Stop() there is still active\
248  sockets",
249             _managerNumber);
250         _pCrit->Leave();
251         return false;
252     }
253     // No active sockets. Stop all worker threads.
254     bool result = StopWorkerThreads();
255     _pCrit->Leave();
256     return result;
257 }
258 
StopWorkerThreads()259 bool UdpSocket2ManagerWindows::StopWorkerThreads()
260 {
261     int32_t error = 0;
262     WEBRTC_TRACE(
263         kTraceDebug,
264         kTraceTransport,
265         _id,
266         "UdpSocket2ManagerWindows(%d)::StopWorkerThreads() Worker\
267  threadsStoped, numActicve Sockets=%d",
268         _managerNumber,
269         _numActiveSockets);
270 
271     // Release all threads waiting for GetQueuedCompletionStatus(..).
272     if(_ioCompletionHandle)
273     {
274         uint32_t i = 0;
275         for(i = 0; i < _workerThreadsList.size(); i++)
276         {
277             PostQueuedCompletionStatus(_ioCompletionHandle, 0 ,0 , NULL);
278         }
279     }
280     for (WorkerList::iterator iter = _workerThreadsList.begin();
281          iter != _workerThreadsList.end(); ++iter) {
282         if((*iter)->Stop() == false)
283         {
284             error = -1;
285             WEBRTC_TRACE(kTraceWarning,  kTraceTransport, -1,
286                          "failed to stop worker thread");
287         }
288     }
289 
290     if(error)
291     {
292         WEBRTC_TRACE(
293             kTraceError,
294             kTraceTransport,
295             _id,
296             "UdpSocket2ManagerWindows(%d)::StopWorkerThreads() error stopping\
297  worker threads",
298             _managerNumber);
299         return false;
300     }
301     return true;
302 }
303 
AddSocketPrv(UdpSocket2Windows * s)304 bool UdpSocket2ManagerWindows::AddSocketPrv(UdpSocket2Windows* s)
305 {
306     WEBRTC_TRACE(kTraceDebug, kTraceTransport, _id,
307                  "UdpSocket2ManagerWindows(%d)::AddSocketPrv()",_managerNumber);
308     if(!_init)
309     {
310         WEBRTC_TRACE(
311             kTraceError,
312             kTraceTransport,
313             _id,
314             "UdpSocket2ManagerWindows(%d)::AddSocketPrv() manager not\
315  initialized",
316             _managerNumber);
317         return false;
318     }
319     _pCrit->Enter();
320     if(s == NULL)
321     {
322         WEBRTC_TRACE(
323             kTraceError,
324             kTraceTransport,
325             _id,
326             "UdpSocket2ManagerWindows(%d)::AddSocketPrv() socket == NULL",
327             _managerNumber);
328         _pCrit->Leave();
329         return false;
330     }
331     if(s->GetFd() == NULL || s->GetFd() == INVALID_SOCKET)
332     {
333         WEBRTC_TRACE(
334             kTraceError,
335             kTraceTransport,
336             _id,
337             "UdpSocket2ManagerWindows(%d)::AddSocketPrv() socket->GetFd() ==\
338  %d",
339             _managerNumber,
340             (int32_t)s->GetFd());
341         _pCrit->Leave();
342         return false;
343 
344     }
345     _ioCompletionHandle = CreateIoCompletionPort((HANDLE)s->GetFd(),
346                                                  _ioCompletionHandle,
347                                                  (ULONG_PTR)(s), 0);
348     if(_ioCompletionHandle == NULL)
349     {
350         int32_t error = GetLastError();
351         WEBRTC_TRACE(
352             kTraceError,
353             kTraceTransport,
354             _id,
355             "UdpSocket2ManagerWindows(%d)::AddSocketPrv() Error adding to IO\
356  completion: %d",
357             _managerNumber,
358             error);
359         _pCrit->Leave();
360         return false;
361     }
362     _numActiveSockets++;
363     _pCrit->Leave();
364     return true;
365 }
RemoveSocketPrv(UdpSocket2Windows * s)366 bool UdpSocket2ManagerWindows::RemoveSocketPrv(UdpSocket2Windows* s)
367 {
368     if(!_init)
369     {
370         return false;
371     }
372     _pCrit->Enter();
373     _numActiveSockets--;
374     if(_numActiveSockets == 0)
375     {
376         _event->Set();
377     }
378     _pCrit->Leave();
379     return true;
380 }
381 
PopIoContext()382 PerIoContext* UdpSocket2ManagerWindows::PopIoContext()
383 {
384     if(!_init)
385     {
386         return NULL;
387     }
388 
389     PerIoContext* pIoC = NULL;
390     if(!_stopped)
391     {
392         pIoC = _ioContextPool.PopIoContext();
393     }else
394     {
395         WEBRTC_TRACE(
396             kTraceError,
397             kTraceTransport,
398             _id,
399             "UdpSocket2ManagerWindows(%d)::PopIoContext() Manager Not started",
400             _managerNumber);
401     }
402     return pIoC;
403 }
404 
PushIoContext(PerIoContext * pIoContext)405 int32_t UdpSocket2ManagerWindows::PushIoContext(PerIoContext* pIoContext)
406 {
407     return _ioContextPool.PushIoContext(pIoContext);
408 }
409 
IoContextPool()410 IoContextPool::IoContextPool()
411     : _pListHead(NULL),
412       _init(false),
413       _size(0),
414       _inUse(0)
415 {
416 }
417 
~IoContextPool()418 IoContextPool::~IoContextPool()
419 {
420     Free();
421     assert(_size.Value() == 0);
422     AlignedFree(_pListHead);
423 }
424 
Init(uint32_t)425 int32_t IoContextPool::Init(uint32_t /*increaseSize*/)
426 {
427     if(_init)
428     {
429         return 0;
430     }
431 
432     _pListHead = (PSLIST_HEADER)AlignedMalloc(sizeof(SLIST_HEADER),
433                                               MEMORY_ALLOCATION_ALIGNMENT);
434     if(_pListHead == NULL)
435     {
436         return -1;
437     }
438     InitializeSListHead(_pListHead);
439     _init = true;
440     return 0;
441 }
442 
PopIoContext()443 PerIoContext* IoContextPool::PopIoContext()
444 {
445     if(!_init)
446     {
447         return NULL;
448     }
449 
450     PSLIST_ENTRY pListEntry = InterlockedPopEntrySList(_pListHead);
451     if(pListEntry == NULL)
452     {
453         IoContextPoolItem* item = (IoContextPoolItem*)
454             AlignedMalloc(
455                 sizeof(IoContextPoolItem),
456                 MEMORY_ALLOCATION_ALIGNMENT);
457         if(item == NULL)
458         {
459             return NULL;
460         }
461         memset(&item->payload.ioContext,0,sizeof(PerIoContext));
462         item->payload.base = item;
463         pListEntry = &(item->itemEntry);
464         ++_size;
465     }
466     ++_inUse;
467     return &((IoContextPoolItem*)pListEntry)->payload.ioContext;
468 }
469 
PushIoContext(PerIoContext * pIoContext)470 int32_t IoContextPool::PushIoContext(PerIoContext* pIoContext)
471 {
472     // TODO (hellner): Overlapped IO should be completed at this point. Perhaps
473     //                 add an assert?
474     const bool overlappedIOCompleted = HasOverlappedIoCompleted(
475         (LPOVERLAPPED)pIoContext);
476 
477     IoContextPoolItem* item = ((IoContextPoolItemPayload*)pIoContext)->base;
478 
479     const int32_t usedItems = --_inUse;
480     const int32_t totalItems = _size.Value();
481     const int32_t freeItems = totalItems - usedItems;
482     if(freeItems < 0)
483     {
484         assert(false);
485         AlignedFree(item);
486         return -1;
487     }
488     if((freeItems >= totalItems>>1) &&
489         overlappedIOCompleted)
490     {
491         AlignedFree(item);
492         --_size;
493         return 0;
494     }
495     InterlockedPushEntrySList(_pListHead, &(item->itemEntry));
496     return 0;
497 }
498 
Free()499 int32_t IoContextPool::Free()
500 {
501     if(!_init)
502     {
503         return 0;
504     }
505 
506     int32_t itemsFreed = 0;
507     PSLIST_ENTRY pListEntry = InterlockedPopEntrySList(_pListHead);
508     while(pListEntry != NULL)
509     {
510         IoContextPoolItem* item = ((IoContextPoolItem*)pListEntry);
511         AlignedFree(item);
512         --_size;
513         itemsFreed++;
514         pListEntry = InterlockedPopEntrySList(_pListHead);
515     }
516     return itemsFreed;
517 }
518 
519 int32_t UdpSocket2WorkerWindows::_numOfWorkers = 0;
520 
UdpSocket2WorkerWindows(HANDLE ioCompletionHandle)521 UdpSocket2WorkerWindows::UdpSocket2WorkerWindows(HANDLE ioCompletionHandle)
522     : _ioCompletionHandle(ioCompletionHandle),
523       _pThread(Run, this, "UdpSocket2ManagerWindows_thread"),
524       _init(false) {
525     _workerNumber = _numOfWorkers++;
526     WEBRTC_TRACE(kTraceMemory,  kTraceTransport, -1,
527                  "UdpSocket2WorkerWindows created");
528 }
529 
~UdpSocket2WorkerWindows()530 UdpSocket2WorkerWindows::~UdpSocket2WorkerWindows()
531 {
532     WEBRTC_TRACE(kTraceMemory,  kTraceTransport, -1,
533                  "UdpSocket2WorkerWindows deleted");
534 }
535 
Start()536 bool UdpSocket2WorkerWindows::Start()
537 {
538     WEBRTC_TRACE(kTraceStateInfo,  kTraceTransport, -1,
539                  "Start UdpSocket2WorkerWindows");
540     _pThread.Start();
541 
542     _pThread.SetPriority(rtc::kRealtimePriority);
543     return true;
544 }
545 
Stop()546 bool UdpSocket2WorkerWindows::Stop()
547 {
548     WEBRTC_TRACE(kTraceStateInfo,  kTraceTransport, -1,
549                  "Stop UdpSocket2WorkerWindows");
550     _pThread.Stop();
551     return true;
552 }
553 
Init()554 int32_t UdpSocket2WorkerWindows::Init()
555 {
556   _init = true;
557   return 0;
558 }
559 
Run(void * obj)560 bool UdpSocket2WorkerWindows::Run(void* obj)
561 {
562     UdpSocket2WorkerWindows* pWorker =
563         static_cast<UdpSocket2WorkerWindows*>(obj);
564     return pWorker->Process();
565 }
566 
567 // Process should always return true. Stopping the worker threads is done in
568 // the UdpSocket2ManagerWindows::StopWorkerThreads() function.
Process()569 bool UdpSocket2WorkerWindows::Process()
570 {
571     int32_t success = 0;
572     DWORD ioSize = 0;
573     UdpSocket2Windows* pSocket = NULL;
574     PerIoContext* pIOContext = 0;
575     OVERLAPPED* pOverlapped = 0;
576     success = GetQueuedCompletionStatus(_ioCompletionHandle,
577                                         &ioSize,
578                                        (ULONG_PTR*)&pSocket, &pOverlapped, 200);
579 
580     uint32_t error = 0;
581     if(!success)
582     {
583         error = GetLastError();
584         if(error == WAIT_TIMEOUT)
585         {
586             return true;
587         }
588         // This may happen if e.g. PostQueuedCompletionStatus() has been called.
589         // The IO context still needs to be reclaimed or re-used which is done
590         // in UdpSocket2Windows::IOCompleted(..).
591     }
592     if(pSocket == NULL)
593     {
594         WEBRTC_TRACE(
595             kTraceDebug,
596             kTraceTransport,
597             -1,
598             "UdpSocket2WorkerWindows(%d)::Process(), pSocket == 0, end thread",
599             _workerNumber);
600         return true;
601     }
602     pIOContext = (PerIoContext*)pOverlapped;
603     pSocket->IOCompleted(pIOContext,ioSize,error);
604     return true;
605 }
606 
607 }  // namespace test
608 }  // namespace webrtc
609