From 738cd4c2d7abcf62b85b759a5765b08380ee90e8 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 17 Apr 2014 00:05:29 +0000
Subject: [PATCH] Visualizations?

---
 src/tests.c |   92 +++++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/src/tests.c b/src/tests.c
index 5d9136d..a6c3cd3 100644
--- a/src/tests.c
+++ b/src/tests.c
@@ -1,4 +1,5 @@
 #include "connected_layer.h"
+
 //#include "old_conv.h"
 #include "convolutional_layer.h"
 #include "maxpool_layer.h"
@@ -223,7 +224,7 @@
 
 void test_visualize()
 {
-    network net = parse_network_cfg("cfg/imagenet.cfg");
+    network net = parse_network_cfg("cfg/voc_imagenet.cfg");
     srand(2222222);
     visualize_network(net);
     cvWaitKey(0);
@@ -445,6 +446,12 @@
     }
 }
 
+/* Parse cfg/voc_imagenet_orig.cfg and write the resulting network out as cfg/voc_imagenet_rev.cfg. */
+void flip_network()
+{
+    save_network(parse_network_cfg("cfg/voc_imagenet_orig.cfg"), "cfg/voc_imagenet_rev.cfg");
+}
+
 void train_VOC()
 {
     network net = parse_network_cfg("cfg/voc_start.cfg");
@@ -498,6 +505,7 @@
     IplImage *sized = cvCreateImage(cvSize(w,h), src->depth, src->nChannels);
     cvResize(src, sized, CV_INTER_LINEAR);
     image im = ipl_to_image(sized);
+    normalize_array(im.data, im.h*im.w*im.c);
     resize_network(net, im.h, im.w, im.c);
     forward_network(net, im.data);
     image out = get_network_image_layer(net, 6);
@@ -523,6 +531,69 @@
     free_image(out);
     cvReleaseImage(&src);
 }
+/* Track, per output channel of cfg/voc_imagenet.cfg, the topk highest-scoring h x w sub-images over
+ * the images listed in `filename`; the grid is re-shown and re-saved after every processed image. */
+void visualize_imagenet_topk(char *filename)
+{
+    int i,j,k,l;
+    int topk = 10;
+    network net = parse_network_cfg("cfg/voc_imagenet.cfg");
+    list *plist = get_paths(filename);
+    node *n = plist->front;
+    int h = voc_size(1), w = voc_size(1);
+    int num = get_network_image(net).c;
+    image **vizs = calloc(num, sizeof(image*));    /* vizs[c][r]: r-th best patch for channel c */
+    float **score = calloc(num, sizeof(float *));  /* score[c][r]: its response, kept descending */
+    for(i = 0; i < num; ++i){
+        vizs[i] = calloc(topk, sizeof(image));
+        for(j = 0; j < topk; ++j) vizs[i][j] = make_image(h,w,3);
+        score[i] = calloc(topk, sizeof(float));    /* zeroed: only responses > 0 can ever enter */
+    }
+    while(n){
+        char *image_path = (char *)n->val;
+        image im = load_image(image_path, 0, 0);
+        n = n->next;
+        if(im.h < 200 || im.w < 200){ free_image(im); continue; }  /* release skipped images too */
+        printf("Processing %dx%d image\n", im.h, im.w);
+        resize_network(net, im.h, im.w, im.c);
+        translate_image(im, -144);
+        forward_network(net, im.data);
+        image out = get_network_image(net);
+        int dh = (im.h - h)/h;  /* NOTE(review): presumably the input step per output cell; confirm */
+        int dw = (im.w - w)/w;
+        for(i = 0; i < out.h; ++i){
+            for(j = 0; j < out.w; ++j){
+                image sub = get_sub_image(im, dh*i, dw*j, h, w);
+                for(k = 0; k < out.c; ++k){
+                    float val = get_pixel(out, i, j, k);
+                    image sub_c = copy_image(sub);
+                    for(l = 0; l < topk; ++l){  /* insert; the displaced entry ripples down */
+                        if(val > score[k][l]){
+                            float swap = score[k][l];
+                            score[k][l] = val;
+                            val = swap;
+                            image swapi = vizs[k][l];
+                            vizs[k][l] = sub_c;
+                            sub_c = swapi;
+                        }
+                    }
+                    free_image(sub_c);  /* either the evicted image or the unused copy */
+                }
+                free_image(sub);
+            }
+        }
+        free_image(im);
+        image grid = grid_images(vizs, num, topk);
+        show_image(grid, "IMAGENET Visualization");
+        save_image(grid, "IMAGENET Grid");
+        free_image(grid);
+    }
+    for(i = 0; i < num; ++i){  /* release the top-k tables built above */
+        for(j = 0; j < topk; ++j) free_image(vizs[i][j]);
+        free(vizs[i]); free(score[i]);
+    }
+    free(vizs); free(score);
+}
 
 void visualize_imagenet_features(char *filename)
 {
@@ -566,6 +637,20 @@
     cvWaitKey(0);
 }
 
+/* Run data/cat.png through the cfg/voc_imagenet.cfg network and display the network visualization. */
+void visualize_cat()
+{
+    network net = parse_network_cfg("cfg/voc_imagenet.cfg");
+    image im = load_image("data/cat.png", 0, 0);
+    printf("Processing %dx%d image\n", im.h, im.w);
+    resize_network(net, im.h, im.w, im.c);  /* size the network input to this image */
+    forward_network(net, im.data);
+    visualize_network(net);
+    cvWaitKey(1000);
+    cvWaitKey(0);    /* a delay of 0 waits for a key press with no timeout */
+    free_image(im);  /* the loaded image is owned by this function */
+}
+
 void features_VOC_image(char *image_file, char *image_dir, char *out_dir)
 {
     int i,j;
@@ -693,7 +778,10 @@
     //features_VOC_image(argv[1], argv[2], argv[3]);
     //features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3]));
     //visualize_imagenet_features("data/assira/train.list");
-    visualize_imagenet_features("data/VOC2011.list");
+    visualize_imagenet_topk("data/VOC2011.list");
+    //visualize_cat();
+    //flip_network();
+    //test_visualize();
     fprintf(stderr, "Success!\n");
     //test_random_preprocess();
     //test_random_classify();

--
Gitblit v1.10.0