From 73f7aacf35ec9b1d0f9de9ddf38af0889f213e99 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 20 Sep 2016 18:34:49 +0000
Subject: [PATCH] better multigpu
---
src/yolo.c | 1
src/swag.c | 1
src/voxel.c | 1
src/cifar.c | 2
src/rnn.c | 1
src/classifier.c | 204 +++++++++++------
src/dice.c | 5
src/go.c | 2
src/coco.c | 1
src/rnn_vid.c | 1
src/convolutional_kernels.cu | 6
src/compare.c | 1
src/network.c | 2
src/network.h | 3
src/network_kernels.cu | 243 ++++++++++++++++-----
src/connected_layer.c | 2
src/data.c | 96 ++++++--
src/super.c | 1
src/data.h | 10
src/writing.c | 1
src/tag.c | 1
src/convolutional_layer.c | 11
src/parser.c | 4
src/detector.c | 1
src/captcha.c | 1
src/blas_kernels.cu | 2
26 files changed, 404 insertions(+), 200 deletions(-)
diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index 271f017..0391e2e 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -365,7 +365,7 @@
__global__ void constrain_kernel(int N, float ALPHA, float *X, int INCX)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
- if(i < N) X[i*INCX] = min(ALPHA, max(-ALPHA, X[i*INCX]));
+ if(i < N) X[i*INCX] = fminf(ALPHA, fmaxf(-ALPHA, X[i*INCX]));
}
__global__ void supp_kernel(int N, float ALPHA, float *X, int INCX)
diff --git a/src/captcha.c b/src/captcha.c
index 79b4e4e..3d449b2 100644
--- a/src/captcha.c
+++ b/src/captcha.c
@@ -28,7 +28,6 @@
void train_captcha(char *cfgfile, char *weightfile)
{
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
diff --git a/src/cifar.c b/src/cifar.c
index de52bb8..af1b4d6 100644
--- a/src/cifar.c
+++ b/src/cifar.c
@@ -10,7 +10,6 @@
void train_cifar(char *cfgfile, char *weightfile)
{
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
@@ -59,7 +58,6 @@
void train_cifar_distill(char *cfgfile, char *weightfile)
{
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
diff --git a/src/classifier.c b/src/classifier.c
index 3424216..b42d010 100644
--- a/src/classifier.c
+++ b/src/classifier.c
@@ -55,10 +55,8 @@
void train_classifier_multi(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
{
#ifdef GPU
- int nthreads = 8;
int i;
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
@@ -68,17 +66,20 @@
for(i = 0; i < ngpus; ++i){
cuda_set_device(gpus[i]);
nets[i] = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&(nets[i]), weightfile);
- }
if(clear) *nets[i].seen = 0;
+ if(weightfile){
+ load_weights(&nets[i], weightfile);
+ }
}
network net = nets[0];
+ for(i = 0; i < ngpus; ++i){
+ *nets[i].seen = *net.seen;
+ nets[i].learning_rate *= ngpus;
+ }
+
+ int imgs = net.batch * net.subdivisions * ngpus;
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
- int imgs = net.batch*ngpus/nthreads;
- assert(net.batch*ngpus % nthreads == 0);
-
list *options = read_data_cfg(datacfg);
char *backup_directory = option_find_str(options, "backup", "/backup/");
@@ -93,13 +94,10 @@
int N = plist->size;
clock_t time;
- pthread_t *load_threads = calloc(nthreads, sizeof(pthread_t));
- data *trains = calloc(nthreads, sizeof(data));
- data *buffers = calloc(nthreads, sizeof(data));
-
load_args args = {0};
args.w = net.w;
args.h = net.h;
+ args.threads = 16;
args.min = net.min_crop;
args.max = net.max_crop;
@@ -117,36 +115,28 @@
args.labels = labels;
args.type = CLASSIFICATION_DATA;
- for(i = 0; i < nthreads; ++i){
- args.d = buffers + i;
- load_threads[i] = load_data_in_thread(args);
- }
+ data train;
+ data buffer;
+ pthread_t load_thread;
+ args.d = &buffer;
+ load_thread = load_data(args);
int epoch = (*net.seen)/N;
while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
time=clock();
- for(i = 0; i < nthreads; ++i){
- pthread_join(load_threads[i], 0);
- trains[i] = buffers[i];
- }
- data train = concat_datas(trains, nthreads);
- for(i = 0; i < nthreads; ++i){
- args.d = buffers + i;
- load_threads[i] = load_data_in_thread(args);
- }
+ pthread_join(load_thread, 0);
+ train = buffer;
+ load_thread = load_data(args);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
- float loss = train_networks(nets, ngpus, train);
+ float loss = train_networks(nets, ngpus, train, 4);
if(avg_loss == -1) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
free_data(train);
- for(i = 0; i < nthreads; ++i){
- free_data(trains[i]);
- }
if(*net.seen/N > epoch){
epoch = *net.seen/N;
char buff[256];
@@ -163,14 +153,6 @@
sprintf(buff, "%s/%s.weights", backup_directory, base);
save_weights(net, buff);
- for(i = 0; i < nthreads; ++i){
- pthread_join(load_threads[i], 0);
- free_data(buffers[i]);
- }
- free(buffers);
- free(trains);
- free(load_threads);
-
free_network(net);
free_ptrs((void**)labels, classes);
free_ptrs((void**)paths, plist->size);
@@ -182,10 +164,6 @@
void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
{
- int nthreads = 8;
- int i;
-
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
@@ -195,10 +173,10 @@
load_weights(&net, weightfile);
}
if(clear) *net.seen = 0;
- printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
- int imgs = net.batch*net.subdivisions/nthreads;
- assert(net.batch*net.subdivisions % nthreads == 0);
+ int imgs = net.batch * net.subdivisions;
+
+ printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
list *options = read_data_cfg(datacfg);
char *backup_directory = option_find_str(options, "backup", "/backup/");
@@ -213,13 +191,10 @@
int N = plist->size;
clock_t time;
- pthread_t *load_threads = calloc(nthreads, sizeof(pthread_t));
- data *trains = calloc(nthreads, sizeof(data));
- data *buffers = calloc(nthreads, sizeof(data));
-
load_args args = {0};
args.w = net.w;
args.h = net.h;
+ args.threads = 8;
args.min = net.min_crop;
args.max = net.max_crop;
@@ -237,24 +212,19 @@
args.labels = labels;
args.type = CLASSIFICATION_DATA;
- for(i = 0; i < nthreads; ++i){
- args.d = buffers + i;
- load_threads[i] = load_data_in_thread(args);
- }
+ data train;
+ data buffer;
+ pthread_t load_thread;
+ args.d = &buffer;
+ load_thread = load_data(args);
int epoch = (*net.seen)/N;
while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
time=clock();
- for(i = 0; i < nthreads; ++i){
- pthread_join(load_threads[i], 0);
- trains[i] = buffers[i];
- }
- data train = concat_datas(trains, nthreads);
- for(i = 0; i < nthreads; ++i){
- args.d = buffers + i;
- load_threads[i] = load_data_in_thread(args);
- }
+ pthread_join(load_thread, 0);
+ train = buffer;
+ load_thread = load_data(args);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
@@ -271,13 +241,11 @@
#endif
float loss = train_network(net, train);
+ free_data(train);
+
if(avg_loss == -1) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
- free_data(train);
- for(i = 0; i < nthreads; ++i){
- free_data(trains[i]);
- }
if(*net.seen/N > epoch){
epoch = *net.seen/N;
char buff[256];
@@ -294,14 +262,6 @@
sprintf(buff, "%s/%s.weights", backup_directory, base);
save_weights(net, buff);
- for(i = 0; i < nthreads; ++i){
- pthread_join(load_threads[i], 0);
- free_data(buffers[i]);
- }
- free(buffers);
- free(trains);
- free(load_threads);
-
free_network(net);
free_ptrs((void**)labels, classes);
free_ptrs((void**)paths, plist->size);
@@ -934,7 +894,19 @@
int w = x2 - x1 - 2*border;
float *predictions = network_predict(net, in_s.data);
- float curr_threat = predictions[0] * 0 + predictions[1] * .6 + predictions[2];
+ float curr_threat = 0;
+ if(1){
+ curr_threat = predictions[0] * 0 +
+ predictions[1] * .6 +
+ predictions[2];
+ } else {
+ curr_threat = predictions[218] +
+ predictions[539] +
+ predictions[540] +
+ predictions[368] +
+ predictions[369] +
+ predictions[370];
+ }
threat = roll * curr_threat + (1-roll) * threat;
draw_box_width(out, x2 + border, y1 + .02*h, x2 + .5 * w, y1 + .02*h + border, border, 0,0,0);
@@ -970,7 +942,7 @@
top_predictions(net, top, indexes);
char buff[256];
sprintf(buff, "/home/pjreddie/tmp/threat_%06d", count);
- save_image(out, buff);
+ //save_image(out, buff);
printf("\033[2J");
printf("\033[1;1H");
@@ -981,7 +953,7 @@
printf("%.1f%%: %s\n", predictions[index]*100, names[index]);
}
- if(0){
+ if(1){
show_image(out, "Threat");
cvWaitKey(10);
}
@@ -997,6 +969,85 @@
}
+void gun_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
+{
+#ifdef OPENCV
+ int bad_cats[] = {218, 539, 540, 1213, 1501, 1742, 1911, 2415, 4348, 19223, 368, 369, 370, 1133, 1200, 1306, 2122, 2301, 2537, 2823, 3179, 3596, 3639, 4489, 5107, 5140, 5289, 6240, 6631, 6762, 7048, 7171, 7969, 7984, 7989, 8824, 8927, 9915, 10270, 10448, 13401, 15205, 18358, 18894, 18895, 19249, 19697};
+
+ printf("Classifier Demo\n");
+ network net = parse_network_cfg(cfgfile);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ set_batch_network(&net, 1);
+ list *options = read_data_cfg(datacfg);
+
+ srand(2222222);
+ CvCapture * cap;
+
+ if(filename){
+ cap = cvCaptureFromFile(filename);
+ }else{
+ cap = cvCaptureFromCAM(cam_index);
+ }
+
+ int top = option_find_int(options, "top", 1);
+
+ char *name_list = option_find_str(options, "names", 0);
+ char **names = get_labels(name_list);
+
+ int *indexes = calloc(top, sizeof(int));
+
+ if(!cap) error("Couldn't connect to webcam.\n");
+ cvNamedWindow("Threat Detection", CV_WINDOW_NORMAL);
+ cvResizeWindow("Threat Detection", 512, 512);
+ float fps = 0;
+ int i;
+
+ while(1){
+ struct timeval tval_before, tval_after, tval_result;
+ gettimeofday(&tval_before, NULL);
+
+ image in = get_image_from_stream(cap);
+ image in_s = resize_image(in, net.w, net.h);
+ show_image(in, "Threat Detection");
+
+ float *predictions = network_predict(net, in_s.data);
+ top_predictions(net, top, indexes);
+
+ printf("\033[2J");
+ printf("\033[1;1H");
+
+ int threat = 0;
+ for(i = 0; i < sizeof(bad_cats)/sizeof(bad_cats[0]); ++i){
+ int index = bad_cats[i];
+ if(predictions[index] > .01){
+ printf("Threat Detected!\n");
+ threat = 1;
+ break;
+ }
+ }
+ if(!threat) printf("Scanning...\n");
+ for(i = 0; i < sizeof(bad_cats)/sizeof(bad_cats[0]); ++i){
+ int index = bad_cats[i];
+ if(predictions[index] > .01){
+ printf("%s\n", names[index]);
+ }
+ }
+
+ free_image(in_s);
+ free_image(in);
+
+ cvWaitKey(10);
+
+ gettimeofday(&tval_after, NULL);
+ timersub(&tval_after, &tval_before, &tval_result);
+ float curr = 1000000.f/((long int)tval_result.tv_usec);
+ fps = .9*fps + .1*curr;
+ }
+#endif
+}
+
void demo_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
{
#ifdef OPENCV
@@ -1102,6 +1153,7 @@
else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, clear);
else if(0==strcmp(argv[2], "trainm")) train_classifier_multi(data, cfg, weights, gpus, ngpus, clear);
else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename);
+ else if(0==strcmp(argv[2], "gun")) gun_classifier(data, cfg, weights, cam_index, filename);
else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename);
else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);
else if(0==strcmp(argv[2], "label")) label_classifier(data, cfg, weights);
diff --git a/src/coco.c b/src/coco.c
index 1371870..b78d640 100644
--- a/src/coco.c
+++ b/src/coco.c
@@ -28,7 +28,6 @@
//char *train_images = "data/bags.train.list";
char *backup_directory = "/home/pjreddie/backup/";
srand(time(0));
- data_seed = time(0);
char *base = basecfg(cfgfile);
printf("%s\n", base);
float avg_loss = -1;
diff --git a/src/compare.c b/src/compare.c
index a1f494e..4fd266c 100644
--- a/src/compare.c
+++ b/src/compare.c
@@ -9,7 +9,6 @@
void train_compare(char *cfgfile, char *weightfile)
{
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
diff --git a/src/connected_layer.c b/src/connected_layer.c
index b4ced2d..f46c3e1 100644
--- a/src/connected_layer.c
+++ b/src/connected_layer.c
@@ -204,10 +204,12 @@
if(l.batch_normalize){
printf("Scales ");
print_statistics(l.scales, l.outputs);
+ /*
printf("Rolling Mean ");
print_statistics(l.rolling_mean, l.outputs);
printf("Rolling Variance ");
print_statistics(l.rolling_variance, l.outputs);
+ */
}
printf("Biases ");
print_statistics(l.biases, l.outputs);
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 8244792..b8d6478 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -237,8 +237,10 @@
axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
- axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
- scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
+ if(layer.scales_gpu){
+ axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
+ scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
+ }
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 299af75..01bb700 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -241,9 +241,6 @@
l.biases_gpu = cuda_make_array(l.biases, n);
l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
- l.scales_gpu = cuda_make_array(l.scales, n);
- l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
-
l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
@@ -265,6 +262,9 @@
l.mean_delta_gpu = cuda_make_array(l.mean, n);
l.variance_delta_gpu = cuda_make_array(l.variance, n);
+ l.scales_gpu = cuda_make_array(l.scales, n);
+ l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
+
l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
}
@@ -511,6 +511,11 @@
axpy_cpu(l.n, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
scal_cpu(l.n, momentum, l.bias_updates, 1);
+ if(l.scales){
+ axpy_cpu(l.n, learning_rate/batch, l.scale_updates, 1, l.scales, 1);
+ scal_cpu(l.n, momentum, l.scale_updates, 1);
+ }
+
axpy_cpu(size, -decay*batch, l.weights, 1, l.weight_updates, 1);
axpy_cpu(size, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
scal_cpu(size, momentum, l.weight_updates, 1);
diff --git a/src/data.c b/src/data.c
index 02dbac4..5977a3f 100644
--- a/src/data.c
+++ b/src/data.c
@@ -7,7 +7,6 @@
#include <stdlib.h>
#include <string.h>
-unsigned int data_seed;
pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
list *get_paths(char *filename)
@@ -23,13 +22,14 @@
return lines;
}
+/*
char **get_random_paths_indexes(char **paths, int n, int m, int *indexes)
{
char **random_paths = calloc(n, sizeof(char*));
int i;
pthread_mutex_lock(&mutex);
for(i = 0; i < n; ++i){
- int index = rand_r(&data_seed)%m;
+ int index = rand()%m;
indexes[i] = index;
random_paths[i] = paths[index];
if(i == 0) printf("%s\n", paths[index]);
@@ -37,6 +37,7 @@
pthread_mutex_unlock(&mutex);
return random_paths;
}
+*/
char **get_random_paths(char **paths, int n, int m)
{
@@ -44,7 +45,7 @@
int i;
pthread_mutex_lock(&mutex);
for(i = 0; i < n; ++i){
- int index = rand_r(&data_seed)%m;
+ int index = rand()%m;
random_paths[i] = paths[index];
if(i == 0) printf("%s\n", paths[index]);
}
@@ -111,7 +112,7 @@
for(i = 0; i < n; ++i){
image im = load_image_color(paths[i], 0, 0);
image crop = random_augment_image(im, angle, aspect, min, max, size);
- int flip = rand_r(&data_seed)%2;
+ int flip = rand()%2;
if (flip) flip_image(crop);
random_distort_image(crop, hue, saturation, exposure);
@@ -159,7 +160,7 @@
int i;
for(i = 0; i < n; ++i){
box_label swap = b[i];
- int index = rand_r(&data_seed)%n;
+ int index = rand()%n;
b[i] = b[index];
b[index] = swap;
}
@@ -430,9 +431,6 @@
void free_data(data d)
{
- if(d.indexes){
- free(d.indexes);
- }
if(!d.shallow){
free_matrix(d.X);
free_matrix(d.y);
@@ -476,7 +474,7 @@
float sx = (float)swidth / ow;
float sy = (float)sheight / oh;
- int flip = rand_r(&data_seed)%2;
+ int flip = rand()%2;
image cropped = crop_image(orig, pleft, ptop, swidth, sheight);
float dx = ((float)pleft/ow)/sx;
@@ -560,7 +558,7 @@
data load_data_swag(char **paths, int n, int classes, float jitter)
{
- int index = rand_r(&data_seed)%n;
+ int index = rand()%n;
char *random_path = paths[index];
image orig = load_image_color(random_path, 0, 0);
@@ -593,7 +591,7 @@
float sx = (float)swidth / w;
float sy = (float)sheight / h;
- int flip = rand_r(&data_seed)%2;
+ int flip = rand()%2;
image cropped = crop_image(orig, pleft, ptop, swidth, sheight);
float dx = ((float)pleft/w)/sx;
@@ -643,7 +641,7 @@
float sx = (float)swidth / ow;
float sy = (float)sheight / oh;
- int flip = rand_r(&data_seed)%2;
+ int flip = rand()%2;
image cropped = crop_image(orig, pleft, ptop, swidth, sheight);
float dx = ((float)pleft/ow)/sx;
@@ -666,26 +664,18 @@
void *load_thread(void *ptr)
{
-
-#ifdef GPU
- cudaError_t status = cudaSetDevice(gpu_index);
- check_error(status);
-#endif
-
- //printf("Loading data: %d\n", rand_r(&data_seed));
+ //printf("Loading data: %d\n", rand());
load_args a = *(struct load_args*)ptr;
if(a.exposure == 0) a.exposure = 1;
if(a.saturation == 0) a.saturation = 1;
if(a.aspect == 0) a.aspect = 1;
if (a.type == OLD_CLASSIFICATION_DATA){
- *a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
+ *a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
} else if (a.type == CLASSIFICATION_DATA){
*a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
} else if (a.type == SUPER_DATA){
*a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale);
- } else if (a.type == STUDY_DATA){
- *a.d = load_data_study(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
} else if (a.type == WRITING_DATA){
*a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
} else if (a.type == REGION_DATA){
@@ -701,7 +691,6 @@
*(a.resized) = resize_image(*(a.im), a.w, a.h);
} else if (a.type == TAG_DATA){
*a.d = load_data_tag(a.paths, a.n, a.m, a.classes, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
- //*a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
}
free(ptr);
return 0;
@@ -716,6 +705,43 @@
return thread;
}
+void *load_threads(void *ptr)
+{
+ int i;
+ load_args args = *(load_args *)ptr;
+ data *out = args.d;
+ int total = args.n;
+ free(ptr);
+ data *buffers = calloc(args.threads, sizeof(data));
+ pthread_t *threads = calloc(args.threads, sizeof(pthread_t));
+ for(i = 0; i < args.threads; ++i){
+ args.d = buffers + i;
+ args.n = (i+1) * total/args.threads - i * total/args.threads;
+ threads[i] = load_data_in_thread(args);
+ }
+ for(i = 0; i < args.threads; ++i){
+ pthread_join(threads[i], 0);
+ }
+ *out = concat_datas(buffers, args.threads);
+ out->shallow = 0;
+ for(i = 0; i < args.threads; ++i){
+ buffers[i].shallow = 1;
+ free_data(buffers[i]);
+ }
+ free(buffers);
+ free(threads);
+ return 0;
+}
+
+pthread_t load_data(load_args args)
+{
+ pthread_t thread;
+ struct load_args *ptr = calloc(1, sizeof(struct load_args));
+ *ptr = args;
+ if(pthread_create(&thread, 0, load_threads, ptr)) error("Thread creation failed");
+ return thread;
+}
+
data load_data_writing(char **paths, int n, int m, int w, int h, int out_w, int out_h)
{
if(m) paths = get_random_paths(paths, n, m);
@@ -731,7 +757,7 @@
return d;
}
-data load_data(char **paths, int n, int m, char **labels, int k, int w, int h)
+data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h)
{
if(m) paths = get_random_paths(paths, n, m);
data d = {0};
@@ -742,6 +768,7 @@
return d;
}
+/*
data load_data_study(char **paths, int n, int m, char **labels, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure)
{
data d = {0};
@@ -753,6 +780,7 @@
if(m) free(paths);
return d;
}
+*/
data load_data_super(char **paths, int n, int m, int w, int h, int scale)
{
@@ -772,7 +800,7 @@
for(i = 0; i < n; ++i){
image im = load_image_color(paths[i], 0, 0);
image crop = random_crop_image(im, w*scale, h*scale);
- int flip = rand_r(&data_seed)%2;
+ int flip = rand()%2;
if (flip) flip_image(crop);
image resize = resize_image(crop, w, h);
d.X.vals[i] = resize.data;
@@ -837,7 +865,6 @@
{
int i;
data out = {0};
- out.shallow = 1;
for(i = 0; i < n; ++i){
data new = concat_data(d[i], out);
free_data(out);
@@ -895,7 +922,7 @@
{
int j;
for(j = 0; j < n; ++j){
- int index = rand_r(&data_seed)%d.X.rows;
+ int index = rand()%d.X.rows;
memcpy(X+j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float));
memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float));
}
@@ -1008,7 +1035,7 @@
{
int i;
for(i = d.X.rows-1; i > 0; --i){
- int index = rand_r(&data_seed)%i;
+ int index = rand()%i;
float *swap = d.X.vals[index];
d.X.vals[index] = d.X.vals[i];
d.X.vals[i] = swap;
@@ -1043,6 +1070,19 @@
}
}
+data get_data_part(data d, int part, int total)
+{
+ data p = {0};
+ p.shallow = 1;
+ p.X.rows = d.X.rows * (part + 1) / total - d.X.rows * part / total;
+ p.y.rows = d.y.rows * (part + 1) / total - d.y.rows * part / total;
+ p.X.cols = d.X.cols;
+ p.y.cols = d.y.cols;
+ p.X.vals = d.X.vals + d.X.rows * part / total;
+ p.y.vals = d.y.vals + d.y.rows * part / total;
+ return p;
+}
+
data get_random_data(data d, int num)
{
data r = {0};
diff --git a/src/data.h b/src/data.h
index 07c994b..c24201d 100644
--- a/src/data.h
+++ b/src/data.h
@@ -6,8 +6,6 @@
#include "list.h"
#include "image.h"
-extern unsigned int data_seed;
-
static inline float distance_from_edge(int x, int max)
{
int dx = (max/2) - x;
@@ -23,7 +21,6 @@
int w, h;
matrix X;
matrix y;
- int *indexes;
int shallow;
int *num_boxes;
box **boxes;
@@ -34,6 +31,7 @@
} data_type;
typedef struct load_args{
+ int threads;
char **paths;
char *path;
int n;
@@ -70,17 +68,18 @@
void free_data(data d);
+pthread_t load_data(load_args args);
+
pthread_t load_data_in_thread(load_args args);
void print_letters(float *pred, int n);
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
-data load_data(char **paths, int n, int m, char **labels, int k, int w, int h);
+data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure);
data load_data_tag(char **paths, int n, int m, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
matrix load_image_augment_paths(char **paths, int n, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
data load_data_super(char **paths, int n, int m, int w, int h, int scale);
-data load_data_study(char **paths, int n, int m, char **labels, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
data load_data_augment(char **paths, int n, int m, char **labels, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
data load_go(char *filename);
@@ -93,6 +92,7 @@
list *get_paths(char *filename);
char **get_labels(char *filename);
void get_random_batch(data d, int n, float *X, float *y);
+data get_data_part(data d, int part, int total);
data get_random_data(data d, int num);
void get_next_batch(data d, int n, int offset, float *X, float *y);
data load_categorical_data_csv(char *filename, int target, int k);
diff --git a/src/detector.c b/src/detector.c
index f4991ac..9498750 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -17,7 +17,6 @@
char *train_images = "/data/voc/train.txt";
char *backup_directory = "/home/pjreddie/backup/";
srand(time(0));
- data_seed = time(0);
char *base = basecfg(cfgfile);
printf("%s\n", base);
float avg_loss = -1;
diff --git a/src/dice.c b/src/dice.c
index 6f148b0..2286459 100644
--- a/src/dice.c
+++ b/src/dice.c
@@ -6,7 +6,6 @@
void train_dice(char *cfgfile, char *weightfile)
{
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
@@ -27,7 +26,7 @@
while(1){
++i;
time=clock();
- data train = load_data(paths, imgs, plist->size, labels, 6, net.w, net.h);
+ data train = load_data_old(paths, imgs, plist->size, labels, 6, net.w, net.h);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
@@ -60,7 +59,7 @@
int m = plist->size;
free_list(plist);
- data val = load_data(paths, m, 0, labels, 6, net.w, net.h);
+ data val = load_data_old(paths, m, 0, labels, 6, net.w, net.h);
float *acc = network_accuracies(net, val, 2);
printf("Validation Accuracy: %f, %d images\n", acc[0], m);
free_data(val);
diff --git a/src/go.c b/src/go.c
index bb5e60e..89297b5 100644
--- a/src/go.c
+++ b/src/go.c
@@ -116,7 +116,6 @@
void train_go(char *cfgfile, char *weightfile)
{
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
@@ -401,7 +400,6 @@
void valid_go(char *cfgfile, char *weightfile, int multi)
{
- data_seed = time(0);
srand(time(0));
char *base = basecfg(cfgfile);
printf("%s\n", base);
diff --git a/src/network.c b/src/network.c
index f5e1f82..72c8943 100644
--- a/src/network.c
+++ b/src/network.c
@@ -1,5 +1,6 @@
#include <stdio.h>
#include <time.h>
+#include <assert.h>
#include "network.h"
#include "image.h"
#include "data.h"
@@ -356,6 +357,7 @@
float train_network(network net, data d)
{
+ assert(d.X.rows % net.batch == 0);
int batch = net.batch;
int n = d.X.rows / batch;
float *X = calloc(batch*d.X.cols, sizeof(float));
diff --git a/src/network.h b/src/network.h
index af0ad8c..4f9ba75 100644
--- a/src/network.h
+++ b/src/network.h
@@ -65,7 +65,8 @@
} network_state;
#ifdef GPU
-float train_networks(network *nets, int n, data d);
+float train_networks(network *nets, int n, data d, int interval);
+void sync_nets(network *nets, int n, int interval);
float train_network_datum_gpu(network net, float *x, float *y);
float *network_predict_gpu(network net, float *input);
float * get_network_output_gpu_layer(network net, int i);
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index 3e0c2b6..b7d1d2b 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -219,34 +219,32 @@
typedef struct {
network net;
- float *X;
- float *y;
+ data d;
+ float *err;
} train_args;
void *train_thread(void *ptr)
{
train_args args = *(train_args*)ptr;
-
- cuda_set_device(args.net.gpu_index);
- forward_backward_network_gpu(args.net, args.X, args.y);
free(ptr);
+ cuda_set_device(args.net.gpu_index);
+ *args.err = train_network(args.net, args.d);
return 0;
}
-pthread_t train_network_in_thread(network net, float *X, float *y)
+pthread_t train_network_in_thread(network net, data d, float *err)
{
pthread_t thread;
train_args *ptr = (train_args *)calloc(1, sizeof(train_args));
ptr->net = net;
- ptr->X = X;
- ptr->y = y;
+ ptr->d = d;
+ ptr->err = err;
if(pthread_create(&thread, 0, train_thread, ptr)) error("Thread creation failed");
return thread;
}
void pull_updates(layer l)
{
-#ifdef GPU
if(l.type == CONVOLUTIONAL){
cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.n);
cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
@@ -255,12 +253,10 @@
cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
}
-#endif
}
void push_updates(layer l)
{
-#ifdef GPU
if(l.type == CONVOLUTIONAL){
cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
@@ -269,9 +265,95 @@
cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
}
-#endif
}
+void update_layer(layer l, network net)
+{
+ int update_batch = net.batch*net.subdivisions;
+ float rate = get_current_rate(net);
+ if(l.type == CONVOLUTIONAL){
+ update_convolutional_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
+ } else if(l.type == DECONVOLUTIONAL){
+ update_deconvolutional_layer_gpu(l, rate, net.momentum, net.decay);
+ } else if(l.type == CONNECTED){
+ update_connected_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
+ } else if(l.type == RNN){
+ update_rnn_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
+ } else if(l.type == GRU){
+ update_gru_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
+ } else if(l.type == CRNN){
+ update_crnn_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
+ } else if(l.type == LOCAL){
+ update_local_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
+ }
+}
+
+void merge_weights(layer l, layer base)
+{
+ if (l.type == CONVOLUTIONAL) {
+ axpy_cpu(l.n, 1, l.biases, 1, base.biases, 1);
+ axpy_cpu(l.n*l.size*l.size*l.c, 1, l.weights, 1, base.weights, 1);
+ if (l.scales) {
+ axpy_cpu(l.n, 1, l.scales, 1, base.scales, 1);
+ }
+ } else if(l.type == CONNECTED) {
+ axpy_cpu(l.outputs, 1, l.biases, 1, base.biases, 1);
+ axpy_cpu(l.outputs*l.inputs, 1, l.weights, 1, base.weights, 1);
+ }
+}
+
+void scale_weights(layer l, float s)
+{
+ if (l.type == CONVOLUTIONAL) {
+ scal_cpu(l.n, s, l.biases, 1);
+ scal_cpu(l.n*l.size*l.size*l.c, s, l.weights, 1);
+ if (l.scales) {
+ scal_cpu(l.n, s, l.scales, 1);
+ }
+ } else if(l.type == CONNECTED) {
+ scal_cpu(l.outputs, s, l.biases, 1);
+ scal_cpu(l.outputs*l.inputs, s, l.weights, 1);
+ }
+}
+
+
+void pull_weights(layer l)
+{
+ if(l.type == CONVOLUTIONAL){
+ cuda_pull_array(l.biases_gpu, l.biases, l.n);
+ cuda_pull_array(l.weights_gpu, l.weights, l.n*l.size*l.size*l.c);
+ if(l.scales) cuda_pull_array(l.scales_gpu, l.scales, l.n);
+ } else if(l.type == CONNECTED){
+ cuda_pull_array(l.biases_gpu, l.biases, l.outputs);
+ cuda_pull_array(l.weights_gpu, l.weights, l.outputs*l.inputs);
+ }
+}
+
+void push_weights(layer l)
+{
+ if(l.type == CONVOLUTIONAL){
+ cuda_push_array(l.biases_gpu, l.biases, l.n);
+ cuda_push_array(l.weights_gpu, l.weights, l.n*l.size*l.size*l.c);
+ if(l.scales) cuda_push_array(l.scales_gpu, l.scales, l.n);
+ } else if(l.type == CONNECTED){
+ cuda_push_array(l.biases_gpu, l.biases, l.outputs);
+ cuda_push_array(l.weights_gpu, l.weights, l.outputs*l.inputs);
+ }
+}
+
+void distribute_weights(layer l, layer base)
+{
+ if(l.type == CONVOLUTIONAL){
+ cuda_push_array(l.biases_gpu, base.biases, l.n);
+ cuda_push_array(l.weights_gpu, base.weights, l.n*l.size*l.size*l.c);
+ if(base.scales) cuda_push_array(l.scales_gpu, base.scales, l.n);
+ } else if(l.type == CONNECTED){
+ cuda_push_array(l.biases_gpu, base.biases, l.outputs);
+ cuda_push_array(l.weights_gpu, base.weights, l.outputs*l.inputs);
+ }
+}
+
+
void merge_updates(layer l, layer base)
{
if (l.type == CONVOLUTIONAL) {
@@ -288,79 +370,110 @@
void distribute_updates(layer l, layer base)
{
- if (l.type == CONVOLUTIONAL) {
- copy_cpu(l.n, base.bias_updates, 1, l.bias_updates, 1);
- copy_cpu(l.n*l.size*l.size*l.c, base.weight_updates, 1, l.weight_updates, 1);
- if (l.scale_updates) {
- copy_cpu(l.n, base.scale_updates, 1, l.scale_updates, 1);
- }
- } else if(l.type == CONNECTED) {
- copy_cpu(l.outputs, base.bias_updates, 1, l.bias_updates, 1);
- copy_cpu(l.outputs*l.inputs, base.weight_updates, 1, l.weight_updates, 1);
+ if(l.type == CONVOLUTIONAL){
+ cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.n);
+ cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.n*l.size*l.size*l.c);
+ if(base.scale_updates) cuda_push_array(l.scale_updates_gpu, base.scale_updates, l.n);
+ } else if(l.type == CONNECTED){
+ cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.outputs);
+ cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.outputs*l.inputs);
}
}
-void sync_updates(network *nets, int n)
+void sync_layer(network *nets, int n, int j)
{
- int i,j;
- int layers = nets[0].n;
+ //printf("Syncing layer %d\n", j);
+ int i;
network net = nets[0];
- for (j = 0; j < layers; ++j) {
- layer base = net.layers[j];
- cuda_set_device(net.gpu_index);
- pull_updates(base);
- for (i = 1; i < n; ++i) {
- cuda_set_device(nets[i].gpu_index);
- layer l = nets[i].layers[j];
- pull_updates(l);
- merge_updates(l, base);
- }
- for (i = 1; i < n; ++i) {
- cuda_set_device(nets[i].gpu_index);
- layer l = nets[i].layers[j];
- distribute_updates(l, base);
- push_updates(l);
- }
- cuda_set_device(net.gpu_index);
- push_updates(base);
+ layer base = net.layers[j];
+ cuda_set_device(net.gpu_index);
+ pull_weights(base);
+ for (i = 1; i < n; ++i) {
+ cuda_set_device(nets[i].gpu_index);
+ layer l = nets[i].layers[j];
+ pull_weights(l);
+ merge_weights(l, base);
}
+ scale_weights(base, 1./n);
for (i = 0; i < n; ++i) {
cuda_set_device(nets[i].gpu_index);
- if(i > 0) nets[i].momentum = 0;
- update_network_gpu(nets[i]);
+ layer l = nets[i].layers[j];
+ distribute_weights(l, base);
}
+ //printf("Done syncing layer %d\n", j);
}
-float train_networks(network *nets, int n, data d)
-{
- int batch = nets[0].batch;
- assert(batch * n == d.X.rows);
- assert(nets[0].subdivisions % n == 0);
- float **X = (float **) calloc(n, sizeof(float *));
- float **y = (float **) calloc(n, sizeof(float *));
- pthread_t *threads = (pthread_t *) calloc(n, sizeof(pthread_t));
+typedef struct{
+ network *nets;
+ int n;
+ int j;
+} sync_args;
+void *sync_layer_thread(void *ptr)
+{
+ sync_args args = *(sync_args*)ptr;
+ sync_layer(args.nets, args.n, args.j);
+ free(ptr);
+ return 0;
+}
+
+pthread_t sync_layer_in_thread(network *nets, int n, int j)
+{
+ pthread_t thread;
+ sync_args *ptr = (sync_args *)calloc(1, sizeof(sync_args));
+ ptr->nets = nets;
+ ptr->n = n;
+ ptr->j = j;
+ if(pthread_create(&thread, 0, sync_layer_thread, ptr)) error("Thread creation failed");
+ return thread;
+}
+
+void sync_nets(network *nets, int n, int interval)
+{
+ int j;
+ int layers = nets[0].n;
+ pthread_t *threads = (pthread_t *) calloc(layers, sizeof(pthread_t));
+
+ *nets[0].seen += interval * (n-1) * nets[0].batch * nets[0].subdivisions;
+ for (j = 0; j < n; ++j){
+ *nets[j].seen = *nets[0].seen;
+ }
+ for (j = 0; j < layers; ++j) {
+ threads[j] = sync_layer_in_thread(nets, n, j);
+ }
+ for (j = 0; j < layers; ++j) {
+ pthread_join(threads[j], 0);
+ }
+ free(threads);
+}
+
+float train_networks(network *nets, int n, data d, int interval)
+{
int i;
+ int batch = nets[0].batch;
+ int subdivisions = nets[0].subdivisions;
+ assert(batch * subdivisions * n == d.X.rows);
+ pthread_t *threads = (pthread_t *) calloc(n, sizeof(pthread_t));
+ float *errors = (float *) calloc(n, sizeof(float));
+
float sum = 0;
for(i = 0; i < n; ++i){
- X[i] = (float *) calloc(batch*d.X.cols, sizeof(float));
- y[i] = (float *) calloc(batch*d.y.cols, sizeof(float));
- get_next_batch(d, batch, i*batch, X[i], y[i]);
- threads[i] = train_network_in_thread(nets[i], X[i], y[i]);
+ data p = get_data_part(d, i, n);
+ threads[i] = train_network_in_thread(nets[i], p, errors + i);
}
for(i = 0; i < n; ++i){
pthread_join(threads[i], 0);
- *nets[i].seen += n*nets[i].batch;
- printf("%f\n", get_network_cost(nets[i]) / batch);
- sum += get_network_cost(nets[i]);
- free(X[i]);
- free(y[i]);
+ printf("%f\n", errors[i]);
+ sum += errors[i];
}
- if (((*nets[0].seen) / nets[0].batch) % nets[0].subdivisions == 0) sync_updates(nets, n);
- free(X);
- free(y);
+ if (get_current_batch(nets[0]) % interval == 0) {
+ printf("Syncing... ");
+ sync_nets(nets, n, interval);
+ printf("Done!\n");
+ }
free(threads);
- return (float)sum/(n*batch);
+ free(errors);
+ return (float)sum/(n);
}
float *get_network_output_layer_gpu(network net, int i)
diff --git a/src/parser.c b/src/parser.c
index 3551983..2b285b5 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -954,7 +954,9 @@
void save_weights_upto(network net, char *filename, int cutoff)
{
#ifdef GPU
+ if(net.gpu_index >= 0){
cuda_set_device(net.gpu_index);
+ }
#endif
fprintf(stderr, "Saving weights to %s\n", filename);
FILE *fp = fopen(filename, "w");
@@ -1120,7 +1122,9 @@
void load_weights_upto(network *net, char *filename, int cutoff)
{
#ifdef GPU
+ if(net->gpu_index >= 0){
cuda_set_device(net->gpu_index);
+ }
#endif
fprintf(stderr, "Loading weights from %s...", filename);
fflush(stdout);
diff --git a/src/rnn.c b/src/rnn.c
index 4f0e011..eca6f55 100644
--- a/src/rnn.c
+++ b/src/rnn.c
@@ -129,7 +129,6 @@
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
{
srand(time(0));
- data_seed = time(0);
unsigned char *text = 0;
int *tokens = 0;
size_t size;
diff --git a/src/rnn_vid.c b/src/rnn_vid.c
index 183ae77..bf024f9 100644
--- a/src/rnn_vid.c
+++ b/src/rnn_vid.c
@@ -76,7 +76,6 @@
char *train_videos = "data/vid/train.txt";
char *backup_directory = "/home/pjreddie/backup/";
srand(time(0));
- data_seed = time(0);
char *base = basecfg(cfgfile);
printf("%s\n", base);
float avg_loss = -1;
diff --git a/src/super.c b/src/super.c
index 67b941f..63e9860 100644
--- a/src/super.c
+++ b/src/super.c
@@ -12,7 +12,6 @@
char *train_images = "/data/imagenet/imagenet1k.train.list";
char *backup_directory = "/home/pjreddie/backup/";
srand(time(0));
- data_seed = time(0);
char *base = basecfg(cfgfile);
printf("%s\n", base);
float avg_loss = -1;
diff --git a/src/swag.c b/src/swag.c
index f06db4c..2cb3093 100644
--- a/src/swag.c
+++ b/src/swag.c
@@ -14,7 +14,6 @@
char *train_images = "data/voc.0712.trainval";
char *backup_directory = "/home/pjreddie/backup/";
srand(time(0));
- data_seed = time(0);
char *base = basecfg(cfgfile);
printf("%s\n", base);
float avg_loss = -1;
diff --git a/src/tag.c b/src/tag.c
index dd10591..1e43e7d 100644
--- a/src/tag.c
+++ b/src/tag.c
@@ -8,7 +8,6 @@
void train_tag(char *cfgfile, char *weightfile, int clear)
{
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
diff --git a/src/voxel.c b/src/voxel.c
index a047c6a..c277bcf 100644
--- a/src/voxel.c
+++ b/src/voxel.c
@@ -48,7 +48,6 @@
char *train_images = "/data/imagenet/imagenet1k.train.list";
char *backup_directory = "/home/pjreddie/backup/";
srand(time(0));
- data_seed = time(0);
char *base = basecfg(cfgfile);
printf("%s\n", base);
float avg_loss = -1;
diff --git a/src/writing.c b/src/writing.c
index 32e4f6d..0a76d48 100644
--- a/src/writing.c
+++ b/src/writing.c
@@ -9,7 +9,6 @@
void train_writing(char *cfgfile, char *weightfile)
{
char *backup_directory = "/home/pjreddie/backup/";
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
diff --git a/src/yolo.c b/src/yolo.c
index 43d5355..2465a2c 100644
--- a/src/yolo.c
+++ b/src/yolo.c
@@ -18,7 +18,6 @@
char *train_images = "/data/voc/train.txt";
char *backup_directory = "/home/pjreddie/backup/";
srand(time(0));
- data_seed = time(0);
char *base = basecfg(cfgfile);
printf("%s\n", base);
float avg_loss = -1;
--
Gitblit v1.10.0