syntax = "proto3";

package tensorflow.data;

import "tensorflow/core/data/dataset.proto";
import "tensorflow/core/data/service/common.proto";

// Request for WorkerService.ProcessTask.
message ProcessTaskRequest {
  // The task the worker should start processing.
  TaskDef task = 1;
}

// Response for WorkerService.ProcessTask. Currently carries no fields; kept as
// a dedicated message so fields can be added later without changing the RPC
// signature.
message ProcessTaskResponse {}

// Request for WorkerService.GetElement.
message GetElementRequest {
  // The task to fetch an element from.
  int64 task_id = 1;
  // Optional index to identify the consumer. Wrapped in a single-member oneof
  // so that "unset" can be distinguished from an explicit 0.
  oneof optional_consumer_index {
    int64 consumer_index = 2;
  }
  // Optional round index, indicating which round of round-robin the consumer
  // wants to read from. This is used to keep consumers in sync. Wrapped in a
  // single-member oneof so that "unset" can be distinguished from round 0.
  oneof optional_round_index {
    int64 round_index = 3;
  }
  // Whether the previous round was skipped. This information is needed by the
  // worker to recover after restarts.
  bool skipped_previous_round = 4;
  // Whether to skip the round if data isn't ready fast enough.
  bool allow_skip = 5;
}

// Response for WorkerService.GetElement.
message GetElementResponse {
  // Field number 1 is not used by this message; it is reserved so it cannot be
  // accidentally reassigned with a different meaning.
  // NOTE(review): presumably held by a removed field — confirm against history.
  reserved 1;

  // The produced element.
  CompressedElement compressed_element = 3;
  // Boolean to indicate whether the iterator has been exhausted.
  bool end_of_sequence = 2;
  // Indicates whether the round was skipped.
  bool skip_task = 4;
}

// Request for WorkerService.GetWorkerTasks.
// Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto.
message GetWorkerTasksRequest {}

// Response for WorkerService.GetWorkerTasks.
message GetWorkerTasksResponse {
  // The tasks currently being executed by the worker.
  repeated TaskInfo tasks = 1;
}

// RPC interface exposed by a tf.data service worker.
service WorkerService {
  // Processes a task for a dataset, making elements available to clients.
  rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);

  // Gets the next dataset element.
  rpc GetElement(GetElementRequest) returns (GetElementResponse);

  // Gets the tasks currently being executed by the worker.
  rpc GetWorkerTasks(GetWorkerTasksRequest) returns (GetWorkerTasksResponse);
}