| | |
| | | } |
| | | } |
| | | |
| | | void predict_detections(network net, data d, float threshold, int offset, int classes, int objectness, int background, int num_boxes, int per_box) |
| | | { |
| | | matrix pred = network_predict_data(net, d); |
| | | int j, k, class; |
| | | for(j = 0; j < pred.rows; ++j){ |
| | | for(k = 0; k < pred.cols; k += per_box){ |
| | | float scale = 1.; |
| | | int index = k/per_box; |
| | | int row = index / num_boxes; |
| | | int col = index % num_boxes; |
| | | if (objectness) scale = 1.-pred.vals[j][k]; |
| | | for (class = 0; class < classes; ++class){ |
| | | int ci = k+classes+(background || objectness); |
| | | float x = (pred.vals[j][ci + 0] + col)/num_boxes; |
| | | float y = (pred.vals[j][ci + 1] + row)/num_boxes; |
| | | float w = pred.vals[j][ci + 2]; // distance_from_edge(row, num_boxes); |
| | | float h = pred.vals[j][ci + 3]; // distance_from_edge(col, num_boxes); |
| | | w = pow(w, 2); |
| | | h = pow(h, 2); |
| | | float prob = scale*pred.vals[j][k+class+(background || objectness)]; |
| | | if(prob < threshold) continue; |
| | | printf("%d %d %f %f %f %f %f\n", offset + j, class, prob, x, y, w, h); |
| | | } |
| | | } |
| | | } |
| | | free_matrix(pred); |
| | | } |
| | | |
| | | void validate_detection(char *cfgfile, char *weightfile) |
| | | { |
| | | network net = parse_network_cfg(cfgfile); |
| | | if(weightfile){ |
| | | load_weights(&net, weightfile); |
| | | } |
| | | detection_layer layer = get_network_detection_layer(net); |
| | | fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); |
| | | srand(time(0)); |
| | | |
| | | list *plist = get_paths("/home/pjreddie/data/voc/test.txt"); |
| | | char **paths = (char **)list_to_array(plist); |
| | | |
| | | int classes = layer.classes; |
| | | int objectness = layer.objectness; |
| | | int background = layer.background; |
| | | int num_boxes = sqrt(get_detection_layer_locations(layer)); |
| | | |
| | | int per_box = 4+classes+(background || objectness); |
| | | int num_output = num_boxes*num_boxes*per_box; |
| | | |
| | | int m = plist->size; |
| | | int i = 0; |
| | | int splits = 100; |
| | | |
| | | int nthreads = 4; |
| | | int t; |
| | | data *val = calloc(nthreads, sizeof(data)); |
| | | data *buf = calloc(nthreads, sizeof(data)); |
| | | pthread_t *thr = calloc(nthreads, sizeof(data)); |
| | | |
| | | time_t start = time(0); |
| | | |
| | | for(t = 0; t < nthreads; ++t){ |
| | | int num = (i+1+t)*m/splits - (i+t)*m/splits; |
| | | char **part = paths+((i+t)*m/splits); |
| | | thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t])); |
| | | } |
| | | |
| | | for(i = nthreads; i <= splits; i += nthreads){ |
| | | for(t = 0; t < nthreads; ++t){ |
| | | pthread_join(thr[t], 0); |
| | | val[t] = buf[t]; |
| | | } |
| | | for(t = 0; t < nthreads && i < splits; ++t){ |
| | | int num = (i+1+t)*m/splits - (i+t)*m/splits; |
| | | char **part = paths+((i+t)*m/splits); |
| | | thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t])); |
| | | } |
| | | |
| | | fprintf(stderr, "%d\n", i); |
| | | for(t = 0; t < nthreads; ++t){ |
| | | predict_detections(net, val[t], .001, (i-nthreads+t)*m/splits, classes, objectness, background, num_boxes, per_box); |
| | | free_data(val[t]); |
| | | } |
| | | } |
| | | fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); |
| | | } |
| | | |
| | | |
| | | void convert_detections(float *predictions, int classes, int objectness, int background, int num_boxes, int w, int h, float thresh, float **probs, box *boxes) |
| | | { |
| | | int i,j; |
| | |
| | | } |
| | | } |
| | | |
| | | void valid_detection(char *cfgfile, char *weightfile) |
| | | void validate_detection(char *cfgfile, char *weightfile) |
| | | { |
| | | network net = parse_network_cfg(cfgfile); |
| | | if(weightfile){ |
| | |
| | | fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); |
| | | srand(time(0)); |
| | | |
| | | char *base = "/home/pjreddie/data/voc/devkit/results/VOC2012/Main/comp4_det_test_"; |
| | | list *plist = get_paths("/home/pjreddie/data/voc/test.txt"); |
| | | char *base = "results/comp4_det_test_"; |
| | | list *plist = get_paths("data/voc.2012test.list"); |
| | | char **paths = (char **)list_to_array(plist); |
| | | |
| | | int classes = layer.classes; |
| | |
| | | if(0==strcmp(argv[2], "test")) test_detection(cfg, weights, filename); |
| | | else if(0==strcmp(argv[2], "train")) train_detection(cfg, weights); |
| | | else if(0==strcmp(argv[2], "valid")) validate_detection(cfg, weights); |
| | | else if(0==strcmp(argv[2], "run")) valid_detection(cfg, weights); |
| | | } |