AFLplusplus cmplog Instrumentation Analysis

AFL++'s Cmplog Instrumentation

1. Source Code Analysis

afl-cc.c

  • Code snippet:
if (cmplog_mode) {

  cc_params[cc_par_cnt++] = "-fno-inline";

#if LLVM_MAJOR >= 11 /* use new pass manager */
  cc_params[cc_par_cnt++] = "-fexperimental-new-pass-manager";
  cc_params[cc_par_cnt++] =
      alloc_printf("-fpass-plugin=%s/cmplog-switches-pass.so", obj_path);
  cc_params[cc_par_cnt++] = "-fexperimental-new-pass-manager";
  cc_params[cc_par_cnt++] =
      alloc_printf("-fpass-plugin=%s/split-switches-pass.so", obj_path);
#else
  cc_params[cc_par_cnt++] = "-Xclang";
  cc_params[cc_par_cnt++] = "-load";
  cc_params[cc_par_cnt++] = "-Xclang";
  cc_params[cc_par_cnt++] =
      alloc_printf("%s/cmplog-switches-pass.so", obj_path);

  // reuse split switches from laf
  cc_params[cc_par_cnt++] = "-Xclang";
  cc_params[cc_par_cnt++] = "-load";
  cc_params[cc_par_cnt++] = "-Xclang";
  cc_params[cc_par_cnt++] =
      alloc_printf("%s/split-switches-pass.so", obj_path);
#endif

}
  • cmplog-switches-pass.so and split-switches-pass.so are the two key shared libraries (LLVM pass plugins) loaded in cmplog mode

cmplog-switches-pass.cc

  • CmplogSwitches class:

    • Constructor: initInstrumentList() ==> parses the allowlist/denylist from the environment variables AFL_LLVM_ALLOWLIST / AFL_LLVM_DENYLIST

    • runOnModule() / run() (LLVM >= 11)

    • getPassName() ==> returns the string "cmplog switch split", the name of the pass

    • Private method: hookInstrs()

  • hookInstrs()

    • Code snippet:

      #if LLVM_VERSION_MAJOR >= 9
      FunctionCallee
      #else
      Constant *
      #endif
      c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty,
      Int8Ty
      #if LLVM_VERSION_MAJOR < 5
      ,
      NULL
      #endif
      );
      ...
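    • c1, c2, c4, c8 ==> __cmplog_ins_hook1(), __cmplog_ins_hook2(), __cmplog_ins_hook4(), __cmplog_ins_hook8()

      • The hook functions are defined in afl-compiler-rt.o.c
      • getOrInsertFunction() here only declares the corresponding functions in the module so that the pass can emit calls to them
      • The hook implementation is analyzed in the afl-compiler-rt.o.c part below
    • __afl_cmp_map

      • The comparison log map introduced by cmplog
      • Purpose: the hooks write the compared operands into this (shared-memory) map, and afl-fuzz reads these logs in the Redqueen / input-to-state stage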
    • Iterates over all SwitchInst instructions in the IR, stores those with NumCases > 1 in the switches vector, and removes duplicate SwitchInst entries with the vector's erase() and std::remove()

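    • Then, for each SwitchInst (SI):

      • SI is skipped if 1. the integer bit width of the condition is < 16 (too simple, mutation easily hits such values), or 2. it has no cases
      • If the bit width is not a multiple of 8, it is rounded up to the next multiple of 8, and the values must be cast to that width
      • The instrumentation itself passes the condition value, the case value and the constant 1 to a hook function, once per case. The pseudocode is roughly: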
      if (__afl_cmp_map) {
        // hookN with N = (rounded) bit width / 8, the same hook for every case;
        // the cast is only needed if the width was rounded up
        __cmplog_ins_hookN((cast_size) condition, (cast_size) caseValue1, 1);
        __cmplog_ins_hookN((cast_size) condition, (cast_size) caseValue2, 1);
        ...
      }

  • runOnModule()
    • Calls hookInstrs()

split-switches-pass.so.cc

  • SplitSwitchesTransform class:

    • Constructor: initInstrumentList() ==> parses the allowlist/denylist from the environment variables AFL_LLVM_ALLOWLIST / AFL_LLVM_DENYLIST
    • runOnModule() / run() (LLVM >= 11)
    • getPassName() ==> returns the string "splits switch constructs", the name of the pass
    • CaseExpr struct: maps a case value to its basic block; CaseVector stores the CaseExpr entries
    • Private methods: 1. splitSwitches (returns bool); 2. transformCmps (returns bool); 3. switchConvert
  • runOnModule() / run()

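    • Calls splitSwitches(&M);
  • splitSwitches()

    • Iterates over all basic blocks in the module, collects every switch instruction into the switches array, and skips switches with NumCases < 1
    • Then processes each switch statement as follows: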
    for (auto &SI : switches) {

      BasicBlock *CurBlock = SI->getParent();
      BasicBlock *OrigBlock = CurBlock;
      Function   *F = CurBlock->getParent();
      /* this is the value we are switching on */
      Value      *Val = SI->getCondition();
      BasicBlock *Default = SI->getDefaultDest();
      unsigned    bitw = Val->getType()->getIntegerBitWidth();

      /*
      if (!be_quiet)
        errs() << "switch: " << SI->getNumCases() << " cases " << bitw
               << " bit\n";
      */

      /* If there is only the default destination or the condition checks 8 bit or
       * less, don't bother with the code below. */
      if (SI->getNumCases() < 2 || bitw % 8 || bitw > 64) {

        // if (!be_quiet) errs() << "skip switch..\n";
        continue;

      }

      /* Create a new, empty default block so that the new hierarchy of
       * if-then statements go to this and the PHI nodes are happy.
       * if the default block is set as an unreachable we avoid creating one
       * because will never be a valid target.*/
      BasicBlock *NewDefault = nullptr;
      // Insert a block named NewDefault right before Default. Every failed byte
      // comparison branches here, so for Default's PHI nodes the edge coming from
      // this switch now originates in NewDefault instead of OrigBlock.
      NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault", F, Default);
      BranchInst::Create(Default, NewDefault);  // NewDefault ends with "br Default"

      /* Prepare cases vector. */
      CaseVector Cases;
      for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
           ++i)
#if LLVM_VERSION_MAJOR >= 5
        Cases.push_back(CaseExpr(i->getCaseValue(), i->getCaseSuccessor()));  // getCaseSuccessor(): the block this case jumps to
#else
        Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor()));
#endif
      /* bugfix thanks to pbst
       * round up bytesChecked (in case getBitWidth() % 8 != 0) */
      std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8,
                                     false);
      // build the byte-by-byte comparison tree; SwitchBlock is its root
      BasicBlock *SwitchBlock =
          switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0);

      /* Branch to our shiny new if-then stuff... */
      BranchInst::Create(SwitchBlock, OrigBlock);  // append "br SwitchBlock" to OrigBlock (OrigBlock --> SwitchBlock)

      /* We are now done with the switch instruction, delete it. */
      CurBlock->getInstList().erase(SI);  // OrigBlock now enters the comparison tree instead of the switch

      /* we have to update the phi nodes! */
      // Default is now reached from NewDefault rather than OrigBlock
      for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) {

        if (!isa<PHINode>(&*I)) { continue; }
        PHINode *PN = cast<PHINode>(I);

        /* Only update the first occurrence. */
        unsigned Idx = 0, E = PN->getNumIncomingValues();
        for (; Idx != E; ++Idx) {

          if (PN->getIncomingBlock(Idx) == OrigBlock) {

            PN->setIncomingBlock(Idx, NewDefault);
            break;

          }

        }

      }

    }
  • switchConvert()

    BasicBlock *SplitSwitchesTransform::switchConvert(
        CaseVector Cases, std::vector<bool> bytesChecked, BasicBlock *OrigBlock,
        BasicBlock *NewDefault, Value *Val, unsigned level) {

      unsigned     ValTypeBitWidth = Cases[0].Val->getBitWidth();
      IntegerType *ValType =
          IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth);
      IntegerType *ByteType = IntegerType::get(OrigBlock->getContext(), 8);
      unsigned     BytesInValue = bytesChecked.size();
      std::vector<uint8_t>            setSizes;
      std::vector<std::set<uint8_t> > byteSets(BytesInValue, std::set<uint8_t>());

      /* for each of the possible cases we iterate over all bytes of the values
       * build a set of possible values at each byte position in byteSets */
      for (CaseExpr &Case : Cases) {

        for (unsigned i = 0; i < BytesInValue; i++) {

          // byte i of the case value (byte 0 is the least significant one)
          uint8_t byte = (Case.Val->getZExtValue() >> (i * 8)) & 0xFF;
          byteSets[i].insert(byte);

        }

      }

      /* find the index of the first byte position that was not yet checked. then
       * save the number of possible values at that byte position */
      // Effectively: the unchecked position with the fewest distinct values.
      // bytesChecked is passed by value into every recursive call, so positions
      // already fixed by an equality test higher up in the tree are skipped.
      unsigned smallestIndex = 0;
      unsigned smallestSize = 257;
      for (unsigned i = 0; i < byteSets.size(); i++) {

        if (bytesChecked[i]) continue;
        if (byteSets[i].size() < smallestSize) {

          smallestIndex = i;
          smallestSize = byteSets[i].size();

        }

      }

      assert(bytesChecked[smallestIndex] == false);

      /* there are only smallestSize different bytes at index smallestIndex */

      Instruction *Shift, *Trunc;
      Function    *F = OrigBlock->getParent();
      BasicBlock  *NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F);
      // logical shift right: byte smallestIndex of the condition becomes the lowest byte
      Shift = BinaryOperator::Create(Instruction::LShr, Val,
                                     ConstantInt::get(ValType, smallestIndex * 8));
      NewNode->getInstList().push_back(Shift);

      if (ValTypeBitWidth > 8) {

        Trunc = new TruncInst(Shift, ByteType);  // keep only the lowest 8 bits
        NewNode->getInstList().push_back(Trunc);

      } else {

        /* not necessary to trunc */
        Trunc = Shift;

      }

      /* this is a trivial case, we can directly check for the byte,
       * if the byte is not found go to default. if the byte was found
       * mark the byte as checked. if this was the last byte to check
       * we can finally execute the block belonging to this case */

      if (smallestSize == 1) {

        uint8_t byte = *(byteSets[smallestIndex].begin());

        /* insert instructions to check whether the value we are switching on is
         * equal to byte */
        ICmpInst *Comp =
            new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte),
                         "byteMatch");
        NewNode->getInstList().push_back(Comp);

        bytesChecked[smallestIndex] = true;
        bool allBytesAreChecked = true;

        for (std::vector<bool>::iterator BCI = bytesChecked.begin(),
                                         E = bytesChecked.end();
             BCI != E; ++BCI) {

          if (!*BCI) {

            allBytesAreChecked = false;
            break;

          }

        }

        // if (std::all_of(bytesChecked.begin(), bytesChecked.end(),
        //                 [](bool b) { return b; })) {

        if (allBytesAreChecked) {

          assert(Cases.size() == 1);
          BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode);

          /* we have to update the phi nodes! */
          for (BasicBlock::iterator I = Cases[0].BB->begin();
               I != Cases[0].BB->end(); ++I) {

            if (!isa<PHINode>(&*I)) { continue; }
            PHINode *PN = cast<PHINode>(I);

            /* Only update the first occurrence. */
            unsigned Idx = 0, E = PN->getNumIncomingValues();
            for (; Idx != E; ++Idx) {

              if (PN->getIncomingBlock(Idx) == OrigBlock) {

                PN->setIncomingBlock(Idx, NewNode);
                break;

              }

            }

          }

        } else {

          BasicBlock *BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault,
                                         Val, level + 1);
          BranchInst::Create(BB, NewDefault, Comp, NewNode);

        }

      }

      /* there is no byte which we can directly check on, split the tree */
      else {

        // divide and conquer: split the cases at the median value of this byte
        std::vector<uint8_t> byteVector;
        std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(),
                  std::back_inserter(byteVector));
        std::sort(byteVector.begin(), byteVector.end());
        uint8_t pivot = byteVector[byteVector.size() / 2];

        /* we already chose to divide the cases based on the value of byte at index
         * smallestIndex the pivot value determines the threshold for the decision;
         * if a case value
         * is smaller at this byte index move it to the LHS vector, otherwise to the
         * RHS vector */

        CaseVector LHSCases, RHSCases;

        for (CaseExpr &Case : Cases) {

          uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex * 8)) & 0xFF;

          if (byte < pivot) {

            LHSCases.push_back(Case);

          } else {

            RHSCases.push_back(Case);

          }

        }

        BasicBlock *LBB, *RBB;
        LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val,
                            level + 1);
        RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val,
                            level + 1);

        /* insert instructions to check whether the byte of the value we are
         * switching on is smaller than pivot */
        ICmpInst *Comp =
            new ICmpInst(ICmpInst::ICMP_ULT, Trunc,
                         ConstantInt::get(ByteType, pivot), "byteMatch");
        NewNode->getInstList().push_back(Comp);
        BranchInst::Create(LBB, RBB, Comp, NewNode);

      }

      return NewNode;

    }

    • In short, the pass turns each switch into fine-grained, byte-by-byte branches, for example:
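    // before splitting switches...
    {
      uint32_t x;  // 32-bit condition for brevity (a 64-bit x would first get
                   // equality checks on bytes 4..7, which are 0 in both cases)
      switch (x) {
        case 0x12345678:
          do_something1();
          break;
        case 0x78563412:
          do_something2();
          break;
      }
    }
    // after splitting switches (byte i is (u8)(x >> (8 * i)); every failed test
    // jumps to the default block)...
    {
      if ((u8)x < 0x78) {  // byte 0 has two possible values: split at pivot 0x78
        if ((u8)x == 0x12)
          if ((u8)(x >> 8) == 0x34)
            if ((u8)(x >> 16) == 0x56)
              if ((u8)(x >> 24) == 0x78)
                do_something2();
      } else {
        if ((u8)x == 0x78)
          if ((u8)(x >> 8) == 0x56)
            if ((u8)(x >> 16) == 0x34)
              if ((u8)(x >> 24) == 0x12)
                do_something1();
      }
    }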

afl-compiler-rt.o.c

  • Taking __cmplog_ins_hook1 as an example:
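void __cmplog_ins_hook1(uint8_t arg1, uint8_t arg2, uint8_t attr) {

  // fprintf(stderr, "hook1 arg0=%02x arg1=%02x attr=%u\n",
  // (u8) arg1, (u8) arg2, attr);

  // nothing to do if cmplog is not active, or if the operands are already equal
  if (unlikely(!__afl_cmp_map || arg1 == arg2)) return;

  // __builtin_return_address(0) is the return address of the current call, i.e. a
  // location right after the call site in the instrumented code. It identifies
  // which comparison site called the hook, so each site gets its own map slot.
  uintptr_t k = (uintptr_t)__builtin_return_address(0);
  // hash the return address and mask it to the map width CMP_MAP_W (65536)
  k = (uintptr_t)(default_hash((u8 *)&k, sizeof(uintptr_t)) & (CMP_MAP_W - 1));

  u32 hits;

  if (__afl_cmp_map->headers[k].type != CMP_TYPE_INS) {  // first instruction-type hit on this slot

    __afl_cmp_map->headers[k].type = CMP_TYPE_INS;  // mark the slot as an instruction comparison
    hits = 0;                                       // this call writes log row 0
    __afl_cmp_map->headers[k].hits = 1;             // one hit recorded so far
    __afl_cmp_map->headers[k].shape = 0;            // operand size in bytes minus 1 (1-byte operands)

  } else {

    hits = __afl_cmp_map->headers[k].hits++;  // post-increment: hits gets the old count, i.e. the row to write

  }

  __afl_cmp_map->headers[k].attribute = attr;  // comparison attribute passed by the instrumentation (cmplog-switches-pass always passes 1)

  hits &= CMP_MAP_H - 1;  // keep CMP_MAP_H (32) rows per site; later hits wrap around and overwrite older rows
  __afl_cmp_map->log[k][hits].v0 = arg1;
  __afl_cmp_map->log[k][hits].v1 = arg2;

}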

2. Redqueen

afl-fuzz-redqueen.c

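  • colorization(), the colorization function ==> produces taint information

    • type_replace() replaces every byte of the input with a different byte of the same type, e.g. an uppercase letter is replaced with another uppercase letter
    • pop_biggest_range() takes the largest range out of ranges (a doubly linked list holding the ranges the current input is split into); the bytes of that range are replaced and the input is re-run to check whether the execution path changes. If it does, the bytes are restored and the range is split into two halves that go back into ranges
    • The ranges that survive are the ones whose bytes could be replaced without changing the execution path. They form the taint information: these input bytes can be changed freely, and their new, more random values are easier to recognize in the comparison logs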
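  • After colorization() there is a colorized input buf and the original input orig_buf; both test cases are executed and their cmp logs are saved for the following analysis:

    • input-to-state stage: calls cmp_fuzz() (TODO)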

AFLplusplus cmplog Instrumentation Analysis
http://bladchan.github.io/2023/04/20/AFLplusplus-cmplog插桩解析/
Author: bladchan
Published: April 20, 2023