// all-reduce-max.mlir (11.3 KB)
// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s

// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK: gpu.module @kernels {
// Test input: a GPU module holding a single kernel whose body is one
// "gpu.all_reduce" with op = "max" over an f32 argument. The expected-output
// patterns interleaved with the code were autogenerated (see the NOTE near the
// top of this file).
gpu.module @kernels {

  // Summary of what the patterns below expect from the lowered kernel:
  //   * the kernel gains a workgroup attribution of type memref<32xf32, 3>
  //     (the input function further down declares none);
  //   * a linear thread index is computed from gpu.thread_id / gpu.block_dim
  //     as (tid.z * dim.y + tid.y) * dim.x + tid.x, and the total thread
  //     count as dim.x * dim.y * dim.z;
  //   * values are combined through gpu.shuffle xor with offsets 1, 2, 4, 8
  //     and 16; each combine step is cmpf "ugt" followed by select. In the
  //     guarded variant, a shuffle whose second result is false passes the
  //     current value through unchanged;
  //   * threads whose linear index has its low five bits clear store their
  //     partial result into the workgroup buffer at (linear index / 32);
  //   * after a gpu.barrier, threads with linear index below
  //     (thread count + 31) / 32 load from the buffer, run a second shuffle
  //     pass, and store the result at index 0, followed by another
  //     gpu.barrier.
  // CHECK-LABEL: gpu.func @kernel(
  // CHECK-SAME: [[VAL_0:%.*]]: f32) workgroup([[VAL_1:%.*]] : memref<32xf32, 3>) kernel {
  gpu.func @kernel(%arg0 : f32) kernel {
    // CHECK:   [[VAL_2:%.*]] = constant 31 : i32
    // CHECK:   [[VAL_3:%.*]] = constant 0 : i32
    // CHECK:   [[VAL_4:%.*]] = constant 0 : index
    // CHECK:   [[VAL_5:%.*]] = constant 32 : i32
    // CHECK:   [[VAL_6:%.*]] = constant 1 : i32
    // CHECK:   [[VAL_7:%.*]] = constant 2 : i32
    // CHECK:   [[VAL_8:%.*]] = constant 4 : i32
    // CHECK:   [[VAL_9:%.*]] = constant 8 : i32
    // CHECK:   [[VAL_10:%.*]] = constant 16 : i32
    // CHECK:   [[VAL_11:%.*]] = "gpu.block_dim"() {dimension = "x"} : () -> index
    // CHECK:   [[VAL_12:%.*]] = index_cast [[VAL_11]] : index to i32
    // CHECK:   [[VAL_13:%.*]] = "gpu.block_dim"() {dimension = "y"} : () -> index
    // CHECK:   [[VAL_14:%.*]] = index_cast [[VAL_13]] : index to i32
    // CHECK:   [[VAL_15:%.*]] = "gpu.block_dim"() {dimension = "z"} : () -> index
    // CHECK:   [[VAL_16:%.*]] = index_cast [[VAL_15]] : index to i32
    // CHECK:   [[VAL_17:%.*]] = "gpu.thread_id"() {dimension = "x"} : () -> index
    // CHECK:   [[VAL_18:%.*]] = index_cast [[VAL_17]] : index to i32
    // CHECK:   [[VAL_19:%.*]] = "gpu.thread_id"() {dimension = "y"} : () -> index
    // CHECK:   [[VAL_20:%.*]] = index_cast [[VAL_19]] : index to i32
    // CHECK:   [[VAL_21:%.*]] = "gpu.thread_id"() {dimension = "z"} : () -> index
    // CHECK:   [[VAL_22:%.*]] = index_cast [[VAL_21]] : index to i32
    // CHECK:   [[VAL_23:%.*]] = muli [[VAL_22]], [[VAL_14]] : i32
    // CHECK:   [[VAL_24:%.*]] = addi [[VAL_23]], [[VAL_20]] : i32
    // CHECK:   [[VAL_25:%.*]] = muli [[VAL_24]], [[VAL_12]] : i32
    // CHECK:   [[VAL_26:%.*]] = muli [[VAL_12]], [[VAL_14]] : i32
    // CHECK:   [[VAL_27:%.*]] = addi [[VAL_25]], [[VAL_18]] : i32
    // CHECK:   [[VAL_28:%.*]] = muli [[VAL_26]], [[VAL_16]] : i32
    // CHECK:   [[VAL_29:%.*]] = and [[VAL_27]], [[VAL_2]] : i32
    // CHECK:   [[VAL_30:%.*]] = cmpi "eq", [[VAL_29]], [[VAL_3]] : i32
    // CHECK:   [[VAL_31:%.*]] = subi [[VAL_27]], [[VAL_29]] : i32
    // CHECK:   [[VAL_32:%.*]] = subi [[VAL_28]], [[VAL_31]] : i32
    // CHECK:   [[VAL_33:%.*]] = cmpi "slt", [[VAL_32]], [[VAL_5]] : i32
    // CHECK:   cond_br [[VAL_33]], ^bb1, ^bb17
    // CHECK: ^bb1:
    // CHECK:   [[VAL_34:%.*]], [[VAL_35:%.*]] = gpu.shuffle [[VAL_0]], [[VAL_6]], [[VAL_32]] xor : f32
    // CHECK:   cond_br [[VAL_35]], ^bb2, ^bb3
    // CHECK: ^bb2:
    // CHECK:   [[VAL_36:%.*]] = cmpf "ugt", [[VAL_0]], [[VAL_34]] : f32
    // CHECK:   [[VAL_37:%.*]] = select [[VAL_36]], [[VAL_0]], [[VAL_34]] : f32
    // CHECK:   br ^bb4([[VAL_37]] : f32)
    // CHECK: ^bb3:
    // CHECK:   br ^bb4([[VAL_0]] : f32)
    // CHECK: ^bb4([[VAL_38:%.*]]: f32):
    // CHECK:   [[VAL_39:%.*]], [[VAL_40:%.*]] = gpu.shuffle [[VAL_38]], [[VAL_7]], [[VAL_32]] xor : f32
    // CHECK:   cond_br [[VAL_40]], ^bb5, ^bb6
    // CHECK: ^bb5:
    // CHECK:   [[VAL_41:%.*]] = cmpf "ugt", [[VAL_38]], [[VAL_39]] : f32
    // CHECK:   [[VAL_42:%.*]] = select [[VAL_41]], [[VAL_38]], [[VAL_39]] : f32
    // CHECK:   br ^bb7([[VAL_42]] : f32)
    // CHECK: ^bb6:
    // CHECK:   br ^bb7([[VAL_38]] : f32)
    // CHECK: ^bb7([[VAL_43:%.*]]: f32):
    // CHECK:   [[VAL_44:%.*]], [[VAL_45:%.*]] = gpu.shuffle [[VAL_43]], [[VAL_8]], [[VAL_32]] xor : f32
    // CHECK:   cond_br [[VAL_45]], ^bb8, ^bb9
    // CHECK: ^bb8:
    // CHECK:   [[VAL_46:%.*]] = cmpf "ugt", [[VAL_43]], [[VAL_44]] : f32
    // CHECK:   [[VAL_47:%.*]] = select [[VAL_46]], [[VAL_43]], [[VAL_44]] : f32
    // CHECK:   br ^bb10([[VAL_47]] : f32)
    // CHECK: ^bb9:
    // CHECK:   br ^bb10([[VAL_43]] : f32)
    // CHECK: ^bb10([[VAL_48:%.*]]: f32):
    // CHECK:   [[VAL_49:%.*]], [[VAL_50:%.*]] = gpu.shuffle [[VAL_48]], [[VAL_9]], [[VAL_32]] xor : f32
    // CHECK:   cond_br [[VAL_50]], ^bb11, ^bb12
    // CHECK: ^bb11:
    // CHECK:   [[VAL_51:%.*]] = cmpf "ugt", [[VAL_48]], [[VAL_49]] : f32
    // CHECK:   [[VAL_52:%.*]] = select [[VAL_51]], [[VAL_48]], [[VAL_49]] : f32
    // CHECK:   br ^bb13([[VAL_52]] : f32)
    // CHECK: ^bb12:
    // CHECK:   br ^bb13([[VAL_48]] : f32)
    // CHECK: ^bb13([[VAL_53:%.*]]: f32):
    // CHECK:   [[VAL_54:%.*]], [[VAL_55:%.*]] = gpu.shuffle [[VAL_53]], [[VAL_10]], [[VAL_32]] xor : f32
    // CHECK:   cond_br [[VAL_55]], ^bb14, ^bb15
    // CHECK: ^bb14:
    // CHECK:   [[VAL_56:%.*]] = cmpf "ugt", [[VAL_53]], [[VAL_54]] : f32
    // CHECK:   [[VAL_57:%.*]] = select [[VAL_56]], [[VAL_53]], [[VAL_54]] : f32
    // CHECK:   br ^bb16([[VAL_57]] : f32)
    // CHECK: ^bb15:
    // CHECK:   br ^bb16([[VAL_53]] : f32)
    // CHECK: ^bb16([[VAL_58:%.*]]: f32):
    // CHECK:   br ^bb18([[VAL_58]] : f32)
    // CHECK: ^bb17:
    // CHECK:   [[VAL_59:%.*]], [[VAL_60:%.*]] = gpu.shuffle [[VAL_0]], [[VAL_6]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_61:%.*]] = cmpf "ugt", [[VAL_0]], [[VAL_59]] : f32
    // CHECK:   [[VAL_62:%.*]] = select [[VAL_61]], [[VAL_0]], [[VAL_59]] : f32
    // CHECK:   [[VAL_63:%.*]], [[VAL_64:%.*]] = gpu.shuffle [[VAL_62]], [[VAL_7]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_65:%.*]] = cmpf "ugt", [[VAL_62]], [[VAL_63]] : f32
    // CHECK:   [[VAL_66:%.*]] = select [[VAL_65]], [[VAL_62]], [[VAL_63]] : f32
    // CHECK:   [[VAL_67:%.*]], [[VAL_68:%.*]] = gpu.shuffle [[VAL_66]], [[VAL_8]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_69:%.*]] = cmpf "ugt", [[VAL_66]], [[VAL_67]] : f32
    // CHECK:   [[VAL_70:%.*]] = select [[VAL_69]], [[VAL_66]], [[VAL_67]] : f32
    // CHECK:   [[VAL_71:%.*]], [[VAL_72:%.*]] = gpu.shuffle [[VAL_70]], [[VAL_9]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_73:%.*]] = cmpf "ugt", [[VAL_70]], [[VAL_71]] : f32
    // CHECK:   [[VAL_74:%.*]] = select [[VAL_73]], [[VAL_70]], [[VAL_71]] : f32
    // CHECK:   [[VAL_75:%.*]], [[VAL_76:%.*]] = gpu.shuffle [[VAL_74]], [[VAL_10]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_77:%.*]] = cmpf "ugt", [[VAL_74]], [[VAL_75]] : f32
    // CHECK:   [[VAL_78:%.*]] = select [[VAL_77]], [[VAL_74]], [[VAL_75]] : f32
    // CHECK:   br ^bb18([[VAL_78]] : f32)
    // CHECK: ^bb18([[VAL_79:%.*]]: f32):
    // CHECK:   cond_br [[VAL_30]], ^bb19, ^bb20
    // CHECK: ^bb19:
    // CHECK:   [[VAL_80:%.*]] = divi_signed [[VAL_27]], [[VAL_5]] : i32
    // CHECK:   [[VAL_81:%.*]] = index_cast [[VAL_80]] : i32 to index
    // CHECK:   store [[VAL_79]], [[VAL_1]]{{\[}}[[VAL_81]]] : memref<32xf32, 3>
    // CHECK:   br ^bb21
    // CHECK: ^bb20:
    // CHECK:   br ^bb21
    // CHECK: ^bb21:
    // CHECK:   gpu.barrier
    // CHECK:   [[VAL_82:%.*]] = addi [[VAL_28]], [[VAL_2]] : i32
    // CHECK:   [[VAL_83:%.*]] = divi_signed [[VAL_82]], [[VAL_5]] : i32
    // CHECK:   [[VAL_84:%.*]] = cmpi "slt", [[VAL_27]], [[VAL_83]] : i32
    // CHECK:   cond_br [[VAL_84]], ^bb22, ^bb41
    // CHECK: ^bb22:
    // CHECK:   [[VAL_85:%.*]] = index_cast [[VAL_27]] : i32 to index
    // CHECK:   [[VAL_86:%.*]] = load [[VAL_1]]{{\[}}[[VAL_85]]] : memref<32xf32, 3>
    // CHECK:   [[VAL_87:%.*]] = cmpi "slt", [[VAL_83]], [[VAL_5]] : i32
    // CHECK:   cond_br [[VAL_87]], ^bb23, ^bb39
    // CHECK: ^bb23:
    // CHECK:   [[VAL_88:%.*]], [[VAL_89:%.*]] = gpu.shuffle [[VAL_86]], [[VAL_6]], [[VAL_83]] xor : f32
    // CHECK:   cond_br [[VAL_89]], ^bb24, ^bb25
    // CHECK: ^bb24:
    // CHECK:   [[VAL_90:%.*]] = cmpf "ugt", [[VAL_86]], [[VAL_88]] : f32
    // CHECK:   [[VAL_91:%.*]] = select [[VAL_90]], [[VAL_86]], [[VAL_88]] : f32
    // CHECK:   br ^bb26([[VAL_91]] : f32)
    // CHECK: ^bb25:
    // CHECK:   br ^bb26([[VAL_86]] : f32)
    // CHECK: ^bb26([[VAL_92:%.*]]: f32):
    // CHECK:   [[VAL_93:%.*]], [[VAL_94:%.*]] = gpu.shuffle [[VAL_92]], [[VAL_7]], [[VAL_83]] xor : f32
    // CHECK:   cond_br [[VAL_94]], ^bb27, ^bb28
    // CHECK: ^bb27:
    // CHECK:   [[VAL_95:%.*]] = cmpf "ugt", [[VAL_92]], [[VAL_93]] : f32
    // CHECK:   [[VAL_96:%.*]] = select [[VAL_95]], [[VAL_92]], [[VAL_93]] : f32
    // CHECK:   br ^bb29([[VAL_96]] : f32)
    // CHECK: ^bb28:
    // CHECK:   br ^bb29([[VAL_92]] : f32)
    // CHECK: ^bb29([[VAL_97:%.*]]: f32):
    // CHECK:   [[VAL_98:%.*]], [[VAL_99:%.*]] = gpu.shuffle [[VAL_97]], [[VAL_8]], [[VAL_83]] xor : f32
    // CHECK:   cond_br [[VAL_99]], ^bb30, ^bb31
    // CHECK: ^bb30:
    // CHECK:   [[VAL_100:%.*]] = cmpf "ugt", [[VAL_97]], [[VAL_98]] : f32
    // CHECK:   [[VAL_101:%.*]] = select [[VAL_100]], [[VAL_97]], [[VAL_98]] : f32
    // CHECK:   br ^bb32([[VAL_101]] : f32)
    // CHECK: ^bb31:
    // CHECK:   br ^bb32([[VAL_97]] : f32)
    // CHECK: ^bb32([[VAL_102:%.*]]: f32):
    // CHECK:   [[VAL_103:%.*]], [[VAL_104:%.*]] = gpu.shuffle [[VAL_102]], [[VAL_9]], [[VAL_83]] xor : f32
    // CHECK:   cond_br [[VAL_104]], ^bb33, ^bb34
    // CHECK: ^bb33:
    // CHECK:   [[VAL_105:%.*]] = cmpf "ugt", [[VAL_102]], [[VAL_103]] : f32
    // CHECK:   [[VAL_106:%.*]] = select [[VAL_105]], [[VAL_102]], [[VAL_103]] : f32
    // CHECK:   br ^bb35([[VAL_106]] : f32)
    // CHECK: ^bb34:
    // CHECK:   br ^bb35([[VAL_102]] : f32)
    // CHECK: ^bb35([[VAL_107:%.*]]: f32):
    // CHECK:   [[VAL_108:%.*]], [[VAL_109:%.*]] = gpu.shuffle [[VAL_107]], [[VAL_10]], [[VAL_83]] xor : f32
    // CHECK:   cond_br [[VAL_109]], ^bb36, ^bb37
    // CHECK: ^bb36:
    // CHECK:   [[VAL_110:%.*]] = cmpf "ugt", [[VAL_107]], [[VAL_108]] : f32
    // CHECK:   [[VAL_111:%.*]] = select [[VAL_110]], [[VAL_107]], [[VAL_108]] : f32
    // CHECK:   br ^bb38([[VAL_111]] : f32)
    // CHECK: ^bb37:
    // CHECK:   br ^bb38([[VAL_107]] : f32)
    // CHECK: ^bb38([[VAL_112:%.*]]: f32):
    // CHECK:   br ^bb40([[VAL_112]] : f32)
    // CHECK: ^bb39:
    // CHECK:   [[VAL_113:%.*]], [[VAL_114:%.*]] = gpu.shuffle [[VAL_86]], [[VAL_6]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_115:%.*]] = cmpf "ugt", [[VAL_86]], [[VAL_113]] : f32
    // CHECK:   [[VAL_116:%.*]] = select [[VAL_115]], [[VAL_86]], [[VAL_113]] : f32
    // CHECK:   [[VAL_117:%.*]], [[VAL_118:%.*]] = gpu.shuffle [[VAL_116]], [[VAL_7]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_119:%.*]] = cmpf "ugt", [[VAL_116]], [[VAL_117]] : f32
    // CHECK:   [[VAL_120:%.*]] = select [[VAL_119]], [[VAL_116]], [[VAL_117]] : f32
    // CHECK:   [[VAL_121:%.*]], [[VAL_122:%.*]] = gpu.shuffle [[VAL_120]], [[VAL_8]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_123:%.*]] = cmpf "ugt", [[VAL_120]], [[VAL_121]] : f32
    // CHECK:   [[VAL_124:%.*]] = select [[VAL_123]], [[VAL_120]], [[VAL_121]] : f32
    // CHECK:   [[VAL_125:%.*]], [[VAL_126:%.*]] = gpu.shuffle [[VAL_124]], [[VAL_9]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_127:%.*]] = cmpf "ugt", [[VAL_124]], [[VAL_125]] : f32
    // CHECK:   [[VAL_128:%.*]] = select [[VAL_127]], [[VAL_124]], [[VAL_125]] : f32
    // CHECK:   [[VAL_129:%.*]], [[VAL_130:%.*]] = gpu.shuffle [[VAL_128]], [[VAL_10]], [[VAL_5]] xor : f32
    // CHECK:   [[VAL_131:%.*]] = cmpf "ugt", [[VAL_128]], [[VAL_129]] : f32
    // CHECK:   [[VAL_132:%.*]] = select [[VAL_131]], [[VAL_128]], [[VAL_129]] : f32
    // CHECK:   br ^bb40([[VAL_132]] : f32)
    // CHECK: ^bb40([[VAL_133:%.*]]: f32):
    // CHECK:   store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3>
    // CHECK:   br ^bb42
    // CHECK: ^bb41:
    // CHECK:   br ^bb42
    // CHECK: ^bb42:
    // CHECK:   gpu.barrier

    // The op under test: an all-reduce of %arg0 with op = "max" and an empty
    // region, written in generic op syntax.
    // NOTE(review): the result is named %sum although the reduction is a max,
    // and it is not used before gpu.return. This looks like a leftover name
    // from a sum variant of this test; confirm before renaming.
    %sum = "gpu.all_reduce"(%arg0) ({}) {op = "max"} : (f32) -> (f32)
    gpu.return
  }

}