1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_
18 
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
26 
27 namespace xla {
28 
29 class OutfeedReceiverImpl;
30 
31 // Implements a multithreaded receiver of outfeeds from devices.
32 class OutfeedReceiver {
33  public:
34   // A callback takes: device, consumer id, received.
35   using Callback =
36       std::function<void(PjRtDevice*, uint32_t, std::shared_ptr<Literal>)>;
37 
38   // Constructs the receiver for the given clients and callback function.
39   //
40   // Args:
41   //   callback: a function to be called when an outfeed is ready for
42   //     processing.
43   //   clients: the clients for whose devices to listen.
44   //   max_callback_queue_size_bytes: the maximum number of bytes for all
45   //     received outfeeds queued to be processed. When this limit is reached
46   //     we pause receiving outfeeds from devices.
47   OutfeedReceiver(Callback callback, absl::Span<PjRtClient* const> clients,
48                   ssize_t max_callback_queue_size_bytes);
49 
50   OutfeedReceiver(const OutfeedReceiver&) = delete;
51   OutfeedReceiver& operator=(const OutfeedReceiver&) = delete;
52 
53   // Blocks until all data has been received from devices and all data
54   // in the queue has been passed to Python.
55   ~OutfeedReceiver();
56 
57   // Starts the listener threads and the callback thread.
58   void Start();
59 
60   // Adds to the computation builder the outfeed of the arrays.
61   // Has the side-effect of registering the sent shape for the consumer_id.
62   // Returns error status if the outfeed shape is different than the
63   // previously used shape for the same consumer_id or the consumer id is
64   // invalid.
65   StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token,
66                                       uint32_t consumer_id,
67                                       std::vector<XlaOp> arrays);
68 
69  private:
70   std::unique_ptr<OutfeedReceiverImpl> p_impl_;
71 };
72 
73 }  // namespace xla
74 
75 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_
76