Edmond Yoo
2018-10-13 23d94e4846bf4ec13069703a28b1d776f4bbe44f
src/region_layer.c
@@ -127,17 +127,17 @@
            class_id = hier->parent[class_id];
        }
        *avg_cat += pred;
    } else {
    } else {
        // Focal loss
        if (focal_loss) {
            // Focal Loss
            float alpha = 0.5;    // 0.25 or 0.5
            //float gamma = 2;    // hardcoded in many places of the grad-formula
            //float gamma = 2;    // hardcoded in many places of the grad-formula
            int ti = index + class_id;
            float pt = output[ti] + 0.000000000000001F;
            // http://fooplot.com/#W3sidHlwZSI6MCwiZXEiOiItKDEteCkqKDIqeCpsb2coeCkreC0xKSIsImNvbG9yIjoiIzAwMDAwMCJ9LHsidHlwZSI6MTAwMH1d
            float grad = -(1 - pt) * (2 * pt*logf(pt) + pt - 1);    // http://blog.csdn.net/linmingan/article/details/77885832
            float grad = -(1 - pt) * (2 * pt*logf(pt) + pt - 1);    // http://blog.csdn.net/linmingan/article/details/77885832
            //float grad = (1 - pt) * (2 * pt*logf(pt) + pt - 1);    // https://github.com/unsky/focal-loss
            for (n = 0; n < classes; ++n) {
@@ -224,7 +224,7 @@
            int onlyclass_id = 0;
            for(t = 0; t < l.max_boxes; ++t){
                box truth = float_to_box(state.truth + t*5 + b*l.truths);
                if(!truth.x) break;
                if(!truth.x) break; // continue;
                int class_id = state.truth[t*5 + b*l.truths + 4];
                float maxp = 0;
                int maxi = 0;
@@ -258,7 +258,7 @@
                        box truth = float_to_box(state.truth + t*5 + b*l.truths);
                        int class_id = state.truth[t * 5 + b*l.truths + 4];
                        if (class_id >= l.classes) continue; // if label contains class_id more than number of classes in the cfg-file
                        if(!truth.x) break;
                        if(!truth.x) break; // continue;
                        float iou = box_iou(pred, truth);
                        if (iou > best_iou) {
                            best_class_id = state.truth[t*5 + b*l.truths + 4];
@@ -302,7 +302,7 @@
                continue; // if label contains class_id more than number of classes in the cfg-file
            }
            if(!truth.x) break;
            if(!truth.x) break; // continue;
            float best_iou = 0;
            int best_index = 0;
            int best_n = 0;