| | |
| | | activate_array(l.output + index + offset, locations*l.n*(1+l.coords), LOGISTIC); |
| | | } |
| | | } |
| | | if (l.object_logistic) { |
| | | for(b = 0; b < l.batch; ++b){ |
| | | int index = b*l.inputs; |
| | | int p_index = index + locations*l.classes; |
| | | activate_array(l.output + p_index, locations*l.n, LOGISTIC); |
| | | } |
| | | } |
| | | |
| | | if (l.coord_logistic) { |
| | | for(b = 0; b < l.batch; ++b){ |
| | | int index = b*l.inputs; |
| | | int coord_index = index + locations*(l.classes + l.n); |
| | | activate_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC); |
| | | } |
| | | } |
| | | |
| | | if (l.class_logistic) { |
| | | for(b = 0; b < l.batch; ++b){ |
| | | int class_index = b*l.inputs; |
| | | activate_array(l.output + class_index, locations*l.classes, LOGISTIC); |
| | | } |
| | | } |
| | | |
| | | if(state.train){ |
| | | float avg_iou = 0; |
| | |
| | | float best_rmse = 20; |
| | | |
| | | if (!is_obj){ |
| | | //printf("."); |
| | | continue; |
| | | } |
| | | |
| | |
| | | } |
| | | |
| | | float iou = box_iou(out, truth); |
| | | //iou = 0; |
| | | float rmse = box_rmse(out, truth); |
| | | if(best_iou > 0 || iou > 0){ |
| | | if(iou > best_iou){ |
| | |
| | | gradient_array(l.output + index + locations*l.classes, locations*l.n*(1+l.coords), |
| | | LOGISTIC, l.delta + index + locations*l.classes); |
| | | } |
| | | if (l.object_logistic) { |
| | | int p_index = index + locations*l.classes; |
| | | gradient_array(l.output + p_index, locations*l.n, LOGISTIC, l.delta + p_index); |
| | | } |
| | | |
| | | if (l.class_logistic) { |
| | | int class_index = index; |
| | | gradient_array(l.output + class_index, locations*l.classes, LOGISTIC, l.delta + class_index); |
| | | } |
| | | |
| | | if (l.coord_logistic) { |
| | | int coord_index = index + locations*(l.classes + l.n); |
| | | gradient_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC, l.delta + coord_index); |
| | | } |
| | | //printf("\n"); |
| | | } |
| | | printf("Region Avg IOU: %f, Pos Cat: %f, All Cat: %f, Pos Obj: %f, Any Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_allcat/(count*l.classes), avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count); |