// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics

// Negative tests for the verifier of the "quant.stats" op. Each split below
// passes a single malformed attribute configuration and pins the diagnostic
// that the verifier is expected to emit for it.

// -----
// layerStats uses an integer (i8) element type instead of a float type.
func @invalidStatisticsMismatchedLayerType(%arg0: tensor<8x4x3xf32>) ->
    tensor<8x4x3xf32> {
  // expected-error@+1 {{layerStats must have a floating point element type}}
  %0 = "quant.stats"(%arg0) {
    layerStats = dense<[-1, 1]> : tensor<2xi8>
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// layerStats has the right element count but rank 2 (1x2) instead of rank 1.
func @invalidStatisticsMismatchedLayerRank(%arg0: tensor<8x4x3xf32>) ->
    tensor<8x4x3xf32> {
  // expected-error@+1 {{layerStats must have shape [2]}}
  %0 = "quant.stats"(%arg0) {
    layerStats = dense<[[-1.0, 1.0]]> : tensor<1x2xf32>
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// layerStats is rank 1 but holds three values instead of a [min, max] pair.
func @invalidStatisticsMismatchedLayerShape(%arg0: tensor<8x4x3xf32>) ->
    tensor<8x4x3xf32> {
  // expected-error@+1 {{layerStats must have shape [2]}}
  %0 = "quant.stats"(%arg0) {
    layerStats = dense<[-1.0, 1.0, 2.0]> : tensor<3xf32>
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// axisStats has a correct [3,2] shape for axis 2 (dim size 3) but an integer
// (i8) element type.
func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) ->
    tensor<8x4x3xf32> {
  // expected-error@+1 {{axisStats must have a floating point element type}}
  %0 = "quant.stats"(%arg0) {
    layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
    axisStats = dense<[
      [-1, 1],
      [-8, 8],
      [-1, 0]
    ]> : tensor<3x2xi8>, axis = 2 : i64
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// axisStats provides 4 rows while the axis-2 dimension of the argument has
// size 3, so N does not match.
func @invalidStatisticsMismatchedAxisSize(%arg0: tensor<8x4x3xf32>) ->
    tensor<8x4x3xf32> {
  // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
  %0 = "quant.stats"(%arg0) {
    layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
    axisStats = dense<[
      [-1.0, 1.0],
      [-8.0, 8.0],
      [-0.5, 0.5],
      [-2.0, 3.5]
    ]> : tensor<4x2xf32>, axis = 2 : i64
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// axisStats has the right row count for axis 2 but 3 columns instead of a
// [min, max] pair per row.
func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) ->
    tensor<8x4x3xf32> {
  // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
  %0 = "quant.stats"(%arg0) {
    layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
    axisStats = dense<[
      [-1.0, 1.0, 1.0],
      [-8.0, 8.0, 1.0],
      [-0.5, 0.5, 1.0]
    ]> : tensor<3x3xf32>, axis = 2 : i64
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// axisStats is otherwise well-formed but no axis attribute is given.
func @axisIsRequiredForAxisStats(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
  // expected-error@+1 {{axis must be specified for axisStats}}
  %1 = "quant.stats"(%arg0) {
    layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
    axisStats = dense<[
      [-1.0, 1.0],
      [-8.0, 8.0],
      [-0.5, 0.5]
    ]> : tensor<3x2xf32>
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %1 : tensor<8x4x3xf32>
}

// -----