InliningUtils.cpp 14.4 KB
//===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous inlining utilities.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/InliningUtils.h"

#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

/// Remap locations from the inlined blocks with CallSiteLoc locations with the
/// provided caller location.
static void
remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
                      Location callerLoc) {
  DenseMap<Location, Location> mappedLocations;
  auto remapOpLoc = [&](Operation *op) {
    auto it = mappedLocations.find(op->getLoc());
    if (it == mappedLocations.end()) {
      auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
      it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
    }
    op->setLoc(it->second);
  };
  for (auto &block : inlinedBlocks)
    block.walk(remapOpLoc);
}

static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
                                 BlockAndValueMapping &mapper) {
  auto remapOperands = [&](Operation *op) {
    for (auto &operand : op->getOpOperands())
      if (auto mappedOp = mapper.lookupOrNull(operand.get()))
        operand.set(mappedOp);
  };
  for (auto &block : inlinedBlocks)
    block.walk(remapOperands);
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

bool InlinerInterface::isLegalToInline(
    Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
  // Regions can always be inlined into functions.
  if (isa<FuncOp>(dest->getParentOp()))
    return true;

  auto *handler = getInterfaceFor(dest->getParentOp());
  return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
}

bool InlinerInterface::isLegalToInline(
    Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const {
  auto *handler = getInterfaceFor(op);
  return handler ? handler->isLegalToInline(op, dest, valueMapping) : false;
}

bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
  auto *handler = getInterfaceFor(op);
  return handler ? handler->shouldAnalyzeRecursively(op) : true;
}

/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
  auto *handler = getInterfaceFor(op);
  assert(handler && "expected valid dialect handler");
  handler->handleTerminator(op, newDest);
}

/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void InlinerInterface::handleTerminator(Operation *op,
                                        ArrayRef<Value> valuesToRepl) const {
  auto *handler = getInterfaceFor(op);
  assert(handler && "expected valid dialect handler");
  handler->handleTerminator(op, valuesToRepl);
}

/// Utility to check that all of the operations within 'src' can be inlined.
static bool isLegalToInline(InlinerInterface &interface, Region *src,
                            Region *insertRegion,
                            BlockAndValueMapping &valueMapping) {
  for (auto &block : *src) {
    for (auto &op : block) {
      // Check this operation.
      if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) {
        LLVM_DEBUG({
          llvm::dbgs() << "* Illegal to inline because of op: ";
          op.dump();
        });
        return false;
      }
      // Check any nested regions.
      if (interface.shouldAnalyzeRecursively(&op) &&
          llvm::any_of(op.getRegions(), [&](Region &region) {
            return !isLegalToInline(interface, &region, insertRegion,
                                    valueMapping);
          }))
        return false;
    }
  }
  return true;
}

//===----------------------------------------------------------------------===//
// Inline Methods
//===----------------------------------------------------------------------===//

LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                 Operation *inlinePoint,
                                 BlockAndValueMapping &mapper,
                                 ValueRange resultsToReplace,
                                 TypeRange regionResultTypes,
                                 Optional<Location> inlineLoc,
                                 bool shouldCloneInlinedRegion) {
  assert(resultsToReplace.size() == regionResultTypes.size());
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();

  // Check that all of the region arguments have been mapped.
  auto *srcEntryBlock = &src->front();
  if (llvm::any_of(srcEntryBlock->getArguments(),
                   [&](BlockArgument arg) { return !mapper.contains(arg); }))
    return failure();

  // The insertion point must be within a block.
  Block *insertBlock = inlinePoint->getBlock();
  if (!insertBlock)
    return failure();
  Region *insertRegion = insertBlock->getParent();

  // Check that the operations within the source region are valid to inline.
  if (!interface.isLegalToInline(insertRegion, src, mapper) ||
      !isLegalToInline(interface, src, insertRegion, mapper))
    return failure();

  // Split the insertion block.
  Block *postInsertBlock =
      insertBlock->splitBlock(++inlinePoint->getIterator());

  // Check to see if the region is being cloned, or moved inline. In either
  // case, move the new blocks after the 'insertBlock' to improve IR
  // readability.
  if (shouldCloneInlinedRegion)
    src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
  else
    insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
                                     src->getBlocks(), src->begin(),
                                     src->end());

  // Get the range of newly inserted blocks.
  auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
                                    postInsertBlock->getIterator());
  Block *firstNewBlock = &*newBlocks.begin();

  // Remap the locations of the inlined operations if a valid source location
  // was provided.
  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
    remapInlinedLocations(newBlocks, *inlineLoc);

  // If the blocks were moved in-place, make sure to remap any necessary
  // operands.
  if (!shouldCloneInlinedRegion)
    remapInlinedOperands(newBlocks, mapper);

  // Process the newly inlined blocks.
  interface.processInlinedBlocks(newBlocks);

  // Handle the case where only a single block was inlined.
  if (std::next(newBlocks.begin()) == newBlocks.end()) {
    // Have the interface handle the terminator of this block.
    auto *firstBlockTerminator = firstNewBlock->getTerminator();
    interface.handleTerminator(firstBlockTerminator,
                               llvm::to_vector<6>(resultsToReplace));
    firstBlockTerminator->erase();

    // Merge the post insert block into the cloned entry block.
    firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                          postInsertBlock->getOperations());
    postInsertBlock->erase();
  } else {
    // Otherwise, there were multiple blocks inlined. Add arguments to the post
    // insertion block to represent the results to replace.
    for (auto resultToRepl : llvm::enumerate(resultsToReplace)) {
      resultToRepl.value().replaceAllUsesWith(postInsertBlock->addArgument(
          regionResultTypes[resultToRepl.index()]));
    }

    /// Handle the terminators for each of the new blocks.
    for (auto &newBlock : newBlocks)
      interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
  }

  // Splice the instructions of the inlined entry block into the insert block.
  insertBlock->getOperations().splice(insertBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
  return success();
}

/// This function is an overload of the above 'inlineRegion' that allows for
/// providing the set of operands ('inlinedOperands') that should be used
/// in-favor of the region arguments when inlining.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                 Operation *inlinePoint,
                                 ValueRange inlinedOperands,
                                 ValueRange resultsToReplace,
                                 Optional<Location> inlineLoc,
                                 bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();

  auto *entryBlock = &src->front();
  if (inlinedOperands.size() != entryBlock->getNumArguments())
    return failure();

  // Map the provided call operands to the arguments of the region.
  BlockAndValueMapping mapper;
  for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
    // Verify that the types of the provided values match the function argument
    // types.
    BlockArgument regionArg = entryBlock->getArgument(i);
    if (inlinedOperands[i].getType() != regionArg.getType())
      return failure();
    mapper.map(regionArg, inlinedOperands[i]);
  }

  // Call into the main region inliner function.
  return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
                      resultsToReplace.getTypes(), inlineLoc,
                      shouldCloneInlinedRegion);
}

/// Utility function used to generate a cast operation from the given interface,
/// or return nullptr if a cast could not be generated.
static Value materializeConversion(const DialectInlinerInterface *interface,
                                   SmallVectorImpl<Operation *> &castOps,
                                   OpBuilder &castBuilder, Value arg, Type type,
                                   Location conversionLoc) {
  if (!interface)
    return nullptr;

  // Check to see if the interface for the call can materialize a conversion.
  Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
                                                           type, conversionLoc);
  if (!castOp)
    return nullptr;
  castOps.push_back(castOp);

  // Ensure that the generated cast is correct.
  assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
         castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
  return castOp->getResult(0);
}

/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This
/// function returns failure if inlining is not possible, success otherwise. On
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
                               CallOpInterface call,
                               CallableOpInterface callable, Region *src,
                               bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();
  auto *entryBlock = &src->front();
  ArrayRef<Type> callableResultTypes = callable.getCallableResults();

  // Make sure that the number of arguments and results matchup between the call
  // and the region.
  SmallVector<Value, 8> callOperands(call.getArgOperands());
  SmallVector<Value, 8> callResults(call.getOperation()->getResults());
  if (callOperands.size() != entryBlock->getNumArguments() ||
      callResults.size() != callableResultTypes.size())
    return failure();

  // A set of cast operations generated to matchup the signature of the region
  // with the signature of the call.
  SmallVector<Operation *, 4> castOps;
  castOps.reserve(callOperands.size() + callResults.size());

  // Functor used to cleanup generated state on failure.
  auto cleanupState = [&] {
    for (auto *op : castOps) {
      op->getResult(0).replaceAllUsesWith(op->getOperand(0));
      op->erase();
    }
    return failure();
  };

  // Builder used for any conversion operations that need to be materialized.
  OpBuilder castBuilder(call);
  Location castLoc = call.getLoc();
  auto *callInterface = interface.getInterfaceFor(call.getDialect());

  // Map the provided call operands to the arguments of the region.
  BlockAndValueMapping mapper;
  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
    BlockArgument regionArg = entryBlock->getArgument(i);
    Value operand = callOperands[i];

    // If the call operand doesn't match the expected region argument, try to
    // generate a cast.
    Type regionArgType = regionArg.getType();
    if (operand.getType() != regionArgType) {
      if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
                                            operand, regionArgType, castLoc)))
        return cleanupState();
    }
    mapper.map(regionArg, operand);
  }

  // Ensure that the resultant values of the call match the callable.
  castBuilder.setInsertionPointAfter(call);
  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
    Value callResult = callResults[i];
    if (callResult.getType() == callableResultTypes[i])
      continue;

    // Generate a conversion that will produce the original type, so that the IR
    // is still valid after the original call gets replaced.
    Value castResult =
        materializeConversion(callInterface, castOps, castBuilder, callResult,
                              callResult.getType(), castLoc);
    if (!castResult)
      return cleanupState();
    callResult.replaceAllUsesWith(castResult);
    castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
  }

  // Attempt to inline the call.
  if (failed(inlineRegion(interface, src, call, mapper, callResults,
                          callableResultTypes, call.getLoc(),
                          shouldCloneInlinedRegion)))
    return cleanupState();
  return success();
}