1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/placer.h"
17 
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/device_factory.h"
26 #include "tensorflow/core/common_runtime/device_set.h"
27 #include "tensorflow/core/framework/device_attributes.pb.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/function_testlib.h"
30 #include "tensorflow/core/framework/kernel_def_builder.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/op_def_builder.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/graph/graph.h"
36 #include "tensorflow/core/graph/graph_constructor.h"
37 #include "tensorflow/core/graph/graph_def_builder.h"
38 #include "tensorflow/core/graph/graph_def_builder_util.h"
39 #include "tensorflow/core/lib/core/error_codes.pb.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/core/status_test_util.h"
42 #include "tensorflow/core/lib/strings/str_util.h"
43 #include "tensorflow/core/lib/strings/strcat.h"
44 #include "tensorflow/core/platform/test.h"
45 
46 namespace tensorflow {
47 
// Shorthand for names from ::tensorflow::test::function and for
// FunctionDefHelper.
// NOTE(review): none of these aliases is referenced in the part of the file
// reviewed here; presumably they are used by tests further down -- confirm.
using ::tensorflow::test::function::GDef;
using ::tensorflow::test::function::NDef;
using FDH = ::tensorflow::FunctionDefHelper;

// Partial device names (device type and index only) for fake device 0 of each
// type.
constexpr char kCPU[] = "/device:fakecpu:0";
constexpr char kGPU[] = "/device:fakegpu:0";

// Fully qualified forms of the same names. They match the names PlacerTest's
// constructor assigns to index-0 devices under /job:a/replica:0/task:0.
constexpr char kFullCPU[] = "/job:a/replica:0/task:0/device:fakecpu:0";
constexpr char kFullGPU[] = "/job:a/replica:0/task:0/device:fakegpu:0";
57 
58 namespace {
59 
60 ////////////////////////////////////////////////////////////////////////////////
61 //
62 // Op, kernel, and device registrations to set up the environment.
63 //
64 // The Placer uses information about the op (input types),
65 // kernel (device constraints), and available devices to make
66 // placement decisions. To avoid depending on the full runtime, we
67 // define dummy implementations of these, and register them with the
68 // runtime.
69 //
70 ////////////////////////////////////////////////////////////////////////////////
71 
72 // A dummy OpKernel that is used to register ops on different devices.
73 class DummyOp : public OpKernel {
74  public:
DummyOp(OpKernelConstruction * context)75   explicit DummyOp(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)76   void Compute(OpKernelContext* context) override {}
77 };
78 
79 // A fake device that has specific device attributes, used to simulate
80 // the presence of a CPU or a GPU (without depending on that part of
81 // the runtime.
82 class FakeDevice : public Device {
83  private:
FakeDevice(const DeviceAttributes & device_attributes)84   explicit FakeDevice(const DeviceAttributes& device_attributes)
85       : Device(nullptr, device_attributes) {}
86 
87  public:
Sync()88   Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
89 
GetAllocator(AllocatorAttributes attr)90   Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
91 
MakeCPU(const string & name)92   static std::unique_ptr<Device> MakeCPU(const string& name) {
93     DeviceAttributes device_attributes;
94     device_attributes.set_name(name);
95     device_attributes.set_device_type(DeviceType("FakeCPU").type());
96     return std::unique_ptr<Device>(new FakeDevice(device_attributes));
97   }
98 
MakeGPU(const string & name)99   static std::unique_ptr<Device> MakeGPU(const string& name) {
100     DeviceAttributes device_attributes;
101     device_attributes.set_name(name);
102     device_attributes.set_device_type(DeviceType("FakeGPU").type());
103     return std::unique_ptr<Device>(new FakeDevice(device_attributes));
104   }
105 };
106 
107 class DummyFactory : public DeviceFactory {
108  public:
CreateDevices(const SessionOptions & options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)109   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
110                        std::vector<std::unique_ptr<Device>>* devices) override {
111     return Status::OK();
112   }
113 };
114 
// Device order now depends on the registration of devices, not a fixed
// value in device_set.cc.  To avoid the need to link in the real CPU and GPU
// devices into this test, we create fake devices and registrations that
// can stand-in for the real devices for the purposes of testing placement
// and ordering.
//
// NOTE(review): FakeGPU is registered with an explicit priority of 51 while
// FakeCPU relies on the macro's default. Presumably this ranks FakeGPU ahead
// of FakeCPU (TestNoConstraints expects unconstrained nodes on FakeGPU) --
// confirm the default priority in device_factory.h.
REGISTER_LOCAL_DEVICE_FACTORY("FakeCPU", DummyFactory);
REGISTER_LOCAL_DEVICE_FACTORY("FakeGPU", DummyFactory, 51);
122 
// Register the following ops so they can be added to a Graph, and
// kernels so that they can be placed on particular device types.
//
// Naming convention used below: a "Test*" op usually has kernels for both
// fake device types, while a "*CPU" / "*GPU" op has a kernel for that device
// type only. Exceptions are called out next to the registration.

// Variable-like ops: no inputs, one float reference output.
REGISTER_OP("TestVariable").Output("o: Ref(float)");
REGISTER_KERNEL_BUILDER(Name("TestVariable").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestVariable").Device("FakeGPU"), DummyOp);

REGISTER_OP("VariableCPU").Output("o: Ref(float)");
REGISTER_KERNEL_BUILDER(Name("VariableCPU").Device("FakeCPU"), DummyOp);

REGISTER_OP("VariableGPU").Output("o: Ref(float)");
REGISTER_KERNEL_BUILDER(Name("VariableGPU").Device("FakeGPU"), DummyOp);

// Registered as an op only; no kernel exists for any device type.
REGISTER_OP("VariableNoKernels").Output("o: Ref(float)");

// Binary float op.
REGISTER_OP("TestAdd").Input("a: float").Input("b: float").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("TestAdd").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestAdd").Device("FakeGPU"), DummyOp);

// Unary float ops.
REGISTER_OP("TestRelu").Input("i: float").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("TestRelu").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestRelu").Device("FakeGPU"), DummyOp);

REGISTER_OP("ReluCPU").Input("i: float").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("ReluCPU").Device("FakeCPU"), DummyOp);

REGISTER_OP("ReluGPU").Input("i: float").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("ReluGPU").Device("FakeGPU"), DummyOp);

// Assign-like ops: consume a float reference plus a float value and produce
// no output.
REGISTER_OP("TestAssign").Input("i: Ref(float)").Input("v: float");
REGISTER_KERNEL_BUILDER(Name("TestAssign").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestAssign").Device("FakeGPU"), DummyOp);

REGISTER_OP("AssignCPU").Input("i: Ref(float)").Input("v: float");
REGISTER_KERNEL_BUILDER(Name("AssignCPU").Device("FakeCPU"), DummyOp);

REGISTER_OP("AssignGPU").Input("i: Ref(float)").Input("v: float");
REGISTER_KERNEL_BUILDER(Name("AssignGPU").Device("FakeGPU"), DummyOp);

// Source op with two float outputs. Despite the "Test" prefix it has a
// FakeCPU kernel only.
REGISTER_OP("TestInput").Output("a: float").Output("b: float");
REGISTER_KERNEL_BUILDER(Name("TestInput").Device("FakeCPU"), DummyOp);

// Op producing an output that can be placed on CPU or GPU.
REGISTER_OP("TestCPUGPUOutput").Output("a: float");
REGISTER_KERNEL_BUILDER(Name("TestCPUGPUOutput").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestCPUGPUOutput").Device("FakeGPU"), DummyOp);

// Source ops that have a FakeGPU kernel only.
REGISTER_OP("TestGPUOutput").Output("a: float");
REGISTER_KERNEL_BUILDER(Name("TestGPUOutput").Device("FakeGPU"), DummyOp);

REGISTER_OP("TestDevice").Output("a: float").Output("b: float");
REGISTER_KERNEL_BUILDER(Name("TestDevice").Device("FakeGPU"), DummyOp);

// Op that consumes a float reference and produces a float value.
REGISTER_OP("TestDeviceEnforce").Input("a: Ref(float)").Output("b: float");
REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeGPU"), DummyOp);

// NOTE(review): only kernels are registered for "Shape" here, with no
// REGISTER_OP; presumably the op definition comes from the real op registry
// linked into this test -- confirm.
REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp);

// Op that has kernels with device priorities specified. The FakeCPU kernel
// carries the larger priority value (2 vs. 1); the tests in this file expect
// such nodes to be placed on FakeCPU.
REGISTER_OP("TestDatasetOp").Input("a: float").Output("b: float");
REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeCPU").Priority(2),
                        DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeGPU").Priority(1),
                        DummyOp);
188 
189 ////////////////////////////////////////////////////////////////////////////////
190 //
191 // A PlacerTest method has three phases:
192 //
193 // 1. Build a TensorFlow graph, with no (or partial) device assignments.
194 // 2. Attempt to compute a placement using the Placer.
195 // 3. EITHER: test that the constraints implied by the graph are respected;
196 //    or that an appropriate error was reported.
197 //
198 ////////////////////////////////////////////////////////////////////////////////
199 class PlacerTest : public ::testing::Test {
200  protected:
PlacerTest()201   PlacerTest() {
202     // Build a set of 10 GPU and 10 CPU devices.
203     // NOTE: this->local_devices_ owns the device objects;
204     // this->devices_ contains borrowed pointers to the device
205     // objects.
206     for (int i = 0; i < 10; ++i) {
207       local_devices_.emplace_back(FakeDevice::MakeCPU(
208           strings::StrCat("/job:a/replica:0/task:0/device:fakecpu:", i)));
209       devices_.AddDevice(local_devices_.back().get());
210       // Insert the GPUs in reverse order.
211       local_devices_.emplace_back(FakeDevice::MakeGPU(
212           strings::StrCat("/job:a/replica:0/task:0/device:fakegpu:", 9 - i)));
213       devices_.AddDevice(local_devices_.back().get());
214     }
215   }
216 
217   // Builds the given graph, and (if successful) indexes the node
218   // names for use in placement, and later lookup.
BuildGraph(const GraphDefBuilder & builder,Graph * out_graph)219   Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) {
220     TF_RETURN_IF_ERROR(GraphDefBuilderToGraph(builder, out_graph));
221     nodes_by_name_.clear();
222     for (Node* node : out_graph->nodes()) {
223       nodes_by_name_[node->name()] = node->id();
224     }
225     return Status::OK();
226   }
227 
BuildGraph(const GraphDef & graph_def,Graph * out_graph)228   Status BuildGraph(const GraphDef& graph_def, Graph* out_graph) {
229     GraphConstructorOptions opts;
230     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, out_graph));
231     nodes_by_name_.clear();
232     for (Node* node : out_graph->nodes()) {
233       nodes_by_name_[node->name()] = node->id();
234     }
235     return Status::OK();
236   }
237 
238   // Invokes the Placer on "graph". If no DeviceSet is specified, the
239   // placement will use the default DeviceSet (of 10 CPU and 10 GPU devices).
240   //
241   // REQUIRES: "*graph" was produced by the most recent call to BuildGraph.
Place(Graph * graph,DeviceSet * devices,bool allow_soft_placement,bool log_device_placement)242   Status Place(Graph* graph, DeviceSet* devices, bool allow_soft_placement,
243                bool log_device_placement) {
244     Placer placer(graph, devices, nullptr, allow_soft_placement,
245                   log_device_placement);
246     return placer.Run();
247   }
248 
Place(Graph * graph,DeviceSet * devices)249   Status Place(Graph* graph, DeviceSet* devices) {
250     return Place(graph, devices, true, false);
251   }
252 
Place(Graph * graph,bool allow_soft_placement,bool log_device_placement)253   Status Place(Graph* graph, bool allow_soft_placement,
254                bool log_device_placement) {
255     return Place(graph, &devices_, allow_soft_placement, log_device_placement);
256   }
257 
Place(Graph * graph)258   Status Place(Graph* graph) { return Place(graph, &devices_, true, false); }
259 
260   // Returns the node in "graph" with the given name.
261   //
262   // REQUIRES: "graph" was produced by the most recent call to BuildGraph.
GetNodeByName(const Graph & graph,const string & name)263   Node* GetNodeByName(const Graph& graph, const string& name) {
264     const auto search = nodes_by_name_.find(name);
265     CHECK(search != nodes_by_name_.end()) << "Unknown node name: " << name;
266     return graph.FindNodeId(search->second);
267   }
268 
269  protected:
270   std::vector<std::unique_ptr<Device>> local_devices_;
271   DeviceSet devices_;
272   Placer::NodeNameToIdMap nodes_by_name_;
273 
274   Status ReferenceTestHelper(const string& variable_op_type,
275                              const string& assign_op_type,
276                              const DeviceType& expected_device_type);
277 };
278 
// Fixture that adds a bool test parameter for allow_soft_placement.
// Test cases that want to test behavior with and without soft placement
// can use this fixture instead of PlacerTest.
class SoftPlacementPlacerTest : public PlacerTest,
                                public ::testing::WithParamInterface<bool> {};

// Instantiates the SoftPlacementPlacerTest suite for both parameter values
// (false and true); PrintToStringParamName() derives the test-name suffix
// from the parameter value.
// NOTE(review): the instantiation-name (first) argument is left empty; some
// googletest releases reject an empty prefix -- confirm against the
// googletest version this project builds with.
INSTANTIATE_TEST_SUITE_P(, SoftPlacementPlacerTest,
                         ::testing::Values(false, true),
                         ::testing::PrintToStringParamName());
288 
// The macros below call GetNodeByName() (and, for EXPECT_DEVICE_TYPE, read
// devices_) unqualified, so they can only be used inside member functions of
// PlacerTest or a fixture derived from it.

// Expects that the nodes named name_a and name_b in graph g have the same
// assigned device name. g is evaluated once and bound to a local reference.
#define EXPECT_COLOCATED(g, name_a, name_b)                         \
  do {                                                              \
    Graph& g_ = (g);                                                \
    EXPECT_EQ(GetNodeByName(g_, (name_a))->assigned_device_name(),  \
              GetNodeByName(g_, (name_b))->assigned_device_name()); \
  } while (0)

// Expects that the nodes named name_a and name_b in graph g have different
// assigned device names.
#define EXPECT_NOT_COLOCATED(g, name_a, name_b)                     \
  do {                                                              \
    Graph& g_ = (g);                                                \
    EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(),  \
              GetNodeByName(g_, (name_b))->assigned_device_name()); \
  } while (0)

// Expects that the device assigned to the named node has the given device
// type. The device is looked up by name in the fixture's default devices_
// set, not in any custom DeviceSet passed to Place().
// NOTE(review): the result of FindDeviceByName() is dereferenced without a
// null check, so this presumably crashes rather than fails when the node has
// no assigned device or the device is not in devices_ -- callers should make
// sure placement succeeded first.
#define EXPECT_DEVICE_TYPE(g, name, expected_device_type)               \
  EXPECT_EQ(DeviceType(expected_device_type).type(),                    \
            devices_                                                    \
                .FindDeviceByName(                                      \
                    GetNodeByName((g), (name))->assigned_device_name()) \
                ->attributes()                                          \
                .device_type())

// Expects that the assigned device name of the named node contains
// device_substr as a substring.
#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \
  EXPECT_TRUE(::tensorflow::str_util::StrContains(     \
      GetNodeByName((g), (name))->assigned_device_name(), device_substr))
314 
315 // Test that a graph with no constraints will successfully assign nodes to the
316 // "best available" device (i.e. prefer GPU over CPU).
TEST_F(PlacerTest,TestNoConstraints)317 TEST_F(PlacerTest, TestNoConstraints) {
318   Graph g(OpRegistry::Global());
319   {  // Scope for temporary variables used to construct g.
320     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
321     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
322     ops::UnaryOp("TestRelu", ops::NodeOut(input, 0), b.opts().WithName("n1"));
323     ops::UnaryOp("TestRelu", ops::NodeOut(input, 1), b.opts().WithName("n2"));
324     TF_EXPECT_OK(BuildGraph(b, &g));
325   }
326 
327   TF_EXPECT_OK(Place(&g));
328   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
329   EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU");
330   EXPECT_DEVICE_TYPE(g, "n2", "FakeGPU");
331 }
332 
333 // Test that a graph with no constraints but using kernels that have a specified
334 // device priority will successfully assign nodes to the device with higher
335 // priority
TEST_F(PlacerTest,TestNoConstraintsWithPrioritizedKernels)336 TEST_F(PlacerTest, TestNoConstraintsWithPrioritizedKernels) {
337   Graph g(OpRegistry::Global());
338   {  // Scope for temporary variables used to construct g.
339     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
340     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
341     ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
342                  b.opts().WithName("n1"));
343     ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 1),
344                  b.opts().WithName("n2"));
345     TF_EXPECT_OK(BuildGraph(b, &g));
346   }
347 
348   TF_EXPECT_OK(Place(&g));
349   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
350   EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU");
351   EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU");
352 }
353 
// Tests that a prioritized kernel keeps its preferred device even when its
// input lives elsewhere: "in" (TestGPUOutput) has a FakeGPU kernel only, while
// "n1" (TestDatasetOp, FakeCPU priority 2 vs. FakeGPU priority 1) is still
// expected on CPU.
TEST_F(PlacerTest, TestGPUInputIntoPrioritizedKernel) {
  Graph g(OpRegistry::Global());
  {
    // Scope for temp variables used to construct g.
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));
    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
                 b.opts().WithName("n1"));
    // Fatal on failure: GetNodeByName() CHECK-fails for names that BuildGraph
    // never indexed.
    TF_ASSERT_OK(BuildGraph(b, &g));
  }

  // Fatal on failure: EXPECT_DEVICE_TYPE dereferences the device looked up
  // from the assigned name, so do not continue past a failed placement.
  TF_ASSERT_OK(Place(&g));
  EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
  EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU");
}
369 
370 // Tests that a GPU kernel colocated with prioritized kernel respects it.
TEST_F(PlacerTest,TestGPUInputColocatedWithPrioritizedKernel)371 TEST_F(PlacerTest, TestGPUInputColocatedWithPrioritizedKernel) {
372   Graph g(OpRegistry::Global());
373   {
374     // Scope for temp variables used to construct g.
375     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
376     Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));
377     // We colocate n1 with in.
378     ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
379                  b.opts().WithName("n1").WithAttr("_class", {"loc:@in"}));
380     // We don't colocate n2 with in.
381     ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
382                  b.opts().WithName("n2"));
383     TF_EXPECT_OK(BuildGraph(b, &g));
384   }
385 
386   TF_EXPECT_OK(Place(&g));
387   EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
388   EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU");
389   EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU");
390 }
391 
// Ops connected by a resource-typed edge, used by the kernel-priority tests
// below. "CreateDataset*" ops produce a resource; "Iterator*" ops consume one
// and produce a float. The suffix encodes the kernel registrations:
//   CPU / GPU: a kernel for that fake device type only.
//   SP:        FakeCPU kernel with Priority(2), FakeGPU kernel with Priority(1).
//   RP:        FakeCPU kernel with Priority(1), FakeGPU kernel with Priority(2).
//   NP:        kernels for both device types, no priority specified.
REGISTER_OP("CreateDatasetCPU").Output("o: resource");
REGISTER_KERNEL_BUILDER(Name("CreateDatasetCPU").Device("FakeCPU"), DummyOp);

REGISTER_OP("CreateDatasetSP").Output("o: resource");
REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeCPU").Priority(2),
                        DummyOp);
REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeGPU").Priority(1),
                        DummyOp);

REGISTER_OP("CreateDatasetRP").Output("o: resource");
REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeCPU").Priority(1),
                        DummyOp);
REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeGPU").Priority(2),
                        DummyOp);

REGISTER_OP("CreateDatasetNP").Output("o: resource");
REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeGPU"), DummyOp);

REGISTER_OP("IteratorNP").Input("i: resource").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeGPU"), DummyOp);

REGISTER_OP("IteratorSP").Input("i: resource").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeCPU").Priority(2),
                        DummyOp);
REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeGPU").Priority(1),
                        DummyOp);

REGISTER_OP("IteratorRP").Input("i: resource").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeCPU").Priority(1),
                        DummyOp);
REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeGPU").Priority(2),
                        DummyOp);

REGISTER_OP("IteratorGPU").Input("i: resource").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("IteratorGPU").Device("FakeGPU"), DummyOp);
429 
430 // Test reference edges with one node having prioritized kernels and the other
431 // has no preference. We should respect priority here.
TEST_F(PlacerTest,TestDSWithPriority)432 TEST_F(PlacerTest, TestDSWithPriority) {
433   Graph g(OpRegistry::Global());
434   {
435     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
436     Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
437     ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
438     TF_EXPECT_OK(BuildGraph(b, &g));
439   }
440   TF_EXPECT_OK(Place(&g));
441   EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
442   EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
443 }
444 
445 // Test reference edges with one node having kernels with regular priority and
446 // the other has no preference. We should respect priority here.
TEST_F(PlacerTest,TestDSWithGPUPriority)447 TEST_F(PlacerTest, TestDSWithGPUPriority) {
448   Graph g(OpRegistry::Global());
449   {
450     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
451     Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
452     ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
453     TF_EXPECT_OK(BuildGraph(b, &g));
454   }
455   TF_EXPECT_OK(Place(&g));
456   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
457   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
458 }
459 
460 // Test reference edges with one node having prioritized kernels and the other
461 // has no preference. We should respect priority here.
TEST_F(PlacerTest,TestITWithPriority)462 TEST_F(PlacerTest, TestITWithPriority) {
463   Graph g(OpRegistry::Global());
464   {
465     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
466     Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds"));
467     ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
468     TF_EXPECT_OK(BuildGraph(b, &g));
469   }
470   TF_EXPECT_OK(Place(&g));
471   EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
472   EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
473 }
474 
475 // Test reference edges with one node having kernels with regular priority and
476 // the other has no preference. We should respect priority here.
TEST_F(PlacerTest,TestITWithGPUPriority)477 TEST_F(PlacerTest, TestITWithGPUPriority) {
478   Graph g(OpRegistry::Global());
479   {
480     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
481     Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds"));
482     ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
483     TF_EXPECT_OK(BuildGraph(b, &g));
484   }
485   TF_EXPECT_OK(Place(&g));
486   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
487   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
488 }
489 
490 // Test reference edges with one node having prioritized kernels and other node
491 // can only be placed on GPU. We should respect the constraint then.
TEST_F(PlacerTest,TestITGPU)492 TEST_F(PlacerTest, TestITGPU) {
493   Graph g(OpRegistry::Global());
494   {
495     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
496     Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
497     ops::UnaryOp("IteratorGPU", ops::NodeOut(ds, 0), b.opts().WithName("it"));
498     TF_EXPECT_OK(BuildGraph(b, &g));
499   }
500   TF_EXPECT_OK(Place(&g));
501   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
502   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
503 }
504 
505 // Test reference edges with one node having prioritized kernels and other node
506 // can only be placed on CPU. We should respect the constraint then.
TEST_F(PlacerTest,TestSimpleIteratorOnlyGPU)507 TEST_F(PlacerTest, TestSimpleIteratorOnlyGPU) {
508   Graph g(OpRegistry::Global());
509   {
510     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
511     Node* ds = ops::SourceOp("CreateDatasetCPU", b.opts().WithName("ds"));
512     ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
513     TF_EXPECT_OK(BuildGraph(b, &g));
514   }
515   TF_EXPECT_OK(Place(&g));
516   EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
517   EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
518 }
519 
520 // Test constraints with agreeing priorities.
TEST_F(PlacerTest,TestAgreeingPriorities)521 TEST_F(PlacerTest, TestAgreeingPriorities) {
522   Graph g(OpRegistry::Global());
523   {
524     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
525     Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
526     ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
527     TF_EXPECT_OK(BuildGraph(b, &g));
528   }
529   TF_EXPECT_OK(Place(&g));
530   EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
531   EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
532 }
533 
534 // Test constraints with agreeing regular priorities.
TEST_F(PlacerTest,TestAgreeingRegularPriorities)535 TEST_F(PlacerTest, TestAgreeingRegularPriorities) {
536   Graph g(OpRegistry::Global());
537   {
538     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
539     Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
540     ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
541     TF_EXPECT_OK(BuildGraph(b, &g));
542   }
543   TF_EXPECT_OK(Place(&g));
544   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
545   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
546 }
547 
548 // Test constraints with different priorities. In this case, we should bail
549 // and just revert to default.
TEST_F(PlacerTest,TestConflictingPriorities)550 TEST_F(PlacerTest, TestConflictingPriorities) {
551   Graph g(OpRegistry::Global());
552   {
553     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
554     Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
555     ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
556     TF_EXPECT_OK(BuildGraph(b, &g));
557   }
558   TF_EXPECT_OK(Place(&g));
559   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
560   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
561 }
562 
563 // Test constraints with different priorities. In this case, we should bail
564 // and just revert to default.
TEST_F(PlacerTest,TestConflictingPrioritiesReversed)565 TEST_F(PlacerTest, TestConflictingPrioritiesReversed) {
566   Graph g(OpRegistry::Global());
567   {
568     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
569     Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
570     ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
571     TF_EXPECT_OK(BuildGraph(b, &g));
572   }
573   TF_EXPECT_OK(Place(&g));
574   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
575   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
576 }
577 
578 // Test that a graph with device type and reference constraints on
579 // some of the ops will successfully assign nodes to the constrained
580 // device, and colocate nodes with reference connections.
TEST_F(PlacerTest,TestDeviceTypeConstraints)581 TEST_F(PlacerTest, TestDeviceTypeConstraints) {
582   Graph g(OpRegistry::Global());
583   {  // Scope for temporary variables used to construct g.
584     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
585     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
586     Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
587     ops::BinaryOp("AssignCPU", var_cpu, input, b.opts().WithName("assign_cpu"));
588     Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu"));
589     ops::BinaryOp("AssignGPU", var_gpu, input, b.opts().WithName("assign_gpu"));
590     TF_EXPECT_OK(BuildGraph(b, &g));
591   }
592 
593   TF_EXPECT_OK(Place(&g));
594   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
595   EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
596   EXPECT_DEVICE_TYPE(g, "assign_cpu", "FakeCPU");
597   EXPECT_COLOCATED(g, "var_cpu", "assign_cpu");
598   EXPECT_DEVICE_TYPE(g, "var_gpu", "FakeGPU");
599   EXPECT_DEVICE_TYPE(g, "assign_gpu", "FakeGPU");
600   EXPECT_COLOCATED(g, "var_gpu", "assign_gpu");
601 }
602 
// Tests that a metadata op (Shape) is placed together with the node that
// produces its input, even though it also has a kernel for the other device
// type.
TEST_F(PlacerTest, TestMetadataColocatedWithInput) {
  Graph g(OpRegistry::Global());
  {  // Scope for temporary variables used to construct g.
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));

    // Normally, shape has a GPU implementation and would be placed
    // on GPU.  However, because it is a metadata operation, it is
    // placed on CPU to avoid transferring the data from CPU to GPU.
    ops::UnaryOp("Shape", var_cpu, b.opts().WithName("shape_op"));
    // Fatal on failure: GetNodeByName() CHECK-fails for names that BuildGraph
    // never indexed.
    TF_ASSERT_OK(BuildGraph(b, &g));
  }

  // Fatal on failure: EXPECT_DEVICE_TYPE dereferences the device looked up
  // from the assigned name, so do not continue past a failed placement.
  TF_ASSERT_OK(Place(&g));
  EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
  EXPECT_DEVICE_TYPE(g, "shape_op", "FakeCPU");
  EXPECT_COLOCATED(g, "var_cpu", "shape_op");
}
621 
622 // Heuristic A implements "Island fusing": if a node only generates
623 // an output and it has only one consumer, we place the node
624 // with its consumer.
TEST_F(PlacerTest,TestHeuristicGeneratorFollowsSingleConsumer)625 TEST_F(PlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) {
626   Graph g(OpRegistry::Global());
627   {  // Scope for temporary variables used to construct g.
628     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
629 
630     // A variable is only on CPU
631     Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
632 
633     // The constant to be assigned can be on both GPU or CPU.
634     //
635     // Because of the heuristic, it gets placed on CPU to avoid a
636     // copy.
637     Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in"));
638 
639     // The assign is bound to CPU by the reference edge.
640     ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign"));
641 
642     TF_EXPECT_OK(BuildGraph(b, &g));
643   }
644 
645   TF_EXPECT_OK(Place(&g));
646   EXPECT_COLOCATED(g, "var_cpu", "in");
647   EXPECT_COLOCATED(g, "assign", "in");
648 }
649 
// Tests that the "generator follows its single consumer" heuristic is skipped
// when the generator has no kernel for the consumer's device type: the
// GPU-only generator stays on GPU while the variable and assign stay on CPU.
TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) {
  Graph g(OpRegistry::Global());
  {  // Scope for temporary variables used to construct g.
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);

    // A variable is only on CPU
    Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));

    // The constant to be assigned can only be on GPU.
    //
    // The heuristic to place the generator with its consumer does
    // not apply since the consumer's device is not in the list
    // of valid devices for the generator.
    Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));

    // The assign is bound to CPU by the reference edge.
    ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign"));

    // Fatal on failure: GetNodeByName() CHECK-fails for names that BuildGraph
    // never indexed.
    TF_ASSERT_OK(BuildGraph(b, &g));
  }

  // Fatal on failure: EXPECT_DEVICE_TYPE dereferences the device looked up
  // from the assigned name, so do not continue past a failed placement.
  TF_ASSERT_OK(Place(&g));
  EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
  EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
  EXPECT_COLOCATED(g, "var_cpu", "assign");
}
676 
// Tests that the "generator follows its single consumer" heuristic is skipped
// when the generator carries an explicit partial device specification that
// differs from the consumer's device: the generator stays on fakecpu:1 while
// the variable and assign are expected on fakecpu:0.
TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) {
  Graph g(OpRegistry::Global());
  {  // Scope for temporary variables used to construct g.
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);

    // A variable is only on CPU
    Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));

    // The constant to be assigned can be on CPU or GPU, but is explicitly
    // placed on CPU:1.
    //
    // The heuristic to place the generator with its consumer does
    // not apply since the consumer's device is not in the list
    // of valid devices for the generator.
    Node* input =
        ops::SourceOp("TestCPUGPUOutput",
                      b.opts().WithName("in").WithDevice("/device:fakecpu:1"));

    // The assign is bound to CPU by the reference edge.
    ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign"));

    // Fatal on failure: GetNodeByName() CHECK-fails for names that BuildGraph
    // never indexed.
    TF_ASSERT_OK(BuildGraph(b, &g));
  }

  // Fatal on failure: EXPECT_DEVICE_TYPE dereferences the device looked up
  // from the assigned name, so do not continue past a failed placement.
  TF_ASSERT_OK(Place(&g));
  EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
  EXPECT_DEVICE_CONTAINS(g, "in", "/device:fakecpu:1");
  EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
  EXPECT_COLOCATED(g, "var_cpu", "assign");
  EXPECT_DEVICE_CONTAINS(g, "var_cpu", "/device:fakecpu:0");
}
708 
709 // Test that a graph with partial device specifications on the ops
710 // will successfully
TEST_F(PlacerTest,TestPartialSpec)711 TEST_F(PlacerTest, TestPartialSpec) {
712   Graph g(OpRegistry::Global());
713   {  // Scope for temporary variables used to construct g.
714     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
715     ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:a"));
716     ops::SourceOp("TestVariable",
717                   b.opts().WithName("var").WithDevice("/job:a"));
718     TF_EXPECT_OK(BuildGraph(b, &g));
719   }
720 
721   TF_EXPECT_OK(Place(&g));
722   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
723   EXPECT_DEVICE_CONTAINS(g, "in", "/job:a");
724   EXPECT_DEVICE_TYPE(g, "var", "FakeGPU");
725   EXPECT_DEVICE_CONTAINS(g, "var", "/job:a");
726 }
727 
728 // Test that a node with a pre-assigned device is not relocated.
TEST_F(PlacerTest,TestAssignedDevicePreserved)729 TEST_F(PlacerTest, TestAssignedDevicePreserved) {
730   Graph g(OpRegistry::Global());
731   {  // Scope for temporary variables used to construct g.
732     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
733     ops::SourceOp("TestInput", b.opts().WithName("in"));
734     TF_EXPECT_OK(BuildGraph(b, &g));
735   }
736 
737   GetNodeByName(g, "in")->set_assigned_device_name(
738       "/job:a/replica:0/task:0/device:fakecpu:7");
739 
740   TF_EXPECT_OK(Place(&g));
741   EXPECT_EQ("/job:a/replica:0/task:0/device:fakecpu:7",
742             GetNodeByName(g, "in")->assigned_device_name());
743 }
744 
745 // Test that a graph with partial device specifications for CPU-only ops
746 // will be relocated to CPU.
TEST_F(PlacerTest,TestPartialSpecGpuToCpu)747 TEST_F(PlacerTest, TestPartialSpecGpuToCpu) {
748   Graph g(OpRegistry::Global());
749   {  // Scope for temporary variables used to construct g.
750     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
751     ops::SourceOp("TestInput",
752                   b.opts().WithName("in").WithDevice("/device:fakegpu:0"));
753     ops::SourceOp("TestVariable",
754                   b.opts().WithName("var").WithDevice("/device:fakegpu:0"));
755     TF_EXPECT_OK(BuildGraph(b, &g));
756   }
757 
758   TF_EXPECT_OK(Place(&g, true, false));
759   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
760   EXPECT_DEVICE_CONTAINS(g, "in", "/device:fakecpu");
761   EXPECT_DEVICE_TYPE(g, "var", "FakeGPU");
762   EXPECT_DEVICE_CONTAINS(g, "var", "/device:fakegpu:0");
763 }
764 
765 // Test that a node with an assigned GPU device but has not registered
766 // OpKernel will fail.
TEST_F(PlacerTest,TestAssignedGpuDeviceToCpuDevice)767 TEST_F(PlacerTest, TestAssignedGpuDeviceToCpuDevice) {
768   Graph g(OpRegistry::Global());
769   {  // Scope for temporary variables used to construct g.
770     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
771     ops::SourceOp("TestInput", b.opts().WithName("in"));
772     TF_EXPECT_OK(BuildGraph(b, &g));
773   }
774 
775   GetNodeByName(g, "in")->set_assigned_device_name(
776       "/job:a/replica:0/task:0/device:fakegpu:0");
777 
778   Status s = Place(&g);
779   EXPECT_EQ(error::INTERNAL, s.code());
780   EXPECT_TRUE(str_util::StrContains(
781       s.error_message(),
782       "Assigned device '/job:a/replica:0/task:0/device:fakegpu:0' "
783       "does not have registered OpKernel support for TestInput"));
784 }
785 
786 // Test that graphs with reference connections are correctly placed.
787 
788 // Build a graph containing a Variable op of "variable_op_type" and an
789 // Assign op of "assign_op_type", and expect all of the ops to be
790 // placed on a device of type "expected_device_type".
ReferenceTestHelper(const string & variable_op_type,const string & assign_op_type,const DeviceType & expected_device_type)791 Status PlacerTest::ReferenceTestHelper(const string& variable_op_type,
792                                        const string& assign_op_type,
793                                        const DeviceType& expected_device_type) {
794   Graph g(OpRegistry::Global());
795   {  // Scope for temporary variables used to construct g.
796     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
797     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
798     // Build ten variable-and-assignment pairs.
799     for (int i = 0; i < 10; ++i) {
800       Node* var = ops::SourceOp(variable_op_type,
801                                 b.opts().WithName(strings::StrCat("var_", i)));
802       ops::BinaryOp(assign_op_type, var, input,
803                     b.opts().WithName(strings::StrCat("assign_", i)));
804     }
805     TF_EXPECT_OK(BuildGraph(b, &g));
806   }
807 
808   TF_RETURN_IF_ERROR(Place(&g));
809 
810   for (int i = 0; i < 10; ++i) {
811     EXPECT_COLOCATED(g, strings::StrCat("var_", i),
812                      strings::StrCat("assign_", i));
813     EXPECT_DEVICE_TYPE(g, strings::StrCat("var_", i), expected_device_type);
814     EXPECT_DEVICE_TYPE(g, strings::StrCat("assign_", i), expected_device_type);
815   }
816 
817   return Status::OK();
818 }
819 
820 // Test all 2^3 combinations of Variable and Assignment op types
821 // (unconstrained, CPU-only, and GPU-only).
TEST_F(PlacerTest,TestReferenceConnection)822 TEST_F(PlacerTest, TestReferenceConnection) {
823   Status s;
824   TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "TestAssign", "FakeGPU"));
825   TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignCPU", "FakeCPU"));
826   TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignGPU", "FakeGPU"));
827   TF_EXPECT_OK(ReferenceTestHelper("VariableCPU", "TestAssign", "FakeCPU"));
828   TF_EXPECT_OK(ReferenceTestHelper("VariableCPU", "AssignCPU", "FakeCPU"));
829   {
830     Status s = ReferenceTestHelper("VariableCPU", "AssignGPU", "FakeCPU");
831     EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
832     EXPECT_TRUE(str_util::StrContains(
833         s.error_message(), "no device type supports both of those nodes"));
834   }
835   TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "TestAssign", "FakeGPU"));
836   {
837     Status s = ReferenceTestHelper("VariableGPU", "AssignCPU", "FakeCPU");
838     EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
839     EXPECT_TRUE(str_util::StrContains(
840         s.error_message(), "no device type supports both of those nodes"));
841   }
842   TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", "FakeGPU"));
843 }
844 
// Handle-using dummy variable ops.
//
// These registrations mirror the reference-typed test ops, but pass a
// `resource` handle instead of a reference. Each op is registered with
// DummyOp kernels for the device types listed next to it.

// Resource-producing variable with kernels on both FakeCPU and FakeGPU.
REGISTER_OP("TestHandleVariable").Output("o: resource");
REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device("FakeGPU"), DummyOp);

// Resource-producing variable with a FakeCPU kernel only.
REGISTER_OP("HandleVariableCPU").Output("o: resource");
REGISTER_KERNEL_BUILDER(Name("HandleVariableCPU").Device("FakeCPU"), DummyOp);

// Resource-producing variable with a FakeGPU kernel only.
REGISTER_OP("HandleVariableGPU").Output("o: resource");
REGISTER_KERNEL_BUILDER(Name("HandleVariableGPU").Device("FakeGPU"), DummyOp);

// Assign taking a resource handle and a float value; kernels on both
// FakeCPU and FakeGPU.
REGISTER_OP("TestHandleAssign").Input("i: resource").Input("v: float");
REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device("FakeGPU"), DummyOp);

// Same signature as TestHandleAssign, with a FakeCPU kernel only.
REGISTER_OP("HandleAssignCPU").Input("i: resource").Input("v: float");
REGISTER_KERNEL_BUILDER(Name("HandleAssignCPU").Device("FakeCPU"), DummyOp);

// Same signature as TestHandleAssign, with a FakeGPU kernel only.
REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float");
REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device("FakeGPU"), DummyOp);

// Op consuming two resource handles; kernels on both FakeCPU and FakeGPU.
REGISTER_OP("TestTwoHandlesIn").Input("i: resource").Input("j: resource");
REGISTER_KERNEL_BUILDER(Name("TestTwoHandlesIn").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestTwoHandlesIn").Device("FakeGPU"), DummyOp);
869 
870 // Tests all combinations of resource handles and ops using them.
TEST_F(PlacerTest,TestResourceHandle)871 TEST_F(PlacerTest, TestResourceHandle) {
872   auto handle_test = [this](const string& var_op_name,
873                             const string& use_op_name, DeviceType device) {
874     Graph g(OpRegistry::Global());
875     {  // Scope for temporary variables used to construct g.
876       GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
877       Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
878       Node* var = ops::SourceOp(var_op_name, b.opts().WithName("var"));
879       ops::BinaryOp(use_op_name, var, input, b.opts().WithName("assign"));
880       TF_EXPECT_OK(BuildGraph(b, &g));
881     }
882 
883     TF_RETURN_IF_ERROR(Place(&g));
884 
885     EXPECT_COLOCATED(g, "var", "assign");
886     EXPECT_DEVICE_TYPE(g, "var", device);
887     EXPECT_DEVICE_TYPE(g, "assign", device);
888     return Status::OK();
889   };
890   TF_EXPECT_OK(
891       handle_test("TestHandleVariable", "TestHandleAssign", "FakeGPU"));
892   TF_EXPECT_OK(handle_test("TestHandleVariable", "HandleAssignCPU", "FakeCPU"));
893   TF_EXPECT_OK(handle_test("TestHandleVariable", "HandleAssignGPU", "FakeGPU"));
894   TF_EXPECT_OK(handle_test("HandleVariableCPU", "TestHandleAssign", "FakeCPU"));
895   TF_EXPECT_OK(handle_test("HandleVariableCPU", "HandleAssignCPU", "FakeCPU"));
896   TF_EXPECT_OK(handle_test("HandleVariableGPU", "HandleAssignGPU", "FakeGPU"));
897   TF_EXPECT_OK(handle_test("HandleVariableGPU", "TestHandleAssign", "FakeGPU"));
898   EXPECT_FALSE(
899       handle_test("HandleVariableGPU", "HandleAssignCPU", "FakeCPU").ok());
900   EXPECT_FALSE(
901       handle_test("HandleVariableCPU", "HandleAssignGPU", "FakeCPU").ok());
902 }
903 
// An op that consumes two resource handles pinned to different devices (one
// on CPU, one on GPU) cannot be placed. This must hold whether the handles'
// devices are requested or already assigned, and with or without soft
// placement.
TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) {
  // Builds the conflicting graph and expects INVALID_ARGUMENT from the
  // placer. `use_assigned` picks assigned devices over requested devices.
  auto expect_conflict = [this](bool soft_placement, bool use_assigned) {
    const char kCpuDevice[] = "/job:a/replica:0/task:0/device:fakecpu:0";
    const char kGpuDevice[] = "/job:a/replica:0/task:0/device:fakegpu:0";

    Graph graph(OpRegistry::Global());
    {  // Builder temporaries are confined to this scope.
      GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
      Node* cpu_handle = ops::SourceOp("TestHandleVariable",
                                       builder.opts().WithName("var_cpu"));
      Node* gpu_handle = ops::SourceOp("TestHandleVariable",
                                       builder.opts().WithName("var_gpu"));
      ops::BinaryOp("TestTwoHandlesIn", cpu_handle, gpu_handle,
                    builder.opts().WithName("two_handles_in"));
      TF_EXPECT_OK(BuildGraph(builder, &graph));

      if (!use_assigned) {
        GetNodeByName(graph, "var_cpu")->set_requested_device(kCpuDevice);
        GetNodeByName(graph, "var_gpu")->set_requested_device(kGpuDevice);
      } else {
        GetNodeByName(graph, "var_cpu")->set_assigned_device_name(kCpuDevice);
        GetNodeByName(graph, "var_gpu")->set_assigned_device_name(kGpuDevice);
      }
    }

    const Status status = Place(&graph, soft_placement, true);
    EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status.ToString();
    EXPECT_TRUE(str_util::StrContains(
        status.error_message(),
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible assigned devices: "
        "/job:a/replica:0/task:0/device:fakegpu:0 vs "
        "/job:a/replica:0/task:0/device:fakecpu:0"));

    return Status::OK();
  };

  // (allow_soft_placement, set_assigned) in every combination.
  TF_EXPECT_OK(expect_conflict(false, false));
  TF_EXPECT_OK(expect_conflict(false, true));
  TF_EXPECT_OK(expect_conflict(true, false));
  TF_EXPECT_OK(expect_conflict(true, true));
}
950 
951 // Test that an assignment of an operator to the wrong device
952 // is ignored when it could never be satisfied (due to reference
953 // edges, for example).
TEST_F(PlacerTest,TestReferenceConnectionIgnoreInfeasible)954 TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) {
955   Status s;
956   Graph g(OpRegistry::Global());
957   {
958     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
959     Node* input = ops::SourceOp(
960         "TestDevice",
961         b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0"));
962     Node* var =
963         ops::SourceOp("TestVariable", b.opts().WithName("var_0").WithDevice(
964                                           "/job:a/task:0/device:fakegpu:0"));
965 
966     // This op is specified on CPU, but in practice will be ignored,
967     // because the reference edges forces it on GPU.
968     ops::BinaryOp("TestAssign", var, input,
969                   b.opts().WithName("assign").WithDevice(
970                       "/job:a/task:0/device:fakecpu:0"));
971     TF_EXPECT_OK(BuildGraph(b, &g));
972   }
973 
974   s = Place(&g, false, false);
975   TF_EXPECT_OK(s);
976   EXPECT_DEVICE_TYPE(g, "var_0", "FakeGPU");
977   EXPECT_DEVICE_TYPE(g, "assign", "FakeGPU");
978 }
979 
980 // Test that an assignment of an operator to the a more specified device
981 // causes the device to maintain its more specific placement.
TEST_F(PlacerTest,TestReferenceConnectionMoreSpecificDestinationSourceWins)982 TEST_F(PlacerTest, TestReferenceConnectionMoreSpecificDestinationSourceWins) {
983   Status s;
984   Graph g(OpRegistry::Global());
985   {
986     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
987     // Input can be on either device
988     Node* input =
989         ops::SourceOp("TestCPUGPUOutput",
990                       b.opts().WithName("in").WithDevice("/job:a/task:0"));
991 
992     // Variable can be on either device
993     Node* var = ops::SourceOp(
994         "TestVariable", b.opts().WithName("var_0").WithDevice("/job:a/task:0"));
995 
996     // This op is specified on CPU and is more specific than the variable.
997     // Because the variable is less specified, the variable will be
998     // assigned to CPU.
999     ops::BinaryOp("TestAssign", var, input,
1000                   b.opts().WithName("assign").WithDevice(
1001                       "/job:a/task:0/device:fakecpu:0"));
1002     TF_EXPECT_OK(BuildGraph(b, &g));
1003   }
1004 
1005   s = Place(&g, false, false);
1006   TF_EXPECT_OK(s);
1007   EXPECT_DEVICE_TYPE(g, "var_0", "FakeCPU");
1008   EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU");
1009 }
1010 
1011 // A reference connection exists between a variable and an assign,
1012 // where the assign has a device but the variable does not.  In this
1013 // case, the variable gets placed on the location of the assign
1014 // operation.
TEST_F(PlacerTest,TestReferenceConnectionNoSourceDevice)1015 TEST_F(PlacerTest, TestReferenceConnectionNoSourceDevice) {
1016   Status s;
1017   Graph g(OpRegistry::Global());
1018   {
1019     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1020     Node* input = ops::SourceOp(
1021         "TestDevice",
1022         b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0"));
1023     Node* var = ops::SourceOp("TestVariable", b.opts().WithName("var_0"));
1024     ops::BinaryOp("TestAssign", var, input,
1025                   b.opts().WithName("assign").WithDevice(
1026                       "/job:a/task:0/device:fakecpu:0"));
1027     TF_EXPECT_OK(BuildGraph(b, &g));
1028   }
1029 
1030   s = Place(&g, false, false);
1031   TF_EXPECT_OK(s);
1032   EXPECT_DEVICE_TYPE(g, "var_0", "FakeCPU");
1033   EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU");
1034 }
1035 
// A node carrying a "_class" colocation attribute is placed with the node it
// names, while an otherwise identical node without the attribute is not.
TEST_F(PlacerTest, TestColocationGroup) {
  Graph graph(OpRegistry::Global());
  {  // Builder temporaries are confined to this scope.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* source = ops::SourceOp("TestInput", builder.opts().WithName("in"));

    // Explicitly grouped with "in" through the colocation attribute.
    Node* grouped_relu =
        ops::UnaryOp("TestRelu", source,
                     builder.opts().WithName("colocated_1").WithAttr(
                         "_class", {"loc:@in"}));

    // Not grouped: TestInput is only available on CPU, whereas TestRelu
    // defaults to GPU, so this node ends up away from the input.
    Node* free_relu =
        ops::UnaryOp("TestRelu", source, builder.opts().WithName("foo"));

    CHECK(grouped_relu);
    CHECK(free_relu);
    TF_EXPECT_OK(BuildGraph(builder, &graph));
  }

  TF_EXPECT_OK(Place(&graph));
  EXPECT_COLOCATED(graph, "in", "colocated_1");
  EXPECT_NOT_COLOCATED(graph, "in", "foo");
}
1058 
// A node may name several colocation groups in its "_class" attribute; when
// they are compatible, every node involved ends up on the same device.
TEST_F(PlacerTest, TestMultipleColocationGroups) {
  Graph graph(OpRegistry::Global());
  {  // Builder temporaries are confined to this scope.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* source = ops::SourceOp("TestInput", builder.opts().WithName("in"));

    // Grouped with "in" only.
    Node* single_group_relu =
        ops::UnaryOp("TestRelu", source,
                     builder.opts().WithName("colocated_1").WithAttr(
                         "_class", {"loc:@in"}));

    // Grouped with both "in" and "colocated_1".
    Node* double_group_relu = ops::UnaryOp(
        "TestRelu", source,
        builder.opts().WithName("foo").WithAttr(
            "_class", {"loc:@in", "loc:@colocated_1"}));

    CHECK(single_group_relu);
    CHECK(double_group_relu);
    TF_EXPECT_OK(BuildGraph(builder, &graph));
  }

  TF_EXPECT_OK(Place(&graph));
  EXPECT_COLOCATED(graph, "in", "colocated_1");
  EXPECT_COLOCATED(graph, "in", "foo");
}
1080 
// A GPU-only node is asked to colocate with CPU-only nodes. With soft
// placement the constraint is relaxed and each node goes where it can run;
// without it, placement fails with a "Cannot colocate nodes" error.
TEST_P(SoftPlacementPlacerTest, TestInvalidMultipleColocationGroups) {
  Graph graph(OpRegistry::Global());
  {  // Builder temporaries are confined to this scope.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* source = ops::SourceOp("TestInput", builder.opts().WithName("in"));

    // CPU relu grouped with "in".
    Node* cpu_relu =
        ops::UnaryOp("ReluCPU", source,
                     builder.opts().WithName("colocated_1").WithAttr(
                         "_class", {"loc:@in"}));

    // GPU relu grouped with both "in" and the CPU relu.
    Node* gpu_relu = ops::UnaryOp(
        "ReluGPU", source,
        builder.opts().WithName("foo").WithAttr(
            "_class", {"loc:@in", "loc:@colocated_1"}));

    CHECK(cpu_relu);
    CHECK(gpu_relu);
    TF_EXPECT_OK(BuildGraph(builder, &graph));
  }

  const bool soft_placement = GetParam();
  const Status status = Place(&graph, soft_placement, true);

  if (!soft_placement) {
    // Hard placement: the incompatible group is reported as an error.
    EXPECT_TRUE(str_util::StrContains(
        status.error_message(),
        "Cannot colocate nodes {{colocation_node foo}} and "
        "{{colocation_node in}} because no device type supports both of those "
        "nodes and the other nodes colocated with them"))
        << status.ToString();
    return;
  }

  // Soft placement: placement succeeds and "foo" is split off onto GPU.
  EXPECT_EQ(error::OK, status.code()) << status.ToString();
  EXPECT_DEVICE_TYPE(graph, "in", "FakeCPU");
  EXPECT_DEVICE_TYPE(graph, "colocated_1", "FakeCPU");
  EXPECT_DEVICE_TYPE(graph, "foo", "FakeGPU");
}
1114 
// Colocation attributes and reference edges may both constrain the same
// nodes; when all the resulting groups resolve to the same device the graph
// is placed successfully.
TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
  Graph graph(OpRegistry::Global());
  {  // Builder temporaries are confined to this scope.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* source = ops::SourceOp("TestInput", builder.opts().WithName("in"));
    Node* first_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var1"));
    Node* second_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var2"));

    // Each assign has a reference edge to its variable and names that
    // variable's colocation group. The two groups are distinct but map to
    // the same device, so the assignment is valid.
    const auto first_assign_opts =
        builder.opts().WithName("assign1").WithAttr("_class", {"loc:@var1"});
    ops::BinaryOp("TestAssign", first_var, source, first_assign_opts);
    const auto second_assign_opts =
        builder.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"});
    ops::BinaryOp("TestAssign", second_var, source, second_assign_opts);

    TF_EXPECT_OK(BuildGraph(builder, &graph));
  }

  TF_EXPECT_OK(Place(&graph));
  EXPECT_COLOCATED(graph, "in", "var1");
  EXPECT_COLOCATED(graph, "in", "var2");
  EXPECT_COLOCATED(graph, "var1", "assign2");
  EXPECT_COLOCATED(graph, "var2", "assign1");
}
1141 
// A colocation attribute that contradicts a reference edge (an assign to a
// GPU-only variable that claims the group of a CPU-only variable) is an
// error under hard placement and tolerated under soft placement.
TEST_P(SoftPlacementPlacerTest,
       TestColocationGroupWithUnsatisfiableReferenceConnections) {
  Graph graph(OpRegistry::Global());
  {  // Builder temporaries are confined to this scope.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* source = ops::SourceOp("TestInput", builder.opts().WithName("in"));

    // Two CPU-only variables and one GPU-only variable.
    Node* cpu_var1 =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var1"));
    Node* cpu_var2 =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var2"));
    Node* gpu_var3 =
        ops::SourceOp("VariableGPU", builder.opts().WithName("var3"));

    // The first two assigns each join the group of the variable they
    // reference; both groups map to the same device, which is fine.
    const auto assign1_opts =
        builder.opts().WithName("assign1").WithAttr("_class", {"loc:@var1"});
    ops::BinaryOp("TestAssign", cpu_var1, source, assign1_opts);
    const auto assign2_opts =
        builder.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"});
    ops::BinaryOp("TestAssign", cpu_var2, source, assign2_opts);

    // The third assign references var3 (GPU) yet claims var2's group.
    // assign2 must be on CPU because of its reference edge to var2, and
    // assign3 must be on GPU because of its reference edge to var3, so the
    // group cannot be satisfied.
    const auto assign3_opts =
        builder.opts().WithName("assign3").WithAttr("_class", {"loc:@var2"});
    ops::BinaryOp("TestAssign", gpu_var3, source, assign3_opts);

    TF_EXPECT_OK(BuildGraph(builder, &graph));
  }

  const bool soft_placement = GetParam();
  const Status status = Place(&graph, soft_placement, true);

  if (soft_placement) {
    EXPECT_EQ(error::OK, status.code()) << status.ToString();
    return;
  }

  EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status.ToString();
  EXPECT_TRUE(str_util::StrContains(
      status.error_message(),
      "Cannot colocate nodes {{colocation_node assign3}} and "
      "{{colocation_node var2}} because no device type supports both of "
      "those nodes and the other nodes colocated with them."))
      << status.ToString();
}
1187 
// Stress a mix of reference edges and colocation attributes: ten plain
// (variable, assign) pairs, then ninety more pairs whose variable and assign
// are each colocated with one of the earlier nodes.
TEST_F(PlacerTest, TestColocationAndReferenceConnections) {
  const int kPlainPairs = 10;
  const int kTotalPairs = 100;

  Graph graph(OpRegistry::Global());
  {  // Builder temporaries are confined to this scope.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* source = ops::SourceOp("TestInput", builder.opts().WithName("in"));

    // Plain variable/assign pairs with no colocation attributes.
    for (int k = 0; k < kPlainPairs; ++k) {
      Node* variable = ops::SourceOp(
          "TestVariable", builder.opts().WithName(strings::StrCat("var_", k)));
      ops::BinaryOp("TestAssign", variable, source,
                    builder.opts().WithName(strings::StrCat("assign_", k)));
    }

    // Each further variable is colocated with an existing variable, and its
    // assign is colocated with a possibly different pair's assign.
    for (int k = kPlainPairs; k < kTotalPairs; ++k) {
      const string var_group = strings::StrCat("loc:@var_", k % 6);
      Node* variable =
          ops::SourceOp("TestVariable",
                        builder.opts()
                            .WithName(strings::StrCat("var_", k))
                            .WithAttr("_class", {var_group}));
      const string assign_group = strings::StrCat("loc:@assign_", k % 3);
      ops::BinaryOp("TestAssign", variable, source,
                    builder.opts()
                        .WithName(strings::StrCat("assign_", k))
                        .WithAttr("_class", {assign_group}));
    }
    TF_EXPECT_OK(BuildGraph(builder, &graph));
  }

  TF_EXPECT_OK(Place(&graph));

  for (int k = 0; k < kTotalPairs; ++k) {
    // Every variable shares a device with its own assign.
    EXPECT_COLOCATED(graph, strings::StrCat("var_", k),
                     strings::StrCat("assign_", k));
    if (k < kPlainPairs) continue;

    // The later pairs must also honour their colocation attributes.
    EXPECT_COLOCATED(graph, strings::StrCat("var_", k),
                     strings::StrCat("var_", k % 6));
    EXPECT_COLOCATED(graph, strings::StrCat("assign_", k),
                     strings::StrCat("assign_", k % 3));
  }
}
1231 
1232 // Test that placement fails when no devices are registered.
TEST_F(PlacerTest,TestEmptyDeviceSet)1233 TEST_F(PlacerTest, TestEmptyDeviceSet) {
1234   Graph g(OpRegistry::Global());
1235   {  // Scope for temporary variables used to construct g.
1236     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1237     ops::SourceOp("TestInput", b.opts().WithName("in"));
1238     TF_EXPECT_OK(BuildGraph(b, &g));
1239   }
1240 
1241   DeviceSet empty;
1242 
1243   Status s = Place(&g, &empty);
1244   EXPECT_TRUE(
1245       str_util::StrContains(s.error_message(), "No devices are registered"));
1246 }
1247 
1248 // Test that placement fails when the requested device forces an
1249 // indirect constraint to be violated.
TEST_F(PlacerTest,TestHeterogeneousDeviceSetFailure)1250 TEST_F(PlacerTest, TestHeterogeneousDeviceSetFailure) {
1251   Graph g(OpRegistry::Global());
1252   {  // Scope for temporary variables used to construct g.
1253     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1254     Node* in = ops::SourceOp("TestInput", b.opts().WithName("in"));
1255     Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var"));
1256     ops::BinaryOp("TestAssign", var, in,
1257                   b.opts().WithName("assign").WithDevice("/job:b/task:1"));
1258     TF_EXPECT_OK(BuildGraph(b, &g));
1259   }
1260 
1261   DeviceSet heterogeneous;
1262   std::unique_ptr<Device> gpu(
1263       FakeDevice::MakeGPU("/job:b/replica:0/task:0/device:fakegpu:0"));
1264   heterogeneous.AddDevice(gpu.get());
1265   std::unique_ptr<Device> cpu(
1266       FakeDevice::MakeCPU("/job:b/replica:0/task:1/device:fakecpu:0"));
1267   heterogeneous.AddDevice(cpu.get());
1268   Status s = Place(&g, &heterogeneous);
1269   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1270   EXPECT_TRUE(
1271       str_util::StrContains(s.error_message(),
1272                             "colocated with a group of nodes that required "
1273                             "incompatible device"));
1274 
1275   // The error message should contain information that indicates which
1276   // op types have which registered device types.
1277   EXPECT_TRUE(str_util::StrContains(s.error_message(), "VariableGPU: FakeGPU"))
1278       << s;
1279   EXPECT_TRUE(
1280       str_util::StrContains(s.error_message(), "TestAssign: FakeGPU FakeCPU"))
1281       << s;
1282 }
1283 
1284 // Test that placement fails when an unknown device is requested.
TEST_F(PlacerTest,TestUnknownDevice)1285 TEST_F(PlacerTest, TestUnknownDevice) {
1286   Graph g(OpRegistry::Global());
1287   {  // Scope for temporary variables used to construct g.
1288     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1289     ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
1290     TF_EXPECT_OK(BuildGraph(b, &g));
1291   }
1292 
1293   Status s = Place(&g);
1294   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1295   EXPECT_TRUE(str_util::StrContains(s.error_message(), "/job:foo"));
1296 }
1297 
1298 // Test that placement fails when the combination of partial
1299 // constraints leads to an unknown device.
TEST_F(PlacerTest,TestUnknownMergedDevice)1300 TEST_F(PlacerTest, TestUnknownMergedDevice) {
1301   Graph g(OpRegistry::Global());
1302   {  // Scope for temporary variables used to construct g.
1303     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1304     ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
1305     TF_EXPECT_OK(BuildGraph(b, &g));
1306   }
1307 
1308   Status s = Place(&g);
1309   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1310   EXPECT_TRUE(str_util::StrContains(s.error_message(), "/job:foo"));
1311 }
1312 
1313 // Test that placement fails when the previously-assigned device for a
1314 // node is unknown.
TEST_F(PlacerTest,TestUnknownAssignedDevice)1315 TEST_F(PlacerTest, TestUnknownAssignedDevice) {
1316   Graph g(OpRegistry::Global());
1317   {  // Scope for temporary variables used to construct g.
1318     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1319     ops::SourceOp("TestInput", b.opts().WithName("in"));
1320     TF_EXPECT_OK(BuildGraph(b, &g));
1321   }
1322 
1323   GetNodeByName(g, "in")->set_assigned_device_name("/job:foo");
1324 
1325   Status s = Place(&g);
1326   EXPECT_EQ(error::INTERNAL, s.code());
1327   EXPECT_TRUE(str_util::StrContains(
1328       s.error_message(),
1329       "Assigned device '/job:foo' does not match any device"));
1330 }
1331 
1332 // Test that placement fails when an op with no registered kernels is
1333 // requested.
TEST_F(PlacerTest,TestNoKernelsRegistered)1334 TEST_F(PlacerTest, TestNoKernelsRegistered) {
1335   Graph g(OpRegistry::Global());
1336   {  // Scope for temporary variables used to construct g.
1337     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1338     ops::SourceOp("VariableNoKernels", b.opts().WithName("var"));
1339     TF_EXPECT_OK(BuildGraph(b, &g));
1340   }
1341 
1342   Status s = Place(&g);
1343   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1344   EXPECT_TRUE(
1345       str_util::StrContains(s.error_message(),
1346                             "No OpKernel was registered to support Op "
1347                             "'VariableNoKernels' used by {{node var}}"));
1348   EXPECT_TRUE(
1349       str_util::StrContains(s.error_message(), "<no registered kernels>"));
1350 }
1351 
1352 // Test that placement fails when a kernel is registered but no known
1353 // device supports it.
TEST_F(PlacerTest, TestNoDevicesRegistered) {
  Graph g(OpRegistry::Global());
  {
    // The builder is only needed while `g` is being populated.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    ops::SourceOp("VariableGPU", builder.opts().WithName("var"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  // Run the placer against a device set that holds a single fake CPU.
  DeviceSet cpu_only;
  std::unique_ptr<Device> fake_cpu(
      FakeDevice::MakeCPU("/job:a/replica:0/task:0/device:fakecpu:0"));
  cpu_only.AddDevice(fake_cpu.get());

  const Status status = Place(&g, &cpu_only);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_TRUE(str_util::StrContains(
      message,
      "No OpKernel was registered to support Op "
      "'VariableGPU' used by {{node var}}"));
  // The message must also contain the FakeGPU device entry.
  EXPECT_TRUE(str_util::StrContains(message, "device='FakeGPU'"));
}
1374 
1375 // Test that placement fails when a requested device is malformed.
TEST_F(PlacerTest, TestMalformedDeviceSpecification) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    // "/foo:bar" is the requested device string the placer has to reject.
    ops::SourceOp("TestInput",
                  builder.opts().WithName("in").WithDevice("/foo:bar"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  const Status status = Place(&g);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_TRUE(str_util::StrContains(
      message, "Malformed device specification '/foo:bar'"));
}
1389 
1390 // Test that placement fails when a previously-assigned device is malformed.
TEST_F(PlacerTest, TestMalformedAssignedDevice) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    ops::SourceOp("TestInput", builder.opts().WithName("in"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  // Write a malformed device name straight into the node's assigned device,
  // bypassing the requested-device path.
  Node* const input_node = GetNodeByName(g, "in");
  input_node->set_assigned_device_name("/foo:bar");

  const Status status = Place(&g);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INTERNAL, status.code());
  EXPECT_TRUE(
      str_util::StrContains(message, "Malformed assigned device '/foo:bar'"));
}
1406 
1407 // Test that placement fails when a device was previously assigned to
1408 // a node, but it does not uniquely identify a particular device.
TEST_F(PlacerTest, TestNonUniqueAssignedDevice) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    ops::SourceOp("TestInput", builder.opts().WithName("in"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  // Assign a job-only device name to the node before running the placer.
  Node* const input_node = GetNodeByName(g, "in");
  input_node->set_assigned_device_name("/job:a");

  const Status status = Place(&g);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INTERNAL, status.code());
  EXPECT_TRUE(str_util::StrContains(
      message, "Assigned device '/job:a' does not match any device"));
}
1424 
// Test that an op requested to be placed on a non-existent device is
// relocated to an existing device of the same type if allow_soft_placement is
// set.
TEST_F(PlacerTest, TestNonexistentGpuAllowSoftPlacement) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    // Request fake GPU index 11; the check below expects the node to land on
    // fakegpu:0 instead.
    ops::SourceOp(
        "TestDevice",
        builder.opts().WithName("in").WithDevice("/device:fakegpu:11"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  TF_EXPECT_OK(Place(&g, /*allow_soft_placement=*/true, false));
  EXPECT_DEVICE_CONTAINS(g, "in", "/device:fakegpu:0");
}
1439 
// Test that placement fails for an op requested to be placed on a
// non-existent device if allow_soft_placement is not set.
TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacement) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    ops::SourceOp(
        "TestDevice",
        builder.opts().WithName("in").WithDevice("/device:fakegpu:11"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  const Status status = Place(&g, /*allow_soft_placement=*/false, false);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  // The failure has to mention the device that was requested.
  EXPECT_TRUE(str_util::StrContains(message, "/device:fakegpu:11"));
}
1455 
1456 // Test that the "Cannot assign a device" error message contains a format tag
1457 // when requested.
TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    ops::SourceOp(
        "TestDevice",
        builder.opts().WithName("in").WithDevice("/device:fakegpu:11"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  const Status status = Place(&g, /*allow_soft_placement=*/false, false);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  // Emit the full message to the test log before checking its pieces.
  LOG(WARNING) << message;
  EXPECT_TRUE(str_util::StrContains(
      message, "Cannot assign a device for operation in"));
  // The node must be referenced through the {{node ...}} format tag.
  EXPECT_TRUE(str_util::StrContains(message, "{{node in}}"));
}
1474 
// Test that placement fails when a node requests an explicit device that is not
// supported by the registered kernels if allow_soft_placement is not set.
TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    // Ask for a fake CPU device for the "VariableGPU" op.
    ops::SourceOp(
        "VariableGPU",
        builder.opts().WithName("var").WithDevice("/device:fakecpu:0"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  const Status status = Place(&g, /*allow_soft_placement=*/false, false);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_TRUE(str_util::StrContains(message, "/device:fakecpu:0"));
  EXPECT_TRUE(str_util::StrContains(
      message, "no supported kernel for fakecpu devices is available"));
}
1493 
// Test that placement fails when a node explicitly requests a device
// ("/job:foo/replica:17") that matches none of the available devices, and that
// the error message reports the available devices.
TEST_F(PlacerTest, TestNonExistentDevice) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    ops::SourceOp(
        "VariableGPU",
        builder.opts().WithName("var").WithDevice("/job:foo/replica:17"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  const Status status = Place(&g, /*allow_soft_placement=*/false, false);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  // Emit the full message to the test log before checking its pieces.
  LOG(WARNING) << message;
  // The error must quote the requested device and mention the alternatives.
  EXPECT_TRUE(str_util::StrContains(
      message, "was explicitly assigned to /job:foo/replica:17"));
  EXPECT_TRUE(str_util::StrContains(message, "but available devices"));
}
1513 
1514 #if !GOOGLE_CUDA
1515 // Test that we inform the user if they appear to be explicitly placing nodes
1516 // on a GPU when CUDA is not available
TEST_F(PlacerTest, TestUseGpuWithNoCuda) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    // Note the real "gpu" device type here, not the fake one used elsewhere.
    ops::SourceOp("VariableGPU",
                  builder.opts().WithName("var").WithDevice("/device:gpu:0"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  const Status status = Place(&g, /*allow_soft_placement=*/false, false);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  // Emit the full message to the test log before checking it.
  LOG(WARNING) << message;
  EXPECT_TRUE(str_util::StrContains(
      message,
      "The requested device appears to be a GPU, but CUDA is not enabled."));
}
1533 #endif
1534 
TEST_F(PlacerTest, TestUnsupportedDeviceAllowSoftPlacement) {
  // Same graph as the "NoAllowSoftPlacement" variant of this test: the
  // "VariableGPU" op requests a fake CPU device. With soft placement enabled,
  // placement is expected to succeed.
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    ops::SourceOp(
        "VariableGPU",
        builder.opts().WithName("var").WithDevice("/device:fakecpu:0"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  TF_EXPECT_OK(Place(&g, /*allow_soft_placement=*/true, false));
}
1546 
1547 // Test that a graph with device type and reference constraints on
1548 // some of the ops will successfully assign nodes to the constrained
1549 // device, and colocate nodes with reference connections.
TEST_F(PlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);

    // GPU pair: "var_gpu" is a GPU variable with a ref output, and its
    // consumer "force_gpu" asks for a CPU device. Both are expected on GPU.
    Node* gpu_variable =
        ops::SourceOp("VariableGPU", builder.opts().WithName("var_gpu"));
    ops::UnaryOp(
        "TestDeviceEnforce", gpu_variable,
        builder.opts().WithName("force_gpu").WithDevice("/device:fakecpu:0"));

    // CPU pair: the mirror image. "var_cpu" is a CPU variable and its
    // consumer "force_cpu" asks for a GPU device. Both are expected on CPU.
    Node* cpu_variable =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var_cpu"));
    ops::UnaryOp(
        "TestDeviceEnforce", cpu_variable,
        builder.opts().WithName("force_cpu").WithDevice("/device:fakegpu:0"));

    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  TF_EXPECT_OK(Place(&g, /*allow_soft_placement=*/true, false));

  // GPU pair.
  EXPECT_DEVICE_TYPE(g, "var_gpu", "FakeGPU");
  EXPECT_DEVICE_TYPE(g, "force_gpu", "FakeGPU");
  EXPECT_COLOCATED(g, "var_gpu", "force_gpu");

  // CPU pair.
  EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
  EXPECT_DEVICE_TYPE(g, "force_cpu", "FakeCPU");
  EXPECT_COLOCATED(g, "var_cpu", "force_cpu");
}
1579 
1580 // Test that placement fails when two nodes have a reference connection
1581 // constraint, and each node requires a mutually incompatible device.
TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    // "assign" (an AssignCPU op) consumes "var" (a VariableGPU op) together
    // with a plain input.
    Node* variable =
        ops::SourceOp("VariableGPU", builder.opts().WithName("var"));
    Node* value = ops::SourceOp("TestInput", builder.opts().WithName("in"));
    ops::BinaryOp("AssignCPU", variable, value,
                  builder.opts().WithName("assign"));
    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  const Status status = Place(&g);
  const auto& message = status.error_message();

  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_TRUE(str_util::StrContains(
      message,
      "Cannot colocate nodes {{colocation_node "
      "var}} and {{colocation_node assign}}"));
}
1598 
1599 // Test that a generator node follows its consumers (where there are several
1600 // consumer nodes on the same devices).
TEST_F(PlacerTest, TestGeneratorNodeFollowsConsumerNode) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);

    // Two CPU-only variables.
    Node* first_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var1_cpu"));
    Node* second_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var2_cpu"));

    // The generator of the value being assigned could run on either CPU or
    // GPU. The placement heuristic is expected to put it on the CPU next to
    // its consumers, which avoids a copy.
    Node* generator =
        ops::SourceOp("TestCPUGPUOutput", builder.opts().WithName("in"));

    // Each assign is tied to its variable's CPU through the reference edge.
    ops::BinaryOp("TestAssign", first_var, generator,
                  builder.opts().WithName("assign1"));
    ops::BinaryOp("TestAssign", second_var, generator,
                  builder.opts().WithName("assign2"));

    TF_EXPECT_OK(BuildGraph(builder, &g));
  }

  TF_EXPECT_OK(Place(&g));

  // Every node ends up on the same device as the generator "in".
  EXPECT_COLOCATED(g, "var1_cpu", "in");
  EXPECT_COLOCATED(g, "assign1", "in");
  EXPECT_COLOCATED(g, "var2_cpu", "in");
  EXPECT_COLOCATED(g, "assign2", "in");
}
1631 
1632 // Test that a generator node does not follow its consumers (where there are
1633 // several consumers on different devices).
TEST_F(PlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) {
  Graph g(OpRegistry::Global());
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);

    // Two CPU-only variables.
    Node* first_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var1_cpu"));
    Node* second_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var2_cpu"));

    // The generator of the value being assigned could run on either CPU or
    // GPU. Its two consumers sit on different devices, so it cannot be
    // colocated with both; the heuristic should then fall back to the
    // 'standard' choice, the GPU.
    Node* generator =
        ops::SourceOp("TestCPUGPUOutput", builder.opts().WithName("in"));

    // Each assign is tied to its variable's CPU through the reference edge.
    ops::BinaryOp("TestAssign", first_var, generator,
                  builder.opts().WithName("assign1"));
    ops::BinaryOp("TestAssign", second_var, generator,
                  builder.opts().WithName("assign2"));

    TF_EXPECT_OK(BuildGraph(builder, &g));

    // Pin the two variables to distinct fake CPUs (indices 1 and 2).
    Node* const placed_var1 = GetNodeByName(g, "var1_cpu");
    placed_var1->set_assigned_device_name(
        "/job:a/replica:0/task:0/device:fakecpu:1");

    Node* const placed_var2 = GetNodeByName(g, "var2_cpu");
    placed_var2->set_assigned_device_name(
        "/job:a/replica:0/task:0/device:fakecpu:2");
  }

  TF_EXPECT_OK(Place(&g));

  // Assigns follow their variables; the generator goes to the GPU.
  EXPECT_COLOCATED(g, "assign1", "var1_cpu");
  EXPECT_COLOCATED(g, "assign2", "var2_cpu");
  EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
}
1669 
// Register DummyOp as the kernel for each op below on both fake device types
// (FakeCPU and FakeGPU). The tests that follow build graphs out of "_Arg" and
// "Identity" nodes and run them through the placer.
// NOTE(review): "_Retval", "Const", "Mul" and "Add" are not referenced by name
// in the tests that follow; presumably they are needed by function bodies such
// as test::function::ResourceOutput() -- confirm.
REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeGPU"), DummyOp);
1682 
TEST_P(SoftPlacementPlacerTest,
       RequestedDeviceOnResourceGeneratorIsTreatedAsAssigned) {
  /*
   * Graph under test ("a" requests the GPU, "b" requests the CPU, and "id1"
   * carries a "loc:@id2" colocation attribute):
   *
   *    a:RES:GPU  b:RES:CPU
   *       |         |
   *       v         v
   *      id1       id2
   *    loc:@id2
   */
  FunctionDef resource_fn = test::function::ResourceOutput();
  GraphDef graph_def = GDef(
      {
          NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
          NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
          NDef("id1", "Identity", {"a"},
               {{"T", DT_RESOURCE},
                {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
          NDef("id2", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
      },
      // FunctionLib
      {resource_fn});

  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(BuildGraph(graph_def, &g));

  const bool allow_soft_placement = GetParam();
  const Status status = Place(&g, allow_soft_placement, true);

  if (!allow_soft_placement) {
    // Without soft placement the colocation request cannot be satisfied and
    // placement fails with a device-type conflict.
    EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
    EXPECT_TRUE(str_util::StrContains(
        status.error_message(),
        "Cannot colocate nodes {{colocation_node id2}} and {{colocation_node "
        "id1}}: Cannot merge devices with incompatible types: "
        "'/device:fakecpu:0' and '/device:fakegpu:0'"))
        << status.ToString();
    return;
  }

  // With soft placement each chain stays on the device its _Arg requested.
  EXPECT_EQ(error::OK, status.code()) << status.ToString();
  EXPECT_DEVICE_TYPE(g, "a", "FakeGPU");
  EXPECT_DEVICE_TYPE(g, "id1", "FakeGPU");
  EXPECT_DEVICE_TYPE(g, "b", "FakeCPU");
  EXPECT_DEVICE_TYPE(g, "id2", "FakeCPU");
}
1727 
TEST_F(PlacerTest, RequestedDeviceCanBeOverridden) {
  /*
   * Graph under test ("id_a" requests the GPU, "id_b" requests the CPU, and
   * "id1" carries a "loc:@id2" colocation attribute):
   *
   *     a:RES      b:RES
   *       |         |
   *     id_a:GPU   id_b:CPU
   *       |         |
   *       v         v
   *      id1       id2
   *    loc:@id2
   */
  FunctionDef resource_fn = test::function::ResourceOutput();
  GraphDef graph_def = GDef(
      {
          NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
          NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
          NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}, kGPU),
          NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}, kCPU),
          NDef("id1", "Identity", {"id_a"},
               {{"T", DT_RESOURCE},
                {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
          NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
      },
      // FunctionLib
      {resource_fn});

  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(BuildGraph(graph_def, &g));
  TF_ASSERT_OK(Place(&g));

  // Despite the conflicting device requests, both chains are expected to end
  // up together: check the pairs across the two chains first...
  EXPECT_COLOCATED(g, "a", "b");
  EXPECT_COLOCATED(g, "id_a", "id_b");
  EXPECT_COLOCATED(g, "id1", "id2");
  // ...then along the "a" chain.
  EXPECT_COLOCATED(g, "a", "id_a");
  EXPECT_COLOCATED(g, "a", "id1");
}
1764 
TEST_P(SoftPlacementPlacerTest,
       AssignedDevicesAreNotOverriddenDueToResourcesAndColocation) {
  /*
   * Graph under test. No node requests a device; instead "id_a" and "id_b"
   * get an *assigned* device (kFullGPU / kFullCPU) after the graph is built,
   * and "id1" carries a "loc:@id2" colocation attribute:
   *
   *     a:RES      b:RES
   *       |         |
   *     id_a:GPU   id_b:CPU
   *       |         |
   *       v         v
   *      id1       id2
   *    loc:@id2
   */
  FunctionDef func = test::function::ResourceOutput();
  GraphDef graph = GDef(
      {
          NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
          NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
          NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}),
          NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
          NDef("id1", "Identity", {"id_a"},
               {{"T", DT_RESOURCE},
                {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
          NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
      },
      // FunctionLib
      {func});

  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(BuildGraph(graph, &g));

  // Pre-assign the two middle nodes to different device types. The nodes are
  // looked up by name directly, so no separate name index is needed.
  GetNodeByName(g, "id_a")->set_assigned_device_name(kFullGPU);
  GetNodeByName(g, "id_b")->set_assigned_device_name(kFullCPU);

  bool allow_soft_placement = GetParam();

  Status s = Place(&g, allow_soft_placement, false);
  if (allow_soft_placement) {
    // Each chain stays on the device type assigned to its middle node.
    EXPECT_EQ(error::OK, s.code()) << s.ToString();
    EXPECT_DEVICE_TYPE(g, "a", "FakeGPU");
    EXPECT_DEVICE_TYPE(g, "id_a", "FakeGPU");
    EXPECT_DEVICE_TYPE(g, "id1", "FakeGPU");
    EXPECT_DEVICE_TYPE(g, "b", "FakeCPU");
    EXPECT_DEVICE_TYPE(g, "id_b", "FakeCPU");
    EXPECT_DEVICE_TYPE(g, "id2", "FakeCPU");
  } else {
    // Without soft placement the colocation request conflicts with the
    // assigned devices and placement fails.
    EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
    EXPECT_TRUE(str_util::StrContains(
        s.error_message(),
        "Cannot colocate nodes {{colocation_node id2}} and {{colocation_node "
        "id1}}: Cannot merge devices with incompatible types: "
        "'/job:a/replica:0/task:0/device:fakecpu:0' and "
        "'/job:a/replica:0/task:0/device:fakegpu:0'"))
        << s.ToString();
  }
}
1819 
1820 }  // namespace
1821 }  // namespace tensorflow
1822