1 /* 2 * Copyright 2012 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 // This file contains mock implementations of observers used in PeerConnection. 12 // TODO(steveanton): These aren't really mocks and should be renamed. 13 14 #ifndef PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_ 15 #define PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_ 16 17 #include <map> 18 #include <memory> 19 #include <string> 20 #include <utility> 21 #include <vector> 22 23 #include "api/data_channel_interface.h" 24 #include "api/jsep_ice_candidate.h" 25 #include "pc/stream_collection.h" 26 #include "rtc_base/checks.h" 27 28 namespace webrtc { 29 30 class MockPeerConnectionObserver : public PeerConnectionObserver { 31 public: 32 struct AddTrackEvent { AddTrackEventAddTrackEvent33 explicit AddTrackEvent( 34 rtc::scoped_refptr<RtpReceiverInterface> event_receiver, 35 std::vector<rtc::scoped_refptr<MediaStreamInterface>> event_streams) 36 : receiver(std::move(event_receiver)), 37 streams(std::move(event_streams)) { 38 for (auto stream : streams) { 39 std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>> tracks; 40 for (auto audio_track : stream->GetAudioTracks()) { 41 tracks.push_back(audio_track); 42 } 43 for (auto video_track : stream->GetVideoTracks()) { 44 tracks.push_back(video_track); 45 } 46 snapshotted_stream_tracks[stream] = tracks; 47 } 48 } 49 50 rtc::scoped_refptr<RtpReceiverInterface> receiver; 51 std::vector<rtc::scoped_refptr<MediaStreamInterface>> streams; 52 // This map records the tracks present in each stream at the time the 53 // OnAddTrack callback was issued. 54 std::map<rtc::scoped_refptr<MediaStreamInterface>, 55 std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>>> 56 snapshotted_stream_tracks; 57 }; 58 MockPeerConnectionObserver()59 MockPeerConnectionObserver() : remote_streams_(StreamCollection::Create()) {} ~MockPeerConnectionObserver()60 virtual ~MockPeerConnectionObserver() {} SetPeerConnectionInterface(PeerConnectionInterface * pc)61 void SetPeerConnectionInterface(PeerConnectionInterface* pc) { 62 pc_ = pc; 63 if (pc) { 64 state_ = pc_->signaling_state(); 65 } 66 } OnSignalingChange(PeerConnectionInterface::SignalingState new_state)67 void OnSignalingChange( 68 PeerConnectionInterface::SignalingState new_state) override { 69 RTC_DCHECK(pc_); 70 RTC_DCHECK(pc_->signaling_state() == new_state); 71 state_ = new_state; 72 } 73 RemoteStream(const std::string & label)74 MediaStreamInterface* RemoteStream(const std::string& label) { 75 return remote_streams_->find(label); 76 } remote_streams()77 StreamCollectionInterface* remote_streams() const { return remote_streams_; } OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream)78 void OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream) override { 79 last_added_stream_ = stream; 80 remote_streams_->AddStream(stream); 81 } OnRemoveStream(rtc::scoped_refptr<MediaStreamInterface> stream)82 void OnRemoveStream( 83 rtc::scoped_refptr<MediaStreamInterface> stream) override { 84 last_removed_stream_ = stream; 85 remote_streams_->RemoveStream(stream); 86 } OnRenegotiationNeeded()87 void OnRenegotiationNeeded() override { renegotiation_needed_ = true; } OnDataChannel(rtc::scoped_refptr<DataChannelInterface> data_channel)88 void OnDataChannel( 89 rtc::scoped_refptr<DataChannelInterface> data_channel) override { 90 last_datachannel_ = data_channel; 91 } 92 OnIceConnectionChange(PeerConnectionInterface::IceConnectionState new_state)93 void OnIceConnectionChange( 94 PeerConnectionInterface::IceConnectionState new_state) override { 95 RTC_DCHECK(pc_); 96 RTC_DCHECK(pc_->ice_connection_state() == new_state); 97 // When ICE is finished, the caller will get to a kIceConnectionCompleted 98 // state, because it has the ICE controlling role, while the callee 99 // will get to a kIceConnectionConnected state. This means that both ICE 100 // and DTLS are connected. 101 ice_connected_ = 102 (new_state == PeerConnectionInterface::kIceConnectionConnected) || 103 (new_state == PeerConnectionInterface::kIceConnectionCompleted); 104 callback_triggered_ = true; 105 } OnIceGatheringChange(PeerConnectionInterface::IceGatheringState new_state)106 void OnIceGatheringChange( 107 PeerConnectionInterface::IceGatheringState new_state) override { 108 RTC_DCHECK(pc_); 109 RTC_DCHECK(pc_->ice_gathering_state() == new_state); 110 ice_gathering_complete_ = 111 new_state == PeerConnectionInterface::kIceGatheringComplete; 112 callback_triggered_ = true; 113 } OnIceCandidate(const IceCandidateInterface * candidate)114 void OnIceCandidate(const IceCandidateInterface* candidate) override { 115 RTC_DCHECK(pc_); 116 RTC_DCHECK(PeerConnectionInterface::kIceGatheringNew != 117 pc_->ice_gathering_state()); 118 candidates_.push_back(std::make_unique<JsepIceCandidate>( 119 candidate->sdp_mid(), candidate->sdp_mline_index(), 120 candidate->candidate())); 121 callback_triggered_ = true; 122 } 123 OnIceCandidatesRemoved(const std::vector<cricket::Candidate> & candidates)124 void OnIceCandidatesRemoved( 125 const std::vector<cricket::Candidate>& candidates) override { 126 num_candidates_removed_++; 127 callback_triggered_ = true; 128 } 129 OnIceConnectionReceivingChange(bool receiving)130 void OnIceConnectionReceivingChange(bool receiving) override { 131 callback_triggered_ = true; 132 } 133 OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver,const std::vector<rtc::scoped_refptr<MediaStreamInterface>> & streams)134 void OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver, 135 const std::vector<rtc::scoped_refptr<MediaStreamInterface>>& 136 streams) override { 137 RTC_DCHECK(receiver); 138 num_added_tracks_++; 139 last_added_track_label_ = receiver->id(); 140 add_track_events_.push_back(AddTrackEvent(receiver, streams)); 141 } 142 OnTrack(rtc::scoped_refptr<RtpTransceiverInterface> transceiver)143 void OnTrack( 144 rtc::scoped_refptr<RtpTransceiverInterface> transceiver) override { 145 on_track_transceivers_.push_back(transceiver); 146 } 147 OnRemoveTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver)148 void OnRemoveTrack( 149 rtc::scoped_refptr<RtpReceiverInterface> receiver) override { 150 remove_track_events_.push_back(receiver); 151 } 152 GetAddTrackReceivers()153 std::vector<rtc::scoped_refptr<RtpReceiverInterface>> GetAddTrackReceivers() { 154 std::vector<rtc::scoped_refptr<RtpReceiverInterface>> receivers; 155 for (const AddTrackEvent& event : add_track_events_) { 156 receivers.push_back(event.receiver); 157 } 158 return receivers; 159 } 160 CountAddTrackEventsForStream(const std::string & stream_id)161 int CountAddTrackEventsForStream(const std::string& stream_id) { 162 int found_tracks = 0; 163 for (const AddTrackEvent& event : add_track_events_) { 164 bool has_stream_id = false; 165 for (auto stream : event.streams) { 166 if (stream->id() == stream_id) { 167 has_stream_id = true; 168 break; 169 } 170 } 171 if (has_stream_id) { 172 ++found_tracks; 173 } 174 } 175 return found_tracks; 176 } 177 178 // Returns the id of the last added stream. 179 // Empty string if no stream have been added. GetLastAddedStreamId()180 std::string GetLastAddedStreamId() { 181 if (last_added_stream_.get()) 182 return last_added_stream_->id(); 183 return ""; 184 } GetLastRemovedStreamId()185 std::string GetLastRemovedStreamId() { 186 if (last_removed_stream_.get()) 187 return last_removed_stream_->id(); 188 return ""; 189 } 190 last_candidate()191 IceCandidateInterface* last_candidate() { 192 if (candidates_.empty()) { 193 return nullptr; 194 } else { 195 return candidates_.back().get(); 196 } 197 } 198 GetAllCandidates()199 std::vector<const IceCandidateInterface*> GetAllCandidates() { 200 std::vector<const IceCandidateInterface*> candidates; 201 for (const auto& candidate : candidates_) { 202 candidates.push_back(candidate.get()); 203 } 204 return candidates; 205 } 206 GetCandidatesByMline(int mline_index)207 std::vector<IceCandidateInterface*> GetCandidatesByMline(int mline_index) { 208 std::vector<IceCandidateInterface*> candidates; 209 for (const auto& candidate : candidates_) { 210 if (candidate->sdp_mline_index() == mline_index) { 211 candidates.push_back(candidate.get()); 212 } 213 } 214 return candidates; 215 } 216 negotiation_needed()217 bool negotiation_needed() const { return renegotiation_needed_; } clear_negotiation_needed()218 void clear_negotiation_needed() { renegotiation_needed_ = false; } 219 220 rtc::scoped_refptr<PeerConnectionInterface> pc_; 221 PeerConnectionInterface::SignalingState state_; 222 std::vector<std::unique_ptr<IceCandidateInterface>> candidates_; 223 rtc::scoped_refptr<DataChannelInterface> last_datachannel_; 224 rtc::scoped_refptr<StreamCollection> remote_streams_; 225 bool renegotiation_needed_ = false; 226 bool ice_gathering_complete_ = false; 227 bool ice_connected_ = false; 228 bool callback_triggered_ = false; 229 int num_added_tracks_ = 0; 230 std::string last_added_track_label_; 231 std::vector<AddTrackEvent> add_track_events_; 232 std::vector<rtc::scoped_refptr<RtpReceiverInterface>> remove_track_events_; 233 std::vector<rtc::scoped_refptr<RtpTransceiverInterface>> 234 on_track_transceivers_; 235 int num_candidates_removed_ = 0; 236 237 private: 238 rtc::scoped_refptr<MediaStreamInterface> last_added_stream_; 239 rtc::scoped_refptr<MediaStreamInterface> last_removed_stream_; 240 }; 241 242 class MockCreateSessionDescriptionObserver 243 : public webrtc::CreateSessionDescriptionObserver { 244 public: MockCreateSessionDescriptionObserver()245 MockCreateSessionDescriptionObserver() 246 : called_(false), 247 error_("MockCreateSessionDescriptionObserver not called") {} ~MockCreateSessionDescriptionObserver()248 virtual ~MockCreateSessionDescriptionObserver() {} OnSuccess(SessionDescriptionInterface * desc)249 void OnSuccess(SessionDescriptionInterface* desc) override { 250 called_ = true; 251 error_ = ""; 252 desc_.reset(desc); 253 } OnFailure(webrtc::RTCError error)254 void OnFailure(webrtc::RTCError error) override { 255 called_ = true; 256 error_ = error.message(); 257 } called()258 bool called() const { return called_; } result()259 bool result() const { return error_.empty(); } error()260 const std::string& error() const { return error_; } MoveDescription()261 std::unique_ptr<SessionDescriptionInterface> MoveDescription() { 262 return std::move(desc_); 263 } 264 265 private: 266 bool called_; 267 std::string error_; 268 std::unique_ptr<SessionDescriptionInterface> desc_; 269 }; 270 271 class MockSetSessionDescriptionObserver 272 : public webrtc::SetSessionDescriptionObserver { 273 public: Create()274 static rtc::scoped_refptr<MockSetSessionDescriptionObserver> Create() { 275 return new rtc::RefCountedObject<MockSetSessionDescriptionObserver>(); 276 } 277 MockSetSessionDescriptionObserver()278 MockSetSessionDescriptionObserver() 279 : called_(false), 280 error_("MockSetSessionDescriptionObserver not called") {} ~MockSetSessionDescriptionObserver()281 ~MockSetSessionDescriptionObserver() override {} OnSuccess()282 void OnSuccess() override { 283 called_ = true; 284 error_ = ""; 285 } OnFailure(webrtc::RTCError error)286 void OnFailure(webrtc::RTCError error) override { 287 called_ = true; 288 error_ = error.message(); 289 } 290 called()291 bool called() const { return called_; } result()292 bool result() const { return error_.empty(); } error()293 const std::string& error() const { return error_; } 294 295 private: 296 bool called_; 297 std::string error_; 298 }; 299 300 class MockSetRemoteDescriptionObserver 301 : public rtc::RefCountedObject<SetRemoteDescriptionObserverInterface> { 302 public: called()303 bool called() const { return error_.has_value(); } error()304 RTCError& error() { 305 RTC_DCHECK(error_.has_value()); 306 return *error_; 307 } 308 309 // SetRemoteDescriptionObserverInterface implementation. OnSetRemoteDescriptionComplete(RTCError error)310 void OnSetRemoteDescriptionComplete(RTCError error) override { 311 error_ = std::move(error); 312 } 313 314 private: 315 // Set on complete, on success this is set to an RTCError::OK() error. 316 absl::optional<RTCError> error_; 317 }; 318 319 class MockDataChannelObserver : public webrtc::DataChannelObserver { 320 public: MockDataChannelObserver(webrtc::DataChannelInterface * channel)321 explicit MockDataChannelObserver(webrtc::DataChannelInterface* channel) 322 : channel_(channel) { 323 channel_->RegisterObserver(this); 324 state_ = channel_->state(); 325 } ~MockDataChannelObserver()326 virtual ~MockDataChannelObserver() { channel_->UnregisterObserver(); } 327 OnBufferedAmountChange(uint64_t previous_amount)328 void OnBufferedAmountChange(uint64_t previous_amount) override {} 329 OnStateChange()330 void OnStateChange() override { state_ = channel_->state(); } OnMessage(const DataBuffer & buffer)331 void OnMessage(const DataBuffer& buffer) override { 332 messages_.push_back( 333 std::string(buffer.data.data<char>(), buffer.data.size())); 334 } 335 IsOpen()336 bool IsOpen() const { return state_ == DataChannelInterface::kOpen; } messages()337 std::vector<std::string> messages() const { return messages_; } last_message()338 std::string last_message() const { 339 return messages_.empty() ? std::string() : messages_.back(); 340 } received_message_count()341 size_t received_message_count() const { return messages_.size(); } 342 343 private: 344 rtc::scoped_refptr<webrtc::DataChannelInterface> channel_; 345 DataChannelInterface::DataState state_; 346 std::vector<std::string> messages_; 347 }; 348 349 class MockStatsObserver : public webrtc::StatsObserver { 350 public: MockStatsObserver()351 MockStatsObserver() : called_(false), stats_() {} ~MockStatsObserver()352 virtual ~MockStatsObserver() {} 353 OnComplete(const StatsReports & reports)354 virtual void OnComplete(const StatsReports& reports) { 355 RTC_CHECK(!called_); 356 called_ = true; 357 stats_.Clear(); 358 stats_.number_of_reports = reports.size(); 359 for (const auto* r : reports) { 360 if (r->type() == StatsReport::kStatsReportTypeSsrc) { 361 stats_.timestamp = r->timestamp(); 362 GetIntValue(r, StatsReport::kStatsValueNameAudioOutputLevel, 363 &stats_.audio_output_level); 364 GetIntValue(r, StatsReport::kStatsValueNameAudioInputLevel, 365 &stats_.audio_input_level); 366 GetIntValue(r, StatsReport::kStatsValueNameBytesReceived, 367 &stats_.bytes_received); 368 GetIntValue(r, StatsReport::kStatsValueNameBytesSent, 369 &stats_.bytes_sent); 370 GetInt64Value(r, StatsReport::kStatsValueNameCaptureStartNtpTimeMs, 371 &stats_.capture_start_ntp_time); 372 stats_.track_ids.emplace_back(); 373 GetStringValue(r, StatsReport::kStatsValueNameTrackId, 374 &stats_.track_ids.back()); 375 } else if (r->type() == StatsReport::kStatsReportTypeBwe) { 376 stats_.timestamp = r->timestamp(); 377 GetIntValue(r, StatsReport::kStatsValueNameAvailableReceiveBandwidth, 378 &stats_.available_receive_bandwidth); 379 } else if (r->type() == StatsReport::kStatsReportTypeComponent) { 380 stats_.timestamp = r->timestamp(); 381 GetStringValue(r, StatsReport::kStatsValueNameDtlsCipher, 382 &stats_.dtls_cipher); 383 GetStringValue(r, StatsReport::kStatsValueNameSrtpCipher, 384 &stats_.srtp_cipher); 385 } 386 } 387 } 388 called()389 bool called() const { return called_; } number_of_reports()390 size_t number_of_reports() const { return stats_.number_of_reports; } timestamp()391 double timestamp() const { return stats_.timestamp; } 392 AudioOutputLevel()393 int AudioOutputLevel() const { 394 RTC_CHECK(called_); 395 return stats_.audio_output_level; 396 } 397 AudioInputLevel()398 int AudioInputLevel() const { 399 RTC_CHECK(called_); 400 return stats_.audio_input_level; 401 } 402 BytesReceived()403 int BytesReceived() const { 404 RTC_CHECK(called_); 405 return stats_.bytes_received; 406 } 407 BytesSent()408 int BytesSent() const { 409 RTC_CHECK(called_); 410 return stats_.bytes_sent; 411 } 412 CaptureStartNtpTime()413 int64_t CaptureStartNtpTime() const { 414 RTC_CHECK(called_); 415 return stats_.capture_start_ntp_time; 416 } 417 AvailableReceiveBandwidth()418 int AvailableReceiveBandwidth() const { 419 RTC_CHECK(called_); 420 return stats_.available_receive_bandwidth; 421 } 422 DtlsCipher()423 std::string DtlsCipher() const { 424 RTC_CHECK(called_); 425 return stats_.dtls_cipher; 426 } 427 SrtpCipher()428 std::string SrtpCipher() const { 429 RTC_CHECK(called_); 430 return stats_.srtp_cipher; 431 } 432 TrackIds()433 std::vector<std::string> TrackIds() const { 434 RTC_CHECK(called_); 435 return stats_.track_ids; 436 } 437 438 private: GetIntValue(const StatsReport * report,StatsReport::StatsValueName name,int * value)439 bool GetIntValue(const StatsReport* report, 440 StatsReport::StatsValueName name, 441 int* value) { 442 const StatsReport::Value* v = report->FindValue(name); 443 if (v) { 444 // TODO(tommi): We should really just be using an int here :-/ 445 *value = rtc::FromString<int>(v->ToString()); 446 } 447 return v != nullptr; 448 } 449 GetInt64Value(const StatsReport * report,StatsReport::StatsValueName name,int64_t * value)450 bool GetInt64Value(const StatsReport* report, 451 StatsReport::StatsValueName name, 452 int64_t* value) { 453 const StatsReport::Value* v = report->FindValue(name); 454 if (v) { 455 // TODO(tommi): We should really just be using an int here :-/ 456 *value = rtc::FromString<int64_t>(v->ToString()); 457 } 458 return v != nullptr; 459 } 460 GetStringValue(const StatsReport * report,StatsReport::StatsValueName name,std::string * value)461 bool GetStringValue(const StatsReport* report, 462 StatsReport::StatsValueName name, 463 std::string* value) { 464 const StatsReport::Value* v = report->FindValue(name); 465 if (v) 466 *value = v->ToString(); 467 return v != nullptr; 468 } 469 470 bool called_; 471 struct { Clear__anon763bc74c0108472 void Clear() { 473 number_of_reports = 0; 474 timestamp = 0; 475 audio_output_level = 0; 476 audio_input_level = 0; 477 bytes_received = 0; 478 bytes_sent = 0; 479 capture_start_ntp_time = 0; 480 available_receive_bandwidth = 0; 481 dtls_cipher.clear(); 482 srtp_cipher.clear(); 483 track_ids.clear(); 484 } 485 486 size_t number_of_reports; 487 double timestamp; 488 int audio_output_level; 489 int audio_input_level; 490 int bytes_received; 491 int bytes_sent; 492 int64_t capture_start_ntp_time; 493 int available_receive_bandwidth; 494 std::string dtls_cipher; 495 std::string srtp_cipher; 496 std::vector<std::string> track_ids; 497 } stats_; 498 }; 499 500 // Helper class that just stores the report from the callback. 501 class MockRTCStatsCollectorCallback : public webrtc::RTCStatsCollectorCallback { 502 public: report()503 rtc::scoped_refptr<const RTCStatsReport> report() { return report_; } 504 called()505 bool called() const { return called_; } 506 507 protected: OnStatsDelivered(const rtc::scoped_refptr<const RTCStatsReport> & report)508 void OnStatsDelivered( 509 const rtc::scoped_refptr<const RTCStatsReport>& report) override { 510 report_ = report; 511 called_ = true; 512 } 513 514 private: 515 bool called_ = false; 516 rtc::scoped_refptr<const RTCStatsReport> report_; 517 }; 518 519 } // namespace webrtc 520 521 #endif // PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_ 522