/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is an exploratory prototype emitter for convolution using MLIR.
// This prototype is still under construction.
// TODO(timshen): Fix the documentation once it's implemented.
//
// Goals:
// * Autotune-able tiling.
// * Autotune-able memory accesses.
// * Autotune-able lowering logic (from a portable program to a thread-oriented
//   CUDA program).
// * Use mlir::AffineExpr to analyze all accesses. It aims to algorithmically
//   find memory access strategies for given input layouts and tiling configs.

#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.h"

#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
namespace experimental {
namespace {

using mlir::OpBuilder;

// Information extracted from an input shape.
struct ShapeInfo {
  // Buffer dimensions in NCHW order.
  std::vector<int64_t> nchw_dimensions;

  // Buffer dimensions in major-to-minor (physical) order.
  std::vector<int64_t> physical_dimensions;

  // The affine map that takes NCHW indices and maps them to the physical
  // order.
  mlir::AffineMap affine_map;

  mlir::Type element_type;
};

ShapeInfo GetShapeInfo(
    const Shape& shape, int64 n_dim, int64 c_dim,
    absl::Span<const tensorflow::protobuf_int64> spatial_dims,
    mlir::Builder builder) {
  ShapeInfo shape_info;

  std::vector<int64> physical_to_logical(
      shape.layout().minor_to_major().rbegin(),
      shape.layout().minor_to_major().rend());

  std::vector<int64> nchw_to_logical;

  nchw_to_logical.push_back(n_dim);
  nchw_to_logical.push_back(c_dim);
  for (int64 dim : spatial_dims) {
    nchw_to_logical.push_back(dim);
  }

  for (int64 dim : nchw_to_logical) {
    shape_info.nchw_dimensions.push_back(shape.dimensions(dim));
  }

  for (int64 dim : physical_to_logical) {
    shape_info.physical_dimensions.push_back(shape.dimensions(dim));
  }

  std::vector<mlir::AffineExpr> affine_exprs;
  // For each physical dimension (major to minor), pick the NCHW index that
  // feeds it, so that the resulting map sends NCHW indices to physical order.
  for (int64 dim : ComposePermutations(InversePermutation(nchw_to_logical),
                                       physical_to_logical)) {
    affine_exprs.push_back(builder.getAffineDimExpr(dim));
  }

  shape_info.affine_map = mlir::AffineMap::get(
      /*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs,
      builder.getContext());

  shape_info.element_type = [&] {
    switch (shape.element_type()) {
      case xla::F16:
        return builder.getF16Type();
      case xla::F32:
        return builder.getF32Type();
      default:
        break;
    }
    CHECK(false);
  }();

  return shape_info;
}
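
// Worked example for GetShapeInfo above (a sketch; it assumes XLA's
// ComposePermutations(p1, p2)[i] == p1[p2[i]]): for an input whose logical
// dimensions are [N, H, W, C] with the default (row-major) layout, n_dim = 0,
// c_dim = 3 and spatial_dims = {1, 2}, so nchw_to_logical = {0, 3, 1, 2}.
// The resulting affine_map is (n, c, h, w) -> (n, h, w, c), i.e. it maps NCHW
// indices to the physical NHWC order, and physical_dimensions holds the
// dimension sizes in N, H, W, C order.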

void SetMemRef(mlir::Operation* op, mlir::Value memref) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    load.setMemRef(memref);
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    store.setMemRef(memref);
  } else {
    CHECK(false);
  }
}

// Hoists operations out of `where`. [begin_op, end_op) must be the first
// operations of their parent loop, and `where` must be an ancestor of that
// parent loop.
//
// This always preserves the semantics of the program; to do so it may modify
// the hoisted operations or add extra loops at the destination.
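//
// For illustration only (a hand-written sketch, not IR produced by this file),
// hoisting the alloc below out of a loop with trip count 4:
//
//   affine.for %i = 0 to 4 {
//     %acc = alloc() : memref<f32>
//     affine.store %zero, %acc[] : memref<f32>
//     ...
//   }
//
// yields an alloc whose shape is prefixed with the ancestor loop's trip
// count, with every access re-indexed by the loop's induction variable:
//
//   %acc = alloc() : memref<4xf32>
//   affine.for %i = 0 to 4 {
//     affine.store %zero, %acc[%i] : memref<4xf32>
//     ...
//   }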
mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
                             llvm::iplist<mlir::Operation>::iterator end_op,
                             mlir::AffineForOp where) {
  // All loops to hoist through.
  llvm::SmallVector<mlir::AffineForOp, 4> ancestors;
  getPerfectlyNestedLoops(ancestors, where);
  {
    int i;
    for (i = 0; i < ancestors.size(); i++) {
      if (&ancestors[i].getBody()->front() == &*begin_op) {
        break;
      }
    }
    CHECK(i < ancestors.size());
    ancestors.resize(i + 1);
  }

  std::vector<int64_t> ancestor_dimensions;
  for (auto ancestor : ancestors) {
    CHECK(IsSimpleLoop(ancestor));
    ancestor_dimensions.push_back(
        ancestor.getUpperBoundMap().getSingleConstantResult());
  }

  if (auto alloc = mlir::dyn_cast<mlir::AllocOp>(begin_op)) {
    CHECK(std::next(begin_op) == end_op)
        << "alloc() needs to be hoisted on its own";

    OpBuilder builder(where);
    mlir::MemRefType type = alloc.getType();
    CHECK(type.getAffineMaps().empty());
    ancestor_dimensions.insert(ancestor_dimensions.end(),
                               type.getShape().begin(), type.getShape().end());
    mlir::MemRefType new_type =
        mlir::MemRefType::get(ancestor_dimensions, type.getElementType());
    auto new_alloc =
        builder.create<mlir::AllocOp>(builder.getUnknownLoc(), new_type);

    std::vector<mlir::Value> indvars;
    for (auto ancestor : ancestors) {
      indvars.push_back(ancestor.getInductionVar());
    }
    for (auto& use : llvm::make_early_inc_range(alloc.getResult().getUses())) {
      mlir::Operation* owner = use.getOwner();
      BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
      affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(),
                                 indvars.end());
      CHECK(affine_map.affine_map.isIdentity());
      affine_map.affine_map = mlir::AffineMap::getMultiDimIdentityMap(
          affine_map.operands.size(), builder.getContext());

      mlir::Operation* new_op =
          CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner));
      SetMemRef(new_op, new_alloc);
      owner->replaceAllUsesWith(new_op);
      owner->erase();
    }
    alloc.erase();
    return new_alloc;
  }

  const bool any_op_is_loop_variant = [&] {
    for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) {
      if (mlir::isa<mlir::AffineForOp, mlir::AffineStoreOp>(op)) {
        return true;
      }
    }
    return false;
  }();

  if (any_op_is_loop_variant) {
    auto builder = OpBuilder(where);
    std::vector<mlir::AffineForOp> new_loops;
    for (auto dim : ancestor_dimensions) {
      auto where =
          builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
      new_loops.push_back(where);
      builder = OpBuilder::atBlockTerminator(where.getBody());
    }
    for (mlir::Operation& op :
         llvm::make_early_inc_range(llvm::make_range(begin_op, end_op))) {
      op.moveBefore(&new_loops.back().getBody()->back());
    }
    CHECK_EQ(ancestors.size(), new_loops.size());
    for (int i = 0; i < ancestors.size(); i++) {
      replaceAllUsesInRegionWith(ancestors[i].getInductionVar(),
                                 new_loops[i].getInductionVar(),
                                 new_loops.back().region());
    }
    return new_loops.front();
  }
  CHECK(false);
}

mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) {
  return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where);
}

struct InitialMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
  mlir::AllocOp output_acc;
};

// Returns the following IR with the anchors set to the corresponding
// operations.
//   for (cartesian loops...) {
//     %output_acc = alloc() : memref(f32)
//     output_acc[] = 0
//     for (reduction loops...) {
//       output_acc[] += input[...] * filter[...]
//     }
//     output[...] = output_acc[]
//   }
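//
// For a 2D convolution the cartesian loops run over the output's
// {N, O, H_out, W_out} and the reduction loops over the filter's {C, KH, KW}.
// E.g. (an illustrative shape, not taken from a real HLO) a 128x64x112x112
// output produced by 3x3 filters over 32 input features gives loop trip
// counts {128, 64, 112, 112} and {32, 3, 3} respectively.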
StatusOr<InitialMlirConvAnchors> CreateNaiveMlirConv(
    mlir::Value input, mlir::Value filter, mlir::Value output,
    const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info,
    const ShapeInfo& output_shape_info, const Window& window,
    OpBuilder builder) {
  CHECK(input_shape_info.element_type == builder.getF16Type());
  CHECK(filter_shape_info.element_type == builder.getF16Type());
  CHECK(output_shape_info.element_type == builder.getF16Type());

  auto location = mlir::UnknownLoc::get(builder.getContext());

  std::vector<mlir::AffineForOp> cartesian_product_loops =
      CreateNestedSimpleLoops(output_shape_info.nchw_dimensions, builder);

  builder =
      OpBuilder::atBlockTerminator(cartesian_product_loops.back().getBody());

  mlir::AllocOp output_acc = builder.create<mlir::AllocOp>(
      location, mlir::MemRefType::get({}, builder.getF32Type()));

  builder.create<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::ConstantOp>(
          location, mlir::FloatAttr::get(builder.getF32Type(), 0)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  std::vector<mlir::AffineForOp> reduction_loops;
  reduction_loops = CreateNestedSimpleLoops(
      absl::MakeSpan(filter_shape_info.nchw_dimensions).subspan(1), builder);

  mlir::AffineForOp loop_n = cartesian_product_loops[0];
  mlir::AffineForOp loop_o = cartesian_product_loops[1];
  mlir::AffineForOp loop_c = reduction_loops[0];

  std::vector<mlir::Value> output_spatial_indvars;
  for (auto loop : absl::MakeSpan(cartesian_product_loops).subspan(2)) {
    output_spatial_indvars.push_back(loop.getInductionVar());
  }
  std::vector<mlir::Value> filter_spatial_indvars;
  for (auto loop : absl::MakeSpan(reduction_loops).subspan(1)) {
    filter_spatial_indvars.push_back(loop.getInductionVar());
  }
  int num_spatial_dims = output_spatial_indvars.size();
  CHECK_EQ(num_spatial_dims, filter_spatial_indvars.size());

  builder = OpBuilder::atBlockTerminator(reduction_loops.back().getBody());

  mlir::Value loaded_input = [&] {
    std::vector<mlir::AffineExpr> input_indices;
    input_indices.push_back(builder.getAffineDimExpr(0));
    input_indices.push_back(builder.getAffineDimExpr(1));

    // For spatial dimensions, generate output_index * stride + filter_index -
    // padding_low.
    //
    // TODO(timshen): guard out-of-bound loads and stores brought by padding.
    for (int i = 0; i < num_spatial_dims; i++) {
      const WindowDimension& window_dim = window.dimensions(i);
      input_indices.push_back(
          builder.getAffineDimExpr(i + 2) * window_dim.stride() +
          builder.getAffineDimExpr(2 + num_spatial_dims + i) -
          window_dim.padding_low());
    }
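    // For example (illustrative numbers, not derived from a real HLO): with
    // stride 2 and padding_low 1, spatial dimension i becomes the affine
    // expression d(i + 2) * 2 + d(2 + num_spatial_dims + i) - 1, so output
    // index 0 combined with filter index 0 reads input index -1, i.e. the
    // padded region mentioned in the TODO above.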
    std::vector<mlir::Value> input_vars;
    input_vars.push_back(loop_n.getInductionVar());
    input_vars.push_back(loop_c.getInductionVar());
    input_vars.insert(input_vars.end(), output_spatial_indvars.begin(),
                      output_spatial_indvars.end());
    input_vars.insert(input_vars.end(), filter_spatial_indvars.begin(),
                      filter_spatial_indvars.end());

    return builder.create<mlir::FPExtOp>(
        location,
        builder.createOrFold<mlir::AffineLoadOp>(
            location, input,
            mlir::AffineMap(input_shape_info.affine_map)
                .compose(mlir::AffineMap::get(
                    /*dimCount=*/2 + num_spatial_dims * 2,
                    /*symbolCount=*/0, input_indices, builder.getContext())),
            input_vars),
        builder.getF32Type());
  }();

  mlir::Value loaded_filter = [&] {
    std::vector<mlir::Value> filter_vars;
    filter_vars.push_back(loop_o.getInductionVar());
    filter_vars.push_back(loop_c.getInductionVar());
    filter_vars.insert(filter_vars.end(), filter_spatial_indvars.begin(),
                       filter_spatial_indvars.end());

    return builder.create<mlir::FPExtOp>(
        location,
        builder.createOrFold<mlir::AffineLoadOp>(
            location, filter, filter_shape_info.affine_map, filter_vars),
        builder.getF32Type());
  }();

  auto accum_load_op =
      builder.createOrFold<mlir::AffineLoadOp>(location, output_acc);
  builder.createOrFold<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::AddFOp>(
          location, accum_load_op,
          builder.create<mlir::MulFOp>(location, loaded_input, loaded_filter)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  builder.setInsertionPointAfter(reduction_loops[0]);
  {
    std::vector<mlir::Value> output_vars;
    output_vars.push_back(loop_n.getInductionVar());
    output_vars.push_back(loop_o.getInductionVar());
    output_vars.insert(output_vars.end(), output_spatial_indvars.begin(),
                       output_spatial_indvars.end());
    builder.createOrFold<mlir::AffineStoreOp>(
        location,
        builder.create<mlir::FPTruncOp>(
            location,
            builder.createOrFold<mlir::AffineLoadOp>(location, output_acc),
            builder.getF16Type()),
        output, output_shape_info.affine_map, output_vars);
  }

  return InitialMlirConvAnchors{cartesian_product_loops, reduction_loops,
                                output_acc};
}

// Contains the following pattern, with the anchors pointing into it:
//   for (cartesian loops...) {
//     %output_acc = alloc() : memref(..., f32)
//     for (reduction loops...) {
//       for (tiled cartesian loops...) {
//         output_acc[...] = 0
//       }
//       for (tiled cartesian loops...) {
//         for (reduction loops...) {
//           output_acc[] += input[...] * filter[...]
//         }
//       }
//       for (tiled cartesian loops...) {
//         output[...] = output_acc[...]
//       }
//     }
//   }
struct TransformedMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
};

StatusOr<TransformedMlirConvAnchors> TransformMlirConv(
    InitialMlirConvAnchors anchors) {
  std::vector<mlir::AffineForOp> cartesian_product_loops =
      anchors.cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops = anchors.reduction_loops;
  mlir::AllocOp output_acc = anchors.output_acc;

  // TODO(timshen): consider using pattern matchers for transformations.
  //
  // Initial form:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(f32)
  //     output_acc[] = 0
  //     for (reduction loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[]
  //   }

  // Tile cartesian loops to:
  //   for (cartesian loops...) {
  //     for (tiled cartesian loops...) {
  //       %output_acc = alloc() : memref(f32)
  //       output_acc[] = 0
  //       for (reduction loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[]
  //     }
  //   }
  TileLoop(reduction_loops[0], 4, reduction_loops.back());

  std::vector<mlir::AffineForOp> tiled_cartesian_loops;
  tiled_cartesian_loops.push_back(
      TileLoop(cartesian_product_loops[1], 32, cartesian_product_loops.back()));

  tiled_cartesian_loops.push_back(TileLoop(cartesian_product_loops.back(), 16,
                                           tiled_cartesian_loops.back()));
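
  // A rough picture of the tiling above (a sketch; it assumes TileLoop keeps
  // the major loop in place, creates a minor loop of trip count `size` that
  // takes over the body of `target`, and returns that minor loop): the input
  // feature (reduction) loop is tiled by 4, the output feature loop by 32,
  // and the innermost output spatial loop by 16, so tiled_cartesian_loops
  // holds the two minor loops that form the 32x16 compute tile used below.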

  // The hoists below separate the allocation, the zero initialization, the
  // computation, and the writeback of output_acc into their own loops.
  //
  // After hoisting the alloc:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //       for (reduction loops...) {
  //         output_acc[...] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[...]
  //     }
  //   }
  output_acc = llvm::cast<mlir::AllocOp>(
      HoistAndFix(output_acc, tiled_cartesian_loops.front()));

  // Hoist everything before the reduction loops (i.e. the zero initialization
  // of output_acc):
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (tiled cartesian loops...) {
  //       for (reduction loops...) {
  //         output_acc[...] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[...]
  //     }
  //   }
  HoistAndFix(tiled_cartesian_loops.back().getBody()->begin(),
              reduction_loops.front().getOperation()->getIterator(),
              tiled_cartesian_loops.front());

  // Now hoist all reduction loops outside of the tiled cartesian loops.
  // Notice that HoistAndFix automatically adds a new set of tiled cartesian
  // loops around the hoisted reduction loops to keep the semantics correct.
  //
  // After hoisting the reduction loops:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (tiled cartesian loops...) {
  //       for (reduction loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //     }  // compute loop
  //     for (tiled cartesian loops...) {
  //       output[...] = output_acc[...]
  //     }
  //   }
  {
    auto compute_loop = llvm::cast<mlir::AffineForOp>(
        HoistAndFix(reduction_loops.front(), tiled_cartesian_loops[0]));

    // Fix tiled_cartesian_loops to make them point to the tiled compute
    // loops, not the writeback loops to the output buffer.
    llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
    getPerfectlyNestedLoops(all_loops, compute_loop);
    absl::c_copy_n(all_loops, tiled_cartesian_loops.size(),
                   tiled_cartesian_loops.data());
  }

  // After exchanging the tiled cartesian compute loops with the reduction
  // loops:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (reduction loops...) {
  //       for (tiled cartesian loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //     }
  //     for (tiled cartesian loops...) {
  //       output[...] = output_acc[...]
  //     }
  //   }
  //
  // ...so that the later tiled cartesian loops (with the computations in
  // them) can be replaced by CUDA MMA instructions.
  {
    std::vector<mlir::AffineForOp> loops;
    loops.insert(loops.end(), tiled_cartesian_loops.begin(),
                 tiled_cartesian_loops.end());
    loops.insert(loops.end(), reduction_loops.begin(), reduction_loops.end());
    SinkPerfectlyNestedLoops(loops, tiled_cartesian_loops.size());
  }
  return TransformedMlirConvAnchors{cartesian_product_loops, reduction_loops};
}

}  // namespace

StatusOr<mlir::FuncOp> EmitConvolutionForwardAsMlir(
    HloInstruction* conv, absl::string_view function_name,
    mlir::MLIRContext* context) {
  OpBuilder builder(context);

  const auto& dim_nums = conv->convolution_dimension_numbers();
  ShapeInfo input_shape_info =
      GetShapeInfo(conv->operand(0)->shape(), dim_nums.input_batch_dimension(),
                   dim_nums.input_feature_dimension(),
                   dim_nums.input_spatial_dimensions(), builder);

  ShapeInfo filter_shape_info = GetShapeInfo(
      conv->operand(1)->shape(), dim_nums.kernel_output_feature_dimension(),
      dim_nums.kernel_input_feature_dimension(),
      dim_nums.kernel_spatial_dimensions(), builder);

  ShapeInfo output_shape_info = GetShapeInfo(
      conv->shape().tuple_shapes(0), dim_nums.output_batch_dimension(),
      dim_nums.output_feature_dimension(), dim_nums.output_spatial_dimensions(),
      builder);

  auto function = mlir::FuncOp::create(
      mlir::UnknownLoc::get(builder.getContext()),
      llvm_ir::AsStringRef(function_name),
      builder.getFunctionType(
          {mlir::MemRefType::get(output_shape_info.physical_dimensions,
                                 output_shape_info.element_type, {}),
           mlir::MemRefType::get(input_shape_info.physical_dimensions,
                                 input_shape_info.element_type, {}),
           mlir::MemRefType::get(filter_shape_info.physical_dimensions,
                                 filter_shape_info.element_type, {})},
          {}));

  auto* entry_block = function.addEntryBlock();
  builder.setInsertionPointToStart(entry_block);
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
  builder.setInsertionPointToStart(entry_block);

  mlir::Value input = entry_block->getArgument(1);
  mlir::Value filter = entry_block->getArgument(2);
  mlir::Value output = entry_block->getArgument(0);

  TF_RETURN_IF_ERROR(ConvIsImplemented(conv));

  TF_ASSIGN_OR_RETURN(
      InitialMlirConvAnchors initial_anchors,
      CreateNaiveMlirConv(input, filter, output, input_shape_info,
                          filter_shape_info, output_shape_info, conv->window(),
                          builder));

  TF_ASSIGN_OR_RETURN(TransformedMlirConvAnchors transformed_anchors,
                      TransformMlirConv(initial_anchors));

  // TODO(timshen): Implement a transformation that collects loads to a given
  // buffer, creates a local alloc() for the accessed part, redirects all loads
  // and stores to that local alloc(), and creates code to initialize / write
  // back the local alloc() if needed.

  // TODO(timshen): Implement CUDA-specific lowering.

  return function;
}
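
// Usage sketch (illustrative only; the surrounding harness and the function
// name are hypothetical):
//   mlir::MLIRContext context;
//   TF_ASSIGN_OR_RETURN(
//       mlir::FuncOp func,
//       EmitConvolutionForwardAsMlir(conv, "convolution_forward", &context));
//   func.dump();  // Prints the generated affine-dialect convolution.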

Status ConvIsImplemented(const HloInstruction* conv) {
  if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
    return Unimplemented("Group count is not implemented.");
  }
  if (window_util::HasWindowReversal(conv->window())) {
    return Unimplemented("Window reversal is not implemented.");
  }
  if (window_util::HasDilation(conv->window())) {
    return Unimplemented("Dilation is not implemented.");
  }
  return Status::OK();
}

}  // namespace experimental
}  // namespace xla