Joseph Redmon
2015-11-04 8fd18add6e060a433629fae3fa2a7ef75df4644e
src/region_layer.c
@@ -6,34 +6,32 @@
#include "cuda.h"
#include "utils.h"
#include <stdio.h>
#include <assert.h>
#include <string.h>
#include <stdlib.h>
// Number of grid locations (cells) the region layer predicts for.
//
// make_region_layer asserts inputs == side*side*((1 + coords)*n + classes),
// so inputs / (classes + coords) no longer yields the cell count under that
// layout; use side*side, matching what forward_region_layer computes.
// A layer whose side was never set (make_region_layer zero-initialises the
// struct, so side == 0) keeps the previous formula for backward compatibility.
int get_region_layer_locations(region_layer l)
{
    if (l.side > 0) return l.side*l.side;
    return l.inputs / (l.classes+l.coords);
}
region_layer make_region_layer(int batch, int inputs, int n, int classes, int coords, int rescore)
region_layer make_region_layer(int batch, int inputs, int n, int side, int classes, int coords, int rescore)
{
    region_layer l = {0};
    l.type = REGION;
    l.n = n;
    l.batch = batch;
    l.inputs = inputs;
    l.classes = classes;
    l.coords = coords;
    l.rescore = rescore;
    l.side = side;
    assert(side*side*((1 + l.coords)*l.n + l.classes) == inputs);
    l.cost = calloc(1, sizeof(float));
    int outputs = inputs;
    l.outputs = outputs;
    l.output = calloc(batch*outputs, sizeof(float));
    l.delta = calloc(batch*outputs, sizeof(float));
    #ifdef GPU
    l.output_gpu = cuda_make_array(0, batch*outputs);
    l.delta_gpu = cuda_make_array(0, batch*outputs);
    #endif
    l.outputs = l.inputs;
    l.truths = l.side*l.side*(1+l.coords+l.classes);
    l.output = calloc(batch*l.outputs, sizeof(float));
    l.delta = calloc(batch*l.outputs, sizeof(float));
#ifdef GPU
    l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
    l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
#endif
    fprintf(stderr, "Region Layer\n");
    srand(0);
@@ -43,102 +41,202 @@
// Forward pass of the region layer (plus loss/gradient computation when
// state.train is set).
//
// NOTE(review): this text is a diff rendering with the +/- markers stripped.
// Lines from the pre-change implementation (the per-location classes+coords
// loop driven by get_region_layer_locations, the smooth-L1 coord loss, the
// "Avg IOU" log line) are interleaved with the new implementation. As shown,
// it does not compile: `locations` and `size` are declared twice, and the old
// loops are cut off mid-body.
//
// New per-batch-item layout of l.output / l.delta (b*l.inputs offset), as
// used by the index arithmetic below:
//   [locations*classes class scores]
//   [locations*n objectness scores]
//   [locations*n*coords box coordinates]
// Per-location truth record, (1 + coords + classes) floats starting at
// (b*locations + i)*(1+coords+classes):
//   [is_obj][classes class targets][coords box target]
void forward_region_layer(const region_layer l, network_state state)
{
    // NOTE(review): appears to be the old declaration; redeclared just below.
    int locations = get_region_layer_locations(l);
    int locations = l.side*l.side;
    int i,j;
    // NOTE(review): appears to be the old per-location loop; its body is cut
    // off by the new code that follows.
    for(i = 0; i < l.batch*locations; ++i){
        int index = i*(l.classes + l.coords);
        int mask = (!state.truth || !state.truth[index]);
        for(j = 0; j < l.classes; ++j){
            l.output[index+j] = state.input[index+j];
        }
        softmax_array(l.output + index, l.classes, l.output + index);
        index += l.classes;
        for(j = 0; j < l.coords; ++j){
            l.output[index+j] = mask*state.input[index+j];
    // Start from a raw copy of the input, then apply the optional
    // activations below in place. These run whether or not state.train is set.
    memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float));
    int b;
    // softmax: softmax_array over each location's class scores, then
    // LOGISTIC over the whole objectness + box-coordinate section.
    // NOTE(review): the flags below are independent. Combining softmax with
    // object_logistic/coord_logistic applies LOGISTIC twice to the same
    // values (and gradient_array twice in the backward part). Confirm the
    // flags are meant to be mutually exclusive.
    if (l.softmax){
        for(b = 0; b < l.batch; ++b){
            int index = b*l.inputs;
            for (i = 0; i < locations; ++i) {
                int offset = i*l.classes;
                softmax_array(l.output + index + offset, l.classes,
                        l.output + index + offset);
            }
            int offset = locations*l.classes;
            activate_array(l.output + index + offset, locations*l.n*(1+l.coords), LOGISTIC);
        }
    }
    // object_logistic: LOGISTIC over the locations*n objectness scores only.
    if (l.object_logistic) {
        for(b = 0; b < l.batch; ++b){
            int index = b*l.inputs;
            int p_index = index + locations*l.classes;
            activate_array(l.output + p_index, locations*l.n, LOGISTIC);
        }
    }
    // coord_logistic: LOGISTIC over the locations*n*coords box coordinates.
    if (l.coord_logistic) {
        for(b = 0; b < l.batch; ++b){
            int index = b*l.inputs;
            int coord_index = index + locations*(l.classes + l.n);
            activate_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC);
        }
    }
    // class_logistic: LOGISTIC over the locations*classes class scores.
    if (l.class_logistic) {
        for(b = 0; b < l.batch; ++b){
            int class_index = b*l.inputs;
            activate_array(l.output + class_index, locations*l.classes, LOGISTIC);
        }
    }
    // Training: compare against state.truth, write l.delta and *l.cost, and
    // accumulate the statistics printed at the end.
    if(state.train){
        float avg_iou = 0;
        float avg_cat = 0;
        float avg_allcat = 0;
        float avg_obj = 0;
        float avg_anyobj = 0;
        int count = 0;
        *(l.cost) = 0;
        // NOTE(review): duplicate declaration (old: l.outputs, new: l.inputs).
        // make_region_layer sets outputs equal to inputs, so the value agrees.
        int size = l.outputs * l.batch;
        int size = l.inputs * l.batch;
        memset(l.delta, 0, size * sizeof(float));
        // NOTE(review): appears to be the old training loop (IOU stats plus a
        // smooth-L1 / Huber loss on the coords); cut off by the new
        // per-batch loop below.
        for (i = 0; i < l.batch*locations; ++i) {
            int offset = i*(l.classes+l.coords);
            int bg = state.truth[offset];
            for (j = offset; j < offset+l.classes; ++j) {
                //*(l.cost) += pow(state.truth[j] - l.output[j], 2);
                //l.delta[j] =  state.truth[j] - l.output[j];
            }
            box anchor = {0,0,.5,.5};
            box truth_code = {state.truth[j+0], state.truth[j+1], state.truth[j+2], state.truth[j+3]};
            box out_code =   {l.output[j+0], l.output[j+1], l.output[j+2], l.output[j+3]};
            box out = decode_box(out_code, anchor);
            box truth = decode_box(truth_code, anchor);
            if(bg) continue;
            //printf("Box:       %f %f %f %f\n", truth.x, truth.y, truth.w, truth.h);
            //printf("Code:      %f %f %f %f\n", truth_code.x, truth_code.y, truth_code.w, truth_code.h);
            //printf("Pred     : %f %f %f %f\n", out.x, out.y, out.w, out.h);
            // printf("Pred Code: %f %f %f %f\n", out_code.x, out_code.y, out_code.w, out_code.h);
            float iou = box_iou(out, truth);
            avg_iou += iou;
            ++count;
            /*
             *(l.cost) += pow((1-iou), 2);
             l.delta[j+0] = (state.truth[j+0] - l.output[j+0]);
             l.delta[j+1] = (state.truth[j+1] - l.output[j+1]);
             l.delta[j+2] = (state.truth[j+2] - l.output[j+2]);
             l.delta[j+3] = (state.truth[j+3] - l.output[j+3]);
             */
            for (j = offset+l.classes; j < offset+l.classes+l.coords; ++j) {
                //*(l.cost) += pow(state.truth[j] - l.output[j], 2);
                //l.delta[j] =  state.truth[j] - l.output[j];
                float diff = state.truth[j] - l.output[j];
                if (fabs(diff) < 1){
                    l.delta[j] = diff;
                    *(l.cost) += .5*pow(state.truth[j] - l.output[j], 2);
                } else {
                    l.delta[j] = (diff > 0) ? 1 : -1;
                    *(l.cost) += fabs(diff) - .5;
        // New training loop: one pass per batch item and grid location.
        for (b = 0; b < l.batch; ++b){
            int index = b*l.inputs;
            for (i = 0; i < locations; ++i) {
                int truth_index = (b*locations + i)*(1+l.coords+l.classes);
                int is_obj = state.truth[truth_index];
                // Push every predictor's objectness toward 0. The predictor
                // chosen as responsible below overwrites its own delta/cost.
                for (j = 0; j < l.n; ++j) {
                    int p_index = index + locations*l.classes + i*l.n + j;
                    l.delta[p_index] = l.noobject_scale*(0 - l.output[p_index]);
                    *(l.cost) += l.noobject_scale*pow(l.output[p_index], 2);
                    avg_anyobj += l.output[p_index];
                }
                //l.delta[j] = state.truth[j] - l.output[j];
                // Choose the responsible predictor: highest IOU, or lowest
                // box_rmse (below 20) while no candidate has a positive IOU.
                // NOTE(review): if every candidate has iou == 0 and
                // rmse >= 20 (and l.forced is off), best_index stays -1.
                // box_index/p_index below then address the slot just before
                // this location's first predictor (for i == 0, the preceding
                // section). Consider clamping best_index to 0.
                int best_index = -1;
                float best_iou = 0;
                float best_rmse = 20;
                if (!is_obj){
                    continue;
                }
                // Squared-error class loss against the truth class targets.
                int class_index = index + i*l.classes;
                for(j = 0; j < l.classes; ++j) {
                    l.delta[class_index+j] = l.class_scale * (state.truth[truth_index+1+j] - l.output[class_index+j]);
                    *(l.cost) += l.class_scale * pow(state.truth[truth_index+1+j] - l.output[class_index+j], 2);
                    if(state.truth[truth_index + 1 + j]) avg_cat += l.output[class_index+j];
                    avg_allcat += l.output[class_index+j];
                }
                // x/y are divided by l.side before computing IOU. When l.sqrt
                // is set, predicted w/h are squared, i.e. the network is
                // treated as predicting their square roots.
                box truth = float_to_box(state.truth + truth_index + 1 + l.classes);
                truth.x /= l.side;
                truth.y /= l.side;
                for(j = 0; j < l.n; ++j){
                    int box_index = index + locations*(l.classes + l.n) + (i*l.n + j) * l.coords;
                    box out = float_to_box(l.output + box_index);
                    out.x /= l.side;
                    out.y /= l.side;
                    if (l.sqrt){
                        out.w = out.w*out.w;
                        out.h = out.h*out.h;
                    }
                    float iou  = box_iou(out, truth);
                    //iou = 0;
                    float rmse = box_rmse(out, truth);
                    if(best_iou > 0 || iou > 0){
                        if(iou > best_iou){
                            best_iou = iou;
                            best_index = j;
                        }
                    }else{
                        if(rmse < best_rmse){
                            best_rmse = rmse;
                            best_index = j;
                        }
                    }
                }
                // NOTE(review): l.forced hard-codes predictor 1 for boxes with
                // area < .1. That index is out of range for this location
                // when l.n == 1.
                if(l.forced){
                    if(truth.w*truth.h < .1){
                        best_index = 1;
                    }else{
                        best_index = 0;
                    }
                }
                int box_index = index + locations*(l.classes + l.n) + (i*l.n + best_index) * l.coords;
                int tbox_index = truth_index + 1 + l.classes;
                box out = float_to_box(l.output + box_index);
                out.x /= l.side;
                out.y /= l.side;
                if (l.sqrt) {
                    out.w = out.w*out.w;
                    out.h = out.h*out.h;
                }
                float iou  = box_iou(out, truth);
                //printf("%d", best_index);
                // Swap the responsible predictor's noobject cost term for the
                // object cost term and retarget its objectness delta to 1
                // (or to the achieved IOU when l.rescore is set).
                int p_index = index + locations*l.classes + i*l.n + best_index;
                *(l.cost) -= l.noobject_scale * pow(l.output[p_index], 2);
                *(l.cost) += l.object_scale * pow(1-l.output[p_index], 2);
                avg_obj += l.output[p_index];
                l.delta[p_index] = l.object_scale * (1.-l.output[p_index]);
                if(l.rescore){
                    l.delta[p_index] = l.object_scale * (iou - l.output[p_index]);
                }
                // Coordinate deltas for the responsible box only. With
                // l.sqrt, w/h targets are compared in square-root space.
                l.delta[box_index+0] = l.coord_scale*(state.truth[tbox_index + 0] - l.output[box_index + 0]);
                l.delta[box_index+1] = l.coord_scale*(state.truth[tbox_index + 1] - l.output[box_index + 1]);
                l.delta[box_index+2] = l.coord_scale*(state.truth[tbox_index + 2] - l.output[box_index + 2]);
                l.delta[box_index+3] = l.coord_scale*(state.truth[tbox_index + 3] - l.output[box_index + 3]);
                if(l.sqrt){
                    l.delta[box_index+2] = l.coord_scale*(sqrt(state.truth[tbox_index + 2]) - l.output[box_index + 2]);
                    l.delta[box_index+3] = l.coord_scale*(sqrt(state.truth[tbox_index + 3]) - l.output[box_index + 3]);
                }
                *(l.cost) += pow(1-iou, 2);
                avg_iou += iou;
                ++count;
            }
            // Scale deltas by the derivative of each activation applied in
            // the forward pass, over the same section ranges.
            if(l.softmax){
                gradient_array(l.output + index + locations*l.classes, locations*l.n*(1+l.coords),
                        LOGISTIC, l.delta + index + locations*l.classes);
            }
            if (l.object_logistic) {
                int p_index = index + locations*l.classes;
                gradient_array(l.output + p_index, locations*l.n, LOGISTIC, l.delta + p_index);
            }
            // NOTE(review): dead, commented-out code referring to the old
            // loop's offset/iou variables.
            /*
               if(l.rescore){
               for (j = offset; j < offset+l.classes; ++j) {
               if(state.truth[j]) state.truth[j] = iou;
               l.delta[j] =  state.truth[j] - l.output[j];
               }
               }
              */
            if (l.class_logistic) {
                int class_index = index;
                gradient_array(l.output + class_index, locations*l.classes, LOGISTIC, l.delta + class_index);
            }
            if (l.coord_logistic) {
                    int coord_index = index + locations*(l.classes + l.n);
                gradient_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC, l.delta + coord_index);
            }
            //printf("\n");
        }
        // NOTE(review): the "Avg IOU" line appears to be the old log. Both
        // lines divide by count, which is 0 when no location in the batch has
        // an object, so they print nan/inf in that case.
        printf("Avg IOU: %f\n", avg_iou/count);
        printf("Region Avg IOU: %f, Pos Cat: %f, All Cat: %f, Pos Obj: %f, Any Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_allcat/(count*l.classes), avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
    }
}
// Backward pass of the region layer: accumulates this layer's gradient
// (l.delta, zeroed and filled by forward_region_layer when state.train is
// set) into the caller-provided gradient buffer state.delta.
//
// Only the host-side l.delta is read here. l.delta_gpu is allocated with
// cuda_make_array in GPU builds, so it must not be passed to the host routine
// axpy_cpu. Adding it as well would also accumulate the gradient twice.
void backward_region_layer(const region_layer l, network_state state)
{
    // Nothing to propagate into when no delta buffer is supplied
    // (presumably only for a network's first layer — TODO confirm).
    if (!state.delta) return;
    // axpy with alpha = 1: state.delta += l.delta. This accumulates rather
    // than overwrites what is already in state.delta.
    axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
}
#ifdef GPU
// GPU entry point for the region layer. During training it pulls the input
// (and truth, if any) to the host, runs forward_region_layer there, and
// pushes l.output and l.delta back to the device.
//
// NOTE(review): this text is a diff rendering with the +/- markers stripped,
// and a hunk header splits the body. Old and new variants of some statements
// appear side by side (see notes below).
void forward_region_layer_gpu(const region_layer l, network_state state)
{
    // Inference: the raw input is copied straight to output_gpu.
    // NOTE(review): forward_region_layer applies its softmax / *_logistic
    // activations regardless of state.train, but this early return skips
    // them. GPU inference output may therefore differ from the CPU path
    // whenever any of those flags are set. Confirm this is intended.
    if(!state.train){
        copy_ongpu(l.batch*l.inputs, state.input, 1, l.output_gpu, 1);
        return;
    }
    float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
    float *truth_cpu = 0;
    if(state.truth){
        // NOTE(review): the next two lines appear to be the old truth pull
        // (sized by l.outputs). If both variants run, the first allocation
        // is leaked when truth_cpu is reassigned below.
        truth_cpu = calloc(l.batch*l.outputs, sizeof(float));
        cuda_pull_array(state.truth, truth_cpu, l.batch*l.outputs);
        // Truth holds side*side records of (1 + coords + classes) floats per
        // batch item, which is a different size from l.outputs.
        int num_truth = l.batch*l.side*l.side*(1+l.coords+l.classes);
        truth_cpu = calloc(num_truth, sizeof(float));
        cuda_pull_array(state.truth, truth_cpu, num_truth);
    }
    cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
    // NOTE(review): presumably cpu_state copies state and takes truth_cpu as
    // its truth in the lines elided by the hunk header; cpu_state.truth is
    // freed at the end.
    network_state cpu_state;
@@ -147,7 +245,7 @@
    cpu_state.input = in_cpu;
    forward_region_layer(l, cpu_state);
    cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
    // NOTE(review): old (l.outputs) and new (l.inputs) variants of the same
    // delta push. make_region_layer sets outputs equal to inputs, so the
    // sizes agree, but only one push is needed.
    cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
    cuda_push_array(l.delta_gpu, l.delta, l.batch*l.inputs);
    free(cpu_state.input);
    if(cpu_state.truth) free(cpu_state.truth);
}