// LLJITWithOptimizingIRTransform.cpp (3.6 KB)
//===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// In this example we will use an IR transform to optimize a module as it
// passes through LLJIT's IRTransformLayer.
//
//===----------------------------------------------------------------------===//

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Scalar.h"

#include "../ExampleModules.h"

using namespace llvm;
using namespace llvm::orc;

// File-wide error handler. main() sets its banner to the program name and
// wraps every fallible call (JIT creation, module parsing/adding, symbol
// lookup) in it to unwrap the result.
ExitOnError ExitOnErr;

// Example IR module.
//
// This IR contains a recursive definition of the factorial function:
//
// fac(n) | n == 0    = 1
//        | otherwise = n * fac(n - 1)
//
// It also contains an entry function which calls the factorial function with
// an input value of 5.
//
// We expect the IR optimization transform that we build below to transform
// this into a non-recursive factorial function and an entry function that
// returns a constant value of 5!, or 120.

// Textual LLVM IR for the example module. main() parses this string with
// parseExampleModule (under the name "MainMod") and adds the result to the
// JIT; the @entry function defined here is the symbol that gets looked up
// and called.
const llvm::StringRef MainMod =
    R"(

  define i32 @fac(i32 %n) {
  entry:
    %tobool = icmp eq i32 %n, 0
    br i1 %tobool, label %return, label %if.then

  if.then:                                          ; preds = %entry
    %arg = add nsw i32 %n, -1
    %call_result = call i32 @fac(i32 %arg)
    %result = mul nsw i32 %n, %call_result
    br label %return

  return:                                           ; preds = %entry, %if.then
    %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ]
    ret i32 %final_result
  }

  define i32 @entry() {
  entry:
    %result = call i32 @fac(i32 5)
    ret i32 %result
  }

)";

// A function object that creates a simple pass pipeline to apply to each
// module as it passes through the IRTransformLayer.
class MyOptimizationTransform {
public:
  MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) {
    PM->add(createTailCallEliminationPass());
    PM->add(createFunctionInliningPass());
    PM->add(createIndVarSimplifyPass());
    PM->add(createCFGSimplificationPass());
  }

  Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM,
                                        MaterializationResponsibility &R) {
    TSM.withModuleDo([this](Module &M) {
      dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n";
      PM->run(M);
      dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n";
    });
    return std::move(TSM);
  }

private:
  std::unique_ptr<legacy::PassManager> PM;
};

// Example driver: builds an LLJIT, installs MyOptimizationTransform on its
// IR transform layer, adds the example module, then looks up and calls the
// JIT'd "entry" function and prints its result to outs().
int main(int argc, char *argv[]) {
  // LLVM process-level setup, plus registration of the native target and its
  // asm printer.
  InitLLVM InitGuard(argc, argv);
  InitializeNativeTarget();
  InitializeNativeTargetAsmPrinter();

  // Prefix error output from ExitOnErr with the program name.
  ExitOnErr.setBanner(std::string(argv[0]) + ": ");

  // Step 1: build the JIT instance.
  auto Jit = ExitOnErr(LLJITBuilder().create());

  // Step 2: every module passing through the IR transform layer gets run
  // through our optimization pipeline.
  Jit->getIRTransformLayer().setTransform(MyOptimizationTransform());

  // Step 3: parse the example IR and hand the resulting module to the JIT.
  auto ExampleTSM = ExitOnErr(parseExampleModule(MainMod, "MainMod"));
  ExitOnErr(Jit->addIRModule(std::move(ExampleTSM)));

  // Step 4: resolve "entry", reinterpret its address as a function taking no
  // arguments and returning int, and invoke it.
  auto EntrySymbol = ExitOnErr(Jit->lookup("entry"));
  auto *EntryFn = reinterpret_cast<int (*)()>(EntrySymbol.getAddress());
  const int Result = EntryFn();

  outs() << "--- Result ---\n"
         << "entry() = " << Result << "\n";

  return 0;
}