// SecureIO.cpp
#include "FuncHelper.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <fstream>
#include <iostream>
using namespace llvm;

namespace {
struct SecureIO : public ModulePass {
  static char ID;
  std::vector<Instruction *> inst_lst;
  SecureIO() : ModulePass(ID) {}

  void insertFunc(Instruction &I, Module *m) {
    for (int i = 0; i < I.getNumOperands(); i++) {
      Value *v = I.getOperand(i);
      if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(v)) {
        Value *first = CE->getOperand(0);
        if (const ConstantInt *CI = dyn_cast<ConstantInt>(first)) {
          const uint64_t *a = CI->getValue().getRawData();

          if (*a >= (uint64_t)0x40000000UL &&
              *a <= (uint64_t)0x500E0000) { // I/O Address에 접근할 경우


            std::vector<Type *> newf_args;
            if (I.getOpcode() == Instruction::Store)
              newf_args.push_back(I.getOperand(0)->getType());
            FunctionType *newf_type = FunctionType::get(
                IntegerType::get(m->getContext(), 32), newf_args, false);
            Function *newF = Function::Create(
                newf_type, GlobalValue::ExternalLinkage, "secure_io", m);
            Attribute attr =
                Attribute::get(m->getContext(), "cmse_nonsecure_entry");
            AttributeList newf_att_list;
            SmallVector<AttributeList, 4> Attrs;
            AttributeList PAS;
            AttrBuilder B;
            B.addAttribute(attr);
            PAS = AttributeList::get(m->getContext(), ~0U, B);
            Attrs.push_back(PAS);
            newf_att_list = AttributeList::get(m->getContext(), Attrs);
            newF->setAttributes(newf_att_list);
            FunctionType *newf_type_extern = FunctionType::get(
                IntegerType::get(I.getModule()->getContext(), 32), newf_args,
                false);
            Function *newF_Extern =
                Function::Create(newf_type_extern, GlobalValue::ExternalLinkage,
                                 newF->getName(), I.getModule());
            Instruction *newinst = I.clone();
            Value *opv = newinst->getOperand(0);
            auto *ai = newF->arg_begin();
            if (I.getOpcode() == Instruction::Store)
              newinst->setOperand(0, ai);
            newinst->setDebugLoc(nullptr);

            BasicBlock *entry =
                BasicBlock::Create(m->getContext(), "entry", newF);
            std::vector<Type *> func_enterSecure_ty_args;
            FunctionType *func_enterSecure_ty =
                FunctionType::get(IntegerType::get(m->getContext(), 32),
                                  func_enterSecure_ty_args, false);
            FunctionCallee func_enterSecure =
                m->getOrInsertFunction("enterSecure", func_enterSecure_ty);
            CallInst *call_func_enterSecure =
                CallInst::Create(func_enterSecure, "", entry);
            call_func_enterSecure->setCallingConv(CallingConv::C);
            call_func_enterSecure->setTailCall(false);
            newinst->insertBefore(call_func_enterSecure);

            ReturnInst *ret = ReturnInst::Create(
                m->getContext(),
                ConstantInt::get(m->getContext(),
                                 APInt(32, StringRef(std::to_string(0)), 10)),
                entry);

            std::vector<Value *> newfu_c_arg = {opv};
            CallInst *newfu_c =
                CallInst::Create(newF_Extern, newfu_c_arg, "", &I);
            newfu_c->setCallingConv(CallingConv::C);
            newfu_c->setTailCall(false);
            AttributeList Func_Read_OTP_Call_4_PAL;
            newfu_c->setAttributes(Func_Read_OTP_Call_4_PAL);
            inst_lst.push_back(&I);
          }
        }
      }
    }
  }

  bool runOnModule(Module &MOD) override {
    LLVMContext context;
    Module *m = new llvm::Module("test", MOD.getContext());
    std::vector<Function *> f;
    for (auto &F : MOD) {
      if (F.getName().equals("GPIO_SetMode"))
        continue;
      f.push_back(&F);
    }
    for (auto &F : f) {
      for (auto &BB : *F) {
        for (auto &I : BB) {
          if (I.getOpcode() == Instruction::Store)
            insertFunc(I, m);
        }
      }
    }
    m->dump();
    for (auto I : inst_lst) {
      I->eraseFromParent();
    }

    return false;
  }

}; // end of struct Hello
} // end of anonymous namespace

// Definition of the pass identifier declared in SecureIO; the SecureIO
// constructor passes it to the ModulePass base class.
char SecureIO::ID = 0;

// Static registration object for the pass. "lr" is the first (argument) string
// and "Hello World Pass" the second (name/description) string handed to
// RegisterPass.
// NOTE(review): the description looks like a leftover from the LLVM "Hello"
// tutorial pass — confirm whether it should describe SecureIO instead.
static RegisterPass<SecureIO> X("lr", "Hello World Pass",
                                  false /* Only looks at CFG */,
                                  false /* Analysis Pass */);