From 01bec657080d123a5b44a10ec94ebe132e20b1f3 Mon Sep 17 00:00:00 2001
From: IlyaOvodov <b@ovdv.ru>
Date: Wed, 30 May 2018 16:51:55 +0000
Subject: [PATCH] "channels" parameter of [net] (net.c) is used in detector when preparing images for net.

---
 src/image.c         |   42 +++++++++++++--------
 src/demo.c          |    4 +-
 src/data.c          |   20 +++++----
 src/detector.c      |    9 +++-
 src/http_stream.cpp |   25 ++++++++----
 src/data.h          |    5 +-
 6 files changed, 64 insertions(+), 41 deletions(-)

diff --git a/src/data.c b/src/data.c
index a15bc1d..dc803f0 100644
--- a/src/data.c
+++ b/src/data.c
@@ -687,8 +687,9 @@
 
 #include "http_stream.h"
 
-data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object)
+data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object)
 {
+    c = c ? c : 3;
     char **random_paths = get_random_paths(paths, n, m);
     int i;
     data d = {0};
@@ -696,13 +697,13 @@
 
     d.X.rows = n;
     d.X.vals = calloc(d.X.rows, sizeof(float*));
-    d.X.cols = h*w*3;
+    d.X.cols = h*w*c;
 
     d.y = make_matrix(n, 5*boxes);
     for(i = 0; i < n; ++i){
 		const char *filename = random_paths[i];
 
-		int flag = 1;
+		int flag = (c >= 3);
 		IplImage *src;
 		if ((src = cvLoadImage(filename, flag)) == 0)
 		{
@@ -754,8 +755,9 @@
     return d;
 }
 #else	// OPENCV
-data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object)
+data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object)
 {
+    c = c ? c : 3;
 	char **random_paths = get_random_paths(paths, n, m);
 	int i;
 	data d = { 0 };
@@ -763,11 +765,11 @@
 
 	d.X.rows = n;
 	d.X.vals = calloc(d.X.rows, sizeof(float*));
-	d.X.cols = h*w * 3;
+	d.X.cols = h*w*c;
 
 	d.y = make_matrix(n, 5 * boxes);
 	for (i = 0; i < n; ++i) {
-		image orig = load_image_color(random_paths[i], 0, 0);
+		image orig = load_image(random_paths[i], 0, 0, c);
 
 		int oh = orig.h;
 		int ow = orig.w;
@@ -827,16 +829,16 @@
     } else if (a.type == REGION_DATA){
         *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
     } else if (a.type == DETECTION_DATA){
-        *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.flip, a.jitter, a.hue, a.saturation, a.exposure, a.small_object);
+        *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.jitter, a.hue, a.saturation, a.exposure, a.small_object);
     } else if (a.type == SWAG_DATA){
         *a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
     } else if (a.type == COMPARE_DATA){
         *a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h);
     } else if (a.type == IMAGE_DATA){
-        *(a.im) = load_image_color(a.path, 0, 0);
+        *(a.im) = load_image(a.path, 0, 0, a.c);
         *(a.resized) = resize_image(*(a.im), a.w, a.h);
 	}else if (a.type == LETTERBOX_DATA) {
-		*(a.im) = load_image_color(a.path, 0, 0);
+		*(a.im) = load_image(a.path, 0, 0, a.c);
 		*(a.resized) = letterbox_image(*(a.im), a.w, a.h);
     } else if (a.type == TAG_DATA){
         *a.d = load_data_tag(a.paths, a.n, a.m, a.classes, a.flip, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
diff --git a/src/data.h b/src/data.h
index 57f4702..b46143f 100644
--- a/src/data.h
+++ b/src/data.h
@@ -44,7 +44,8 @@
     char **labels;
     int h;
     int w;
-    int out_w;
+	int c; // number of image channels (used as h*w*c; 0 falls back to 3 in the loaders)
+	int out_w;
     int out_h;
     int nh;
     int nw;
@@ -84,7 +85,7 @@
 data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
 data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
 data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
-data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object);
+data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int small_object);
 data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
 matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
 data load_data_super(char **paths, int n, int m, int w, int h, int scale);
diff --git a/src/demo.c b/src/demo.c
index 81eddb2..72f9c03 100644
--- a/src/demo.c
+++ b/src/demo.c
@@ -51,7 +51,7 @@
 void draw_detections_cv(IplImage* show_img, int num, float thresh, box *boxes, float **probs, char **names, image **alphabet, int classes);
 void draw_detections_cv_v3(IplImage* show_img, detection *dets, int num, float thresh, char **names, image **alphabet, int classes, int ext_output);
 void show_image_cv_ipl(IplImage *disp, const char *name);
-image get_image_from_stream_resize(CvCapture *cap, int w, int h, IplImage** in_img, int cpp_video_capture);
+image get_image_from_stream_resize(CvCapture *cap, int w, int h, int c, IplImage** in_img, int cpp_video_capture);
 IplImage* in_img;
 IplImage* det_img;
 IplImage* show_img;
@@ -61,7 +61,7 @@
 void *fetch_in_thread(void *ptr)
 {
     //in = get_image_from_stream(cap);
-	in_s = get_image_from_stream_resize(cap, net.w, net.h, &in_img, cpp_video_capture);
+	in_s = get_image_from_stream_resize(cap, net.w, net.h, net.c, &in_img, cpp_video_capture);
     if(!in_s.data){
         //error("Stream closed.");
 		printf("Stream closed.\n");
diff --git a/src/detector.c b/src/detector.c
index 6fc6b67..202fbf9 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -87,7 +87,8 @@
     load_args args = {0};
     args.w = net.w;
     args.h = net.h;
-    args.paths = paths;
+	args.c = net.c;
+	args.paths = paths;
     args.n = imgs;
     args.m = plist->size;
     args.classes = classes;
@@ -388,6 +389,7 @@
 	load_args args = { 0 };
 	args.w = net.w;
 	args.h = net.h;
+	args.c = net.c;
 	args.type = IMAGE_DATA;
 	//args.type = LETTERBOX_DATA;
 
@@ -482,7 +484,7 @@
 
 	for (i = 0; i < m; ++i) {
 		char *path = paths[i];
-		image orig = load_image_color(path, 0, 0);
+		image orig = load_image(path, 0, 0, net.c);
 		image sized = resize_image(orig, net.w, net.h);
 		char *id = basecfg(path);
 		network_predict(net, sized.data);
@@ -595,6 +597,7 @@
 	load_args args = { 0 };
 	args.w = net.w;
 	args.h = net.h;
+	args.c = net.c;
 	args.type = IMAGE_DATA;
 	//args.type = LETTERBOX_DATA;
 
@@ -1093,7 +1096,7 @@
             if(!input) return;
             strtok(input, "\n");
         }
-        image im = load_image_color(input,0,0);
+        image im = load_image(input,0,0,net.c);
 		int letterbox = 0;
         //image sized = resize_image(im, net.w, net.h);
 		image sized = letterbox_image(im, net.w, net.h); letterbox = 1;
diff --git a/src/http_stream.cpp b/src/http_stream.cpp
index 9192f75..acb6c8e 100644
--- a/src/http_stream.cpp
+++ b/src/http_stream.cpp
@@ -283,19 +283,26 @@
 
 	// HSV augmentation
 	// CV_BGR2HSV, CV_RGB2HSV, CV_HSV2BGR, CV_HSV2RGB
-	cv::Mat hsv_src;
-	cvtColor(sized, hsv_src, CV_BGR2HSV);	// also BGR -> RGB
+	if (ipl->nChannels >= 3)
+	{
+		cv::Mat hsv_src;
+		cvtColor(sized, hsv_src, CV_BGR2HSV);	// also BGR -> RGB
 	
-	std::vector<cv::Mat> hsv;
-	cv::split(hsv_src, hsv);
+		std::vector<cv::Mat> hsv;
+		cv::split(hsv_src, hsv);
 
-	hsv[1] *= dsat;
-	hsv[2] *= dexp;
-	hsv[0] += 179 * dhue;
+		hsv[1] *= dsat;
+		hsv[2] *= dexp;
+		hsv[0] += 179 * dhue;
 
-	cv::merge(hsv, hsv_src);
+		cv::merge(hsv, hsv_src);
 
-	cvtColor(hsv_src, sized, CV_HSV2RGB);	// now RGB instead of BGR
+		cvtColor(hsv_src, sized, CV_HSV2RGB);	// now RGB instead of BGR
+	}
+	else
+	{
+		sized *= dexp;
+	}
 
 	// Mat -> IplImage -> image
 	IplImage src = sized;
diff --git a/src/image.c b/src/image.c
index 7545e7d..3ffd552 100644
--- a/src/image.c
+++ b/src/image.c
@@ -957,7 +957,7 @@
 {
     IplImage* src = 0;
     int flag = -1;
-    if (channels == 0) flag = -1;
+    if (channels == 0) flag = 1;
     else if (channels == 1) flag = 0;
     else if (channels == 3) flag = 1;
     else {
@@ -975,7 +975,8 @@
     }
     image out = ipl_to_image(src);
     cvReleaseImage(&src);
-    rgbgr_image(out);
+	if (out.c > 1)
+		rgbgr_image(out);
     return out;
 }
 
@@ -1010,8 +1011,9 @@
 	return im;
 }
 
-image get_image_from_stream_resize(CvCapture *cap, int w, int h, IplImage** in_img, int cpp_video_capture)
+image get_image_from_stream_resize(CvCapture *cap, int w, int h, int c, IplImage** in_img, int cpp_video_capture)
 {
+	c = c ? c : 3;
 	IplImage* src;
 	if (cpp_video_capture) {
 		static int once = 1;
@@ -1029,14 +1031,15 @@
 
 	if (!src) return make_empty_image(0, 0, 0);
 	if (src->width < 1 || src->height < 1 || src->nChannels < 1) return make_empty_image(0, 0, 0);
-	IplImage* new_img = cvCreateImage(cvSize(w, h), IPL_DEPTH_8U, 3);
-	*in_img = cvCreateImage(cvSize(src->width, src->height), IPL_DEPTH_8U, 3);
+	IplImage* new_img = cvCreateImage(cvSize(w, h), IPL_DEPTH_8U, c);
+	*in_img = cvCreateImage(cvSize(src->width, src->height), IPL_DEPTH_8U, c);
 	cvResize(src, *in_img, CV_INTER_LINEAR);
 	cvResize(src, new_img, CV_INTER_LINEAR);
 	image im = ipl_to_image(new_img);
 	cvReleaseImage(&new_img);
 	if (cpp_video_capture) cvReleaseImage(&src);
-	rgbgr_image(im);
+	if (c>1)
+		rgbgr_image(im);
 	return im;
 }
 
@@ -1589,16 +1592,23 @@
 
 void distort_image(image im, float hue, float sat, float val)
 {
-    rgb_to_hsv(im);
-    scale_image_channel(im, 1, sat);
-    scale_image_channel(im, 2, val);
-    int i;
-    for(i = 0; i < im.w*im.h; ++i){
-        im.data[i] = im.data[i] + hue;
-        if (im.data[i] > 1) im.data[i] -= 1;
-        if (im.data[i] < 0) im.data[i] += 1;
-    }
-    hsv_to_rgb(im);
+	if (im.c >= 3)
+	{
+		rgb_to_hsv(im);
+		scale_image_channel(im, 1, sat);
+		scale_image_channel(im, 2, val);
+		int i;
+		for(i = 0; i < im.w*im.h; ++i){
+			im.data[i] = im.data[i] + hue;
+			if (im.data[i] > 1) im.data[i] -= 1;
+			if (im.data[i] < 0) im.data[i] += 1;
+		}
+		hsv_to_rgb(im);
+	}
+	else
+	{
+		scale_image_channel(im, 0, val);
+	}
     constrain_image(im);
 }
 

--
Gitblit v1.10.0