1 /*
2  *  Copyright 2012 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 // This file contains mock implementations of observers used in PeerConnection.
12 // TODO(steveanton): These aren't really mocks and should be renamed.
13 
14 #ifndef PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
15 #define PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
16 
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "api/data_channel_interface.h"
24 #include "api/jsep_ice_candidate.h"
25 #include "pc/stream_collection.h"
26 #include "rtc_base/checks.h"
27 
28 namespace webrtc {
29 
30 class MockPeerConnectionObserver : public PeerConnectionObserver {
31  public:
32   struct AddTrackEvent {
AddTrackEventAddTrackEvent33     explicit AddTrackEvent(
34         rtc::scoped_refptr<RtpReceiverInterface> event_receiver,
35         std::vector<rtc::scoped_refptr<MediaStreamInterface>> event_streams)
36         : receiver(std::move(event_receiver)),
37           streams(std::move(event_streams)) {
38       for (auto stream : streams) {
39         std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>> tracks;
40         for (auto audio_track : stream->GetAudioTracks()) {
41           tracks.push_back(audio_track);
42         }
43         for (auto video_track : stream->GetVideoTracks()) {
44           tracks.push_back(video_track);
45         }
46         snapshotted_stream_tracks[stream] = tracks;
47       }
48     }
49 
50     rtc::scoped_refptr<RtpReceiverInterface> receiver;
51     std::vector<rtc::scoped_refptr<MediaStreamInterface>> streams;
52     // This map records the tracks present in each stream at the time the
53     // OnAddTrack callback was issued.
54     std::map<rtc::scoped_refptr<MediaStreamInterface>,
55              std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>>>
56         snapshotted_stream_tracks;
57   };
58 
MockPeerConnectionObserver()59   MockPeerConnectionObserver() : remote_streams_(StreamCollection::Create()) {}
~MockPeerConnectionObserver()60   virtual ~MockPeerConnectionObserver() {}
SetPeerConnectionInterface(PeerConnectionInterface * pc)61   void SetPeerConnectionInterface(PeerConnectionInterface* pc) {
62     pc_ = pc;
63     if (pc) {
64       state_ = pc_->signaling_state();
65     }
66   }
OnSignalingChange(PeerConnectionInterface::SignalingState new_state)67   void OnSignalingChange(
68       PeerConnectionInterface::SignalingState new_state) override {
69     RTC_DCHECK(pc_);
70     RTC_DCHECK(pc_->signaling_state() == new_state);
71     state_ = new_state;
72   }
73 
RemoteStream(const std::string & label)74   MediaStreamInterface* RemoteStream(const std::string& label) {
75     return remote_streams_->find(label);
76   }
remote_streams()77   StreamCollectionInterface* remote_streams() const { return remote_streams_; }
OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream)78   void OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream) override {
79     last_added_stream_ = stream;
80     remote_streams_->AddStream(stream);
81   }
OnRemoveStream(rtc::scoped_refptr<MediaStreamInterface> stream)82   void OnRemoveStream(
83       rtc::scoped_refptr<MediaStreamInterface> stream) override {
84     last_removed_stream_ = stream;
85     remote_streams_->RemoveStream(stream);
86   }
OnRenegotiationNeeded()87   void OnRenegotiationNeeded() override { renegotiation_needed_ = true; }
OnDataChannel(rtc::scoped_refptr<DataChannelInterface> data_channel)88   void OnDataChannel(
89       rtc::scoped_refptr<DataChannelInterface> data_channel) override {
90     last_datachannel_ = data_channel;
91   }
92 
OnIceConnectionChange(PeerConnectionInterface::IceConnectionState new_state)93   void OnIceConnectionChange(
94       PeerConnectionInterface::IceConnectionState new_state) override {
95     RTC_DCHECK(pc_);
96     RTC_DCHECK(pc_->ice_connection_state() == new_state);
97     // When ICE is finished, the caller will get to a kIceConnectionCompleted
98     // state, because it has the ICE controlling role, while the callee
99     // will get to a kIceConnectionConnected state. This means that both ICE
100     // and DTLS are connected.
101     ice_connected_ =
102         (new_state == PeerConnectionInterface::kIceConnectionConnected) ||
103         (new_state == PeerConnectionInterface::kIceConnectionCompleted);
104     callback_triggered_ = true;
105   }
OnIceGatheringChange(PeerConnectionInterface::IceGatheringState new_state)106   void OnIceGatheringChange(
107       PeerConnectionInterface::IceGatheringState new_state) override {
108     RTC_DCHECK(pc_);
109     RTC_DCHECK(pc_->ice_gathering_state() == new_state);
110     ice_gathering_complete_ =
111         new_state == PeerConnectionInterface::kIceGatheringComplete;
112     callback_triggered_ = true;
113   }
OnIceCandidate(const IceCandidateInterface * candidate)114   void OnIceCandidate(const IceCandidateInterface* candidate) override {
115     RTC_DCHECK(pc_);
116     RTC_DCHECK(PeerConnectionInterface::kIceGatheringNew !=
117                pc_->ice_gathering_state());
118     candidates_.push_back(std::make_unique<JsepIceCandidate>(
119         candidate->sdp_mid(), candidate->sdp_mline_index(),
120         candidate->candidate()));
121     callback_triggered_ = true;
122   }
123 
OnIceCandidatesRemoved(const std::vector<cricket::Candidate> & candidates)124   void OnIceCandidatesRemoved(
125       const std::vector<cricket::Candidate>& candidates) override {
126     num_candidates_removed_++;
127     callback_triggered_ = true;
128   }
129 
OnIceConnectionReceivingChange(bool receiving)130   void OnIceConnectionReceivingChange(bool receiving) override {
131     callback_triggered_ = true;
132   }
133 
OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver,const std::vector<rtc::scoped_refptr<MediaStreamInterface>> & streams)134   void OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver,
135                   const std::vector<rtc::scoped_refptr<MediaStreamInterface>>&
136                       streams) override {
137     RTC_DCHECK(receiver);
138     num_added_tracks_++;
139     last_added_track_label_ = receiver->id();
140     add_track_events_.push_back(AddTrackEvent(receiver, streams));
141   }
142 
OnTrack(rtc::scoped_refptr<RtpTransceiverInterface> transceiver)143   void OnTrack(
144       rtc::scoped_refptr<RtpTransceiverInterface> transceiver) override {
145     on_track_transceivers_.push_back(transceiver);
146   }
147 
OnRemoveTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver)148   void OnRemoveTrack(
149       rtc::scoped_refptr<RtpReceiverInterface> receiver) override {
150     remove_track_events_.push_back(receiver);
151   }
152 
GetAddTrackReceivers()153   std::vector<rtc::scoped_refptr<RtpReceiverInterface>> GetAddTrackReceivers() {
154     std::vector<rtc::scoped_refptr<RtpReceiverInterface>> receivers;
155     for (const AddTrackEvent& event : add_track_events_) {
156       receivers.push_back(event.receiver);
157     }
158     return receivers;
159   }
160 
CountAddTrackEventsForStream(const std::string & stream_id)161   int CountAddTrackEventsForStream(const std::string& stream_id) {
162     int found_tracks = 0;
163     for (const AddTrackEvent& event : add_track_events_) {
164       bool has_stream_id = false;
165       for (auto stream : event.streams) {
166         if (stream->id() == stream_id) {
167           has_stream_id = true;
168           break;
169         }
170       }
171       if (has_stream_id) {
172         ++found_tracks;
173       }
174     }
175     return found_tracks;
176   }
177 
178   // Returns the id of the last added stream.
179   // Empty string if no stream have been added.
GetLastAddedStreamId()180   std::string GetLastAddedStreamId() {
181     if (last_added_stream_.get())
182       return last_added_stream_->id();
183     return "";
184   }
GetLastRemovedStreamId()185   std::string GetLastRemovedStreamId() {
186     if (last_removed_stream_.get())
187       return last_removed_stream_->id();
188     return "";
189   }
190 
last_candidate()191   IceCandidateInterface* last_candidate() {
192     if (candidates_.empty()) {
193       return nullptr;
194     } else {
195       return candidates_.back().get();
196     }
197   }
198 
GetAllCandidates()199   std::vector<const IceCandidateInterface*> GetAllCandidates() {
200     std::vector<const IceCandidateInterface*> candidates;
201     for (const auto& candidate : candidates_) {
202       candidates.push_back(candidate.get());
203     }
204     return candidates;
205   }
206 
GetCandidatesByMline(int mline_index)207   std::vector<IceCandidateInterface*> GetCandidatesByMline(int mline_index) {
208     std::vector<IceCandidateInterface*> candidates;
209     for (const auto& candidate : candidates_) {
210       if (candidate->sdp_mline_index() == mline_index) {
211         candidates.push_back(candidate.get());
212       }
213     }
214     return candidates;
215   }
216 
negotiation_needed()217   bool negotiation_needed() const { return renegotiation_needed_; }
clear_negotiation_needed()218   void clear_negotiation_needed() { renegotiation_needed_ = false; }
219 
220   rtc::scoped_refptr<PeerConnectionInterface> pc_;
221   PeerConnectionInterface::SignalingState state_;
222   std::vector<std::unique_ptr<IceCandidateInterface>> candidates_;
223   rtc::scoped_refptr<DataChannelInterface> last_datachannel_;
224   rtc::scoped_refptr<StreamCollection> remote_streams_;
225   bool renegotiation_needed_ = false;
226   bool ice_gathering_complete_ = false;
227   bool ice_connected_ = false;
228   bool callback_triggered_ = false;
229   int num_added_tracks_ = 0;
230   std::string last_added_track_label_;
231   std::vector<AddTrackEvent> add_track_events_;
232   std::vector<rtc::scoped_refptr<RtpReceiverInterface>> remove_track_events_;
233   std::vector<rtc::scoped_refptr<RtpTransceiverInterface>>
234       on_track_transceivers_;
235   int num_candidates_removed_ = 0;
236 
237  private:
238   rtc::scoped_refptr<MediaStreamInterface> last_added_stream_;
239   rtc::scoped_refptr<MediaStreamInterface> last_removed_stream_;
240 };
241 
242 class MockCreateSessionDescriptionObserver
243     : public webrtc::CreateSessionDescriptionObserver {
244  public:
MockCreateSessionDescriptionObserver()245   MockCreateSessionDescriptionObserver()
246       : called_(false),
247         error_("MockCreateSessionDescriptionObserver not called") {}
~MockCreateSessionDescriptionObserver()248   virtual ~MockCreateSessionDescriptionObserver() {}
OnSuccess(SessionDescriptionInterface * desc)249   void OnSuccess(SessionDescriptionInterface* desc) override {
250     called_ = true;
251     error_ = "";
252     desc_.reset(desc);
253   }
OnFailure(webrtc::RTCError error)254   void OnFailure(webrtc::RTCError error) override {
255     called_ = true;
256     error_ = error.message();
257   }
called()258   bool called() const { return called_; }
result()259   bool result() const { return error_.empty(); }
error()260   const std::string& error() const { return error_; }
MoveDescription()261   std::unique_ptr<SessionDescriptionInterface> MoveDescription() {
262     return std::move(desc_);
263   }
264 
265  private:
266   bool called_;
267   std::string error_;
268   std::unique_ptr<SessionDescriptionInterface> desc_;
269 };
270 
271 class MockSetSessionDescriptionObserver
272     : public webrtc::SetSessionDescriptionObserver {
273  public:
Create()274   static rtc::scoped_refptr<MockSetSessionDescriptionObserver> Create() {
275     return new rtc::RefCountedObject<MockSetSessionDescriptionObserver>();
276   }
277 
MockSetSessionDescriptionObserver()278   MockSetSessionDescriptionObserver()
279       : called_(false),
280         error_("MockSetSessionDescriptionObserver not called") {}
~MockSetSessionDescriptionObserver()281   ~MockSetSessionDescriptionObserver() override {}
OnSuccess()282   void OnSuccess() override {
283     called_ = true;
284     error_ = "";
285   }
OnFailure(webrtc::RTCError error)286   void OnFailure(webrtc::RTCError error) override {
287     called_ = true;
288     error_ = error.message();
289   }
290 
called()291   bool called() const { return called_; }
result()292   bool result() const { return error_.empty(); }
error()293   const std::string& error() const { return error_; }
294 
295  private:
296   bool called_;
297   std::string error_;
298 };
299 
300 class MockSetRemoteDescriptionObserver
301     : public rtc::RefCountedObject<SetRemoteDescriptionObserverInterface> {
302  public:
called()303   bool called() const { return error_.has_value(); }
error()304   RTCError& error() {
305     RTC_DCHECK(error_.has_value());
306     return *error_;
307   }
308 
309   // SetRemoteDescriptionObserverInterface implementation.
OnSetRemoteDescriptionComplete(RTCError error)310   void OnSetRemoteDescriptionComplete(RTCError error) override {
311     error_ = std::move(error);
312   }
313 
314  private:
315   // Set on complete, on success this is set to an RTCError::OK() error.
316   absl::optional<RTCError> error_;
317 };
318 
319 class MockDataChannelObserver : public webrtc::DataChannelObserver {
320  public:
MockDataChannelObserver(webrtc::DataChannelInterface * channel)321   explicit MockDataChannelObserver(webrtc::DataChannelInterface* channel)
322       : channel_(channel) {
323     channel_->RegisterObserver(this);
324     state_ = channel_->state();
325   }
~MockDataChannelObserver()326   virtual ~MockDataChannelObserver() { channel_->UnregisterObserver(); }
327 
OnBufferedAmountChange(uint64_t previous_amount)328   void OnBufferedAmountChange(uint64_t previous_amount) override {}
329 
OnStateChange()330   void OnStateChange() override { state_ = channel_->state(); }
OnMessage(const DataBuffer & buffer)331   void OnMessage(const DataBuffer& buffer) override {
332     messages_.push_back(
333         std::string(buffer.data.data<char>(), buffer.data.size()));
334   }
335 
IsOpen()336   bool IsOpen() const { return state_ == DataChannelInterface::kOpen; }
messages()337   std::vector<std::string> messages() const { return messages_; }
last_message()338   std::string last_message() const {
339     return messages_.empty() ? std::string() : messages_.back();
340   }
received_message_count()341   size_t received_message_count() const { return messages_.size(); }
342 
343  private:
344   rtc::scoped_refptr<webrtc::DataChannelInterface> channel_;
345   DataChannelInterface::DataState state_;
346   std::vector<std::string> messages_;
347 };
348 
349 class MockStatsObserver : public webrtc::StatsObserver {
350  public:
MockStatsObserver()351   MockStatsObserver() : called_(false), stats_() {}
~MockStatsObserver()352   virtual ~MockStatsObserver() {}
353 
OnComplete(const StatsReports & reports)354   virtual void OnComplete(const StatsReports& reports) {
355     RTC_CHECK(!called_);
356     called_ = true;
357     stats_.Clear();
358     stats_.number_of_reports = reports.size();
359     for (const auto* r : reports) {
360       if (r->type() == StatsReport::kStatsReportTypeSsrc) {
361         stats_.timestamp = r->timestamp();
362         GetIntValue(r, StatsReport::kStatsValueNameAudioOutputLevel,
363                     &stats_.audio_output_level);
364         GetIntValue(r, StatsReport::kStatsValueNameAudioInputLevel,
365                     &stats_.audio_input_level);
366         GetIntValue(r, StatsReport::kStatsValueNameBytesReceived,
367                     &stats_.bytes_received);
368         GetIntValue(r, StatsReport::kStatsValueNameBytesSent,
369                     &stats_.bytes_sent);
370         GetInt64Value(r, StatsReport::kStatsValueNameCaptureStartNtpTimeMs,
371                       &stats_.capture_start_ntp_time);
372         stats_.track_ids.emplace_back();
373         GetStringValue(r, StatsReport::kStatsValueNameTrackId,
374                        &stats_.track_ids.back());
375       } else if (r->type() == StatsReport::kStatsReportTypeBwe) {
376         stats_.timestamp = r->timestamp();
377         GetIntValue(r, StatsReport::kStatsValueNameAvailableReceiveBandwidth,
378                     &stats_.available_receive_bandwidth);
379       } else if (r->type() == StatsReport::kStatsReportTypeComponent) {
380         stats_.timestamp = r->timestamp();
381         GetStringValue(r, StatsReport::kStatsValueNameDtlsCipher,
382                        &stats_.dtls_cipher);
383         GetStringValue(r, StatsReport::kStatsValueNameSrtpCipher,
384                        &stats_.srtp_cipher);
385       }
386     }
387   }
388 
called()389   bool called() const { return called_; }
number_of_reports()390   size_t number_of_reports() const { return stats_.number_of_reports; }
timestamp()391   double timestamp() const { return stats_.timestamp; }
392 
AudioOutputLevel()393   int AudioOutputLevel() const {
394     RTC_CHECK(called_);
395     return stats_.audio_output_level;
396   }
397 
AudioInputLevel()398   int AudioInputLevel() const {
399     RTC_CHECK(called_);
400     return stats_.audio_input_level;
401   }
402 
BytesReceived()403   int BytesReceived() const {
404     RTC_CHECK(called_);
405     return stats_.bytes_received;
406   }
407 
BytesSent()408   int BytesSent() const {
409     RTC_CHECK(called_);
410     return stats_.bytes_sent;
411   }
412 
CaptureStartNtpTime()413   int64_t CaptureStartNtpTime() const {
414     RTC_CHECK(called_);
415     return stats_.capture_start_ntp_time;
416   }
417 
AvailableReceiveBandwidth()418   int AvailableReceiveBandwidth() const {
419     RTC_CHECK(called_);
420     return stats_.available_receive_bandwidth;
421   }
422 
DtlsCipher()423   std::string DtlsCipher() const {
424     RTC_CHECK(called_);
425     return stats_.dtls_cipher;
426   }
427 
SrtpCipher()428   std::string SrtpCipher() const {
429     RTC_CHECK(called_);
430     return stats_.srtp_cipher;
431   }
432 
TrackIds()433   std::vector<std::string> TrackIds() const {
434     RTC_CHECK(called_);
435     return stats_.track_ids;
436   }
437 
438  private:
GetIntValue(const StatsReport * report,StatsReport::StatsValueName name,int * value)439   bool GetIntValue(const StatsReport* report,
440                    StatsReport::StatsValueName name,
441                    int* value) {
442     const StatsReport::Value* v = report->FindValue(name);
443     if (v) {
444       // TODO(tommi): We should really just be using an int here :-/
445       *value = rtc::FromString<int>(v->ToString());
446     }
447     return v != nullptr;
448   }
449 
GetInt64Value(const StatsReport * report,StatsReport::StatsValueName name,int64_t * value)450   bool GetInt64Value(const StatsReport* report,
451                      StatsReport::StatsValueName name,
452                      int64_t* value) {
453     const StatsReport::Value* v = report->FindValue(name);
454     if (v) {
455       // TODO(tommi): We should really just be using an int here :-/
456       *value = rtc::FromString<int64_t>(v->ToString());
457     }
458     return v != nullptr;
459   }
460 
GetStringValue(const StatsReport * report,StatsReport::StatsValueName name,std::string * value)461   bool GetStringValue(const StatsReport* report,
462                       StatsReport::StatsValueName name,
463                       std::string* value) {
464     const StatsReport::Value* v = report->FindValue(name);
465     if (v)
466       *value = v->ToString();
467     return v != nullptr;
468   }
469 
470   bool called_;
471   struct {
Clear__anon763bc74c0108472     void Clear() {
473       number_of_reports = 0;
474       timestamp = 0;
475       audio_output_level = 0;
476       audio_input_level = 0;
477       bytes_received = 0;
478       bytes_sent = 0;
479       capture_start_ntp_time = 0;
480       available_receive_bandwidth = 0;
481       dtls_cipher.clear();
482       srtp_cipher.clear();
483       track_ids.clear();
484     }
485 
486     size_t number_of_reports;
487     double timestamp;
488     int audio_output_level;
489     int audio_input_level;
490     int bytes_received;
491     int bytes_sent;
492     int64_t capture_start_ntp_time;
493     int available_receive_bandwidth;
494     std::string dtls_cipher;
495     std::string srtp_cipher;
496     std::vector<std::string> track_ids;
497   } stats_;
498 };
499 
500 // Helper class that just stores the report from the callback.
501 class MockRTCStatsCollectorCallback : public webrtc::RTCStatsCollectorCallback {
502  public:
report()503   rtc::scoped_refptr<const RTCStatsReport> report() { return report_; }
504 
called()505   bool called() const { return called_; }
506 
507  protected:
OnStatsDelivered(const rtc::scoped_refptr<const RTCStatsReport> & report)508   void OnStatsDelivered(
509       const rtc::scoped_refptr<const RTCStatsReport>& report) override {
510     report_ = report;
511     called_ = true;
512   }
513 
514  private:
515   bool called_ = false;
516   rtc::scoped_refptr<const RTCStatsReport> report_;
517 };
518 
519 }  // namespace webrtc
520 
521 #endif  // PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
522