/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is an exploratory prototype emitter for convolution using MLIR.
// This prototype is still under construction.
// TODO(timshen): Fix the documentation once it's implemented.
//
// Goals:
// * Autotune-able tiling.
// * Autotune-able memory accesses.
// * Autotune-able lowering logic (from a portable program to a
//   thread-oriented CUDA program).
// * Use mlir::AffineExpr to analyze all accesses. It aims to algorithmically
//   find memory access strategies for given input layouts and tiling configs.

#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.h"

#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
namespace experimental {
namespace {

using mlir::OpBuilder;

// Various pieces of information extracted from an input shape.
struct ShapeInfo {
  // Buffer dimensions in the order of NCHW.
  std::vector<int64_t> nchw_dimensions;

  // Buffer dimensions in the order of major to minor.
  std::vector<int64_t> physical_dimensions;

  // The affine map that takes NCHW indices and maps them to the physical
  // order.
  mlir::AffineMap affine_map;

  mlir::Type element_type;
};
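
// Illustrative example (assumptions: two spatial dimensions and an NHWC
// physical layout; not tied to any particular HLO): for such a shape,
// GetShapeInfo below produces an `affine_map` that takes the logical NCHW
// indices (n, c, h, w) to the physical (major-to-minor) index order
// (n, h, w, c).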
ShapeInfo GetShapeInfo(
    const Shape& shape, int64 n_dim, int64 c_dim,
    absl::Span<const tensorflow::protobuf_int64> spatial_dims,
    mlir::Builder builder) {
  ShapeInfo shape_info;

  std::vector<int64> physical_to_logical(
      shape.layout().minor_to_major().rbegin(),
      shape.layout().minor_to_major().rend());

  std::vector<int64> nchw_to_logical;

  nchw_to_logical.push_back(n_dim);
  nchw_to_logical.push_back(c_dim);
  for (int64 dim : spatial_dims) {
    nchw_to_logical.push_back(dim);
  }

  for (int64 dim : nchw_to_logical) {
    shape_info.nchw_dimensions.push_back(shape.dimensions(dim));
  }

  for (int64 dim : physical_to_logical) {
    shape_info.physical_dimensions.push_back(shape.dimensions(dim));
  }

  std::vector<mlir::AffineExpr> affine_exprs;
  // We want physical to nchw order.
  for (int64 dim : ComposePermutations(InversePermutation(nchw_to_logical),
                                       physical_to_logical)) {
    affine_exprs.push_back(builder.getAffineDimExpr(dim));
  }

  shape_info.affine_map = mlir::AffineMap::get(
      /*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs,
      builder.getContext());

  shape_info.element_type = [&] {
    switch (shape.element_type()) {
      case xla::F16:
        return builder.getF16Type();
      case xla::F32:
        return builder.getF32Type();
      default:
        break;
    }
    CHECK(false);
  }();

  return shape_info;
}

void SetMemRef(mlir::Operation* op, mlir::Value memref) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    load.setMemRef(memref);
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    store.setMemRef(memref);
  } else {
    CHECK(false);
  }
}

// Hoists operations out of `where`. [begin_op, end_op) must be the first
// operations of their parent loop, and `where` must be an ancestor of that
// parent loop.
//
// The hoist always preserves the semantics of the program; to do so it may
// modify the hoisted operations or add extra loops at the hoisted place.
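//
// Illustrative sketch of the alloc() case (loop bound and value names are
// made up for illustration): hoisting `%acc = alloc() : memref<f32>` out of
//   affine.for %i = 0 to 16 {
//     %acc = alloc() : memref<f32>
//     ... loads/stores to %acc[] ...
//   }
// produces
//   %acc = alloc() : memref<16xf32>
//   affine.for %i = 0 to 16 {
//     ... loads/stores rewritten to %acc[%i] ...
//   }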
mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
                             llvm::iplist<mlir::Operation>::iterator end_op,
                             mlir::AffineForOp where) {
  // All loops to hoist through.
  llvm::SmallVector<mlir::AffineForOp, 4> ancestors;
  getPerfectlyNestedLoops(ancestors, where);
  {
    int i;
    for (i = 0; i < ancestors.size(); i++) {
      if (&ancestors[i].getBody()->front() == &*begin_op) {
        break;
      }
    }
    CHECK(i < ancestors.size());
    ancestors.resize(i + 1);
  }

  std::vector<int64_t> ancestor_dimensions;
  for (auto ancestor : ancestors) {
    CHECK(IsSimpleLoop(ancestor));
    ancestor_dimensions.push_back(
        ancestor.getUpperBoundMap().getSingleConstantResult());
  }

  if (auto alloc = mlir::dyn_cast<mlir::AllocOp>(begin_op)) {
    CHECK(std::next(begin_op) == end_op)
        << "alloc() needs to be hoisted on its own";

    OpBuilder builder(where);
    mlir::MemRefType type = alloc.getType();
    CHECK(type.getAffineMaps().empty());
    ancestor_dimensions.insert(ancestor_dimensions.end(),
                               type.getShape().begin(), type.getShape().end());
    mlir::MemRefType new_type =
        mlir::MemRefType::get(ancestor_dimensions, type.getElementType());
    auto new_alloc =
        builder.create<mlir::AllocOp>(builder.getUnknownLoc(), new_type);

    std::vector<mlir::Value> indvars;
    for (auto ancestor : ancestors) {
      indvars.push_back(ancestor.getInductionVar());
    }
    for (auto& use : llvm::make_early_inc_range(alloc.getResult().getUses())) {
      mlir::Operation* owner = use.getOwner();
      BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
      affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(),
                                 indvars.end());
      CHECK(affine_map.affine_map.isIdentity());
      affine_map.affine_map = mlir::AffineMap::getMultiDimIdentityMap(
          affine_map.operands.size(), builder.getContext());

      mlir::Operation* new_op =
          CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner));
      SetMemRef(new_op, new_alloc);
      owner->replaceAllUsesWith(new_op);
      owner->erase();
    }
    alloc.erase();
    return new_alloc;
  }

  const bool any_op_is_loop_variant = [&] {
    for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) {
      if (mlir::isa<mlir::AffineForOp, mlir::AffineStoreOp>(op)) {
        return true;
      }
    }
    return false;
  }();

  if (any_op_is_loop_variant) {
    auto builder = OpBuilder(where);
    std::vector<mlir::AffineForOp> new_loops;
    for (auto dim : ancestor_dimensions) {
      auto where =
          builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
      new_loops.push_back(where);
      builder = OpBuilder::atBlockTerminator(where.getBody());
    }
    for (mlir::Operation& op :
         llvm::make_early_inc_range(llvm::make_range(begin_op, end_op))) {
      op.moveBefore(&new_loops.back().getBody()->back());
    }
    CHECK_EQ(ancestors.size(), new_loops.size());
    for (int i = 0; i < ancestors.size(); i++) {
      replaceAllUsesInRegionWith(ancestors[i].getInductionVar(),
                                 new_loops[i].getInductionVar(),
                                 new_loops.back().region());
    }
    return new_loops.front();
  }
  CHECK(false);
}

mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) {
  return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where);
}

struct InitialMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
  mlir::AllocOp output_acc;
};

// Return the following IR with the anchors set to corresponding operations.
//   for (cartesian loops...) {
//     %output_acc = alloc() : memref(f32)
//     output_acc[] = 0
//     for (reduction loops...) {
//       output_acc[] += input[...] * filter[...]
//     }
//     output[...] = output_acc[]
//   }
StatusOr<InitialMlirConvAnchors> CreateNaiveMlirConv(
    mlir::Value input, mlir::Value filter, mlir::Value output,
    const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info,
    const ShapeInfo& output_shape_info, const Window& window,
    OpBuilder builder) {
  CHECK(input_shape_info.element_type == builder.getF16Type());
  CHECK(filter_shape_info.element_type == builder.getF16Type());
  CHECK(output_shape_info.element_type == builder.getF16Type());

  auto location = mlir::UnknownLoc::get(builder.getContext());

  std::vector<mlir::AffineForOp> cartesian_product_loops =
      CreateNestedSimpleLoops(output_shape_info.nchw_dimensions, builder);

  builder =
      OpBuilder::atBlockTerminator(cartesian_product_loops.back().getBody());

  mlir::AllocOp output_acc = builder.create<mlir::AllocOp>(
      location, mlir::MemRefType::get({}, builder.getF32Type()));

  builder.create<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::ConstantOp>(
          location, mlir::FloatAttr::get(builder.getF32Type(), 0)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  std::vector<mlir::AffineForOp> reduction_loops;
  reduction_loops = CreateNestedSimpleLoops(
      absl::MakeSpan(filter_shape_info.nchw_dimensions).subspan(1), builder);

  mlir::AffineForOp loop_n = cartesian_product_loops[0];
  mlir::AffineForOp loop_o = cartesian_product_loops[1];
  mlir::AffineForOp loop_c = reduction_loops[0];

  std::vector<mlir::Value> output_spatial_indvars;
  for (auto loop : absl::MakeSpan(cartesian_product_loops).subspan(2)) {
    output_spatial_indvars.push_back(loop.getInductionVar());
  }
  std::vector<mlir::Value> filter_spatial_indvars;
  for (auto loop : absl::MakeSpan(reduction_loops).subspan(1)) {
    filter_spatial_indvars.push_back(loop.getInductionVar());
  }
  int num_spatial_dims = output_spatial_indvars.size();
  CHECK_EQ(num_spatial_dims, filter_spatial_indvars.size());

  builder = OpBuilder::atBlockTerminator(reduction_loops.back().getBody());

  mlir::Value loaded_input = [&] {
    std::vector<mlir::AffineExpr> input_indices;
    input_indices.push_back(builder.getAffineDimExpr(0));
    input_indices.push_back(builder.getAffineDimExpr(1));

    // For spatial dimensions, generate input_index * stride + filter_index -
    // left_pad
    //
    // TODO(timshen): guard out-of-bound loads and stores brought by padding.
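    //
    // For example (illustrative numbers): with stride 2 and padding_low 1,
    // the spatial input index becomes 2 * output_index + filter_index - 1.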
    for (int i = 0; i < num_spatial_dims; i++) {
      const WindowDimension& window_dim = window.dimensions(i);
      input_indices.push_back(
          builder.getAffineDimExpr(i + 2) * window_dim.stride() +
          builder.getAffineDimExpr(2 + num_spatial_dims + i) -
          window_dim.padding_low());
    }
    std::vector<mlir::Value> input_vars;
    input_vars.push_back(loop_n.getInductionVar());
    input_vars.push_back(loop_c.getInductionVar());
    input_vars.insert(input_vars.end(), output_spatial_indvars.begin(),
                      output_spatial_indvars.end());
    input_vars.insert(input_vars.end(), filter_spatial_indvars.begin(),
                      filter_spatial_indvars.end());

    return builder.create<mlir::FPExtOp>(
        location,
        builder.createOrFold<mlir::AffineLoadOp>(
            location, input,
            mlir::AffineMap(input_shape_info.affine_map)
                .compose(mlir::AffineMap::get(
                    /*dimCount=*/2 + num_spatial_dims * 2,
                    /*symbolCount=*/0, input_indices, builder.getContext())),
            input_vars),
        builder.getF32Type());
  }();

  mlir::Value loaded_filter = [&] {
    std::vector<mlir::Value> filter_vars;
    filter_vars.push_back(loop_o.getInductionVar());
    filter_vars.push_back(loop_c.getInductionVar());
    filter_vars.insert(filter_vars.end(), filter_spatial_indvars.begin(),
                       filter_spatial_indvars.end());

    return builder.create<mlir::FPExtOp>(
        location,
        builder.createOrFold<mlir::AffineLoadOp>(
            location, filter, filter_shape_info.affine_map, filter_vars),
        builder.getF32Type());
  }();

  auto accum_load_op =
      builder.createOrFold<mlir::AffineLoadOp>(location, output_acc);
  builder.createOrFold<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::AddFOp>(
          location, accum_load_op,
          builder.create<mlir::MulFOp>(location, loaded_input, loaded_filter)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  builder.setInsertionPointAfter(reduction_loops[0]);
  {
    std::vector<mlir::Value> output_vars;
    output_vars.push_back(loop_n.getInductionVar());
    output_vars.push_back(loop_o.getInductionVar());
    output_vars.insert(output_vars.end(), output_spatial_indvars.begin(),
                       output_spatial_indvars.end());
    builder.createOrFold<mlir::AffineStoreOp>(
        location,
        builder.create<mlir::FPTruncOp>(
            location,
            builder.createOrFold<mlir::AffineLoadOp>(location, output_acc),
            builder.getF16Type()),
        output, output_shape_info.affine_map, output_vars);
  }

  return InitialMlirConvAnchors{cartesian_product_loops, reduction_loops,
                                output_acc};
}

// Contains the following pattern with anchors:
//   for (cartesian loops...) {
//     %output_acc = alloc() : memref(..., f32)
//     for (reduction loops...) {
//       for (tiled cartesian loops...) {
//         output_acc[...] = 0
//       }
//       for (tiled cartesian loops...) {
//         for (reduction loops...) {
//           output_acc[] += input[...] * filter[...]
//         }
//       }
//       for (tiled cartesian loops...) {
//         output[...] = output_acc[...]
//       }
//     }
//   }
struct TransformedMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
};

StatusOr<TransformedMlirConvAnchors> TransformMlirConv(
    InitialMlirConvAnchors anchors) {
  std::vector<mlir::AffineForOp> cartesian_product_loops =
      anchors.cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops = anchors.reduction_loops;
  mlir::AllocOp output_acc = anchors.output_acc;

  // TODO(timshen): consider using pattern matchers for transformations
  //
  // Initial form:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(f32)
  //     output_acc[] = 0
  //     for (reduction loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[]
  //   }

  // Tile cartesian loops to:
  //   for (cartesian loops...) {
  //     for (tiled cartesian loops...) {
  //       %output_acc = alloc() : memref(f32)
  //       output_acc[] = 0
  //       for (reduction loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[]
  //     }
  //   }
  TileLoop(reduction_loops[0], 4, reduction_loops.back());

  std::vector<mlir::AffineForOp> tiled_cartesian_loops;
  tiled_cartesian_loops.push_back(
      TileLoop(cartesian_product_loops[1], 32, cartesian_product_loops.back()));

  tiled_cartesian_loops.push_back(TileLoop(cartesian_product_loops.back(), 16,
                                           tiled_cartesian_loops.back()));

  // Two hoist operations to interleave the allocation, computation, and
  // writebacks to output_acc:
  // After first hoist:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //       for (reduction loops...) {
  //         output_acc[...] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[...]
  //     }
  //   }
  output_acc = llvm::cast<mlir::AllocOp>(
      HoistAndFix(output_acc, tiled_cartesian_loops.front()));

  // Hoist everything before reduction loops (aka zero initializations of
  // output_acc):
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (tiled cartesian loops...) {
  //       for (reduction loops...) {
  //         output_acc[...] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[...]
  //     }
  //   }
  HoistAndFix(tiled_cartesian_loops.back().getBody()->begin(),
              reduction_loops.front().getOperation()->getIterator(),
              tiled_cartesian_loops.front());

  // Now hoist all reduction loops outside of tiled cartesian loops.
  // Notice that HoistAndFix automatically adds a new set of tiled cartesian
  // loops for hoisted reduction loops to keep the semantics correct.
  //
  // After second hoist:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (tiled cartesian loops...) {
  //       for (reduction loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //     }  // compute loop
  //     for (tiled cartesian loops...) {
  //       output[...] = output_acc[...]
  //     }
  //   }
  {
    auto compute_loop = llvm::cast<mlir::AffineForOp>(
        HoistAndFix(reduction_loops.front(), tiled_cartesian_loops[0]));

    // Fix tiled_cartesian_loops to make them point to the tiled compute loops,
    // not the writeback loops to output buffer.
    llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
    getPerfectlyNestedLoops(all_loops, compute_loop);
    absl::c_copy_n(all_loops, tiled_cartesian_loops.size(),
                   tiled_cartesian_loops.data());
  }

  // After exchanging tiled cartesian compute loops with reduction loops:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (reduction loops...) {
  //       for (tiled cartesian loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //     }
  //     for (tiled cartesian loops...) {
  //       output[...] = output_acc[...]
  //     }
  //   }
  //
  // ...so that the later tiled cartesian loops (with computations in them) can
  // be replaced by CUDA MMA instructions.
  {
    std::vector<mlir::AffineForOp> loops;
    loops.insert(loops.end(), tiled_cartesian_loops.begin(),
                 tiled_cartesian_loops.end());
    loops.insert(loops.end(), reduction_loops.begin(), reduction_loops.end());
    SinkPerfectlyNestedLoops(loops, tiled_cartesian_loops.size());
  }
  return TransformedMlirConvAnchors{cartesian_product_loops, reduction_loops};
}

}  // namespace

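// Emits a forward convolution as an MLIR function. The emitted function takes
// (output, input, filter) buffers in their physical layouts and returns
// nothing. Illustrative shape of the emitted function (the exact memref types
// depend on the operand shapes and layouts):
//   func @<function_name>(%output: memref<...xf16>,
//                         %input:  memref<...xf16>,
//                         %filter: memref<...xf16>)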
StatusOr<mlir::FuncOp> EmitConvolutionForwardAsMlir(
    HloInstruction* conv, absl::string_view function_name,
    mlir::MLIRContext* context) {
  OpBuilder builder(context);

  const auto& dim_nums = conv->convolution_dimension_numbers();
  ShapeInfo input_shape_info =
      GetShapeInfo(conv->operand(0)->shape(), dim_nums.input_batch_dimension(),
                   dim_nums.input_feature_dimension(),
                   dim_nums.input_spatial_dimensions(), builder);

  ShapeInfo filter_shape_info = GetShapeInfo(
      conv->operand(1)->shape(), dim_nums.kernel_output_feature_dimension(),
      dim_nums.kernel_input_feature_dimension(),
      dim_nums.kernel_spatial_dimensions(), builder);

  ShapeInfo output_shape_info = GetShapeInfo(
      conv->shape().tuple_shapes(0), dim_nums.output_batch_dimension(),
      dim_nums.output_feature_dimension(), dim_nums.output_spatial_dimensions(),
      builder);

  auto function = mlir::FuncOp::create(
      mlir::UnknownLoc::get(builder.getContext()),
      llvm_ir::AsStringRef(function_name),
      builder.getFunctionType(
          {mlir::MemRefType::get(output_shape_info.physical_dimensions,
                                 output_shape_info.element_type, {}),
           mlir::MemRefType::get(input_shape_info.physical_dimensions,
                                 input_shape_info.element_type, {}),
           mlir::MemRefType::get(filter_shape_info.physical_dimensions,
                                 filter_shape_info.element_type, {})},
          {}));

  auto* entry_block = function.addEntryBlock();
  builder.setInsertionPointToStart(entry_block);
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
  builder.setInsertionPointToStart(entry_block);

  mlir::Value input = entry_block->getArgument(1);
  mlir::Value filter = entry_block->getArgument(2);
  mlir::Value output = entry_block->getArgument(0);

  TF_RETURN_IF_ERROR(ConvIsImplemented(conv));

  TF_ASSIGN_OR_RETURN(
      InitialMlirConvAnchors initial_anchors,
      CreateNaiveMlirConv(input, filter, output, input_shape_info,
                          filter_shape_info, output_shape_info, conv->window(),
                          builder));

  TF_ASSIGN_OR_RETURN(TransformedMlirConvAnchors transformed_anchors,
                      TransformMlirConv(initial_anchors));

  // TODO(timshen): Implement a transformation that collects loads to a given
  // buffer, creates a local alloc() for the accessed part, redirects all loads
  // and stores to that local alloc(), and creates code to initialize /
  // writeback the local alloc() if needed.

  // TODO(timshen): Implement CUDA-specific lowering.

  return function;
}

Status ConvIsImplemented(const HloInstruction* conv) {
  if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
    return Unimplemented("group count is not implemented.");
  }
  if (window_util::HasWindowReversal(conv->window())) {
    return Unimplemented("Window reversal is not implemented.");
  }
  if (window_util::HasDilation(conv->window())) {
    return Unimplemented("Dilation is not implemented.");
  }
  return Status::OK();
}

}  // namespace experimental
}  // namespace xla