LoopNestAnalysis.cpp 10.6 KB
//===- LoopNestAnalysis.cpp - Loop Nest Analysis --------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// The implementation for the loop nest analysis.
///
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/LoopNestAnalysis.h"
#include "llvm/ADT/BreadthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/ValueTracking.h"

using namespace llvm;

#define DEBUG_TYPE "loopnest"
#ifndef NDEBUG
// Debug type ("loopnest-verbose") passed to DEBUG_WITH_TYPE for the more
// detailed diagnostics emitted in this file. It is only defined when NDEBUG is
// not set.
static const char *VerboseDebug = DEBUG_TYPE "-verbose";
#endif

/// Determine whether the loops structure satisfies the basic requirements for
/// perfect nesting:
///  - the inner loop should be the outer loop's only child
///  - both loops should be in loop-simplify form and have their latch as the
///    only exiting block; the inner loop should have a single exit block
///  - the outer loop header should 'flow' into the inner loop preheader
///    or jump around the inner loop to the outer loop latch
///  - if the inner loop latch exits the inner loop, it should 'flow' into
///    the outer loop latch.
/// Returns true if the loop structure satisfies the basic requirements and
/// false otherwise. \p SE is not referenced by the definition in this file.
static bool checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop,
                                ScalarEvolution &SE);

//===----------------------------------------------------------------------===//
// LoopNest implementation
//

/// Build the loop nest rooted at \p Root: compute its maximum perfect depth
/// and record every loop of the nest in breadth-first order, so that \p Root
/// is the first entry of Loops.
LoopNest::LoopNest(Loop &Root, ScalarEvolution &SE)
    : MaxPerfectDepth(getMaxPerfectDepth(Root, SE)) {
  for (auto It = bf_begin(&Root), End = bf_end(&Root); It != End; ++It)
    Loops.push_back(*It);
}

/// Construct the loop nest rooted at \p Root and hand ownership of it to the
/// caller.
std::unique_ptr<LoopNest> LoopNest::getLoopNest(Loop &Root,
                                                ScalarEvolution &SE) {
  auto Nest = std::make_unique<LoopNest>(Root, SE);
  return Nest;
}

bool LoopNest::arePerfectlyNested(const Loop &OuterLoop, const Loop &InnerLoop,
                                  ScalarEvolution &SE) {
  assert(!OuterLoop.getSubLoops().empty() && "Outer loop should have subloops");
  assert(InnerLoop.getParentLoop() && "Inner loop should have a parent");
  LLVM_DEBUG(dbgs() << "Checking whether loop '" << OuterLoop.getName()
                    << "' and '" << InnerLoop.getName()
                    << "' are perfectly nested.\n");

  // Determine whether the loops structure satisfies the following requirements:
  //  - the inner loop should be the outer loop's only child
  //  - the outer loop header should 'flow' into the inner loop preheader
  //    or jump around the inner loop to the outer loop latch
  //  - if the inner loop latch exits the inner loop, it should 'flow' into
  //    the outer loop latch.
  if (!checkLoopsStructure(OuterLoop, InnerLoop, SE)) {
    LLVM_DEBUG(dbgs() << "Not perfectly nested: invalid loop structure.\n");
    return false;
  }

  // Bail out if we cannot retrieve the outer loop bounds.
  auto OuterLoopLB = OuterLoop.getBounds(SE);
  if (OuterLoopLB == None) {
    LLVM_DEBUG(dbgs() << "Cannot compute loop bounds of OuterLoop: "
                      << OuterLoop << "\n";);
    return false;
  }

  // Identify the outer loop latch comparison instruction.
  const BasicBlock *Latch = OuterLoop.getLoopLatch();
  assert(Latch && "Expecting a valid loop latch");
  const BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
  assert(BI && BI->isConditional() &&
         "Expecting loop latch terminator to be a branch instruction");

  const CmpInst *OuterLoopLatchCmp = dyn_cast<CmpInst>(BI->getCondition());
  DEBUG_WITH_TYPE(
      VerboseDebug, if (OuterLoopLatchCmp) {
        dbgs() << "Outer loop latch compare instruction: " << *OuterLoopLatchCmp
               << "\n";
      });

  // Identify the inner loop guard instruction.
  BranchInst *InnerGuard = InnerLoop.getLoopGuardBranch();
  const CmpInst *InnerLoopGuardCmp =
      (InnerGuard) ? dyn_cast<CmpInst>(InnerGuard->getCondition()) : nullptr;

  DEBUG_WITH_TYPE(
      VerboseDebug, if (InnerLoopGuardCmp) {
        dbgs() << "Inner loop guard compare instruction: " << *InnerLoopGuardCmp
               << "\n";
      });

  // Determine whether instructions in a basic block are one of:
  //  - the inner loop guard comparison
  //  - the outer loop latch comparison
  //  - the outer loop induction variable increment
  //  - a phi node, a cast or a branch
  auto containsOnlySafeInstructions = [&](const BasicBlock &BB) {
    return llvm::all_of(BB, [&](const Instruction &I) {
      bool isAllowed = isSafeToSpeculativelyExecute(&I) || isa<PHINode>(I) ||
                       isa<BranchInst>(I);
      if (!isAllowed) {
        DEBUG_WITH_TYPE(VerboseDebug, {
          dbgs() << "Instruction: " << I << "\nin basic block: " << BB
                 << " is considered unsafe.\n";
        });
        return false;
      }

      // The only binary instruction allowed is the outer loop step instruction,
      // the only comparison instructions allowed are the inner loop guard
      // compare instruction and the outer loop latch compare instruction.
      if ((isa<BinaryOperator>(I) && &I != &OuterLoopLB->getStepInst()) ||
          (isa<CmpInst>(I) && &I != OuterLoopLatchCmp &&
           &I != InnerLoopGuardCmp)) {
        DEBUG_WITH_TYPE(VerboseDebug, {
          dbgs() << "Instruction: " << I << "\nin basic block:" << BB
                 << "is unsafe.\n";
        });
        return false;
      }
      return true;
    });
  };

  // Check the code surrounding the inner loop for instructions that are deemed
  // unsafe.
  const BasicBlock *OuterLoopHeader = OuterLoop.getHeader();
  const BasicBlock *OuterLoopLatch = OuterLoop.getLoopLatch();
  const BasicBlock *InnerLoopPreHeader = InnerLoop.getLoopPreheader();

  if (!containsOnlySafeInstructions(*OuterLoopHeader) ||
      !containsOnlySafeInstructions(*OuterLoopLatch) ||
      (InnerLoopPreHeader != OuterLoopHeader &&
       !containsOnlySafeInstructions(*InnerLoopPreHeader)) ||
      !containsOnlySafeInstructions(*InnerLoop.getExitBlock())) {
    LLVM_DEBUG(dbgs() << "Not perfectly nested: code surrounding inner loop is "
                         "unsafe\n";);
    return false;
  }

  LLVM_DEBUG(dbgs() << "Loop '" << OuterLoop.getName() << "' and '"
                    << InnerLoop.getName() << "' are perfectly nested.\n");

  return true;
}

/// Walk the nest depth-first from its outermost loop and group the visited
/// loops into chains where each loop is perfectly nested with its single
/// child. A chain is closed (and appended to the result) as soon as the
/// visited loop does not have exactly one perfectly nested subloop.
SmallVector<LoopVectorTy, 4>
LoopNest::getPerfectLoops(ScalarEvolution &SE) const {
  SmallVector<LoopVectorTy, 4> Result;
  LoopVectorTy Chain;

  for (Loop *Current : depth_first(const_cast<Loop *>(Loops.front()))) {
    // An empty chain means the previous one was just closed (or this is the
    // first loop visited): the current loop starts a new chain.
    if (Chain.empty())
      Chain.push_back(Current);

    const auto &Children = Current->getSubLoops();
    const bool ExtendsChain =
        Children.size() == 1 &&
        arePerfectlyNested(*Current, *Children.front(), SE);
    if (ExtendsChain) {
      Chain.push_back(Children.front());
      continue;
    }

    Result.push_back(Chain);
    Chain.clear();
  }

  return Result;
}

unsigned LoopNest::getMaxPerfectDepth(const Loop &Root, ScalarEvolution &SE) {
  LLVM_DEBUG(dbgs() << "Get maximum perfect depth of loop nest rooted by loop '"
                    << Root.getName() << "'\n");

  const Loop *CurrentLoop = &Root;
  const auto *SubLoops = &CurrentLoop->getSubLoops();
  unsigned CurrentDepth = 1;

  while (SubLoops->size() == 1) {
    const Loop *InnerLoop = SubLoops->front();
    if (!arePerfectlyNested(*CurrentLoop, *InnerLoop, SE)) {
      LLVM_DEBUG({
        dbgs() << "Not a perfect nest: loop '" << CurrentLoop->getName()
               << "' is not perfectly nested with loop '"
               << InnerLoop->getName() << "'\n";
      });
      break;
    }

    CurrentLoop = InnerLoop;
    SubLoops = &CurrentLoop->getSubLoops();
    ++CurrentDepth;
  }

  return CurrentDepth;
}

static bool checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop,
                                ScalarEvolution &SE) {
  // The inner loop must be the only outer loop's child.
  if ((OuterLoop.getSubLoops().size() != 1) ||
      (InnerLoop.getParentLoop() != &OuterLoop))
    return false;

  // We expect loops in normal form which have a preheader, header, latch...
  if (!OuterLoop.isLoopSimplifyForm() || !InnerLoop.isLoopSimplifyForm())
    return false;

  const BasicBlock *OuterLoopHeader = OuterLoop.getHeader();
  const BasicBlock *OuterLoopLatch = OuterLoop.getLoopLatch();
  const BasicBlock *InnerLoopPreHeader = InnerLoop.getLoopPreheader();
  const BasicBlock *InnerLoopLatch = InnerLoop.getLoopLatch();
  const BasicBlock *InnerLoopExit = InnerLoop.getExitBlock();

  // We expect rotated loops. The inner loop should have a single exit block.
  if (OuterLoop.getExitingBlock() != OuterLoopLatch ||
      InnerLoop.getExitingBlock() != InnerLoopLatch || !InnerLoopExit)
    return false;

  // Ensure the only branch that may exist between the loops is the inner loop
  // guard.
  if (OuterLoopHeader != InnerLoopPreHeader) {
    const BranchInst *BI =
        dyn_cast<BranchInst>(OuterLoopHeader->getTerminator());

    if (!BI || BI != InnerLoop.getLoopGuardBranch())
      return false;

    // The successors of the inner loop guard should be the inner loop
    // preheader and the outer loop latch.
    for (const BasicBlock *Succ : BI->successors()) {
      if (Succ == InnerLoopPreHeader)
        continue;
      if (Succ == OuterLoopLatch)
        continue;

      DEBUG_WITH_TYPE(VerboseDebug, {
        dbgs() << "Inner loop guard successor " << Succ->getName()
               << " doesn't lead to inner loop preheader or "
                  "outer loop latch.\n";
      });
      return false;
    }
  }

  // Ensure the inner loop exit block leads to the outer loop latch.
  if (InnerLoopExit->getSingleSuccessor() != OuterLoopLatch) {
    DEBUG_WITH_TYPE(
        VerboseDebug,
        dbgs() << "Inner loop exit block " << *InnerLoopExit
               << " does not directly lead to the outer loop latch.\n";);
    return false;
  }

  return true;
}

/// Print a one-line summary of \p LN to \p OS: whether the nest is perfect
/// (its maximum perfect depth equals its nest depth), its depth, the name of
/// its outermost loop and the names of all its loops.
raw_ostream &llvm::operator<<(raw_ostream &OS, const LoopNest &LN) {
  const bool IsPerfect = LN.getMaxPerfectDepth() == LN.getNestDepth();

  OS << "IsPerfect=" << (IsPerfect ? "true" : "false");
  OS << ", Depth=" << LN.getNestDepth()
     << ", OutermostLoop: " << LN.getOutermostLoop().getName()
     << ", Loops: ( ";
  for (const Loop *Member : LN.getLoops())
    OS << Member->getName() << " ";

  return OS << ")";
}

//===----------------------------------------------------------------------===//
// LoopNestPrinterPass implementation
//

/// Build the loop nest rooted at \p L and print its summary, followed by a
/// newline, to the pass's output stream. \p AM and \p U are not used, and all
/// analyses are reported as preserved.
PreservedAnalyses LoopNestPrinterPass::run(Loop &L, LoopAnalysisManager &AM,
                                           LoopStandardAnalysisResults &AR,
                                           LPMUpdater &U) {
  const std::unique_ptr<LoopNest> Nest = LoopNest::getLoopNest(L, AR.SE);
  if (Nest)
    OS << *Nest << "\n";

  return PreservedAnalyses::all();
}