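/*
 * YOLO系列的IOU、GIOU源码阅读
 * (Reading the IoU / GIoU source code of the YOLO series.)
 *
 * 自古美人都是妖i 提交于 2019-11-30 12:05:26
 * (Posted by user "自古美人都是妖i" on 2019-11-30 12:05:26.)
 */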
#include "box.h"
#include <stdio.h>
#include <math.h>
#include <stdlib.h>

//#define DEBUG_NAN
//#define DEBUG_PRINTS

int nms_comparator(const void *pa, const void *pb)
{
    detection a = *(detection *)pa;
    detection b = *(detection *)pb;
    float diff = 0;
    if(b.sort_class >= 0){
        diff = a.prob[b.sort_class] - b.prob[b.sort_class];
    } else {
        diff = a.objectness - b.objectness;
    }
    if(diff < 0) return 1;
    else if(diff > 0) return -1;
    return 0;
}

/**
 * Class-agnostic non-maximum suppression driven by objectness.
 *
 * Detections whose objectness is exactly 0 are first swapped to the tail of
 * `dets` and excluded. The rest are sorted by descending objectness. For each
 * surviving detection, every later detection whose box IoU with it exceeds
 * `thresh` gets its objectness and all `classes` probabilities set to 0.
 */
void do_nms_obj(detection *dets, int total, int classes, float thresh)
{
    int last = total - 1;
    int idx = 0;
    while (idx <= last) {
        if (dets[idx].objectness == 0) {
            detection tmp = dets[idx];
            dets[idx] = dets[last];
            dets[last] = tmp;
            --last;            /* re-examine whatever was swapped into idx */
        } else {
            ++idx;
        }
    }
    const int live = last + 1;

    for (int n = 0; n < live; ++n) {
        dets[n].sort_class = -1;   /* -1 => comparator uses objectness */
    }
    qsort(dets, live, sizeof *dets, nms_comparator);

    for (int keep = 0; keep < live; ++keep) {
        if (dets[keep].objectness == 0) continue;
        const box ref = dets[keep].bbox;
        for (int cand = keep + 1; cand < live; ++cand) {
            if (dets[cand].objectness == 0) continue;
            if (box_iou(ref, dets[cand].bbox) > thresh) {
                dets[cand].objectness = 0;
                for (int c = 0; c < classes; ++c) {
                    dets[cand].prob[c] = 0;
                }
            }
        }
    }
}


/**
 * Per-class non-maximum suppression.
 *
 * Detections whose objectness is exactly 0 are swapped to the tail of `dets`
 * and ignored. Then, for each class c in [0, classes), the remaining
 * detections are sorted by descending prob[c]. For every detection with a
 * nonzero prob[c], each later detection whose box IoU exceeds `thresh` has
 * its prob[c] set to 0. Objectness is not modified here.
 */
void do_nms_sort(detection *dets, int total, int classes, float thresh)
{
    int last = total - 1;
    int idx = 0;
    while (idx <= last) {
        if (dets[idx].objectness != 0) {
            ++idx;
            continue;
        }
        detection tmp = dets[idx];
        dets[idx] = dets[last];
        dets[last] = tmp;
        --last;                /* idx now holds an unchecked element */
    }
    const int live = last + 1;

    for (int cls = 0; cls < classes; ++cls) {
        for (int n = 0; n < live; ++n) {
            dets[n].sort_class = cls;
        }
        qsort(dets, live, sizeof *dets, nms_comparator);

        for (int keep = 0; keep < live; ++keep) {
            if (dets[keep].prob[cls] == 0) continue;
            const box ref = dets[keep].bbox;
            for (int cand = keep + 1; cand < live; ++cand) {
                if (box_iou(ref, dets[cand].bbox) > thresh) {
                    dets[cand].prob[cls] = 0;
                }
            }
        }
    }
}

/** 从存储矩形框信息的大数组中,提取某一个矩形框定位信息并返回(f是这个大数组中某一起始地址).
 * @param f 存储了矩形框信息(每个矩形框有5个参数值,此函数仅提取其中4个用于定位的参数x,y,w,h,不包含物体类别编号)
 * @param stride 跨度,按倍数跳越取值
 * @return b 矩形框信息
*/
box float_to_box(float *f, int stride)
{
	
    /// f中存储每一个矩形框信息的顺序为: x, y, w, h, class_index,这里仅提取前四个,
    /// 也即矩形框的定位信息,最后一个物体类别编号信息不在此处提取	
    box b = {0};
    b.x = f[0];
    b.y = f[1*stride];
    b.w = f[2*stride];
    b.h = f[3*stride];
    return b;
}

/* Gradient of the 1-D overlap length computed by overlap() with respect to
 * the first interval's center c1 and size s1 (the second interval is fixed).
 * Writes d(overlap)/d(c1) to *dpos and d(overlap)/d(s1) to *dsize. */
static void derivative_axis(float c1, float s1, float c2, float s2,
                            float *dpos, float *dsize)
{
    float grad_pos = 0;
    float grad_size = 0;
    const float lo1 = c1 - s1/2;
    const float lo2 = c2 - s2/2;
    const float hi1 = c1 + s1/2;
    const float hi2 = c2 + s2/2;

    /* overlap = min(hi1, hi2) - max(lo1, lo2).
     * If the left edge is lo1 = c1 - s1/2, it contributes -1 per unit of c1
     * and +0.5 per unit of s1. */
    if (lo1 > lo2) {
        grad_pos -= 1;
        grad_size += .5;
    }
    /* If the right edge is hi1 = c1 + s1/2, it contributes +1 per unit of c1
     * and +0.5 per unit of s1. */
    if (hi1 < hi2) {
        grad_pos += 1;
        grad_size += .5;
    }
    /* Disjoint intervals: only the center gets a gradient (pointing toward
     * the other interval); the size gradient is forced to 0. */
    if (lo1 > hi2) {       /* interval 1 lies to the right of interval 2 */
        grad_pos = -1;
        grad_size = 0;
    }
    if (hi1 < lo2) {       /* interval 1 lies to the left of interval 2 */
        grad_pos = 1;
        grad_size = 0;
    }

    *dpos = grad_pos;
    *dsize = grad_size;
}

/**
 * Partial derivatives of the (unclamped) overlap width and height between
 * boxes a and b with respect to a's x/w and y/h.
 *
 * d.dx, d.dw: derivative of overlap(a.x, a.w, b.x, b.w) w.r.t. a.x and a.w.
 * d.dy, d.dh: derivative of overlap(a.y, a.h, b.y, b.h) w.r.t. a.y and a.h.
 * Box b is treated as constant.
 *
 * (Fix: the previous version was missing a ';' after "d.dx += 1" and did not
 * compile.)
 */
dbox derivative(box a, box b)
{
    dbox d;
    derivative_axis(a.x, a.w, b.x, b.w, &d.dx, &d.dw);
    derivative_axis(a.y, a.h, b.y, b.h, &d.dy, &d.dh);
    return d;
}

/**
 * Length of the overlap of two 1-D intervals given in center/size form.
 *
 * Called with (x, w) pairs it gives the width of the intersection of two
 * boxes; with (y, h) pairs, the height. The overlap is the smaller right
 * edge minus the larger left edge, so the result is negative when the
 * intervals are disjoint.
 *
 * @param x1 center of the first interval
 * @param w1 size of the first interval
 * @param x2 center of the second interval
 * @param w2 size of the second interval
 * @return overlap length (negative if the intervals do not intersect)
 */
float overlap(float x1, float w1, float x2, float w2)
{
    float start1 = x1 - w1/2, end1 = x1 + w1/2;
    float start2 = x2 - w2/2, end2 = x2 + w2/2;

    float lo = (start2 < start1) ? start1 : start2;
    float hi = (end2 > end1) ? end1 : end2;
    return hi - lo;
}


/**
 * Area of the intersection of boxes a and b (center/size form).
 *
 * @return intersection area, or 0 when the boxes do not overlap along either
 *         axis (overlap() reports a negative length in that case)
 */
float box_intersection(box a, box b)
{
    const float iw = overlap(a.x, a.w, b.x, b.w);
    const float ih = overlap(a.y, a.h, b.y, b.h);
    return (iw < 0 || ih < 0) ? 0 : iw * ih;
}

/**
 * Area of the union of boxes a and b:
 * area(a) + area(b) - intersection(a, b).
 * With DEBUG_NAN defined, prints the inputs when either value is NaN.
 */
float box_union(box a, box b)
{
    const float inter = box_intersection(a, b);
    const float total = a.w * a.h + b.w * b.h - inter;
#ifdef DEBUG_NAN
    if (isnan(inter) || isnan(total)) {
      printf("box_union a, b: (%f,%f,%f,%f), (%f,%f,%f,%f): %f/%f\n", a.x, a.y, a.w, a.h, b.x, b.y, b.w, b.h, inter, total);
    }
#endif
    return total;
}

/**
 * Smallest axis-aligned box enclosing both a and b, returned as edges
 * (top, bot, left, right). Used as the "C" term of GIoU.
 */
boxabs box_c(box a, box b) {
  boxabs enclosing = {
    .top   = fmin(a.y - a.h/2, b.y - b.h/2),
    .bot   = fmax(a.y + a.h/2, b.y + b.h/2),
    .left  = fmin(a.x - a.w/2, b.x - b.w/2),
    .right = fmax(a.x + a.w/2, b.x + b.w/2),
  };
  return enclosing;
}

/**
 * Convert a center/size box (x, y, w, h) to edge form:
 * top = y - h/2, bot = y + h/2, left = x - w/2, right = x + w/2.
 */
boxabs to_tblr(box a) {
  boxabs edges = {
    .top   = a.y - (a.h/2),
    .bot   = a.y + (a.h/2),
    .left  = a.x - (a.w/2),
    .right = a.x + (a.w/2),
  };
  return edges;
}

/**
 * Intersection over union of boxes a and b.
 *
 * IoU is the intersection area divided by the union area. It is 1 for
 * identical boxes and 0 for disjoint ones.
 * @return I / U, or 0 when either the intersection or the union is 0
 */
float box_iou(box a, box b)
{
    const float inter = box_intersection(a, b);
    const float uni = box_union(a, b);
    return (inter != 0 && uni != 0) ? inter / uni : 0;
}

/**
 * Generalized IoU: IoU(a, b) - (C - U) / C, where C is the area of the
 * smallest box enclosing both and U is the union area.
 * Falls back to plain IoU when the enclosing area is 0.
 */
float box_giou(box a, box b)
{
    const boxabs enclosing = box_c(a, b);
    const float span_w = enclosing.right - enclosing.left;
    const float span_h = enclosing.bot - enclosing.top;
    const float c_area = span_w * span_h;
    const float iou = box_iou(a, b);
    if (c_area == 0) return iou;

    const float u = box_union(a, b);
    const float penalty = (c_area - u) / c_area;  /* empty share of C */
#ifdef DEBUG_PRINTS
    printf("  c: %f, u: %f, giou_term: %f\n", c_area, u, penalty);
#endif
    return iou - penalty;
}

/**
 * Gradient of IoU (or of GIoU, when iou_loss == GIOU) between `pred` and
 * `truth` with respect to the predicted box's top, bottom, left and right edges.
 *
 * Both boxes are converted from center form (x, y, w, h) to edges with
 * to_tblr(). Only the prediction receives a gradient; the ground-truth
 * gradient code is kept commented out below.
 *
 * IoU part (only when U > 0):   d(I/U) = (U*dI - I*dU) / U^2
 * GIoU part (only when C > 0):  d(U/C) = (C*dU - U*dC) / C^2, added on top
 *
 * @param pred     predicted box (center x, y, width, height)
 * @param truth    ground-truth box in the same form
 * @param iou_loss only GIOU adds the enclosing-box term; any other value
 *                 yields the plain IoU gradient
 * @return dxrep whose dt/db/dl/dr hold the gradient for the predicted
 *         top/bottom/left/right edge. All four are zero when U <= 0, unless
 *         the GIoU term applies.
 *
 * NOTE(review): I = Iw * Ih is not clamped at 0, so for boxes that are
 * disjoint on both axes (Iw < 0 and Ih < 0) I comes out positive. Confirm
 * this is intended.
 */
dxrep dx_box_iou(box pred, box truth, IOU_LOSS iou_loss) {
    boxabs pred_tblr = to_tblr(pred);  // predicted edges; re-ordered below in case w or h is negative
    float pred_t = fmin(pred_tblr.top, pred_tblr.bot);
    float pred_b = fmax(pred_tblr.top, pred_tblr.bot);
    float pred_l = fmin(pred_tblr.left, pred_tblr.right);
    float pred_r = fmax(pred_tblr.left, pred_tblr.right);

    boxabs truth_tblr = to_tblr(truth);  // ground-truth edges (used as-is, not re-ordered)
#ifdef DEBUG_PRINTS
    printf("\niou: %f, giou: %f\n", box_iou(pred, truth), box_giou(pred, truth));
    printf("pred: x,y,w,h: (%f, %f, %f, %f) -> t,b,l,r: (%f, %f, %f, %f)\n", pred.x, pred.y, pred.w, pred.h, pred_tblr.top, pred_tblr.bot, pred_tblr.left, pred_tblr.right);
    printf("truth: x,y,w,h: (%f, %f, %f, %f) -> t,b,l,r: (%f, %f, %f, %f)\n", truth.x, truth.y, truth.w, truth.h, truth_tblr.top, truth_tblr.bot, truth_tblr.left, truth_tblr.right);
#endif
    //printf("pred (t,b,l,r): (%f, %f, %f, %f)\n", pred_t, pred_b, pred_l, pred_r);
    //printf("trut (t,b,l,r): (%f, %f, %f, %f)\n", truth_tblr.top, truth_tblr.bot, truth_tblr.left, truth_tblr.right);
    dxrep dx = {0};
    float X = (pred_b - pred_t) * (pred_r - pred_l);
    float Xhat = (truth_tblr.bot - truth_tblr.top) * (truth_tblr.right - truth_tblr.left);
    float Ih = fmin(pred_b, truth_tblr.bot) - fmax(pred_t, truth_tblr.top);  // intersection height; negative if no vertical overlap
    float Iw = fmin(pred_r, truth_tblr.right) - fmax(pred_l, truth_tblr.left);  // intersection width; negative if no horizontal overlap
    float I = Iw * Ih;    // intersection area (not clamped at 0)
    float U = X + Xhat - I;  // union area

    float Cw = fmax(pred_r, truth_tblr.right) - fmin(pred_l, truth_tblr.left);
    float Ch = fmax(pred_b, truth_tblr.bot) - fmin(pred_t, truth_tblr.top);
    float C = Cw * Ch;  // area of the smallest box enclosing both
#ifdef DEBUG_PRINTS
    printf("X: %f", X);
    printf(", Xhat: %f", Xhat);
    printf(", Ih: %f", Ih);
    printf(", Iw: %f", Iw);
    printf(", I: %f", I);
    printf(", U: %f", U);
    printf(", IoU: %f\n", I / U);
#endif
    // float IoU = I / U;
    // Partial Derivatives, derivatives
    float dX_wrt_t = -1 * (pred_r - pred_l);   // partials of the predicted area X = (b - t) * (r - l)
    float dX_wrt_b = pred_r - pred_l;
    float dX_wrt_l = -1 * (pred_b - pred_t);
    float dX_wrt_r = pred_b - pred_t;
    // UNUSED
    //// Ground truth
    //float dXhat_wrt_t = -1 * (truth_tblr.right - truth_tblr.left);
    //float dXhat_wrt_b = truth_tblr.right - truth_tblr.left;
    //float dXhat_wrt_l = -1 * (truth_tblr.bot - truth_tblr.top);
    //float dXhat_wrt_r = truth_tblr.bot - truth_tblr.top;

    // gradient of I min/max in IoU calc (prediction)
    float dI_wrt_t = pred_t > truth_tblr.top ? (-1 * Iw) : 0;  // nonzero only when the predicted edge is the one picked by fmax/fmin in Ih/Iw
    float dI_wrt_b = pred_b < truth_tblr.bot ? Iw : 0;
    float dI_wrt_l = pred_l > truth_tblr.left ? (-1 * Ih) : 0;
    float dI_wrt_r = pred_r < truth_tblr.right ? Ih : 0;
    // derivative of U with regard to x
    float dU_wrt_t = dX_wrt_t - dI_wrt_t;
    float dU_wrt_b = dX_wrt_b - dI_wrt_b;
    float dU_wrt_l = dX_wrt_l - dI_wrt_l;
    float dU_wrt_r = dX_wrt_r - dI_wrt_r;
    // gradient of C min/max in IoU calc (prediction)
    float dC_wrt_t = pred_t < truth_tblr.top ? (-1 * Cw) : 0;  // same idea as dI: nonzero only when the predicted edge bounds C
    float dC_wrt_b = pred_b > truth_tblr.bot ? Cw : 0;
    float dC_wrt_l = pred_l < truth_tblr.left ? (-1 * Ch) : 0;
    float dC_wrt_r = pred_r > truth_tblr.right ? Ch : 0;

    // UNUSED
    //// ground truth
    //float dI_wrt_xhat_t = pred_t < truth_tblr.top ? (-1 * Iw) : 0;
    //float dI_wrt_xhat_b = pred_b > truth_tblr.bot ? Iw : 0;
    //float dI_wrt_xhat_l = pred_l < truth_tblr.left ? (-1 * Ih) : 0;
    //float dI_wrt_xhat_r = pred_r > truth_tblr.right ? Ih : 0;

    // Final IOU loss (prediction) (negative of IOU gradient, we want the negative loss)
    // i.e. this is +d(IoU)/d(edge), which is the negative of the gradient of the loss 1 - IoU
    float p_dt = 0;
    float p_db = 0;
    float p_dl = 0;
    float p_dr = 0;
    if (U > 0) {
      p_dt = ((U * dI_wrt_t) - (I * dU_wrt_t)) / (U * U);
      p_db = ((U * dI_wrt_b) - (I * dU_wrt_b)) / (U * U);
      p_dl = ((U * dI_wrt_l) - (I * dU_wrt_l)) / (U * U);     // quotient rule: d(I/U) = (U*dI - I*dU) / U^2
      p_dr = ((U * dI_wrt_r) - (I * dU_wrt_r)) / (U * U);
    }

    if (iou_loss == GIOU) {                                // GIoU = IoU - (C-U)/C = IoU - 1 + U/C, so add d(U/C)
      if (C > 0) {
        // apply "C" term from gIOU
        p_dt += ((C * dU_wrt_t) - (U * dC_wrt_t)) / (C * C);
        p_db += ((C * dU_wrt_b) - (U * dC_wrt_b)) / (C * C);
        p_dl += ((C * dU_wrt_l) - (U * dC_wrt_l)) / (C * C);
        p_dr += ((C * dU_wrt_r) - (U * dC_wrt_r)) / (C * C);
      }
    }

    // UNUSED
    //// ground truth
    //float gt_dt = ((U * dI_wrt_xhat_t) - (I * (dXhat_wrt_t - dI_wrt_xhat_t))) / (U * U);
    //float gt_db = ((U * dI_wrt_xhat_b) - (I * (dXhat_wrt_b - dI_wrt_xhat_b))) / (U * U);
    //float gt_dl = ((U * dI_wrt_xhat_l) - (I * (dXhat_wrt_l - dI_wrt_xhat_l))) / (U * U);
    //float gt_dr = ((U * dI_wrt_xhat_r) - (I * (dXhat_wrt_r - dI_wrt_xhat_r))) / (U * U);

    // no min/max grad applied
    //dx.dt = dt;
    //dx.db = db;
    //dx.dl = dl;
    //dx.dr = dr;

    // apply grad from prediction min/max for correct corner selection: if the raw top/bottom
    // (left/right) edges were swapped by the fmin/fmax above, swap their gradients back
    dx.dt = pred_tblr.top < pred_tblr.bot ? p_dt : p_db;
    dx.db = pred_tblr.top < pred_tblr.bot ? p_db : p_dt;
    dx.dl = pred_tblr.left < pred_tblr.right ? p_dl : p_dr;
    dx.dr = pred_tblr.left < pred_tblr.right ? p_dr : p_dl;

    //// sum in gt -- THIS DOESNT WORK
    //dx.dt += gt_dt;
    //dx.db += gt_db;
    //dx.dl += gt_dl;
    //dx.dr += gt_dr;

    //// instead, look at the change between pred and gt, and weight t/b/l/r appropriately...
    //// need the real derivative here (I think?)
    //float delta_t = fmax(truth_tblr.top, pred_t) - fmin(truth_tblr.top, pred_t);
    //float delta_b = fmax(truth_tblr.bot, pred_b) - fmin(truth_tblr.bot, pred_b);
    //float delta_l = fmax(truth_tblr.left, pred_l) - fmin(truth_tblr.left, pred_l);
    //float delta_r = fmax(truth_tblr.right, pred_r) - fmin(truth_tblr.right, pred_r);

    //dx.dt *= delta_t / (delta_t + delta_b);
    //dx.db *= delta_b / (delta_t + delta_b);
    //dx.dl *= delta_l / (delta_l + delta_r);
    //dx.dr *= delta_r / (delta_l + delta_r);

#ifdef DEBUG_PRINTS
    printf("  directions dt: ");
    if ((pred_tblr.top < truth_tblr.top && dx.dt > 0) || (pred_tblr.top > truth_tblr.top && dx.dt < 0)) {
      printf("✓");
    } else {
      printf("𝒙");
    }
    printf(", ");
    if ((pred_tblr.bot < truth_tblr.bot && dx.db > 0) || (pred_tblr.bot > truth_tblr.bot && dx.db < 0)) {
      printf("✓");
    } else {
      printf("𝒙");
    }
    printf(", ");
    if ((pred_tblr.left < truth_tblr.left && dx.dl > 0) || (pred_tblr.left > truth_tblr.left && dx.dl < 0)) {
      printf("✓");
    } else {
      printf("𝒙");
    }
    printf(", ");
    if ((pred_tblr.right < truth_tblr.right && dx.dr > 0) || (pred_tblr.right > truth_tblr.right && dx.dr < 0)) {
      printf("✓");
    } else {
      printf("𝒙");
    }
    printf("\n");

    printf("dx dt:%f", dx.dt);
    printf(", db: %f", dx.db);
    printf(", dl: %f", dx.dl);
    printf(", dr: %f | ", dx.dr);
#endif

#ifdef DEBUG_NAN
    if (isnan(dx.dt)) { printf("dt isnan\n"); }
    if (isnan(dx.db)) { printf("db isnan\n"); }
    if (isnan(dx.dl)) { printf("dl isnan\n"); }
    if (isnan(dx.dr)) { printf("dr isnan\n"); }
#endif

//    // No update if 0 or nan
//    if (dx.dt == 0 || isnan(dx.dt)) { dx.dt = 1; }
//    if (dx.db == 0 || isnan(dx.db)) { dx.db = 1; }
//    if (dx.dl == 0 || isnan(dx.dl)) { dx.dl = 1; }
//    if (dx.dr == 0 || isnan(dx.dr)) { dx.dr = 1; }
//
//#ifdef DEBUG_PRINTS
//    printf("dx dt:%f (t: %f, p: %f)", dx.dt, gt_dt, p_dt);
//    printf(", db: %f (t: %f, p: %f)", dx.db, gt_db, p_db);
//    printf(", dl: %f (t: %f, p: %f)", dx.dl, gt_dl, p_dl);
//    printf(", dr: %f (t: %f, p: %f) | ", dx.dr, gt_dr, p_dr);
//#endif
    return dx;
}

/**
 * Root of the summed squared differences of x, y, w and h between two boxes.
 */
float box_rmse(box a, box b)
{
    const double sq_sum = pow(a.x - b.x, 2) + pow(a.y - b.y, 2)
                        + pow(a.w - b.w, 2) + pow(a.h - b.h, 2);
    return sqrt(sq_sum);
}

/**
 * Partial derivatives of the intersection area I = w * h with respect to
 * box a's x, y, w and h (b held constant).
 *
 * The overlap width depends only on a.x and a.w, and the height only on a.y
 * and a.h. By the product rule, the x/w derivatives are scaled by the height
 * and the y/h derivatives by the width. Uses the unclamped overlap() values.
 */
dbox dintersect(box a, box b)
{
    const float ow = overlap(a.x, a.w, b.x, b.w);
    const float oh = overlap(a.y, a.h, b.y, b.h);
    const dbox g = derivative(a, b);

    dbox out = {
        .dx = g.dx * oh,
        .dy = g.dy * ow,
        .dw = g.dw * oh,
        .dh = g.dh * ow,
    };
    return out;
}

/**
 * Partial derivatives of the union area U = a.w*a.h + b.w*b.h - I with
 * respect to box a's x, y, w and h (b held constant).
 * The area term only depends on w/h, so x/y get just -dI.
 */
dbox dunion(box a, box b)
{
    const dbox dI = dintersect(a, b);
    dbox out = {
        .dx = -dI.dx,
        .dy = -dI.dy,
        .dw = a.h - dI.dw,   /* d(a.w*a.h)/d(a.w) = a.h */
        .dh = a.w - dI.dh,   /* d(a.w*a.h)/d(a.h) = a.w */
    };
    return out;
}


/**
 * Print dunion() next to forward finite differences of box_union() for a
 * unit box and a small, partially overlapping box. Output only; no asserts.
 */
void test_dunion()
{
    box base = {0, 0, 1, 1};
    box other = {.5, .5, .2, .2};
    box nudged[4] = {
        {0+.0001, 0, 1, 1},   /* x perturbed */
        {0, 0+.0001, 1, 1},   /* y perturbed */
        {0, 0, 1+.0001, 1},   /* w perturbed */
        {0, 0, 1, 1+.0001},   /* h perturbed */
    };

    dbox analytic = dunion(base, other);
    printf("Union: %f %f %f %f\n", analytic.dx, analytic.dy, analytic.dw, analytic.dh);

    float ref = box_union(base, other);
    float numeric[4];
    for (int n = 0; n < 4; ++n) {
        numeric[n] = (box_union(nudged[n], other) - ref)/(.0001);
    }
    printf("Union Manual %f %f %f %f\n", numeric[0], numeric[1], numeric[2], numeric[3]);
}
/**
 * Print dintersect() next to forward finite differences of
 * box_intersection() for a unit box and a small, partially overlapping box.
 * Output only; no asserts.
 */
void test_dintersect()
{
    box base = {0, 0, 1, 1};
    box other = {.5, .5, .2, .2};
    box nudged[4] = {
        {0+.0001, 0, 1, 1},   /* x perturbed */
        {0, 0+.0001, 1, 1},   /* y perturbed */
        {0, 0, 1+.0001, 1},   /* w perturbed */
        {0, 0, 1, 1+.0001},   /* h perturbed */
    };

    dbox analytic = dintersect(base, other);
    printf("Inter: %f %f %f %f\n", analytic.dx, analytic.dy, analytic.dw, analytic.dh);

    float ref = box_intersection(base, other);
    float numeric[4];
    for (int n = 0; n < 4; ++n) {
        numeric[n] = (box_intersection(nudged[n], other) - ref)/(.0001);
    }
    printf("Inter Manual %f %f %f %f\n", numeric[0], numeric[1], numeric[2], numeric[3]);
}

/**
 * Run test_dintersect() and test_dunion(), then print diou() alongside
 * forward finite differences of the loss (1 - IoU)^2. Output only.
 */
void test_box()
{
    test_dintersect();
    test_dunion();

    box base = {0, 0, 1, 1};
    box target = {.5, 0, .2, .2};
    box nudged[4] = {
        {0+.00001, 0, 1, 1},   /* x perturbed */
        {0, 0+.00001, 1, 1},   /* y perturbed */
        {0, 0, 1+.00001, 1},   /* w perturbed */
        {0, 0, 1, 1+.00001},   /* h perturbed */
    };

    float loss = box_iou(base, target);
    loss = (1-loss)*(1-loss);
    printf("%f\n", loss);

    dbox grad = diou(base, target);
    printf("%f %f %f %f\n", grad.dx, grad.dy, grad.dw, grad.dh);

    float fd[4];
    for (int n = 0; n < 4; ++n) {
        float v = box_iou(nudged[n], target);
        fd[n] = ((1-v)*(1-v) - loss)/(.00001);
    }
    printf("manual %f %f %f %f\n", fd[0], fd[1], fd[2], fd[3]);
}

/**
 * Regression delta for box a toward box b: the per-coordinate difference
 * (b.x - a.x, b.y - a.y, b.w - a.w, b.h - a.h).
 *
 * This equals the negative gradient of 0.5 * ||b - a||^2 with respect to a.
 *
 * Note: the previous version also computed the analytic gradient of
 * (1 - IoU(a, b))^2 via dintersect()/dunion(). However, it returned this
 * simple difference under `if (i <= 0 || 1)`, which is always true, so that
 * code never executed. The unreachable branch and the unused intersection and
 * union computations were removed; the returned values are unchanged.
 */
dbox diou(box a, box b)
{
    dbox dd = {
        .dx = b.x - a.x,
        .dy = b.y - a.y,
        .dw = b.w - a.w,
        .dh = b.h - a.h,
    };
    return dd;
}


/**
 * Greedy per-class suppression over parallel `boxes` / `probs` arrays.
 *
 * For each box i that has at least one positive class probability, every
 * later box j whose IoU with box i exceeds `thresh` is compared class by
 * class. For each class, the lower of the two probabilities is zeroed; on
 * ties, box j's is zeroed.
 */
void do_nms(box *boxes, float **probs, int total, int classes, float thresh)
{
    for (int i = 0; i < total; ++i) {
        float *pi = probs[i];

        int has_score = 0;
        for (int c = 0; c < classes && !has_score; ++c) {
            has_score = pi[c] > 0;
        }
        if (!has_score) continue;

        for (int j = i + 1; j < total; ++j) {
            if (!(box_iou(boxes[i], boxes[j]) > thresh)) continue;
            float *pj = probs[j];
            for (int c = 0; c < classes; ++c) {
                if (pi[c] < pj[c]) {
                    pi[c] = 0;
                } else {
                    pj[c] = 0;
                }
            }
        }
    }
}

/**
 * Express box `b` relative to `anchor`.
 *
 * The center offset is divided by the anchor's size. Width and height
 * become base-2 log ratios against the anchor. Mathematically inverse to
 * decode_box().
 */
box encode_box(box b, box anchor)
{
    box rel = {
        .x = (b.x - anchor.x) / anchor.w,
        .y = (b.y - anchor.y) / anchor.h,
        .w = log2(b.w / anchor.w),
        .h = log2(b.h / anchor.h),
    };
    return rel;
}

/**
 * Map an anchor-relative encoding `b` back to an absolute box.
 *
 * The center is b.x/b.y scaled by the anchor size plus the anchor center.
 * The size is 2^b.w / 2^b.h times the anchor size.
 */
box decode_box(box b, box anchor)
{
    box abs_box = {
        .x = b.x * anchor.w + anchor.x,
        .y = b.y * anchor.h + anchor.y,
        .w = pow(2., b.w) * anchor.w,
        .h = pow(2., b.h) * anchor.h,
    };
    return abs_box;
}

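/*
 * 标签 (Tags)
 * 易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
 * (All resources on 易学教程 come from the web or from user submissions;
 *  please report any content that violates the law.)
 * 该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!
 * (Didn't this article solve your problem? Click "Ask" and describe it so
 *  more people can discuss it.)
 */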