LinalgToSPIRV.cpp 8.5 KB
//===- LinalgToSPIRV.cpp - Linalg to SPIR-V dialect conversion ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

/// Returns a `Value` holding component `dim` of the SPIR-V `LocalInvocationId`
/// builtin, i.e. the current invocation's ID along that dimension (despite the
/// function's name, this is an ID, not a size).
///
/// The builtin variable is obtained through `spirv::getBuiltinVariableValue`,
/// which creates any necessary operations with `builder` at the proper region
/// containing `op`. The component is then read with a `spv.CompositeExtract`
/// created at `loc`.
static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
                                       OpBuilder *builder) {
  assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
  OpBuilder &b = *builder;

  Value localInvocationId = spirv::getBuiltinVariableValue(
      op, spirv::BuiltIn::LocalInvocationId, b);

  // The builtin is a shaped (vector) value; the extracted scalar has its
  // element type.
  auto idType = localInvocationId.getType().cast<ShapedType>();
  Type componentType = idType.getElementType();

  auto position = b.getI32ArrayAttr({dim});
  return b.create<spirv::CompositeExtractOp>(loc, componentType,
                                             localInvocationId, position);
}

//===----------------------------------------------------------------------===//
// Reduction (single workgroup)
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern lowering a `linalg.generic` op to SPIR-V ops when that
/// op performs a reduction whose workload fits in a single workgroup.
class SingleWorkgroupReduction final
    : public SPIRVOpLowering<linalg::GenericOp> {
  using Base = SPIRVOpLowering<linalg::GenericOp>;

public:
  // Inherit all constructors of the base lowering pattern.
  using Base::Base;

  /// Checks whether `genericOp` is a reduction this pattern can lower. On
  /// success, returns the kind of scalar binary op combining the elements;
  /// otherwise returns `llvm::None`.
  static Optional<linalg::RegionMatcher::BinaryOpKind>
  matchAsPerformingReduction(linalg::GenericOp genericOp);

  /// Replaces a matched reduction `genericOp` with SPIR-V ops, using the
  /// already-converted `operands`.
  LogicalResult
  matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

/// Returns the kind of scalar binary op combining the elements if `genericOp`
/// is a reduction this pattern supports, and `llvm::None` otherwise.
///
/// The supported form is a `linalg.generic` that:
///   - has buffer (memref) semantics,
///   - has exactly one input and one output,
///   - reads a statically shaped, rank-1 input,
///   - writes a statically shaped output holding exactly one element,
///   - has a single reduction loop,
///   - uses indexing maps `(i) -> (i)` for the input and `(i) -> (0)` for the
///     output,
///   - has a region accepted by `linalg::RegionMatcher::matchAsScalarBinaryOp`.
Optional<linalg::RegionMatcher::BinaryOpKind>
SingleWorkgroupReduction::matchAsPerformingReduction(
    linalg::GenericOp genericOp) {
  Operation *op = genericOp.getOperation();
  auto *context = op->getContext();

  // Make sure the linalg.generic is working on memrefs. This is also what makes
  // the MemRefType casts below valid.
  if (!genericOp.hasBufferSemantics())
    return llvm::None;

  // Make sure this is reduction with one input and one output.
  if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1)
    return llvm::None;

  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();

  // Make sure the original input has one dimension.
  if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
    return llvm::None;
  // Make sure the original output has one element.
  if (!originalOutputType.hasStaticShape() ||
      originalOutputType.getNumElements() != 1)
    return llvm::None;

  if (!genericOp.hasSingleReductionLoop())
    return llvm::None;

  // Fetch the indexing maps once; one map is expected per operand.
  auto indexingMaps = genericOp.indexing_maps().getValue();
  if (indexingMaps.size() != 2)
    return llvm::None;

  // TODO: create utility functions for these checks in Linalg
  // and use them.
  auto inputMap = indexingMaps[0].cast<AffineMapAttr>();
  auto outputMap = indexingMaps[1].cast<AffineMapAttr>();
  // The indexing map for the input should be `(i) -> (i)`.
  if (inputMap.getValue() !=
      AffineMap::get(1, 0, getAffineDimExpr(0, context)))
    return llvm::None;
  // The indexing map for the output should be `(i) -> (0)`.
  if (outputMap.getValue() !=
      AffineMap::get(1, 0, getAffineConstantExpr(0, context)))
    return llvm::None;

  return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
}

/// Lowers a `linalg.generic` reduction that fits in one workgroup.
///
/// Each invocation loads the input element indexed by the x component of its
/// local invocation ID. The loaded values are combined with a
/// `spv.GroupNonUniform*` reduction at subgroup scope. Within each subgroup, the
/// invocation selected by `spv.GroupNonUniformElect` atomically applies the
/// partial result to the output's single element.
///
/// Returns failure, leaving the op untouched, in two cases: when
/// `matchAsPerformingReduction` rejects the op, or when the shader's local
/// workgroup size is unknown or does not have its first entry equal to the
/// input length and all remaining entries equal to 1.
LogicalResult SingleWorkgroupReduction::matchAndRewrite(
    linalg::GenericOp genericOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // Match before touching operand types. Only a successful match guarantees
  // buffer semantics with exactly one input and one output. Casting first
  // would assert on tensor-semantic ops or index past the operand list.
  auto binaryOpKind = matchAsPerformingReduction(genericOp);
  if (!binaryOpKind)
    return failure();

  Operation *op = genericOp.getOperation();
  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();

  // Query the shader interface for local workgroup size to make sure the
  // invocation configuration fits with the input memref's shape.
  DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
  if (!localSize)
    return failure();

  if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
    return failure();
  if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
                   [](const APInt &size) { return !size.isOneValue(); }))
    return failure();

  // TODO: Query the target environment to make sure the current
  // workload fits in a local workgroup.

  Value convertedInput = operands[0], convertedOutput = operands[1];
  Location loc = genericOp.getLoc();

  // Get the invocation ID.
  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);

  // TODO: Load to Workgroup storage class first.

  // Get the input element accessed by this invocation.
  Value inputElementPtr = spirv::getElementPtr(
      typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
  Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);

  // Perform the group reduction operation.
  Value groupOperation;
#define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp)                         \
  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
    groupOperation = rewriter.create<spirv::spvOp>(                            \
        loc, originalInputType.getElementType(), spirv::Scope::Subgroup,       \
        spirv::GroupOperation::Reduce, inputElement,                           \
        /*cluster_size=*/nullptr);                                             \
  } break
  switch (*binaryOpKind) {
    CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
  }
#undef CREATE_GROUP_NON_UNIFORM_BIN_OP

  // Get the output element accessed by this reduction.
  Value zero = spirv::ConstantOp::getZero(
      typeConverter.getIndexType(rewriter.getContext()), loc, rewriter);
  SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
  Value outputElementPtr =
      spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput,
                           zeroIndices, loc, rewriter);

  // Write out the final reduction result. This should be only conducted by one
  // invocation. We use spv.GroupNonUniformElect to find the invocation with the
  // lowest ID.
  //
  // ```
  // if (spv.GroupNonUniformElect) { output = ... }
  // ```

  Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
      loc, spirv::Scope::Subgroup);

  auto createAtomicOp = [&](OpBuilder &builder) {
#define CREATE_ATOMIC_BIN_OP(opKind, spvOp)                                    \
  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
    builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device,  \
                                 spirv::MemorySemantics::AcquireRelease,       \
                                 groupOperation);                              \
  } break
    switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
#undef CREATE_ATOMIC_BIN_OP
  };

  spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter);

  rewriter.eraseOp(genericOp);
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

/// Appends the Linalg-to-SPIR-V conversion patterns to `patterns`. Each
/// pattern is constructed with `context` and `typeConverter`.
void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
                                         SPIRVTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
  // Currently the only lowering is for single-workgroup reductions.
  patterns.insert<SingleWorkgroupReduction>(context, typeConverter);
}