1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
17 #define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
18 
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iterator>
#include <map>
#include <string>
#include <vector>

#include <curl/curl.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
33 
34 namespace tensorflow {
35 
36 /// Fake HttpRequest for testing.
37 class FakeHttpRequest : public CurlHttpRequest {
38  public:
39   /// Return the response for the given request.
FakeHttpRequest(const string & request,const string & response)40   FakeHttpRequest(const string& request, const string& response)
41       : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) {}
42 
43   /// Return the response with headers for the given request.
FakeHttpRequest(const string & request,const string & response,const std::map<string,string> & response_headers)44   FakeHttpRequest(const string& request, const string& response,
45                   const std::map<string, string>& response_headers)
46       : FakeHttpRequest(request, response, Status::OK(), nullptr,
47                         response_headers, 200) {}
48 
49   /// \brief Return the response for the request and capture the POST body.
50   ///
51   /// Post body is not expected to be a part of the 'request' parameter.
FakeHttpRequest(const string & request,const string & response,string * captured_post_body)52   FakeHttpRequest(const string& request, const string& response,
53                   string* captured_post_body)
54       : FakeHttpRequest(request, response, Status::OK(), captured_post_body, {},
55                         200) {}
56 
57   /// \brief Return the response and the status for the given request.
FakeHttpRequest(const string & request,const string & response,Status response_status,uint64 response_code)58   FakeHttpRequest(const string& request, const string& response,
59                   Status response_status, uint64 response_code)
60       : FakeHttpRequest(request, response, response_status, nullptr, {},
61                         response_code) {}
62 
63   /// \brief Return the response and the status for the given request
64   ///  and capture the POST body.
65   ///
66   /// Post body is not expected to be a part of the 'request' parameter.
FakeHttpRequest(const string & request,const string & response,Status response_status,string * captured_post_body,const std::map<string,string> & response_headers,uint64 response_code)67   FakeHttpRequest(const string& request, const string& response,
68                   Status response_status, string* captured_post_body,
69                   const std::map<string, string>& response_headers,
70                   uint64 response_code)
71       : expected_request_(request),
72         response_(response),
73         response_status_(response_status),
74         captured_post_body_(captured_post_body),
75         response_headers_(response_headers),
76         response_code_(response_code) {}
77 
SetUri(const string & uri)78   void SetUri(const string& uri) override {
79     actual_uri_ += "Uri: " + uri + "\n";
80   }
SetRange(uint64 start,uint64 end)81   void SetRange(uint64 start, uint64 end) override {
82     actual_request_ += strings::StrCat("Range: ", start, "-", end, "\n");
83   }
AddHeader(const string & name,const string & value)84   void AddHeader(const string& name, const string& value) override {
85     actual_request_ += "Header " + name + ": " + value + "\n";
86   }
AddAuthBearerHeader(const string & auth_token)87   void AddAuthBearerHeader(const string& auth_token) override {
88     actual_request_ += "Auth Token: " + auth_token + "\n";
89   }
SetDeleteRequest()90   void SetDeleteRequest() override { actual_request_ += "Delete: yes\n"; }
SetPutFromFile(const string & body_filepath,size_t offset)91   Status SetPutFromFile(const string& body_filepath, size_t offset) override {
92     std::ifstream stream(body_filepath);
93     const string& content = string(std::istreambuf_iterator<char>(stream),
94                                    std::istreambuf_iterator<char>())
95                                 .substr(offset);
96     actual_request_ += "Put body: " + content + "\n";
97     return Status::OK();
98   }
SetPostFromBuffer(const char * buffer,size_t size)99   void SetPostFromBuffer(const char* buffer, size_t size) override {
100     if (captured_post_body_) {
101       *captured_post_body_ = string(buffer, size);
102     } else {
103       actual_request_ +=
104           strings::StrCat("Post body: ", StringPiece(buffer, size), "\n");
105     }
106   }
SetPutEmptyBody()107   void SetPutEmptyBody() override { actual_request_ += "Put: yes\n"; }
SetPostEmptyBody()108   void SetPostEmptyBody() override {
109     if (captured_post_body_) {
110       *captured_post_body_ = "<empty>";
111     } else {
112       actual_request_ += "Post: yes\n";
113     }
114   }
SetResultBuffer(std::vector<char> * buffer)115   void SetResultBuffer(std::vector<char>* buffer) override {
116     buffer->clear();
117     buffer_ = buffer;
118   }
SetResultBufferDirect(char * buffer,size_t size)119   void SetResultBufferDirect(char* buffer, size_t size) override {
120     direct_result_buffer_ = buffer;
121     direct_result_buffer_size_ = size;
122   }
GetResultBufferDirectBytesTransferred()123   size_t GetResultBufferDirectBytesTransferred() override {
124     return direct_result_bytes_transferred_;
125   }
Send()126   Status Send() override {
127     EXPECT_EQ(expected_request_, actual_request())
128         << "Unexpected HTTP request.";
129     if (buffer_) {
130       buffer_->insert(buffer_->begin(), response_.data(),
131                       response_.data() + response_.size());
132     } else if (direct_result_buffer_ != nullptr) {
133       size_t bytes_to_copy =
134           std::min<size_t>(direct_result_buffer_size_, response_.size());
135       memcpy(direct_result_buffer_, response_.data(), bytes_to_copy);
136       direct_result_bytes_transferred_ += bytes_to_copy;
137     }
138     return response_status_;
139   }
140 
141   // This function just does a simple replacing of "/" with "%2F" instead of
142   // full url encoding.
EscapeString(const string & str)143   string EscapeString(const string& str) override {
144     const string victim = "/";
145     const string encoded = "%2F";
146 
147     string copy_str = str;
148     std::string::size_type n = 0;
149     while ((n = copy_str.find(victim, n)) != std::string::npos) {
150       copy_str.replace(n, victim.size(), encoded);
151       n += encoded.size();
152     }
153     return copy_str;
154   }
155 
GetResponseHeader(const string & name)156   string GetResponseHeader(const string& name) const override {
157     const auto header = response_headers_.find(name);
158     return header != response_headers_.end() ? header->second : "";
159   }
160 
GetResponseCode()161   virtual uint64 GetResponseCode() const override { return response_code_; }
162 
SetTimeouts(uint32 connection,uint32 inactivity,uint32 total)163   void SetTimeouts(uint32 connection, uint32 inactivity,
164                    uint32 total) override {
165     actual_request_ += strings::StrCat("Timeouts: ", connection, " ",
166                                        inactivity, " ", total, "\n");
167   }
168 
169  private:
actual_request()170   string actual_request() const {
171     string s;
172     s.append(actual_uri_);
173     s.append(actual_request_);
174     return s;
175   }
176 
177   std::vector<char>* buffer_ = nullptr;
178   char* direct_result_buffer_ = nullptr;
179   size_t direct_result_buffer_size_ = 0;
180   size_t direct_result_bytes_transferred_ = 0;
181   string expected_request_;
182   string actual_uri_;
183   string actual_request_;
184   string response_;
185   Status response_status_;
186   string* captured_post_body_ = nullptr;
187   std::map<string, string> response_headers_;
188   uint64 response_code_ = 0;
189 };
190 
191 /// Fake HttpRequest factory for testing.
192 class FakeHttpRequestFactory : public HttpRequest::Factory {
193  public:
FakeHttpRequestFactory(const std::vector<HttpRequest * > * requests)194   FakeHttpRequestFactory(const std::vector<HttpRequest*>* requests)
195       : requests_(requests) {}
196 
~FakeHttpRequestFactory()197   ~FakeHttpRequestFactory() {
198     EXPECT_EQ(current_index_, requests_->size())
199         << "Not all expected requests were made.";
200   }
201 
Create()202   HttpRequest* Create() override {
203     EXPECT_LT(current_index_, requests_->size())
204         << "Too many calls of HttpRequest factory.";
205     return (*requests_)[current_index_++];
206   }
207 
208  private:
209   const std::vector<HttpRequest*>* requests_;
210   int current_index_ = 0;
211 };
212 
213 }  // namespace tensorflow
214 
215 #endif  // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
216