1syntax = "proto3";
2
3package tensorflow.data;
4
5import "tensorflow/core/data/dataset.proto";
6import "tensorflow/core/data/service/common.proto";
7
// Request message for WorkerService.ProcessTask.
message ProcessTaskRequest {
  // Definition of the task the worker is asked to process.
  TaskDef task = 1;
}
11
// Response message for WorkerService.ProcessTask. It currently carries no
// fields; a dedicated message lets fields be added later without changing
// the RPC signature.
message ProcessTaskResponse {
}
13
// Request message for WorkerService.GetElement.
message GetElementRequest {
  // The task to fetch an element from.
  int64 task_id = 1;
  // Optional index to identify the consumer. It is wrapped in a single-member
  // oneof so that "unset" can be distinguished from an explicit 0.
  oneof optional_consumer_index {
    int64 consumer_index = 2;
  }
  // Optional round index, indicating which round of round-robin the consumer
  // wants to read from. This is used to keep consumers in sync. It is wrapped
  // in a single-member oneof so that "unset" can be distinguished from 0.
  oneof optional_round_index {
    int64 round_index = 3;
  }
  // Whether the previous round was skipped. This information is needed by the
  // worker to recover after restarts.
  bool skipped_previous_round = 4;
  // Whether to skip the round if data isn't ready fast enough.
  bool allow_skip = 5;
}
32
// Response message for WorkerService.GetElement.
message GetElementResponse {
  // Field number 1 is not used by any field below. NOTE(review): presumably
  // it belonged to a removed field. It is reserved so it cannot be reused
  // with a different meaning; confirm against the schema history.
  reserved 1;

  // The produced element.
  CompressedElement compressed_element = 3;
  // Boolean to indicate whether the iterator has been exhausted.
  bool end_of_sequence = 2;
  // Indicates whether the round was skipped.
  bool skip_task = 4;
}
41
// Request message for WorkerService.GetWorkerTasks. The "Worker" infix keeps
// this name distinct from GetTasks in dispatcher.proto.
message GetWorkerTasksRequest {
}
44
// Response message for WorkerService.GetWorkerTasks.
message GetWorkerTasksResponse {
  // The tasks currently being executed by the worker.
  repeated TaskInfo tasks = 1;
}
48
// RPC interface exposed by a tf.data service worker.
service WorkerService {
  // Processes a task for a dataset, making elements available to clients.
  rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);

  // Gets the next dataset element.
  rpc GetElement(GetElementRequest) returns (GetElementResponse);

  // Gets the tasks currently being executed by the worker.
  rpc GetWorkerTasks(GetWorkerTasksRequest) returns (GetWorkerTasksResponse);
}
59