1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_ 17 #define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_ 18 19 #include <algorithm> 20 #include <fstream> 21 #include <string> 22 #include <vector> 23 #include <curl/curl.h> 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/lib/core/status_test_util.h" 27 #include "tensorflow/core/lib/core/stringpiece.h" 28 #include "tensorflow/core/platform/cloud/curl_http_request.h" 29 #include "tensorflow/core/platform/macros.h" 30 #include "tensorflow/core/platform/protobuf.h" 31 #include "tensorflow/core/platform/test.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace tensorflow { 35 36 /// Fake HttpRequest for testing. 37 class FakeHttpRequest : public CurlHttpRequest { 38 public: 39 /// Return the response for the given request. FakeHttpRequest(const string & request,const string & response)40 FakeHttpRequest(const string& request, const string& response) 41 : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) {} 42 43 /// Return the response with headers for the given request. FakeHttpRequest(const string & request,const string & response,const std::map<string,string> & response_headers)44 FakeHttpRequest(const string& request, const string& response, 45 const std::map<string, string>& response_headers) 46 : FakeHttpRequest(request, response, Status::OK(), nullptr, 47 response_headers, 200) {} 48 49 /// \brief Return the response for the request and capture the POST body. 50 /// 51 /// Post body is not expected to be a part of the 'request' parameter. FakeHttpRequest(const string & request,const string & response,string * captured_post_body)52 FakeHttpRequest(const string& request, const string& response, 53 string* captured_post_body) 54 : FakeHttpRequest(request, response, Status::OK(), captured_post_body, {}, 55 200) {} 56 57 /// \brief Return the response and the status for the given request. FakeHttpRequest(const string & request,const string & response,Status response_status,uint64 response_code)58 FakeHttpRequest(const string& request, const string& response, 59 Status response_status, uint64 response_code) 60 : FakeHttpRequest(request, response, response_status, nullptr, {}, 61 response_code) {} 62 63 /// \brief Return the response and the status for the given request 64 /// and capture the POST body. 65 /// 66 /// Post body is not expected to be a part of the 'request' parameter. FakeHttpRequest(const string & request,const string & response,Status response_status,string * captured_post_body,const std::map<string,string> & response_headers,uint64 response_code)67 FakeHttpRequest(const string& request, const string& response, 68 Status response_status, string* captured_post_body, 69 const std::map<string, string>& response_headers, 70 uint64 response_code) 71 : expected_request_(request), 72 response_(response), 73 response_status_(response_status), 74 captured_post_body_(captured_post_body), 75 response_headers_(response_headers), 76 response_code_(response_code) {} 77 SetUri(const string & uri)78 void SetUri(const string& uri) override { 79 actual_uri_ += "Uri: " + uri + "\n"; 80 } SetRange(uint64 start,uint64 end)81 void SetRange(uint64 start, uint64 end) override { 82 actual_request_ += strings::StrCat("Range: ", start, "-", end, "\n"); 83 } AddHeader(const string & name,const string & value)84 void AddHeader(const string& name, const string& value) override { 85 actual_request_ += "Header " + name + ": " + value + "\n"; 86 } AddAuthBearerHeader(const string & auth_token)87 void AddAuthBearerHeader(const string& auth_token) override { 88 actual_request_ += "Auth Token: " + auth_token + "\n"; 89 } SetDeleteRequest()90 void SetDeleteRequest() override { actual_request_ += "Delete: yes\n"; } SetPutFromFile(const string & body_filepath,size_t offset)91 Status SetPutFromFile(const string& body_filepath, size_t offset) override { 92 std::ifstream stream(body_filepath); 93 const string& content = string(std::istreambuf_iterator<char>(stream), 94 std::istreambuf_iterator<char>()) 95 .substr(offset); 96 actual_request_ += "Put body: " + content + "\n"; 97 return Status::OK(); 98 } SetPostFromBuffer(const char * buffer,size_t size)99 void SetPostFromBuffer(const char* buffer, size_t size) override { 100 if (captured_post_body_) { 101 *captured_post_body_ = string(buffer, size); 102 } else { 103 actual_request_ += 104 strings::StrCat("Post body: ", StringPiece(buffer, size), "\n"); 105 } 106 } SetPutEmptyBody()107 void SetPutEmptyBody() override { actual_request_ += "Put: yes\n"; } SetPostEmptyBody()108 void SetPostEmptyBody() override { 109 if (captured_post_body_) { 110 *captured_post_body_ = "<empty>"; 111 } else { 112 actual_request_ += "Post: yes\n"; 113 } 114 } SetResultBuffer(std::vector<char> * buffer)115 void SetResultBuffer(std::vector<char>* buffer) override { 116 buffer->clear(); 117 buffer_ = buffer; 118 } SetResultBufferDirect(char * buffer,size_t size)119 void SetResultBufferDirect(char* buffer, size_t size) override { 120 direct_result_buffer_ = buffer; 121 direct_result_buffer_size_ = size; 122 } GetResultBufferDirectBytesTransferred()123 size_t GetResultBufferDirectBytesTransferred() override { 124 return direct_result_bytes_transferred_; 125 } Send()126 Status Send() override { 127 EXPECT_EQ(expected_request_, actual_request()) 128 << "Unexpected HTTP request."; 129 if (buffer_) { 130 buffer_->insert(buffer_->begin(), response_.data(), 131 response_.data() + response_.size()); 132 } else if (direct_result_buffer_ != nullptr) { 133 size_t bytes_to_copy = 134 std::min<size_t>(direct_result_buffer_size_, response_.size()); 135 memcpy(direct_result_buffer_, response_.data(), bytes_to_copy); 136 direct_result_bytes_transferred_ += bytes_to_copy; 137 } 138 return response_status_; 139 } 140 141 // This function just does a simple replacing of "/" with "%2F" instead of 142 // full url encoding. EscapeString(const string & str)143 string EscapeString(const string& str) override { 144 const string victim = "/"; 145 const string encoded = "%2F"; 146 147 string copy_str = str; 148 std::string::size_type n = 0; 149 while ((n = copy_str.find(victim, n)) != std::string::npos) { 150 copy_str.replace(n, victim.size(), encoded); 151 n += encoded.size(); 152 } 153 return copy_str; 154 } 155 GetResponseHeader(const string & name)156 string GetResponseHeader(const string& name) const override { 157 const auto header = response_headers_.find(name); 158 return header != response_headers_.end() ? header->second : ""; 159 } 160 GetResponseCode()161 virtual uint64 GetResponseCode() const override { return response_code_; } 162 SetTimeouts(uint32 connection,uint32 inactivity,uint32 total)163 void SetTimeouts(uint32 connection, uint32 inactivity, 164 uint32 total) override { 165 actual_request_ += strings::StrCat("Timeouts: ", connection, " ", 166 inactivity, " ", total, "\n"); 167 } 168 169 private: actual_request()170 string actual_request() const { 171 string s; 172 s.append(actual_uri_); 173 s.append(actual_request_); 174 return s; 175 } 176 177 std::vector<char>* buffer_ = nullptr; 178 char* direct_result_buffer_ = nullptr; 179 size_t direct_result_buffer_size_ = 0; 180 size_t direct_result_bytes_transferred_ = 0; 181 string expected_request_; 182 string actual_uri_; 183 string actual_request_; 184 string response_; 185 Status response_status_; 186 string* captured_post_body_ = nullptr; 187 std::map<string, string> response_headers_; 188 uint64 response_code_ = 0; 189 }; 190 191 /// Fake HttpRequest factory for testing. 192 class FakeHttpRequestFactory : public HttpRequest::Factory { 193 public: FakeHttpRequestFactory(const std::vector<HttpRequest * > * requests)194 FakeHttpRequestFactory(const std::vector<HttpRequest*>* requests) 195 : requests_(requests) {} 196 ~FakeHttpRequestFactory()197 ~FakeHttpRequestFactory() { 198 EXPECT_EQ(current_index_, requests_->size()) 199 << "Not all expected requests were made."; 200 } 201 Create()202 HttpRequest* Create() override { 203 EXPECT_LT(current_index_, requests_->size()) 204 << "Too many calls of HttpRequest factory."; 205 return (*requests_)[current_index_++]; 206 } 207 208 private: 209 const std::vector<HttpRequest*>* requests_; 210 int current_index_ = 0; 211 }; 212 213 } // namespace tensorflow 214 215 #endif // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_ 216