From c7b10ceadb1a78e7480d281444a31ae2a7dc1b05 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 06 May 2016 23:25:16 +0000
Subject: [PATCH] so much need to commit

---
 src/yolo.c                   |   43 -
 src/batchnorm_layer.c        |  175 +++++
 Makefile                     |    9 
 src/batchnorm_layer.h        |   17 
 src/rnn.c                    |  103 ++
 src/classifier.c             |    5 
 src/classifier.h             |    2 
 src/go.c                     |  268 +++++++
 src/convolutional_layer.h    |    2 
 src/image.c                  |   10 
 src/convolutional_kernels.cu |  209 +----
 src/crnn_layer.c             |    6 
 src/layer.h                  |   49 +
 src/utils.c                  |   61 +
 src/gru_layer.h              |   25 
 src/network.c                |   19 
 src/utils.h                  |    6 
 src/network.h                |    1 
 src/network_kernels.cu       |   12 
 src/connected_layer.c        |   61 -
 src/connected_layer.h        |    1 
 src/data.c                   |   60 +
 src/rnn_layer.h              |   22 
 src/blas.h                   |    8 
 src/data.h                   |    5 
 src/gru_layer.c              |  307 +++++++++
 src/yolo_kernels.cu          |    1 
 src/rnn_layer.c              |    6 
 src/yolo_demo.c              |    1 
 src/convolutional_layer.c    |   77 --
 src/activations.h            |   13 
 src/parser.c                 |   91 ++
 src/activation_kernels.cu    |   17 
 cfg/darknet.cfg              |   32 
 src/activations.c            |    7 
 src/blas_kernels.cu          |  189 +++++
 src/darknet.c                |   20 
 37 files changed, 1,502 insertions(+), 438 deletions(-)

diff --git a/Makefile b/Makefile
index cd3edcd..1ef1b3b 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
-GPU=1
-OPENCV=1
+GPU=0
+OPENCV=0
 DEBUG=0
 
 ARCH= --gpu-architecture=compute_20 --gpu-code=compute_20 
@@ -11,7 +11,7 @@
 CC=gcc
 NVCC=nvcc
 OPTS=-Ofast
-LDFLAGS= -lm -pthread -lstdc++ 
+LDFLAGS= -lm -pthread 
 COMMON= 
 CFLAGS=-Wall -Wfatal-errors 
 
@@ -34,8 +34,9 @@
 LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
 endif
 
-OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o
+OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o batchnorm_layer.o
 ifeq ($(GPU), 1) 
+LDFLAGS+= -lstdc++ 
 OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
 endif
 
diff --git a/cfg/darknet.cfg b/cfg/darknet.cfg
index 00e9c36..ff0d33e 100644
--- a/cfg/darknet.cfg
+++ b/cfg/darknet.cfg
@@ -1,27 +1,20 @@
 [net]
 batch=128
 subdivisions=1
-height=256
-width=256
+height=224
+width=224
 channels=3
 momentum=0.9
 decay=0.0005
+max_crop=320
 
-learning_rate=0.01
-policy=sigmoid
-gamma=.00002
-step=400000
-max_batches=800000
-
-[crop]
-crop_height=224
-crop_width=224
-flip=1
-angle=0
-saturation=1
-exposure=1
+learning_rate=0.1
+policy=poly
+power=4
+max_batches=500000
 
 [convolutional]
+batch_normalize=1
 filters=16
 size=3
 stride=1
@@ -33,6 +26,7 @@
 stride=2
 
 [convolutional]
+batch_normalize=1
 filters=32
 size=3
 stride=1
@@ -44,6 +38,7 @@
 stride=2
 
 [convolutional]
+batch_normalize=1
 filters=64
 size=3
 stride=1
@@ -55,6 +50,7 @@
 stride=2
 
 [convolutional]
+batch_normalize=1
 filters=128
 size=3
 stride=1
@@ -66,6 +62,7 @@
 stride=2
 
 [convolutional]
+batch_normalize=1
 filters=256
 size=3
 stride=1
@@ -77,6 +74,7 @@
 stride=2
 
 [convolutional]
+batch_normalize=1
 filters=512
 size=3
 stride=1
@@ -88,6 +86,7 @@
 stride=2
 
 [convolutional]
+batch_normalize=1
 filters=1024
 size=3
 stride=1
@@ -96,9 +95,6 @@
 
 [avgpool]
 
-[dropout]
-probability=.5
-
 [connected]
 output=1000
 activation=leaky
diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu
index 99933c8..3dc3af0 100644
--- a/src/activation_kernels.cu
+++ b/src/activation_kernels.cu
@@ -15,13 +15,19 @@
 __device__ float relie_activate_kernel(float x){return x*(x>0);}
 __device__ float ramp_activate_kernel(float x){return x*(x>0)+.1*x;}
 __device__ float leaky_activate_kernel(float x){return (x>0) ? x : .1*x;}
-__device__ float tanh_activate_kernel(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
+__device__ float tanh_activate_kernel(float x){return (2/(1 + exp(-2*x)) - 1);}
 __device__ float plse_activate_kernel(float x)
 {
     if(x < -4) return .01 * (x + 4);
     if(x > 4)  return .01 * (x - 4) + 1;
     return .125*x + .5;
 }
+__device__ float stair_activate_kernel(float x)
+{
+    int n = floor(x);
+    if (n%2 == 0) return floor(x/2.);
+    else return (x - n) + floor(x/2.);
+}
  
 __device__ float linear_gradient_kernel(float x){return 1;}
 __device__ float logistic_gradient_kernel(float x){return (1-x)*x;}
@@ -37,6 +43,11 @@
 __device__ float leaky_gradient_kernel(float x){return (x>0) ? 1 : .1;}
 __device__ float tanh_gradient_kernel(float x){return 1-x*x;}
 __device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01 : .125;}
+__device__ float stair_gradient_kernel(float x)
+{
+    if (floor(x) == x) return 0;
+    return 1;
+}
 
 __device__ float activate_kernel(float x, ACTIVATION a)
 {
@@ -61,6 +72,8 @@
             return tanh_activate_kernel(x);
         case PLSE:
             return plse_activate_kernel(x);
+        case STAIR:
+            return stair_activate_kernel(x);
     }
     return 0;
 }
@@ -88,6 +101,8 @@
             return tanh_gradient_kernel(x);
         case PLSE:
             return plse_gradient_kernel(x);
+        case STAIR:
+            return stair_gradient_kernel(x);
     }
     return 0;
 }
diff --git a/src/activations.c b/src/activations.c
index 07e3a45..6b98e1c 100644
--- a/src/activations.c
+++ b/src/activations.c
@@ -28,6 +28,8 @@
             return "plse";
         case LEAKY:
             return "leaky";
+        case STAIR:
+            return "stair";
         default:
             break;
     }
@@ -46,6 +48,7 @@
     if (strcmp(s, "ramp")==0) return RAMP;
     if (strcmp(s, "leaky")==0) return LEAKY;
     if (strcmp(s, "tanh")==0) return TANH;
+    if (strcmp(s, "stair")==0) return STAIR;
     fprintf(stderr, "Couldn't find activation function %s, going with ReLU\n", s);
     return RELU;
 }
@@ -73,6 +76,8 @@
             return tanh_activate(x);
         case PLSE:
             return plse_activate(x);
+        case STAIR:
+            return stair_activate(x);
     }
     return 0;
 }
@@ -108,6 +113,8 @@
             return tanh_gradient(x);
         case PLSE:
             return plse_gradient(x);
+        case STAIR:
+            return stair_gradient(x);
     }
     return 0;
 }
diff --git a/src/activations.h b/src/activations.h
index 7806025..05f7bca 100644
--- a/src/activations.h
+++ b/src/activations.h
@@ -4,7 +4,7 @@
 #include "math.h"
 
 typedef enum{
-    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY
+    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR
 }ACTIVATION;
 
 ACTIVATION get_activation(char *s);
@@ -19,6 +19,12 @@
 void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta);
 #endif
 
+static inline float stair_activate(float x)
+{
+    int n = floor(x);
+    if (n%2 == 0) return floor(x/2.);
+    else return (x - n) + floor(x/2.);
+}
 static inline float linear_activate(float x){return x;}
 static inline float logistic_activate(float x){return 1./(1. + exp(-x));}
 static inline float loggy_activate(float x){return 2./(1. + exp(-x)) - 1;}
@@ -42,6 +48,11 @@
     float y = (x+1.)/2.;
     return 2*(1-y)*y;
 }
+static inline float stair_gradient(float x)
+{
+    if (floor(x) == x) return 0;
+    return 1;
+}
 static inline float relu_gradient(float x){return (x>0);}
 static inline float elu_gradient(float x){return (x >= 0) + (x < 0)*(x + 1);}
 static inline float relie_gradient(float x){return (x>0) ? 1 : .01;}
diff --git a/src/batchnorm_layer.c b/src/batchnorm_layer.c
new file mode 100644
index 0000000..6ea4040
--- /dev/null
+++ b/src/batchnorm_layer.c
@@ -0,0 +1,175 @@
+#include "batchnorm_layer.h"
+#include "blas.h"
+#include <stdio.h>
+
+layer make_batchnorm_layer(int batch, int w, int h, int c)
+{
+    fprintf(stderr, "Batch Normalization Layer: %d x %d x %d image\n", w,h,c);
+    layer layer = {0};
+    layer.type = BATCHNORM;
+    layer.batch = batch;
+    layer.h = layer.out_h = h;
+    layer.w = layer.out_w = w;
+    layer.c = layer.out_c = c;
+    layer.output = calloc(h * w * c * batch, sizeof(float));
+    layer.delta  = calloc(h * w * c * batch, sizeof(float));
+    layer.inputs = w*h*c;
+    layer.outputs = layer.inputs;
+
+    layer.scales = calloc(c, sizeof(float));
+    layer.scale_updates = calloc(c, sizeof(float));
+    int i;
+    for(i = 0; i < c; ++i){
+        layer.scales[i] = 1;
+    }
+
+    layer.mean = calloc(c, sizeof(float));
+    layer.variance = calloc(c, sizeof(float));
+
+    layer.rolling_mean = calloc(c, sizeof(float));
+    layer.rolling_variance = calloc(c, sizeof(float));
+#ifdef GPU
+    layer.output_gpu =  cuda_make_array(layer.output, h * w * c * batch);
+    layer.delta_gpu =   cuda_make_array(layer.delta, h * w * c * batch);
+
+    layer.scales_gpu = cuda_make_array(layer.scales, c);
+    layer.scale_updates_gpu = cuda_make_array(layer.scale_updates, c);
+
+    layer.mean_gpu = cuda_make_array(layer.mean, c);
+    layer.variance_gpu = cuda_make_array(layer.variance, c);
+
+    layer.rolling_mean_gpu = cuda_make_array(layer.mean, c);
+    layer.rolling_variance_gpu = cuda_make_array(layer.variance, c);
+
+    layer.mean_delta_gpu = cuda_make_array(layer.mean, c);
+    layer.variance_delta_gpu = cuda_make_array(layer.variance, c);
+
+    layer.x_gpu = cuda_make_array(layer.output, layer.batch*layer.outputs);
+    layer.x_norm_gpu = cuda_make_array(layer.output, layer.batch*layer.outputs);
+#endif
+    return layer;
+}
+
+void backward_scale_cpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+    int i,b,f;
+    for(f = 0; f < n; ++f){
+        float sum = 0;
+        for(b = 0; b < batch; ++b){
+            for(i = 0; i < size; ++i){
+                int index = i + size*(f + n*b);
+                sum += delta[index] * x_norm[index];
+            }
+        }
+        scale_updates[f] += sum;
+    }
+}
+
+void mean_delta_cpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
+{
+
+    int i,j,k;
+    for(i = 0; i < filters; ++i){
+        mean_delta[i] = 0;
+        for (j = 0; j < batch; ++j) {
+            for (k = 0; k < spatial; ++k) {
+                int index = j*filters*spatial + i*spatial + k;
+                mean_delta[i] += delta[index];
+            }
+        }
+        mean_delta[i] *= (-1./sqrt(variance[i] + .00001f));
+    }
+}
+void  variance_delta_cpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
+{
+
+    int i,j,k;
+    for(i = 0; i < filters; ++i){
+        variance_delta[i] = 0;
+        for(j = 0; j < batch; ++j){
+            for(k = 0; k < spatial; ++k){
+                int index = j*filters*spatial + i*spatial + k;
+                variance_delta[i] += delta[index]*(x[index] - mean[i]);
+            }
+        }
+        variance_delta[i] *= -.5 * pow(variance[i] + .00001f, (float)(-3./2.));
+    }
+}
+void normalize_delta_cpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
+{
+    int f, j, k;
+    for(j = 0; j < batch; ++j){
+        for(f = 0; f < filters; ++f){
+            for(k = 0; k < spatial; ++k){
+                int index = j*filters*spatial + f*spatial + k;
+                delta[index] = delta[index] * 1./(sqrt(variance[f]) + .00001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
+            }
+        }
+    }
+}
+
+void resize_batchnorm_layer(layer *layer, int w, int h)
+{
+    fprintf(stderr, "Not implemented\n");
+}
+
+void forward_batchnorm_layer(layer l, network_state state)
+{
+    if(l.type == BATCHNORM) copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1);
+    if(l.type == CONNECTED){
+        l.out_c = l.outputs;
+        l.out_h = l.out_w = 1;
+    }
+    if(state.train){
+        mean_cpu(l.output, l.batch, l.out_c, l.out_h*l.out_w, l.mean);   
+        variance_cpu(l.output, l.mean, l.batch, l.out_c, l.out_h*l.out_w, l.variance);   
+        normalize_cpu(l.output, l.mean, l.variance, l.batch, l.out_c, l.out_h*l.out_w);   
+    } else {
+        normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.out_c, l.out_h*l.out_w);
+    }
+    scale_bias(l.output, l.scales, l.batch, l.out_c, l.out_h*l.out_w);
+}
+
+void backward_batchnorm_layer(const layer layer, network_state state)
+{
+}
+
+#ifdef GPU
+void forward_batchnorm_layer_gpu(layer l, network_state state)
+{
+    if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
+    if(l.type == CONNECTED){
+        l.out_c = l.outputs;
+        l.out_h = l.out_w = 1;
+    }
+    if (state.train) {
+        fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
+        fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
+
+        scal_ongpu(l.out_c, .95, l.rolling_mean_gpu, 1);
+        axpy_ongpu(l.out_c, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
+        scal_ongpu(l.out_c, .95, l.rolling_variance_gpu, 1);
+        axpy_ongpu(l.out_c, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
+
+        copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
+        normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+        copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
+    } else {
+        normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+    }
+
+    scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+}
+
+void backward_batchnorm_layer_gpu(const layer l, network_state state)
+{
+    backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates_gpu);
+
+    scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+
+    fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.mean_delta_gpu);
+    fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta_gpu);
+    normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.delta_gpu);
+    if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
+}
+#endif
diff --git a/src/batchnorm_layer.h b/src/batchnorm_layer.h
new file mode 100644
index 0000000..61810b6
--- /dev/null
+++ b/src/batchnorm_layer.h
@@ -0,0 +1,17 @@
+#ifndef BATCHNORM_LAYER_H
+#define BATCHNORM_LAYER_H
+
+#include "image.h"
+#include "layer.h"
+#include "network.h"
+
+layer make_batchnorm_layer(int batch, int w, int h, int c);
+void forward_batchnorm_layer(layer l, network_state state);
+void backward_batchnorm_layer(layer l, network_state state);
+
+#ifdef GPU
+void forward_batchnorm_layer_gpu(layer l, network_state state);
+void backward_batchnorm_layer_gpu(layer l, network_state state);
+#endif
+
+#endif
diff --git a/src/blas.h b/src/blas.h
index 030ef66..47d930c 100644
--- a/src/blas.h
+++ b/src/blas.h
@@ -7,6 +7,7 @@
 void test_blas();
 
 void const_cpu(int N, float ALPHA, float *X, int INCX);
+void constrain_ongpu(int N, float ALPHA, float * X, int INCX);
 void pow_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
 void mul_cpu(int N, float *X, int INCX, float *Y, int INCY);
 
@@ -58,8 +59,15 @@
 void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
 void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates);
 void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
+void add_bias_gpu(float *output, float *biases, int batch, int n, int size);
+void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
 
 void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error);
 void l2_gpu(int n, float *pred, float *truth, float *delta, float *error);
+void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc);
+void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c);
+void mult_add_into_gpu(int num, float *a, float *b, float *c);
+
+
 #endif
 #endif
diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index 98366f8..ac537d8 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -9,6 +9,137 @@
 #include "utils.h"
 }
 
+__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
+{
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    int filter = blockIdx.y;
+    int batch = blockIdx.z;
+
+    if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
+}
+
+void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
+{
+    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+    dim3 dimBlock(BLOCK, 1, 1);
+
+    scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+    __shared__ float part[BLOCK];
+    int i,b;
+    int filter = blockIdx.x;
+    int p = threadIdx.x;
+    float sum = 0;
+    for(b = 0; b < batch; ++b){
+        for(i = 0; i < size; i += BLOCK){
+            int index = p + i + size*(filter + n*b);
+            sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
+        }
+    }
+    part[p] = sum;
+    __syncthreads();
+    if (p == 0) {
+        for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
+    }
+}
+
+void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+    backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
+{
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    int filter = blockIdx.y;
+    int batch = blockIdx.z;
+
+    if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
+}
+
+void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
+{
+    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+    dim3 dimBlock(BLOCK, 1, 1);
+
+    add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
+{
+    __shared__ float part[BLOCK];
+    int i,b;
+    int filter = blockIdx.x;
+    int p = threadIdx.x;
+    float sum = 0;
+    for(b = 0; b < batch; ++b){
+        for(i = 0; i < size; i += BLOCK){
+            int index = p + i + size*(filter + n*b);
+            sum += (p+i < size) ? delta[index] : 0;
+        }
+    }
+    part[p] = sum;
+    __syncthreads();
+    if (p == 0) {
+        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
+    }
+}
+
+/*
+__global__ void dot_kernel(float *output, float scale, int batch, int n, int size, float *delta)
+{
+    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    int f1 = index / n;
+    int f2 = index % n;
+    if (f2 <= f1) return;
+    
+    float sum = 0;
+    float norm1 = 0;
+    float norm2 = 0;
+    int b, i;
+    for(b = 0; b <  batch; ++b){
+        for(i = 0; i < size; ++i){
+            int i1 = b * size * n + f1 * size + i;
+            int i2 = b * size * n + f2 * size + i;
+            sum += output[i1] * output[i2];
+            norm1 += output[i1] * output[i1];
+            norm2 += output[i2] * output[i2];
+        }
+    }
+    norm1 = sqrt(norm1);
+    norm2 = sqrt(norm2);
+    float norm = norm1 * norm2;
+    sum = sum / norm;
+    for(b = 0; b <  batch; ++b){
+        for(i = 0; i < size; ++i){
+            int i1 = b * size * n + f1 * size + i;
+            int i2 = b * size * n + f2 * size + i;
+            delta[i1] += - scale * sum * output[i2] / norm;
+            delta[i2] += - scale * sum * output[i1] / norm;
+        }
+    }
+}
+
+void dot_error_gpu(layer l)
+{
+    dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
+    check_error(cudaPeekAtLastError());
+}
+*/
+
+void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
+{
+    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
+    check_error(cudaPeekAtLastError());
+}
+
+
 __global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
 {
     int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -199,6 +330,12 @@
     if(i < N) X[i*INCX] = ALPHA;
 }
 
+__global__ void constrain_kernel(int N, float ALPHA, float *X, int INCX)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < N) X[i*INCX] = min(ALPHA, max(-ALPHA, X[i*INCX]));
+}
+
 __global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -363,6 +500,13 @@
     check_error(cudaPeekAtLastError());
 }
 
+extern "C" void constrain_ongpu(int N, float ALPHA, float * X, int INCX)
+{
+    constrain_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    check_error(cudaPeekAtLastError());
+}
+
+
 extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
 {
     scal_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
@@ -448,3 +592,48 @@
     l2_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
     check_error(cudaPeekAtLastError());
 }
+
+
+__global__ void weighted_sum_kernel(int n, float *a, float *b, float *s, float *c)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        c[i] = s[i]*a[i] + (1-s[i])*(b ? b[i] : 0);
+    }
+}
+
+extern "C" void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c)
+{
+    weighted_sum_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, c);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void weighted_delta_kernel(int n, float *a, float *b, float *s, float *da, float *db, float *ds, float *dc)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        if(da) da[i] += dc[i] * s[i];
+        db[i] += dc[i] * (1-s[i]);
+        ds[i] += dc[i] * a[i] + dc[i] * -b[i];
+    }
+}
+
+extern "C" void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc)
+{
+    weighted_delta_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, da, db, ds, dc);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void mult_add_into_kernel(int n, float *a, float *b, float *c)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        c[i] += a[i]*b[i];
+    }
+}
+
+extern "C" void mult_add_into_gpu(int num, float *a, float *b, float *c)
+{
+    mult_add_into_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, c);
+    check_error(cudaPeekAtLastError());
+}
diff --git a/src/classifier.c b/src/classifier.c
index 2e974a5..7060c5e 100644
--- a/src/classifier.c
+++ b/src/classifier.c
@@ -3,6 +3,7 @@
 #include "parser.h"
 #include "option_list.h"
 #include "blas.h"
+#include "classifier.h"
 #include <sys/time.h>
 
 #ifdef OPENCV
@@ -49,7 +50,7 @@
         load_weights(&net, weightfile);
     }
     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
-    int imgs = 1024;
+    int imgs = net.batch;
 
     list *options = read_data_cfg(datacfg);
 
@@ -72,7 +73,7 @@
     args.w = net.w;
     args.h = net.h;
 
-    args.min = net.w;
+    args.min = net.min_crop;
     args.max = net.max_crop;
     args.size = net.w;
 
diff --git a/src/classifier.h b/src/classifier.h
new file mode 100644
index 0000000..3c89f49
--- /dev/null
+++ b/src/classifier.h
@@ -0,0 +1,2 @@
+
+list *read_data_cfg(char *filename);
diff --git a/src/connected_layer.c b/src/connected_layer.c
index df78e67..f20aa93 100644
--- a/src/connected_layer.c
+++ b/src/connected_layer.c
@@ -1,4 +1,5 @@
 #include "connected_layer.h"
+#include "batchnorm_layer.h"
 #include "utils.h"
 #include "cuda.h"
 #include "blas.h"
@@ -19,6 +20,12 @@
     l.outputs = outputs;
     l.batch=batch;
     l.batch_normalize = batch_normalize;
+    l.h = 1;
+    l.w = 1;
+    l.c = inputs;
+    l.out_h = 1;
+    l.out_w = 1;
+    l.out_c = outputs;
 
     l.output = calloc(batch*outputs, sizeof(float));
     l.delta = calloc(batch*outputs, sizeof(float));
@@ -29,7 +36,6 @@
     l.weights = calloc(outputs*inputs, sizeof(float));
     l.biases = calloc(outputs, sizeof(float));
 
-
     //float scale = 1./sqrt(inputs);
     float scale = sqrt(2./inputs);
     for(i = 0; i < outputs*inputs; ++i){
@@ -37,7 +43,7 @@
     }
 
     for(i = 0; i < outputs; ++i){
-        l.biases[i] = scale;
+        l.biases[i] = 0;
     }
 
     if(batch_normalize){
@@ -176,6 +182,19 @@
     if(c) gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
 }
 
+
+void denormalize_connected_layer(layer l)
+{
+    int i, j;
+    for(i = 0; i < l.outputs; ++i){
+        float scale = l.scales[i]/sqrt(l.rolling_variance[i] + .00001);
+        for(j = 0; j < l.inputs; ++j){
+            l.weights[i*l.inputs + j] *= scale;
+        }
+        l.biases[i] -= l.rolling_mean[i] * scale;
+    }
+}
+
 #ifdef GPU
 
 void pull_connected_layer(connected_layer l)
@@ -223,11 +242,7 @@
 {
     int i;
     fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
-    /*
-       for(i = 0; i < l.batch; ++i){
-       copy_ongpu_offset(l.outputs, l.biases_gpu, 0, 1, l.output_gpu, i*l.outputs, 1);
-       }
-     */
+
     int m = l.batch;
     int k = l.inputs;
     int n = l.outputs;
@@ -236,52 +251,26 @@
     float * c = l.output_gpu;
     gemm_ongpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
     if(l.batch_normalize){
-        if(state.train){
-            fast_mean_gpu(l.output_gpu, l.batch, l.outputs, 1, l.mean_gpu);
-            fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.outputs, 1, l.variance_gpu);
-
-            scal_ongpu(l.outputs, .95, l.rolling_mean_gpu, 1);
-            axpy_ongpu(l.outputs, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
-            scal_ongpu(l.outputs, .95, l.rolling_variance_gpu, 1);
-            axpy_ongpu(l.outputs, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
-
-            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
-            normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.outputs, 1);
-            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
-        } else {
-            normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.outputs, 1);
-        }
-
-        scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.outputs, 1);
+        forward_batchnorm_layer_gpu(l, state);
     }
     for(i = 0; i < l.batch; ++i){
         axpy_ongpu(l.outputs, 1, l.biases_gpu, 1, l.output_gpu + i*l.outputs, 1);
     }
     activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
 
-    /*
-       cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
-       float avg = mean_array(l.output, l.outputs*l.batch);
-       printf("%f\n", avg);
-     */
 }
 
 void backward_connected_layer_gpu(connected_layer l, network_state state)
 {
     int i;
+    constrain_ongpu(l.outputs*l.batch, 5, l.delta_gpu, 1);
     gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
     for(i = 0; i < l.batch; ++i){
         axpy_ongpu(l.outputs, 1, l.delta_gpu + i*l.outputs, 1, l.bias_updates_gpu, 1);
     }
 
     if(l.batch_normalize){
-        backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.outputs, 1, l.scale_updates_gpu);
-
-        scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.outputs, 1);
-
-        fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.outputs, 1, l.mean_delta_gpu);
-        fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.outputs, 1, l.variance_delta_gpu);
-        normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.outputs, 1, l.delta_gpu);
+        backward_batchnorm_layer_gpu(l, state);
     }
 
     int m = l.outputs;
diff --git a/src/connected_layer.h b/src/connected_layer.h
index 56bd1c3..affcaaf 100644
--- a/src/connected_layer.h
+++ b/src/connected_layer.h
@@ -12,6 +12,7 @@
 void forward_connected_layer(connected_layer layer, network_state state);
 void backward_connected_layer(connected_layer layer, network_state state);
 void update_connected_layer(connected_layer layer, int batch, float learning_rate, float momentum, float decay);
+void denormalize_connected_layer(layer l);
 
 #ifdef GPU
 void forward_connected_layer_gpu(connected_layer layer, network_state state);
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 3dc125f..62d6079 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -4,6 +4,7 @@
 
 extern "C" {
 #include "convolutional_layer.h"
+#include "batchnorm_layer.h"
 #include "gemm.h"
 #include "blas.h"
 #include "im2col.h"
@@ -12,6 +13,41 @@
 #include "cuda.h"
 }
 
+__global__ void binarize_kernel(float *x, int n, float *binary)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i >= n) return;
+    binary[i] = (x[i] > 0) ? 1 : -1;
+}
+
+void binarize_gpu(float *x, int n, float *binary)
+{
+    binarize_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, binary);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void binarize_input_kernel(float *input, int n, int size, float *binary)
+{
+    int s = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (s >= size) return;
+    int i = 0;
+    float mean = 0;
+    for(i = 0; i < n; ++i){
+        mean += abs(input[i*size + s]);
+    }
+    mean = mean / n;
+    for(i = 0; i < n; ++i){
+        binary[i*size + s] = (input[i*size + s] > 0) ? mean : -mean;
+    }
+}
+
+void binarize_input_gpu(float *input, int n, int size, float *binary)
+{
+    binarize_input_kernel<<<cuda_gridsize(size), BLOCK>>>(input, n, size, binary);
+    check_error(cudaPeekAtLastError());
+}
+
+
 __global__ void binarize_filters_kernel(float *filters, int n, int size, float *binary)
 {
     int f = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -27,140 +63,12 @@
     }
 }
 
-__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
-{
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
-    int filter = blockIdx.y;
-    int batch = blockIdx.z;
-
-    if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
-}
-
-void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
-{
-    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
-    dim3 dimBlock(BLOCK, 1, 1);
-
-    scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
-    check_error(cudaPeekAtLastError());
-}
-
-__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
-{
-    __shared__ float part[BLOCK];
-    int i,b;
-    int filter = blockIdx.x;
-    int p = threadIdx.x;
-    float sum = 0;
-    for(b = 0; b < batch; ++b){
-        for(i = 0; i < size; i += BLOCK){
-            int index = p + i + size*(filter + n*b);
-            sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
-        }
-    }
-    part[p] = sum;
-    __syncthreads();
-    if (p == 0) {
-        for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
-    }
-}
-
 void binarize_filters_gpu(float *filters, int n, int size, float *binary)
 {
     binarize_filters_kernel<<<cuda_gridsize(n), BLOCK>>>(filters, n, size, binary);
     check_error(cudaPeekAtLastError());
 }
 
-void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
-{
-    backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
-    check_error(cudaPeekAtLastError());
-}
-
-__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
-{
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
-    int filter = blockIdx.y;
-    int batch = blockIdx.z;
-
-    if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
-}
-
-void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
-{
-    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
-    dim3 dimBlock(BLOCK, 1, 1);
-
-    add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
-    check_error(cudaPeekAtLastError());
-}
-
-__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
-{
-    __shared__ float part[BLOCK];
-    int i,b;
-    int filter = blockIdx.x;
-    int p = threadIdx.x;
-    float sum = 0;
-    for(b = 0; b < batch; ++b){
-        for(i = 0; i < size; i += BLOCK){
-            int index = p + i + size*(filter + n*b);
-            sum += (p+i < size) ? delta[index] : 0;
-        }
-    }
-    part[p] = sum;
-    __syncthreads();
-    if (p == 0) {
-        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
-    }
-}
-
-__global__ void dot_kernel(float *output, float scale, int batch, int n, int size, float *delta)
-{
-    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-    int f1 = index / n;
-    int f2 = index % n;
-    if (f2 <= f1) return;
-    
-    float sum = 0;
-    float norm1 = 0;
-    float norm2 = 0;
-    int b, i;
-    for(b = 0; b <  batch; ++b){
-        for(i = 0; i < size; ++i){
-            int i1 = b * size * n + f1 * size + i;
-            int i2 = b * size * n + f2 * size + i;
-            sum += output[i1] * output[i2];
-            norm1 += output[i1] * output[i1];
-            norm2 += output[i2] * output[i2];
-        }
-    }
-    norm1 = sqrt(norm1);
-    norm2 = sqrt(norm2);
-    float norm = norm1 * norm2;
-    sum = sum / norm;
-    for(b = 0; b <  batch; ++b){
-        for(i = 0; i < size; ++i){
-            int i1 = b * size * n + f1 * size + i;
-            int i2 = b * size * n + f2 * size + i;
-            delta[i1] += - scale * sum * output[i2] / norm;
-            delta[i2] += - scale * sum * output[i1] / norm;
-        }
-    }
-}
-
-void dot_error_gpu(layer l)
-{
-    dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
-    check_error(cudaPeekAtLastError());
-}
-
-void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
-{
-    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
-    check_error(cudaPeekAtLastError());
-}
-
 void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 {
     int i;
@@ -175,6 +83,16 @@
         swap_binary(&l);
     }
 
+    if(l.xnor){
+        binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
+        //binarize_gpu(l.filters_gpu, l.n*l.c*l.size*l.size, l.binary_filters_gpu);
+        swap_binary(&l);
+        for(i = 0; i < l.batch; ++i){
+            binarize_input_gpu(state.input + i*l.inputs, l.c, l.h*l.w, l.binary_input_gpu + i*l.inputs);
+        }
+        state.input = l.binary_input_gpu;
+    }
+
     for(i = 0; i < l.batch; ++i){
         im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, l.col_image_gpu);
         float * a = l.filters_gpu;
@@ -184,29 +102,13 @@
     }
 
     if (l.batch_normalize) {
-        if (state.train) {
-            fast_mean_gpu(l.output_gpu, l.batch, l.n, l.out_h*l.out_w, l.mean_gpu);
-            fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.n, l.out_h*l.out_w, l.variance_gpu);
-
-            scal_ongpu(l.n, .95, l.rolling_mean_gpu, 1);
-            axpy_ongpu(l.n, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
-            scal_ongpu(l.n, .95, l.rolling_variance_gpu, 1);
-            axpy_ongpu(l.n, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
-
-            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
-            normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_h*l.out_w);
-            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
-        } else {
-            normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.n, l.out_h*l.out_w);
-        }
-
-        scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
+        forward_batchnorm_layer_gpu(l, state);
     }
     add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n);
 
     activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation);
-    if(l.dot > 0) dot_error_gpu(l);
-    if(l.binary) swap_binary(&l);
+    //if(l.dot > 0) dot_error_gpu(l);
+    if(l.binary || l.xnor) swap_binary(&l);
 }
 
 void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
@@ -222,15 +124,10 @@
     backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k);
 
     if(l.batch_normalize){
-        backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.scale_updates_gpu);
-
-        scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
-
-        fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.mean_delta_gpu);
-        fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.variance_delta_gpu);
-        normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.delta_gpu);
+        backward_batchnorm_layer_gpu(l, state);
     }
 
+    if(l.xnor) state.input = l.binary_input_gpu;
     for(i = 0; i < l.batch; ++i){
         float * a = l.delta_gpu;
         float * b = l.col_image_gpu;
@@ -240,7 +137,7 @@
         gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
 
         if(state.delta){
-            if(l.binary) swap_binary(&l);
+            if(l.binary || l.xnor) swap_binary(&l);
             float * a = l.filters_gpu;
             float * b = l.delta_gpu;
             float * c = l.col_image_gpu;
@@ -248,7 +145,7 @@
             gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
 
             col2im_ongpu(l.col_image_gpu, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
-            if(l.binary) swap_binary(&l);
+            if(l.binary || l.xnor) swap_binary(&l);
         }
     }
 }
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index cdc8bd3..d76dfcd 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -1,5 +1,6 @@
 #include "convolutional_layer.h"
 #include "utils.h"
+#include "batchnorm_layer.h"
 #include "im2col.h"
 #include "col2im.h"
 #include "blas.h"
@@ -87,65 +88,7 @@
     return float_to_image(w,h,c,l.delta);
 }
 
-void backward_scale_cpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
-{
-    int i,b,f;
-    for(f = 0; f < n; ++f){
-        float sum = 0;
-        for(b = 0; b < batch; ++b){
-            for(i = 0; i < size; ++i){
-                int index = i + size*(f + n*b);
-                sum += delta[index] * x_norm[index];
-            }
-        }
-        scale_updates[f] += sum;
-    }
-}
-
-void mean_delta_cpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
-{
-
-    int i,j,k;
-    for(i = 0; i < filters; ++i){
-        mean_delta[i] = 0;
-        for (j = 0; j < batch; ++j) {
-            for (k = 0; k < spatial; ++k) {
-                int index = j*filters*spatial + i*spatial + k;
-                mean_delta[i] += delta[index];
-            }
-        }
-        mean_delta[i] *= (-1./sqrt(variance[i] + .00001f));
-    }
-}
-void  variance_delta_cpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
-{
-
-    int i,j,k;
-    for(i = 0; i < filters; ++i){
-        variance_delta[i] = 0;
-        for(j = 0; j < batch; ++j){
-            for(k = 0; k < spatial; ++k){
-                int index = j*filters*spatial + i*spatial + k;
-                variance_delta[i] += delta[index]*(x[index] - mean[i]);
-            }
-        }
-        variance_delta[i] *= -.5 * pow(variance[i] + .00001f, (float)(-3./2.));
-    }
-}
-void normalize_delta_cpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
-{
-    int f, j, k;
-    for(j = 0; j < batch; ++j){
-        for(f = 0; f < filters; ++f){
-            for(k = 0; k < spatial; ++k){
-                int index = j*filters*spatial + f*spatial + k;
-                delta[index] = delta[index] * 1./(sqrt(variance[f]) + .00001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
-            }
-        }
-    }
-}
-
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary)
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
 {
     int i;
     convolutional_layer l = {0};
@@ -220,6 +163,11 @@
     if(binary){
         l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
     }
+    if(xnor){
+        l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+        l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
+    }
+    l.xnor = xnor;
 
     if(batch_normalize){
         l.mean_gpu = cuda_make_array(l.mean, n);
@@ -256,7 +204,7 @@
 
 void test_convolutional_layer()
 {
-    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0);
+    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0);
     l.batch_normalize = 1;
     float data[] = {1,1,1,1,1,
         1,1,1,1,1,
@@ -397,14 +345,7 @@
     }
 
     if(l.batch_normalize){
-        if(state.train){
-            mean_cpu(l.output, l.batch, l.n, l.out_h*l.out_w, l.mean);   
-            variance_cpu(l.output, l.mean, l.batch, l.n, l.out_h*l.out_w, l.variance);   
-            normalize_cpu(l.output, l.mean, l.variance, l.batch, l.n, l.out_h*l.out_w);   
-        } else {
-            normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.n, l.out_h*l.out_w);
-        }
-        scale_bias(l.output, l.scales, l.batch, l.n, out_h*out_w);
+        forward_batchnorm_layer(l, state);
     }
     add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
 
diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h
index d0c3d46..3d52b22 100644
--- a/src/convolutional_layer.h
+++ b/src/convolutional_layer.h
@@ -21,7 +21,7 @@
 void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
 #endif
 
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalization, int binary);
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalization, int binary, int xnor);
 void denormalize_convolutional_layer(convolutional_layer l);
 void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
 void forward_convolutional_layer(const convolutional_layer layer, network_state state);
diff --git a/src/crnn_layer.c b/src/crnn_layer.c
index ed65665..5d5fa63 100644
--- a/src/crnn_layer.c
+++ b/src/crnn_layer.c
@@ -48,17 +48,17 @@
 
     l.input_layer = malloc(sizeof(layer));
     fprintf(stderr, "\t\t");
-    *(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1,  activation, batch_normalize, 0);
+    *(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1,  activation, batch_normalize, 0, 0);
     l.input_layer->batch = batch;
 
     l.self_layer = malloc(sizeof(layer));
     fprintf(stderr, "\t\t");
-    *(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1,  activation, batch_normalize, 0);
+    *(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1,  activation, batch_normalize, 0, 0);
     l.self_layer->batch = batch;
 
     l.output_layer = malloc(sizeof(layer));
     fprintf(stderr, "\t\t");
-    *(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1,  activation, batch_normalize, 0);
+    *(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1,  activation, batch_normalize, 0, 0);
     l.output_layer->batch = batch;
 
     l.output = l.output_layer->output;
diff --git a/src/darknet.c b/src/darknet.c
index 0865c61..f2982ac 100644
--- a/src/darknet.c
+++ b/src/darknet.c
@@ -6,6 +6,7 @@
 #include "utils.h"
 #include "cuda.h"
 #include "blas.h"
+#include "connected_layer.h"
 
 #ifdef OPENCV
 #include "opencv2/highgui/highgui_c.h"
@@ -182,6 +183,25 @@
             denormalize_convolutional_layer(l);
             net.layers[i].batch_normalize=0;
         }
+        if (l.type == CONNECTED && l.batch_normalize) {
+            denormalize_connected_layer(l);
+            net.layers[i].batch_normalize=0;
+        }
+        if (l.type == GRU && l.batch_normalize) {
+            denormalize_connected_layer(*l.input_z_layer);
+            denormalize_connected_layer(*l.input_r_layer);
+            denormalize_connected_layer(*l.input_h_layer);
+            denormalize_connected_layer(*l.state_z_layer);
+            denormalize_connected_layer(*l.state_r_layer);
+            denormalize_connected_layer(*l.state_h_layer);
+            l.input_z_layer->batch_normalize = 0;
+            l.input_r_layer->batch_normalize = 0;
+            l.input_h_layer->batch_normalize = 0;
+            l.state_z_layer->batch_normalize = 0;
+            l.state_r_layer->batch_normalize = 0;
+            l.state_h_layer->batch_normalize = 0;
+            net.layers[i].batch_normalize=0;
+        }
     }
     save_weights(net, outfile);
 }
diff --git a/src/data.c b/src/data.c
index 4d52d11..b0368ee 100644
--- a/src/data.c
+++ b/src/data.c
@@ -22,6 +22,19 @@
     return lines;
 }
 
+char **get_random_paths_indexes(char **paths, int n, int m, int *indexes)
+{
+    char **random_paths = calloc(n, sizeof(char*));
+    int i;
+    for(i = 0; i < n; ++i){
+        int index = rand_r(&data_seed)%m;
+        indexes[i] = index;
+        random_paths[i] = paths[index];
+        if(i == 0) printf("%s\n", paths[index]);
+    }
+    return random_paths;
+}
+
 char **get_random_paths(char **paths, int n, int m)
 {
     char **random_paths = calloc(n, sizeof(char*));
@@ -364,7 +377,7 @@
 data load_data_captcha(char **paths, int n, int m, int k, int w, int h)
 {
     if(m) paths = get_random_paths(paths, n, m);
-    data d;
+    data d = {0};
     d.shallow = 0;
     d.X = load_image_paths(paths, n, w, h);
     d.y = make_matrix(n, k*NUMCHARS);
@@ -379,7 +392,7 @@
 data load_data_captcha_encode(char **paths, int n, int m, int w, int h)
 {
     if(m) paths = get_random_paths(paths, n, m);
-    data d;
+    data d = {0};
     d.shallow = 0;
     d.X = load_image_paths(paths, n, w, h);
     d.X.cols = 17100;
@@ -449,6 +462,9 @@
 
 void free_data(data d)
 {
+    if(d.indexes){
+        free(d.indexes);
+    }
     if(!d.shallow){
         free_matrix(d.X);
         free_matrix(d.y);
@@ -462,7 +478,7 @@
 {
     char **random_paths = get_random_paths(paths, n, m);
     int i;
-    data d;
+    data d = {0};
     d.shallow = 0;
 
     d.X.rows = n;
@@ -514,7 +530,7 @@
 {
     if(m) paths = get_random_paths(paths, 2*n, m);
     int i,j;
-    data d;
+    data d = {0};
     d.shallow = 0;
 
     d.X.rows = n;
@@ -581,7 +597,7 @@
     int h = orig.h;
     int w = orig.w;
 
-    data d;
+    data d = {0};
     d.shallow = 0;
     d.w = w;
     d.h = h;
@@ -629,7 +645,7 @@
 {
     char **random_paths = get_random_paths(paths, n, m);
     int i;
-    data d;
+    data d = {0};
     d.shallow = 0;
 
     d.X.rows = n;
@@ -698,6 +714,8 @@
         *a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
     } else if (a.type == CLASSIFICATION_DATA){
         *a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size);
+    } else if (a.type == STUDY_DATA){
+        *a.d = load_data_study(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size);
     } else if (a.type == DETECTION_DATA){
         *a.d = load_data_detection(a.n, a.paths, a.m, a.classes, a.w, a.h, a.num_boxes, a.background);
     } else if (a.type == WRITING_DATA){
@@ -732,7 +750,7 @@
 {
     if(m) paths = get_random_paths(paths, n, m);
     char **replace_paths = find_replace_paths(paths, n, ".png", "-label.png");
-    data d;
+    data d = {0};
     d.shallow = 0;
     d.X = load_image_paths(paths, n, w, h);
     d.y = load_image_paths_gray(replace_paths, n, out_w, out_h);
@@ -746,7 +764,7 @@
 data load_data(char **paths, int n, int m, char **labels, int k, int w, int h)
 {
     if(m) paths = get_random_paths(paths, n, m);
-    data d;
+    data d = {0};
     d.shallow = 0;
     d.X = load_image_paths(paths, n, w, h);
     d.y = load_labels_paths(paths, n, labels, k);
@@ -754,10 +772,22 @@
     return d;
 }
 
+data load_data_study(char **paths, int n, int m, char **labels, int k, int min, int max, int size)
+{
+    data d = {0};
+    d.indexes = calloc(n, sizeof(int));
+    if(m) paths = get_random_paths_indexes(paths, n, m, d.indexes);
+    d.shallow = 0;
+    d.X = load_image_cropped_paths(paths, n, min, max, size);
+    d.y = load_labels_paths(paths, n, labels, k);
+    if(m) free(paths);
+    return d;
+}
+
 data load_data_augment(char **paths, int n, int m, char **labels, int k, int min, int max, int size)
 {
     if(m) paths = get_random_paths(paths, n, m);
-    data d;
+    data d = {0};
     d.shallow = 0;
     d.X = load_image_cropped_paths(paths, n, min, max, size);
     d.y = load_labels_paths(paths, n, labels, k);
@@ -796,7 +826,7 @@
 
 data concat_data(data d1, data d2)
 {
-    data d;
+    data d = {0};
     d.shallow = 1;
     d.X = concat_matrix(d1.X, d2.X);
     d.y = concat_matrix(d1.y, d2.y);
@@ -805,7 +835,7 @@
 
 data load_categorical_data_csv(char *filename, int target, int k)
 {
-    data d;
+    data d = {0};
     d.shallow = 0;
     matrix X = csv_to_matrix(filename);
     float *truth_1d = pop_column(&X, target);
@@ -822,7 +852,7 @@
 
 data load_cifar10_data(char *filename)
 {
-    data d;
+    data d = {0};
     d.shallow = 0;
     long i,j;
     matrix X = make_matrix(10000, 3072);
@@ -882,7 +912,7 @@
 
 data load_all_cifar10()
 {
-    data d;
+    data d = {0};
     d.shallow = 0;
     int i,j,b;
     matrix X = make_matrix(50000, 3072);
@@ -910,7 +940,7 @@
     //normalize_data_rows(d);
     //translate_data_rows(d, -128);
     scale_data_rows(d, 1./255);
-   // smooth_data(d);
+    smooth_data(d);
     return d;
 }
 
@@ -949,7 +979,7 @@
     X = resize_matrix(X, count);
     y = resize_matrix(y, count);
 
-    data d;
+    data d = {0};
     d.shallow = 0;
     d.X = X;
     d.y = y;
diff --git a/src/data.h b/src/data.h
index f928ade..6befeea 100644
--- a/src/data.h
+++ b/src/data.h
@@ -23,11 +23,12 @@
     int w, h;
     matrix X;
     matrix y;
+    int *indexes;
     int shallow;
 } data;
 
 typedef enum {
-    CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA
+    CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA
 } data_type;
 
 typedef struct load_args{
@@ -70,6 +71,7 @@
 data load_data_detection(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background);
 data load_data_tag(char **paths, int n, int m, int k, int min, int max, int size);
 data load_data_augment(char **paths, int n, int m, char **labels, int k, int min, int max, int size);
+data load_data_study(char **paths, int n, int m, char **labels, int k, int min, int max, int size);
 data load_go(char *filename);
 
 box_label *read_boxes(char *filename, int *n);
@@ -90,5 +92,6 @@
 void randomize_data(data d);
 data *split_data(data d, int part, int total);
 data concat_data(data d1, data d2);
+void fill_truth(char *path, char **labels, int k, float *truth);
 
 #endif
diff --git a/src/go.c b/src/go.c
index 8d0cf52..7883ed5 100644
--- a/src/go.c
+++ b/src/go.c
@@ -98,6 +98,7 @@
         int col = b[1];
         labels[col + 19*(row + i*19)] = 1;
         string_to_board(b+2, boards+i*19*19);
+        boards[col + 19*(row + i*19)] = 0;
 
         int flip = rand()%2;
         int rotate = rand()%4;
@@ -132,6 +133,7 @@
     float *board = calloc(19*19*net.batch, sizeof(float));
     float *move = calloc(19*19*net.batch, sizeof(float));
     moves m = load_go_moves("/home/pjreddie/go.train");
+    //moves m = load_go_moves("games.txt");
 
     int N = m.n;
     int epoch = (*net.seen)/N;
@@ -337,6 +339,90 @@
     return 1;
 }
 
+int generate_move(network net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
+{
+    int i, j;
+    for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
+
+    float move[361];
+    if (player < 0) flip_board(board);
+    predict_move(net, board, move, multi);
+    if (player < 0) flip_board(board);
+
+    
+    for(i = 0; i < 19; ++i){
+        for(j = 0; j < 19; ++j){
+            if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
+        }
+    }
+
+    int indexes[nind];
+    top_k(move, 19*19, nind, indexes);
+    if(thresh > move[indexes[0]]) thresh = move[indexes[nind-1]];
+
+    for(i = 0; i < 19; ++i){
+        for(j = 0; j < 19; ++j){
+            if (move[i*19 + j] < thresh) move[i*19 + j] = 0;
+        }
+    }
+
+
+    int max = max_index(move, 19*19);
+    int row = max / 19;
+    int col = max % 19;
+    int index = sample_array(move, 19*19);
+
+    if(print){
+        top_k(move, 19*19, nind, indexes);
+        for(i = 0; i < nind; ++i){
+            if (!move[indexes[i]]) indexes[i] = -1;
+        }
+        print_board(board, player, indexes);
+        for(i = 0; i < nind; ++i){
+            fprintf(stderr, "%d: %f\n", i+1, move[indexes[i]]);
+        }
+    }
+
+    if(suicide_go(board, player, row, col)){
+        return -1; 
+    }
+    if(suicide_go(board, player, index/19, index%19)) index = max;
+    return index;
+}
+
+void valid_go(char *cfgfile, char *weightfile, int multi)
+{
+    data_seed = time(0);
+    srand(time(0));
+    char *base = basecfg(cfgfile);
+    printf("%s\n", base);
+    network net = parse_network_cfg(cfgfile);
+    if(weightfile){
+        load_weights(&net, weightfile);
+    }
+    set_batch_network(&net, 1);
+    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+
+    float *board = calloc(19*19, sizeof(float));
+    float *move = calloc(19*19, sizeof(float));
+    moves m = load_go_moves("/home/pjreddie/backup/go.test");
+
+    int N = m.n;
+    int i;
+    int correct = 0;
+    for(i = 0; i <N; ++i){
+        char *b = m.data[i];
+        int row = b[0];
+        int col = b[1];
+        int truth = col + 19*row;
+        string_to_board(b+2, board);
+        predict_move(net, board, move, multi);
+        int index = max_index(move, 19*19);
+        if(index == truth) ++correct;
+        printf("%d Accuracy %f\n", i, (float) correct/(i+1));
+    }
+}
+
 void engine_go(char *filename, char *weightfile, int multi)
 {
     network net = parse_network_cfg(filename);
@@ -346,12 +432,10 @@
     srand(time(0));
     set_batch_network(&net, 1);
     float *board = calloc(19*19, sizeof(float));
-    float *move = calloc(19*19, sizeof(float));
     char *one = calloc(91, sizeof(char));
     char *two = calloc(91, sizeof(char));
     int passed = 0;
     while(1){
-        print_board(board, 1, 0);
         char buff[256];
         int id = 0;
         int has_id = (scanf("%d", &id) == 1);
@@ -436,42 +520,34 @@
             board_to_string(one, board);
 
             printf("=%s \n\n", ids);
+            print_board(board, 1, 0);
         } else if (!strcmp(buff, "genmove")){
             char color[256];
             scanf("%s", color);
             int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
 
-            if(player < 0) flip_board(board);
-            predict_move(net, board, move, multi);
-            if(player < 0) flip_board(board);
-
-            int i, j;
-            for(i = 0; i < 19; ++i){
-                for(j = 0; j < 19; ++j){
-                    if (!legal_go(board, two, player, i, j)) move[i*19 + j] = 0;
-                }
-            }
-            int index = max_index(move, 19*19);
-            int row = index / 19;
-            char col = index % 19;
-
-            char *swap = two;
-            two = one;
-            one = swap;
-
-            if(passed || suicide_go(board, player, row, col)){
+            int index = generate_move(net, player, board, multi, .1, .7, two, 1);
+            if(passed || index < 0){
                 printf("=%s pass\n\n", ids);
                 passed = 0;
             } else {
+                int row = index / 19;
+                int col = index % 19;
+
+                char *swap = two;
+                two = one;
+                one = swap;
+
                 move_go(board, player, row, col);
                 board_to_string(one, board);
-
                 row = 19 - row;
                 if (col >= 8) ++col;
                 printf("=%s %c%d\n\n", ids, 'A' + col, row);
+                print_board(board, 1, 0);
             }
+
         } else if (!strcmp(buff, "p")){
-            print_board(board, 1, 0);
+            //print_board(board, 1, 0);
         } else if (!strcmp(buff, "final_status_list")){
             char type[256];
             scanf("%s", type);
@@ -479,7 +555,30 @@
             char *line = fgetl(stdin);
             free(line);
             if(type[0] == 'd' || type[0] == 'D'){
-                printf("=%s \n\n", ids);
+                FILE *f = fopen("game.txt", "w");
+                int i, j;
+                int count = 2;
+                fprintf(f, "boardsize 19\n");
+                fprintf(f, "clear_board\n");
+                for(j = 0; j < 19; ++j){
+                    for(i = 0; i < 19; ++i){
+                        if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
+                        if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
+                        if(board[j*19 + i]) ++count;
+                    }
+                }
+                fprintf(f, "final_status_list dead\n");
+                fclose(f);
+                FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
+                for(i = 0; i < count; ++i){
+                    free(fgetl(p));
+                    free(fgetl(p));
+                }
+                char *l = 0;
+                while((l = fgetl(p))){
+                    printf("%s\n", l);
+                    free(l);
+                }
             } else {
                 printf("?%s unknown command\n\n", ids);
             }
@@ -588,17 +687,118 @@
     }
 }
 
-void boards_go()
+float score_game(float *board)
 {
-    moves m = load_go_moves("/home/pjreddie/go.train");
-    int i;
-    float board[361];
-    for(i = 0; i < 10; ++i){
-        printf("%d %d\n", m.data[i][0], m.data[i][1]);
-        string_to_board(m.data[i]+2, board);
-        print_board(board, 1, 0);
+    FILE *f = fopen("game.txt", "w");
+    int i, j;
+    int count = 3;
+    fprintf(f, "komi 6.5\n");
+    fprintf(f, "boardsize 19\n");
+    fprintf(f, "clear_board\n");
+    for(j = 0; j < 19; ++j){
+        for(i = 0; i < 19; ++i){
+            if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
+            if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
+            if(board[j*19 + i]) ++count;
+        }
+    }
+    fprintf(f, "final_score\n");
+    fclose(f);
+    FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
+    for(i = 0; i < count; ++i){
+        free(fgetl(p));
+        free(fgetl(p));
+    }
+    char *l = 0;
+    float score = 0;
+    char player = 0;
+    while((l = fgetl(p))){
+        fprintf(stderr, "%s  \t", l);
+        int n = sscanf(l, "= %c+%f", &player, &score);
+        free(l);
+        if (n == 2) break;
+    }
+    if(player == 'W') score = -score;
+    pclose(p);
+    return score;
+}
+
+void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
+{
+    network net = parse_network_cfg(filename);
+    if(weightfile){
+        load_weights(&net, weightfile);
     }
 
+    network net2 = net;
+    if(f2){
+        net2 = parse_network_cfg(f2);
+        if(w2){
+            load_weights(&net2, w2);
+        }
+    }
+    srand(time(0));
+    char boards[300][93];
+    int count = 0;
+    set_batch_network(&net, 1);
+    set_batch_network(&net2, 1);
+    float *board = calloc(19*19, sizeof(float));
+    char *one = calloc(91, sizeof(char));
+    char *two = calloc(91, sizeof(char));
+    int done = 0;
+    int player = 1;
+    int p1 = 0;
+    int p2 = 0;
+    int total = 0;
+    while(1){
+        if (done || count >= 300){
+            float score = score_game(board);
+            int i = (score > 0)? 0 : 1;
+            if((score > 0) == (total%2==0)) ++p1;
+            else ++p2;
+            ++total;
+            fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
+            int j;
+            for(; i < count; i += 2){
+                for(j = 0; j < 93; ++j){
+                    printf("%c", boards[i][j]);
+                }
+                printf("\n");
+            }
+            memset(board, 0, 19*19*sizeof(float));
+            player = 1;
+            done = 0;
+            count = 0;
+            fflush(stdout);
+            fflush(stderr);
+        }
+        //print_board(board, 1, 0);
+        //sleep(1);
+        network use = ((total%2==0) == (player==1)) ? net : net2;
+        int index = generate_move(use, player, board, multi, .1, .7, two, 0);
+        if(index < 0){
+            done = 1;
+            continue;
+        }
+        int row = index / 19;
+        int col = index % 19;
+
+        char *swap = two;
+        two = one;
+        one = swap;
+
+        if(player < 0) flip_board(board);
+        boards[count][0] = row;
+        boards[count][1] = col;
+        board_to_string(boards[count] + 2, board);
+        if(player < 0) flip_board(board);
+        ++count;
+
+        move_go(board, player, row, col);
+        board_to_string(one, board);
+
+        player = -player;
+    }
 }
 
 void run_go(int argc, char **argv)
@@ -611,8 +811,12 @@
 
     char *cfg = argv[3];
     char *weights = (argc > 4) ? argv[4] : 0;
+    char *c2 = (argc > 5) ? argv[5] : 0;
+    char *w2 = (argc > 6) ? argv[6] : 0;
     int multi = find_arg(argc, argv, "-multi");
     if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
+    else if(0==strcmp(argv[2], "valid")) valid_go(cfg, weights, multi);
+    else if(0==strcmp(argv[2], "self")) self_go(cfg, weights, c2, w2, multi);
     else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
     else if(0==strcmp(argv[2], "engine")) engine_go(cfg, weights, multi);
 }
diff --git a/src/gru_layer.c b/src/gru_layer.c
new file mode 100644
index 0000000..1c41cbf
--- /dev/null
+++ b/src/gru_layer.c
@@ -0,0 +1,307 @@
+#include "gru_layer.h"
+#include "connected_layer.h"
+#include "utils.h"
+#include "cuda.h"
+#include "blas.h"
+#include "gemm.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+static void increment_layer(layer *l, int steps)
+{
+    int num = l->outputs*l->batch*steps;
+    l->output += num;
+    l->delta += num;
+    l->x += num;
+    l->x_norm += num;
+
+#ifdef GPU
+    l->output_gpu += num;
+    l->delta_gpu += num;
+    l->x_gpu += num;
+    l->x_norm_gpu += num;
+#endif
+}
+
+layer make_gru_layer(int batch, int inputs, int outputs, int steps, int batch_normalize)
+{
+    fprintf(stderr, "GRU Layer: %d inputs, %d outputs\n", inputs, outputs);
+    batch = batch / steps;
+    layer l = {0};
+    l.batch = batch;
+    l.type = GRU;
+    l.steps = steps;
+    l.inputs = inputs;
+
+    l.input_z_layer = malloc(sizeof(layer));
+    fprintf(stderr, "\t\t");
+    *(l.input_z_layer) = make_connected_layer(batch*steps, inputs, outputs, LINEAR, batch_normalize);
+    l.input_z_layer->batch = batch;
+
+    l.state_z_layer = malloc(sizeof(layer));
+    fprintf(stderr, "\t\t");
+    *(l.state_z_layer) = make_connected_layer(batch*steps, outputs, outputs, LINEAR, batch_normalize);
+    l.state_z_layer->batch = batch;
+
+
+
+    l.input_r_layer = malloc(sizeof(layer));
+    fprintf(stderr, "\t\t");
+    *(l.input_r_layer) = make_connected_layer(batch*steps, inputs, outputs, LINEAR, batch_normalize);
+    l.input_r_layer->batch = batch;
+
+    l.state_r_layer = malloc(sizeof(layer));
+    fprintf(stderr, "\t\t");
+    *(l.state_r_layer) = make_connected_layer(batch*steps, outputs, outputs, LINEAR, batch_normalize);
+    l.state_r_layer->batch = batch;
+
+
+
+    l.input_h_layer = malloc(sizeof(layer));
+    fprintf(stderr, "\t\t");
+    *(l.input_h_layer) = make_connected_layer(batch*steps, inputs, outputs, LINEAR, batch_normalize);
+    l.input_h_layer->batch = batch;
+
+    l.state_h_layer = malloc(sizeof(layer));
+    fprintf(stderr, "\t\t");
+    *(l.state_h_layer) = make_connected_layer(batch*steps, outputs, outputs, LINEAR, batch_normalize);
+    l.state_h_layer->batch = batch;
+
+    l.batch_normalize = batch_normalize;
+
+
+    l.outputs = outputs;
+    l.output = calloc(outputs*batch*steps, sizeof(float));
+    l.delta = calloc(outputs*batch*steps, sizeof(float));
+
+#ifdef GPU
+    l.forgot_state_gpu = cuda_make_array(l.output, batch*outputs);
+    l.forgot_delta_gpu = cuda_make_array(l.output, batch*outputs);
+    l.prev_state_gpu = cuda_make_array(l.output, batch*outputs);
+    l.state_gpu = cuda_make_array(l.output, batch*outputs);
+    l.output_gpu = cuda_make_array(l.output, batch*outputs*steps);
+    l.delta_gpu = cuda_make_array(l.delta, batch*outputs*steps);
+    l.r_gpu = cuda_make_array(l.output_gpu, batch*outputs);
+    l.z_gpu = cuda_make_array(l.output_gpu, batch*outputs);
+    l.h_gpu = cuda_make_array(l.output_gpu, batch*outputs);
+#endif
+
+    return l;
+}
+
+void update_gru_layer(layer l, int batch, float learning_rate, float momentum, float decay)
+{
+    update_connected_layer(*(l.input_layer), batch, learning_rate, momentum, decay);
+    update_connected_layer(*(l.self_layer), batch, learning_rate, momentum, decay);
+    update_connected_layer(*(l.output_layer), batch, learning_rate, momentum, decay);
+}
+
+void forward_gru_layer(layer l, network_state state)
+{
+}
+
+void backward_gru_layer(layer l, network_state state)
+{
+}
+
+#ifdef GPU
+
+void pull_gru_layer(layer l)
+{
+}
+
+void push_gru_layer(layer l)
+{
+}
+
+void update_gru_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay)
+{
+    update_connected_layer_gpu(*(l.input_r_layer), batch, learning_rate, momentum, decay);
+    update_connected_layer_gpu(*(l.input_z_layer), batch, learning_rate, momentum, decay);
+    update_connected_layer_gpu(*(l.input_h_layer), batch, learning_rate, momentum, decay);
+    update_connected_layer_gpu(*(l.state_r_layer), batch, learning_rate, momentum, decay);
+    update_connected_layer_gpu(*(l.state_z_layer), batch, learning_rate, momentum, decay);
+    update_connected_layer_gpu(*(l.state_h_layer), batch, learning_rate, momentum, decay);
+}
+
+void forward_gru_layer_gpu(layer l, network_state state)
+{
+    network_state s = {0};
+    s.train = state.train;
+    int i;
+    layer input_z_layer = *(l.input_z_layer);
+    layer input_r_layer = *(l.input_r_layer);
+    layer input_h_layer = *(l.input_h_layer);
+
+    layer state_z_layer = *(l.state_z_layer);
+    layer state_r_layer = *(l.state_r_layer);
+    layer state_h_layer = *(l.state_h_layer);
+
+    fill_ongpu(l.outputs * l.batch * l.steps, 0, input_z_layer.delta_gpu, 1);
+    fill_ongpu(l.outputs * l.batch * l.steps, 0, input_r_layer.delta_gpu, 1);
+    fill_ongpu(l.outputs * l.batch * l.steps, 0, input_h_layer.delta_gpu, 1);
+
+    fill_ongpu(l.outputs * l.batch * l.steps, 0, state_z_layer.delta_gpu, 1);
+    fill_ongpu(l.outputs * l.batch * l.steps, 0, state_r_layer.delta_gpu, 1);
+    fill_ongpu(l.outputs * l.batch * l.steps, 0, state_h_layer.delta_gpu, 1);
+    if(state.train) {
+        fill_ongpu(l.outputs * l.batch * l.steps, 0, l.delta_gpu, 1);
+        copy_ongpu(l.outputs*l.batch, l.state_gpu, 1, l.prev_state_gpu, 1);
+    }
+
+    for (i = 0; i < l.steps; ++i) {
+        s.input = l.state_gpu;
+        forward_connected_layer_gpu(state_z_layer, s);
+        forward_connected_layer_gpu(state_r_layer, s);
+
+        s.input = state.input;
+        forward_connected_layer_gpu(input_z_layer, s);
+        forward_connected_layer_gpu(input_r_layer, s);
+        forward_connected_layer_gpu(input_h_layer, s);
+
+
+        copy_ongpu(l.outputs*l.batch, input_z_layer.output_gpu, 1, l.z_gpu, 1);
+        axpy_ongpu(l.outputs*l.batch, 1, state_z_layer.output_gpu, 1, l.z_gpu, 1);
+
+        copy_ongpu(l.outputs*l.batch, input_r_layer.output_gpu, 1, l.r_gpu, 1);
+        axpy_ongpu(l.outputs*l.batch, 1, state_r_layer.output_gpu, 1, l.r_gpu, 1);
+
+        activate_array_ongpu(l.z_gpu, l.outputs*l.batch, LOGISTIC);
+        activate_array_ongpu(l.r_gpu, l.outputs*l.batch, LOGISTIC);
+
+        copy_ongpu(l.outputs*l.batch, l.state_gpu, 1, l.forgot_state_gpu, 1);
+        mul_ongpu(l.outputs*l.batch, l.r_gpu, 1, l.forgot_state_gpu, 1);
+
+        s.input = l.forgot_state_gpu;
+        forward_connected_layer_gpu(state_h_layer, s);
+
+        copy_ongpu(l.outputs*l.batch, input_h_layer.output_gpu, 1, l.h_gpu, 1);
+        axpy_ongpu(l.outputs*l.batch, 1, state_h_layer.output_gpu, 1, l.h_gpu, 1);
+
+        #ifdef USET
+        activate_array_ongpu(l.h_gpu, l.outputs*l.batch, TANH);
+        #else
+        activate_array_ongpu(l.h_gpu, l.outputs*l.batch, LOGISTIC);
+        #endif
+
+        weighted_sum_gpu(l.state_gpu, l.h_gpu, l.z_gpu, l.outputs*l.batch, l.output_gpu);
+
+        copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.state_gpu, 1);
+
+        state.input += l.inputs*l.batch;
+        l.output_gpu += l.outputs*l.batch;
+        increment_layer(&input_z_layer, 1);
+        increment_layer(&input_r_layer, 1);
+        increment_layer(&input_h_layer, 1);
+
+        increment_layer(&state_z_layer, 1);
+        increment_layer(&state_r_layer, 1);
+        increment_layer(&state_h_layer, 1);
+    }
+}
+
+void backward_gru_layer_gpu(layer l, network_state state)
+{
+    network_state s = {0};
+    s.train = state.train;
+    int i;
+    layer input_z_layer = *(l.input_z_layer);
+    layer input_r_layer = *(l.input_r_layer);
+    layer input_h_layer = *(l.input_h_layer);
+
+    layer state_z_layer = *(l.state_z_layer);
+    layer state_r_layer = *(l.state_r_layer);
+    layer state_h_layer = *(l.state_h_layer);
+
+    increment_layer(&input_z_layer, l.steps - 1);
+    increment_layer(&input_r_layer, l.steps - 1);
+    increment_layer(&input_h_layer, l.steps - 1);
+
+    increment_layer(&state_z_layer, l.steps - 1);
+    increment_layer(&state_r_layer, l.steps - 1);
+    increment_layer(&state_h_layer, l.steps - 1);
+
+    state.input += l.inputs*l.batch*(l.steps-1);
+    if(state.delta) state.delta += l.inputs*l.batch*(l.steps-1);
+    l.output_gpu += l.outputs*l.batch*(l.steps-1);
+    l.delta_gpu += l.outputs*l.batch*(l.steps-1);
+    for (i = l.steps-1; i >= 0; --i) {
+        if(i != 0) copy_ongpu(l.outputs*l.batch, l.output_gpu - l.outputs*l.batch, 1, l.prev_state_gpu, 1);
+        float *prev_delta_gpu = (i == 0) ? 0 : l.delta_gpu - l.outputs*l.batch;
+
+        copy_ongpu(l.outputs*l.batch, input_z_layer.output_gpu, 1, l.z_gpu, 1);
+        axpy_ongpu(l.outputs*l.batch, 1, state_z_layer.output_gpu, 1, l.z_gpu, 1);
+
+        copy_ongpu(l.outputs*l.batch, input_r_layer.output_gpu, 1, l.r_gpu, 1);
+        axpy_ongpu(l.outputs*l.batch, 1, state_r_layer.output_gpu, 1, l.r_gpu, 1);
+
+        activate_array_ongpu(l.z_gpu, l.outputs*l.batch, LOGISTIC);
+        activate_array_ongpu(l.r_gpu, l.outputs*l.batch, LOGISTIC);
+
+        copy_ongpu(l.outputs*l.batch, input_h_layer.output_gpu, 1, l.h_gpu, 1);
+        axpy_ongpu(l.outputs*l.batch, 1, state_h_layer.output_gpu, 1, l.h_gpu, 1);
+
+        #ifdef USET
+        activate_array_ongpu(l.h_gpu, l.outputs*l.batch, TANH);
+        #else
+        activate_array_ongpu(l.h_gpu, l.outputs*l.batch, LOGISTIC);
+        #endif
+        
+        weighted_delta_gpu(l.prev_state_gpu, l.h_gpu, l.z_gpu, prev_delta_gpu, input_h_layer.delta_gpu, input_z_layer.delta_gpu, l.outputs*l.batch, l.delta_gpu);
+
+        #ifdef USET
+        gradient_array_ongpu(l.h_gpu, l.outputs*l.batch, TANH, input_h_layer.delta_gpu);
+        #else
+        gradient_array_ongpu(l.h_gpu, l.outputs*l.batch, LOGISTIC, input_h_layer.delta_gpu);
+        #endif
+
+        copy_ongpu(l.outputs*l.batch, input_h_layer.delta_gpu, 1, state_h_layer.delta_gpu, 1);
+        
+        copy_ongpu(l.outputs*l.batch, l.prev_state_gpu, 1, l.forgot_state_gpu, 1);
+        mul_ongpu(l.outputs*l.batch, l.r_gpu, 1, l.forgot_state_gpu, 1);
+        fill_ongpu(l.outputs*l.batch, 0, l.forgot_delta_gpu, 1);
+
+        s.input = l.forgot_state_gpu;
+        s.delta = l.forgot_delta_gpu;
+        
+        backward_connected_layer_gpu(state_h_layer, s);
+        if(prev_delta_gpu) mult_add_into_gpu(l.outputs*l.batch, l.forgot_delta_gpu, l.r_gpu, prev_delta_gpu);
+        mult_add_into_gpu(l.outputs*l.batch, l.forgot_delta_gpu, l.prev_state_gpu, input_r_layer.delta_gpu);
+
+        gradient_array_ongpu(l.r_gpu, l.outputs*l.batch, LOGISTIC, input_r_layer.delta_gpu);
+        copy_ongpu(l.outputs*l.batch, input_r_layer.delta_gpu, 1, state_r_layer.delta_gpu, 1);
+
+        gradient_array_ongpu(l.z_gpu, l.outputs*l.batch, LOGISTIC, input_z_layer.delta_gpu);
+        copy_ongpu(l.outputs*l.batch, input_z_layer.delta_gpu, 1, state_z_layer.delta_gpu, 1);
+        
+        s.input = l.prev_state_gpu;
+        s.delta = prev_delta_gpu;
+        
+        backward_connected_layer_gpu(state_r_layer, s);
+        backward_connected_layer_gpu(state_z_layer, s);
+
+        s.input = state.input;
+        s.delta = state.delta;
+        
+        backward_connected_layer_gpu(input_h_layer, s);
+        backward_connected_layer_gpu(input_r_layer, s);
+        backward_connected_layer_gpu(input_z_layer, s);
+
+
+        state.input -= l.inputs*l.batch;
+        if(state.delta) state.delta -= l.inputs*l.batch;
+        l.output_gpu -= l.outputs*l.batch;
+        l.delta_gpu -= l.outputs*l.batch;
+        increment_layer(&input_z_layer, -1);
+        increment_layer(&input_r_layer, -1);
+        increment_layer(&input_h_layer, -1);
+
+        increment_layer(&state_z_layer, -1);
+        increment_layer(&state_r_layer, -1);
+        increment_layer(&state_h_layer, -1);
+    }
+}
+#endif
diff --git a/src/gru_layer.h b/src/gru_layer.h
new file mode 100644
index 0000000..bb9478b
--- /dev/null
+++ b/src/gru_layer.h
@@ -0,0 +1,25 @@
+
+#ifndef RNN_LAYER_H
+#define RNN_LAYER_H
+
+#include "activations.h"
+#include "layer.h"
+#include "network.h"
+#define USET
+
+layer make_rnn_layer(int batch, int inputs, int hidden, int outputs, int steps, ACTIVATION activation, int batch_normalize, int log);
+
+void forward_rnn_layer(layer l, network_state state);
+void backward_rnn_layer(layer l, network_state state);
+void update_rnn_layer(layer l, int batch, float learning_rate, float momentum, float decay);
+
+#ifdef GPU
+void forward_rnn_layer_gpu(layer l, network_state state);
+void backward_rnn_layer_gpu(layer l, network_state state);
+void update_rnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);
+void push_rnn_layer(layer l);
+void pull_rnn_layer(layer l);
+#endif
+
+#endif
+
diff --git a/src/image.c b/src/image.c
index ef76d45..aff5f64 100644
--- a/src/image.c
+++ b/src/image.c
@@ -110,6 +110,7 @@
         float prob = probs[i][class];
         if(prob > thresh){
             int width = pow(prob, 1./2.)*10+1;
+            width = 8;
             printf("%s: %.2f\n", names[class], prob);
             int offset = class*17 % classes;
             float red = get_color(0,offset,classes);
@@ -511,6 +512,7 @@
             w = (w * min) / h;
             h = min;
         }
+        if(w == im.w && h == im.h) return im;
         image resized = resize_image(im, w, h);
         return resized;
     }
@@ -523,13 +525,7 @@
         int dy = rand_int(0, resized.h - size);
         image crop = crop_image(resized, dx, dy, size, size);
 
-        /*
-           show_image(im, "orig");
-           show_image(crop, "cropped");
-           cvWaitKey(0);
-         */
-
-        free_image(resized);
+        if(resized.data != im.data) free_image(resized);
         return crop;
     }
 
diff --git a/src/layer.h b/src/layer.h
index c2cf307..2376929 100644
--- a/src/layer.h
+++ b/src/layer.h
@@ -23,7 +23,11 @@
     SHORTCUT,
     ACTIVE,
     RNN,
-    CRNN
+    GRU,
+    CRNN,
+    BATCHNORM,
+    NETWORK,
+    BLANK
 } LAYER_TYPE;
 
 typedef enum{
@@ -54,6 +58,7 @@
     int flip;
     int index;
     int binary;
+    int xnor;
     int steps;
     int hidden;
     float dot;
@@ -95,6 +100,10 @@
     char  *cfilters;
     float *filter_updates;
     float *state;
+    float *state_delta;
+
+    float *concat;
+    float *concat_delta;
 
     float *binary_filters;
 
@@ -132,17 +141,44 @@
     struct layer *self_layer;
     struct layer *output_layer;
 
+    struct layer *input_gate_layer;
+    struct layer *state_gate_layer;
+    struct layer *input_save_layer;
+    struct layer *state_save_layer;
+    struct layer *input_state_layer;
+    struct layer *state_state_layer;
+
+    struct layer *input_z_layer;
+    struct layer *state_z_layer;
+
+    struct layer *input_r_layer;
+    struct layer *state_r_layer;
+
+    struct layer *input_h_layer;
+    struct layer *state_h_layer;
+
     #ifdef GPU
+    float *z_gpu;
+    float *r_gpu;
+    float *h_gpu;
+
     int *indexes_gpu;
+    float * prev_state_gpu;
+    float * forgot_state_gpu;
+    float * forgot_delta_gpu;
     float * state_gpu;
+    float * state_delta_gpu;
+    float * gate_gpu;
+    float * gate_delta_gpu;
+    float * save_gpu;
+    float * save_delta_gpu;
+    float * concat_gpu;
+    float * concat_delta_gpu;
     float * filters_gpu;
     float * filter_updates_gpu;
 
+    float *binary_input_gpu;
     float *binary_filters_gpu;
-    float *mean_filters_gpu;
-
-    float * spatial_mean_gpu;
-    float * spatial_variance_gpu;
 
     float * mean_gpu;
     float * variance_gpu;
@@ -150,9 +186,6 @@
     float * rolling_mean_gpu;
     float * rolling_variance_gpu;
 
-    float * spatial_mean_delta_gpu;
-    float * spatial_variance_delta_gpu;
-
     float * variance_delta_gpu;
     float * mean_delta_gpu;
 
diff --git a/src/network.c b/src/network.c
index e6fb51e..ca485d6 100644
--- a/src/network.c
+++ b/src/network.c
@@ -8,6 +8,7 @@
 
 #include "crop_layer.h"
 #include "connected_layer.h"
+#include "gru_layer.h"
 #include "rnn_layer.h"
 #include "crnn_layer.h"
 #include "local_layer.h"
@@ -16,6 +17,7 @@
 #include "deconvolutional_layer.h"
 #include "detection_layer.h"
 #include "normalization_layer.h"
+#include "batchnorm_layer.h"
 #include "maxpool_layer.h"
 #include "avgpool_layer.h"
 #include "cost_layer.h"
@@ -86,6 +88,8 @@
             return "connected";
         case RNN:
             return "rnn";
+        case GRU:
+            return "gru";
         case CRNN:
             return "crnn";
         case MAXPOOL:
@@ -108,6 +112,8 @@
             return "shortcut";
         case NORMALIZATION:
             return "normalization";
+        case BATCHNORM:
+            return "batchnorm";
         default:
             break;
     }
@@ -146,12 +152,16 @@
             forward_local_layer(l, state);
         } else if(l.type == NORMALIZATION){
             forward_normalization_layer(l, state);
+        } else if(l.type == BATCHNORM){
+            forward_batchnorm_layer(l, state);
         } else if(l.type == DETECTION){
             forward_detection_layer(l, state);
         } else if(l.type == CONNECTED){
             forward_connected_layer(l, state);
         } else if(l.type == RNN){
             forward_rnn_layer(l, state);
+        } else if(l.type == GRU){
+            forward_gru_layer(l, state);
         } else if(l.type == CRNN){
             forward_crnn_layer(l, state);
         } else if(l.type == CROP){
@@ -190,6 +200,8 @@
             update_connected_layer(l, update_batch, rate, net.momentum, net.decay);
         } else if(l.type == RNN){
             update_rnn_layer(l, update_batch, rate, net.momentum, net.decay);
+        } else if(l.type == GRU){
+            update_gru_layer(l, update_batch, rate, net.momentum, net.decay);
         } else if(l.type == CRNN){
             update_crnn_layer(l, update_batch, rate, net.momentum, net.decay);
         } else if(l.type == LOCAL){
@@ -200,6 +212,9 @@
 
 float *get_network_output(network net)
 {
+    #ifdef GPU
+        return get_network_output_gpu(net);
+    #endif 
     int i;
     for(i = net.n-1; i > 0; --i) if(net.layers[i].type != COST) break;
     return net.layers[i].output;
@@ -254,6 +269,8 @@
             backward_activation_layer(l, state);
         } else if(l.type == NORMALIZATION){
             backward_normalization_layer(l, state);
+        } else if(l.type == BATCHNORM){
+            backward_batchnorm_layer(l, state);
         } else if(l.type == MAXPOOL){
             if(i != 0) backward_maxpool_layer(l, state);
         } else if(l.type == AVGPOOL){
@@ -268,6 +285,8 @@
             backward_connected_layer(l, state);
         } else if(l.type == RNN){
             backward_rnn_layer(l, state);
+        } else if(l.type == GRU){
+            backward_gru_layer(l, state);
         } else if(l.type == CRNN){
             backward_crnn_layer(l, state);
         } else if(l.type == LOCAL){
diff --git a/src/network.h b/src/network.h
index f4f8b5c..66ceb30 100644
--- a/src/network.h
+++ b/src/network.h
@@ -37,6 +37,7 @@
     int inputs;
     int h, w, c;
     int max_crop;
+    int min_crop;
 
     #ifdef GPU
     float **input_gpu;
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index 730634e..986a808 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -16,6 +16,7 @@
 #include "crop_layer.h"
 #include "connected_layer.h"
 #include "rnn_layer.h"
+#include "gru_layer.h"
 #include "crnn_layer.h"
 #include "detection_layer.h"
 #include "convolutional_layer.h"
@@ -24,6 +25,7 @@
 #include "maxpool_layer.h"
 #include "avgpool_layer.h"
 #include "normalization_layer.h"
+#include "batchnorm_layer.h"
 #include "cost_layer.h"
 #include "local_layer.h"
 #include "softmax_layer.h"
@@ -60,6 +62,8 @@
             forward_connected_layer_gpu(l, state);
         } else if(l.type == RNN){
             forward_rnn_layer_gpu(l, state);
+        } else if(l.type == GRU){
+            forward_gru_layer_gpu(l, state);
         } else if(l.type == CRNN){
             forward_crnn_layer_gpu(l, state);
         } else if(l.type == CROP){
@@ -70,6 +74,8 @@
             forward_softmax_layer_gpu(l, state);
         } else if(l.type == NORMALIZATION){
             forward_normalization_layer_gpu(l, state);
+        } else if(l.type == BATCHNORM){
+            forward_batchnorm_layer_gpu(l, state);
         } else if(l.type == MAXPOOL){
             forward_maxpool_layer_gpu(l, state);
         } else if(l.type == AVGPOOL){
@@ -119,12 +125,16 @@
             backward_detection_layer_gpu(l, state);
         } else if(l.type == NORMALIZATION){
             backward_normalization_layer_gpu(l, state);
+        } else if(l.type == BATCHNORM){
+            backward_batchnorm_layer_gpu(l, state);
         } else if(l.type == SOFTMAX){
             if(i != 0) backward_softmax_layer_gpu(l, state);
         } else if(l.type == CONNECTED){
             backward_connected_layer_gpu(l, state);
         } else if(l.type == RNN){
             backward_rnn_layer_gpu(l, state);
+        } else if(l.type == GRU){
+            backward_gru_layer_gpu(l, state);
         } else if(l.type == CRNN){
             backward_crnn_layer_gpu(l, state);
         } else if(l.type == COST){
@@ -150,6 +160,8 @@
             update_deconvolutional_layer_gpu(l, rate, net.momentum, net.decay);
         } else if(l.type == CONNECTED){
             update_connected_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
+        } else if(l.type == GRU){
+            update_gru_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
         } else if(l.type == RNN){
             update_rnn_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
         } else if(l.type == CRNN){
diff --git a/src/parser.c b/src/parser.c
index c109a14..6c88fd5 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -9,9 +9,11 @@
 #include "convolutional_layer.h"
 #include "activation_layer.h"
 #include "normalization_layer.h"
+#include "batchnorm_layer.h"
 #include "deconvolutional_layer.h"
 #include "connected_layer.h"
 #include "rnn_layer.h"
+#include "gru_layer.h"
 #include "crnn_layer.h"
 #include "maxpool_layer.h"
 #include "softmax_layer.h"
@@ -37,12 +39,14 @@
 int is_deconvolutional(section *s);
 int is_connected(section *s);
 int is_rnn(section *s);
+int is_gru(section *s);
 int is_crnn(section *s);
 int is_maxpool(section *s);
 int is_avgpool(section *s);
 int is_dropout(section *s);
 int is_softmax(section *s);
 int is_normalization(section *s);
+int is_batchnorm(section *s);
 int is_crop(section *s);
 int is_shortcut(section *s);
 int is_cost(section *s);
@@ -157,8 +161,9 @@
     if(!(h && w && c)) error("Layer before convolutional layer must output image.");
     int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
     int binary = option_find_int_quiet(options, "binary", 0);
+    int xnor = option_find_int_quiet(options, "xnor", 0);
 
-    convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,pad,activation, batch_normalize, binary);
+    convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,pad,activation, batch_normalize, binary, xnor);
     layer.flipped = option_find_int_quiet(options, "flipped", 0);
     layer.dot = option_find_float_quiet(options, "dot", 0);
 
@@ -203,6 +208,16 @@
     return l;
 }
 
+layer parse_gru(list *options, size_params params)
+{
+    int output = option_find_int(options, "output",1);
+    int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
+
+    layer l = make_gru_layer(params.batch, params.inputs, output, params.time_steps, batch_normalize);
+
+    return l;
+}
+
 connected_layer parse_connected(list *options, size_params params)
 {
     int output = option_find_int(options, "output",1);
@@ -333,6 +348,12 @@
     return l;
 }
 
+layer parse_batchnorm(list *options, size_params params)
+{
+    layer l = make_batchnorm_layer(params.batch, params.w, params.h, params.c);
+    return l;
+}
+
 layer parse_shortcut(list *options, size_params params, network net)
 {
     char *l = option_find(options, "from");   
@@ -438,6 +459,7 @@
     net->c = option_find_int_quiet(options, "channels",0);
     net->inputs = option_find_int_quiet(options, "inputs", net->h * net->w * net->c);
     net->max_crop = option_find_int_quiet(options, "max_crop",net->w*2);
+    net->min_crop = option_find_int_quiet(options, "min_crop",net->w);
 
     if(!net->inputs && !(net->h && net->w && net->c)) error("No input parameters supplied");
 
@@ -520,6 +542,8 @@
             l = parse_deconvolutional(options, params);
         }else if(is_rnn(s)){
             l = parse_rnn(options, params);
+        }else if(is_gru(s)){
+            l = parse_gru(options, params);
         }else if(is_crnn(s)){
             l = parse_crnn(options, params);
         }else if(is_connected(s)){
@@ -534,6 +558,8 @@
             l = parse_softmax(options, params);
         }else if(is_normalization(s)){
             l = parse_normalization(options, params);
+        }else if(is_batchnorm(s)){
+            l = parse_batchnorm(options, params);
         }else if(is_maxpool(s)){
             l = parse_maxpool(options, params);
         }else if(is_avgpool(s)){
@@ -573,6 +599,40 @@
     return net;
 }
 
+LAYER_TYPE string_to_layer_type(char * type)
+{
+
+    if (strcmp(type, "[shortcut]")==0) return SHORTCUT;
+    if (strcmp(type, "[crop]")==0) return CROP;
+    if (strcmp(type, "[cost]")==0) return COST;
+    if (strcmp(type, "[detection]")==0) return DETECTION;
+    if (strcmp(type, "[local]")==0) return LOCAL;
+    if (strcmp(type, "[deconv]")==0
+            || strcmp(type, "[deconvolutional]")==0) return DECONVOLUTIONAL;
+    if (strcmp(type, "[conv]")==0
+            || strcmp(type, "[convolutional]")==0) return CONVOLUTIONAL;
+    if (strcmp(type, "[activation]")==0) return ACTIVE;
+    if (strcmp(type, "[net]")==0
+            || strcmp(type, "[network]")==0) return NETWORK;
+    if (strcmp(type, "[crnn]")==0) return CRNN;
+    if (strcmp(type, "[gru]")==0) return GRU;
+    if (strcmp(type, "[rnn]")==0) return RNN;
+    if (strcmp(type, "[conn]")==0
+            || strcmp(type, "[connected]")==0) return CONNECTED;
+    if (strcmp(type, "[max]")==0
+            || strcmp(type, "[maxpool]")==0) return MAXPOOL;
+    if (strcmp(type, "[avg]")==0
+            || strcmp(type, "[avgpool]")==0) return AVGPOOL;
+    if (strcmp(type, "[dropout]")==0) return DROPOUT;
+    if (strcmp(type, "[lrn]")==0
+            || strcmp(type, "[normalization]")==0) return NORMALIZATION;
+    if (strcmp(type, "[batchnorm]")==0) return BATCHNORM;
+    if (strcmp(type, "[soft]")==0
+            || strcmp(type, "[softmax]")==0) return SOFTMAX;
+    if (strcmp(type, "[route]")==0) return ROUTE;
+    return BLANK;
+}
+
 int is_shortcut(section *s)
 {
     return (strcmp(s->type, "[shortcut]")==0);
@@ -616,6 +676,10 @@
 {
     return (strcmp(s->type, "[crnn]")==0);
 }
+int is_gru(section *s)
+{
+    return (strcmp(s->type, "[gru]")==0);
+}
 int is_rnn(section *s)
 {
     return (strcmp(s->type, "[rnn]")==0);
@@ -646,6 +710,11 @@
             || strcmp(s->type, "[normalization]")==0);
 }
 
+int is_batchnorm(section *s)
+{
+    return (strcmp(s->type, "[batchnorm]")==0);
+}
+
 int is_softmax(section *s)
 {
     return (strcmp(s->type, "[soft]")==0
@@ -824,6 +893,13 @@
             save_connected_weights(*(l.input_layer), fp);
             save_connected_weights(*(l.self_layer), fp);
             save_connected_weights(*(l.output_layer), fp);
+        } if(l.type == GRU){
+            save_connected_weights(*(l.input_z_layer), fp);
+            save_connected_weights(*(l.input_r_layer), fp);
+            save_connected_weights(*(l.input_h_layer), fp);
+            save_connected_weights(*(l.state_z_layer), fp);
+            save_connected_weights(*(l.state_r_layer), fp);
+            save_connected_weights(*(l.state_h_layer), fp);
         } if(l.type == CRNN){
             save_convolutional_weights(*(l.input_layer), fp);
             save_convolutional_weights(*(l.self_layer), fp);
@@ -867,10 +943,15 @@
     if(transpose){
         transpose_matrix(l.weights, l.inputs, l.outputs);
     }
+        //printf("Biases: %f mean %f variance\n", mean_array(l.biases, l.outputs), variance_array(l.biases, l.outputs));
+        //printf("Weights: %f mean %f variance\n", mean_array(l.weights, l.outputs*l.inputs), variance_array(l.weights, l.outputs*l.inputs));
     if (l.batch_normalize && (!l.dontloadscales)){
         fread(l.scales, sizeof(float), l.outputs, fp);
         fread(l.rolling_mean, sizeof(float), l.outputs, fp);
         fread(l.rolling_variance, sizeof(float), l.outputs, fp);
+        //printf("Scales: %f mean %f variance\n", mean_array(l.scales, l.outputs), variance_array(l.scales, l.outputs));
+        //printf("rolling_mean: %f mean %f variance\n", mean_array(l.rolling_mean, l.outputs), variance_array(l.rolling_mean, l.outputs));
+        //printf("rolling_variance: %f mean %f variance\n", mean_array(l.rolling_variance, l.outputs), variance_array(l.rolling_variance, l.outputs));
     }
 #ifdef GPU
     if(gpu_index >= 0){
@@ -982,6 +1063,14 @@
             load_connected_weights(*(l.self_layer), fp, transpose);
             load_connected_weights(*(l.output_layer), fp, transpose);
         }
+        if(l.type == GRU){
+            load_connected_weights(*(l.input_z_layer), fp, transpose);
+            load_connected_weights(*(l.input_r_layer), fp, transpose);
+            load_connected_weights(*(l.input_h_layer), fp, transpose);
+            load_connected_weights(*(l.state_z_layer), fp, transpose);
+            load_connected_weights(*(l.state_r_layer), fp, transpose);
+            load_connected_weights(*(l.state_h_layer), fp, transpose);
+        }
         if(l.type == LOCAL){
             int locations = l.out_w*l.out_h;
             int size = l.size*l.size*l.c*l.n*locations;
diff --git a/src/rnn.c b/src/rnn.c
index 30fa4bd..b72fafc 100644
--- a/src/rnn.c
+++ b/src/rnn.c
@@ -1,6 +1,7 @@
 #include "network.h"
 #include "cost_layer.h"
 #include "utils.h"
+#include "blas.h"
 #include "parser.h"
 
 #ifdef OPENCV
@@ -12,29 +13,26 @@
     float *y;
 } float_pair;
 
-float_pair get_rnn_data(unsigned char *text, int characters, int len, int batch, int steps)
+float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, size_t len, int batch, int steps)
 {
     float *x = calloc(batch * steps * characters, sizeof(float));
     float *y = calloc(batch * steps * characters, sizeof(float));
     int i,j;
     for(i = 0; i < batch; ++i){
-        int index = rand() %(len - steps - 1);
-        /*
-        int done = 1;
-        while(!done){
-            index = rand() %(len - steps - 1);
-            while(index < len-steps-1 && text[index++] != '\n');
-            if (index < len-steps-1) done = 1;
-            }
-         */
         for(j = 0; j < steps; ++j){
-            x[(j*batch + i)*characters + text[index + j]] = 1;
-            y[(j*batch + i)*characters + text[index + j + 1]] = 1;
+            unsigned char curr = text[(offsets[i])%len];
+            unsigned char next = text[(offsets[i] + 1)%len];
 
-            if(text[index+j] > 255 || text[index+j] <= 0 || text[index+j+1] > 255 || text[index+j+1] <= 0){
-                text[index+j+2] = 0;
-                printf("%d %d %d %d %d\n", index, j, len, (int)text[index+j], (int)text[index+j+1]);
+            x[(j*batch + i)*characters + curr] = 1;
+            y[(j*batch + i)*characters + next] = 1;
+
+            offsets[i] = (offsets[i] + 1) % len;
+
+            if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
+                /*text[(index+j+2)%len] = 0;
+                printf("%ld %d %d %d %d\n", index, j, len, (int)text[index+j], (int)text[index+j+1]);
                 printf("%s", text+index);
+                */
                 error("Bad char");
             }
         }
@@ -45,8 +43,23 @@
     return p;
 }
 
-void train_char_rnn(char *cfgfile, char *weightfile, char *filename)
+void reset_rnn_state(network net, int b)
 {
+    int i;
+    for (i = 0; i < net.n; ++i) {
+        layer l = net.layers[i];
+        #ifdef GPU
+        if(l.state_gpu){
+            fill_ongpu(l.outputs, 0, l.state_gpu + l.outputs*b, 1);
+        }
+        #endif
+    }
+}
+
+void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear)
+{
+    srand(time(0));
+    data_seed = time(0);
     FILE *fp = fopen(filename, "rb");
 
     fseek(fp, 0, SEEK_END); 
@@ -58,8 +71,6 @@
     fclose(fp);
 
     char *backup_directory = "/home/pjreddie/backup/";
-    srand(time(0));
-    data_seed = time(0);
     char *base = basecfg(cfgfile);
     fprintf(stderr, "%s\n", base);
     float avg_loss = -1;
@@ -67,18 +78,26 @@
     if(weightfile){
         load_weights(&net, weightfile);
     }
+
     int inputs = get_network_input_size(net);
     fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     int batch = net.batch;
     int steps = net.time_steps;
-    //*net.seen = 0;
+    if(clear) *net.seen = 0;
     int i = (*net.seen)/net.batch;
 
+    int streams = batch/steps;
+    size_t *offsets = calloc(streams, sizeof(size_t));
+    int j;
+    for(j = 0; j < streams; ++j){
+        offsets[j] = rand_size_t()%size;
+    }
+
     clock_t time;
     while(get_current_batch(net) < net.max_batches){
         i += 1;
         time=clock();
-        float_pair p = get_rnn_data(text, inputs, size, batch/steps, steps);
+        float_pair p = get_rnn_data(text, offsets, inputs, size, streams, steps);
 
         float loss = train_network_datum(net, p.x, p.y) / (batch);
         free(p.x);
@@ -86,7 +105,18 @@
         if (avg_loss < 0) avg_loss = loss;
         avg_loss = avg_loss*.9 + loss*.1;
 
-        fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time));
+        int chars = get_current_batch(net)*batch;
+        fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), (float) chars/size);
+
+        for(j = 0; j < streams; ++j){
+            //printf("%d\n", j);
+            if(rand()%10 == 0){
+                //fprintf(stderr, "Reset\n");
+                offsets[j] = rand_size_t()%size;
+                reset_rnn_state(net, j);
+            }
+        }
+
         if(i%100==0){
             char buff[256];
             sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
@@ -120,6 +150,15 @@
     unsigned char c;
     int len = strlen(seed);
     float *input = calloc(inputs, sizeof(float));
+
+/*
+    fill_cpu(inputs, 0, input, 1);
+    for(i = 0; i < 10; ++i){
+        network_predict(net, input);
+    }
+    fill_cpu(inputs, 0, input, 1);
+    */
+
     for(i = 0; i < len-1; ++i){
         c = seed[i];
         input[(int)c] = 1;
@@ -130,16 +169,16 @@
     c = seed[len-1];
     for(i = 0; i < num; ++i){
         printf("%c", c);
-        float r = rand_uniform(0,1);
-        float sum = 0;
         input[(int)c] = 1;
         float *out = network_predict(net, input);
         input[(int)c] = 0;
-        for(j = 0; j < inputs; ++j){
-            sum += out[j];
-            if(sum > r) break;
+        for(j = 32; j < 127; ++j){
+            //printf("%d %c %f\n",j, j, out[j]);
         }
-        c = j;
+        for(j = 0; j < inputs; ++j){
+            //if (out[j] < .0001) out[j] = 0;
+        }
+        c = sample_array(out, inputs);
     }
     printf("\n");
 }
@@ -158,11 +197,16 @@
     int count = 0;
     int c;
     float *input = calloc(inputs, sizeof(float));
+    int i;
+    for(i = 0; i < 100; ++i){
+        network_predict(net, input);
+    }
     float sum = 0;
     c = getc(stdin);
     float log2 = log(2);
     while(c != EOF){
         int next = getc(stdin);
+        if(next < 0 || next >= 255) error("Out of range character");
         if(next == EOF) break;
         ++count;
         input[c] = 1;
@@ -170,8 +214,8 @@
         input[c] = 0;
         sum += log(out[next])/log2;
         c = next;
+        printf("%d Perplexity: %f\n", count, pow(2, -sum/count));
     }
-    printf("Perplexity: %f\n", pow(2, -sum/count));
 }
 
 
@@ -186,10 +230,11 @@
     int len = find_int_arg(argc, argv, "-len", 1000);
     float temp = find_float_arg(argc, argv, "-temp", .7);
     int rseed = find_int_arg(argc, argv, "-srand", time(0));
+    int clear = find_arg(argc, argv, "-clear");
 
     char *cfg = argv[3];
     char *weights = (argc > 4) ? argv[4] : 0;
-    if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename);
+    if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear);
     else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights);
     else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed);
 }
diff --git a/src/rnn_layer.c b/src/rnn_layer.c
index 35cf992..b713899 100644
--- a/src/rnn_layer.c
+++ b/src/rnn_layer.c
@@ -242,8 +242,6 @@
     increment_layer(&output_layer, l.steps - 1);
     l.state_gpu += l.hidden*l.batch*l.steps;
     for (i = l.steps-1; i >= 0; --i) {
-        copy_ongpu(l.hidden * l.batch, input_layer.output_gpu, 1, l.state_gpu, 1);
-        axpy_ongpu(l.hidden * l.batch, 1, self_layer.output_gpu, 1, l.state_gpu, 1);
 
         s.input = l.state_gpu;
         s.delta = self_layer.delta_gpu;
@@ -251,12 +249,14 @@
 
         l.state_gpu -= l.hidden*l.batch;
 
+        copy_ongpu(l.hidden*l.batch, self_layer.delta_gpu, 1, input_layer.delta_gpu, 1);
+
         s.input = l.state_gpu;
         s.delta = self_layer.delta_gpu - l.hidden*l.batch;
         if (i == 0) s.delta = 0;
         backward_connected_layer_gpu(self_layer, s);
 
-        copy_ongpu(l.hidden*l.batch, self_layer.delta_gpu, 1, input_layer.delta_gpu, 1);
+        //copy_ongpu(l.hidden*l.batch, self_layer.delta_gpu, 1, input_layer.delta_gpu, 1);
         if (i > 0 && l.shortcut) axpy_ongpu(l.hidden*l.batch, 1, self_layer.delta_gpu, 1, self_layer.delta_gpu - l.hidden*l.batch, 1);
         s.input = state.input + i*l.inputs*l.batch;
         if(state.delta) s.delta = state.delta + i*l.inputs*l.batch;
diff --git a/src/rnn_layer.h b/src/rnn_layer.h
index 00dc1be..9e19cee 100644
--- a/src/rnn_layer.h
+++ b/src/rnn_layer.h
@@ -1,23 +1,23 @@
 
-#ifndef RNN_LAYER_H
-#define RNN_LAYER_H
+#ifndef GRU_LAYER_H
+#define GRU_LAYER_H
 
 #include "activations.h"
 #include "layer.h"
 #include "network.h"
 
-layer make_rnn_layer(int batch, int inputs, int hidden, int outputs, int steps, ACTIVATION activation, int batch_normalize, int log);
+layer make_gru_layer(int batch, int inputs, int outputs, int steps, int batch_normalize);
 
-void forward_rnn_layer(layer l, network_state state);
-void backward_rnn_layer(layer l, network_state state);
-void update_rnn_layer(layer l, int batch, float learning_rate, float momentum, float decay);
+void forward_gru_layer(layer l, network_state state);
+void backward_gru_layer(layer l, network_state state);
+void update_gru_layer(layer l, int batch, float learning_rate, float momentum, float decay);
 
 #ifdef GPU
-void forward_rnn_layer_gpu(layer l, network_state state);
-void backward_rnn_layer_gpu(layer l, network_state state);
-void update_rnn_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);
-void push_rnn_layer(layer l);
-void pull_rnn_layer(layer l);
+void forward_gru_layer_gpu(layer l, network_state state);
+void backward_gru_layer_gpu(layer l, network_state state);
+void update_gru_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);
+void push_gru_layer(layer l);
+void pull_gru_layer(layer l);
 #endif
 
 #endif
diff --git a/src/utils.c b/src/utils.c
index 398d18a..1541e05 100644
--- a/src/utils.c
+++ b/src/utils.c
@@ -273,6 +273,42 @@
     return line;
 }
 
+int read_int(int fd)
+{
+    int n = 0;
+    int next = read(fd, &n, sizeof(int));
+    if(next <= 0) return -1;
+    return n;
+}
+
+void write_int(int fd, int n)
+{
+    int next = write(fd, &n, sizeof(int));
+    if(next <= 0) error("read failed");
+}
+
+int read_all_fail(int fd, char *buffer, size_t bytes)
+{
+    size_t n = 0;
+    while(n < bytes){
+        int next = read(fd, buffer + n, bytes-n);
+        if(next <= 0) return 1;
+        n += next;
+    }
+    return 0;
+}
+
+int write_all_fail(int fd, char *buffer, size_t bytes)
+{
+    size_t n = 0;
+    while(n < bytes){
+        size_t next = write(fd, buffer + n, bytes-n);
+        if(next <= 0) return 1;
+        n += next;
+    }
+    return 0;
+}
+
 void read_all(int fd, char *buffer, size_t bytes)
 {
     size_t n = 0;
@@ -441,6 +477,19 @@
     }
 }
 
+int sample_array(float *a, int n)
+{
+    float sum = sum_array(a, n);
+    scale_array(a, n, 1./sum);
+    float r = rand_uniform(0, 1);
+    int i;
+    for(i = 0; i < n; ++i){
+        r = r - a[i];
+        if (r <= 0) return i;
+    }
+    return n-1;
+}
+
 int max_index(float *a, int n)
 {
     if(n <= 0) return -1;
@@ -495,6 +544,18 @@
    }
  */
 
+size_t rand_size_t()
+{
+    return  ((size_t)(rand()&0xff) << 56) | 
+            ((size_t)(rand()&0xff) << 48) |
+            ((size_t)(rand()&0xff) << 40) |
+            ((size_t)(rand()&0xff) << 32) |
+            ((size_t)(rand()&0xff) << 24) |
+            ((size_t)(rand()&0xff) << 16) |
+            ((size_t)(rand()&0xff) << 8) |
+            ((size_t)(rand()&0xff) << 0);
+}
+
 float rand_uniform(float min, float max)
 {
     return ((float)rand()/RAND_MAX * (max - min)) + min;
diff --git a/src/utils.h b/src/utils.h
index 3af85d3..7e49818 100644
--- a/src/utils.h
+++ b/src/utils.h
@@ -12,8 +12,12 @@
 char *basecfg(char *cfgfile);
 int alphanum_to_int(char c);
 char int_to_alphanum(int i);
+int read_int(int fd);
+void write_int(int fd, int n);
 void read_all(int fd, char *buffer, size_t bytes);
 void write_all(int fd, char *buffer, size_t bytes);
+int read_all_fail(int fd, char *buffer, size_t bytes);
+int write_all_fail(int fd, char *buffer, size_t bytes);
 char *find_replace(char *str, char *orig, char *rep);
 void error(const char *s);
 void malloc_error();
@@ -34,6 +38,7 @@
 float constrain(float min, float max, float a);
 float mse_array(float *a, int n);
 float rand_normal();
+size_t rand_size_t();
 float rand_uniform(float min, float max);
 int rand_int(int min, int max);
 float sum_array(float *a, int n);
@@ -47,6 +52,7 @@
 float find_float_arg(int argc, char **argv, char *arg, float def);
 int find_arg(int argc, char* argv[], char *arg);
 char *find_char_arg(int argc, char **argv, char *arg, char *def);
+int sample_array(float *a, int n);
 
 #endif
 
diff --git a/src/yolo.c b/src/yolo.c
index 02c4fba..9c3999e 100644
--- a/src/yolo.c
+++ b/src/yolo.c
@@ -71,7 +71,7 @@
         avg_loss = avg_loss*.9 + loss*.1;
 
         printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
-        if(i%1000==0 || i == 600){
+        if(i%1000==0 || (i < 1000 && i%100 == 0)){
             char buff[256];
             sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
             save_weights(net, buff);
@@ -143,7 +143,8 @@
     srand(time(0));
 
     char *base = "results/comp4_det_test_";
-    list *plist = get_paths("data/voc.2007.test");
+    //list *plist = get_paths("data/voc.2007.test");
+    list *plist = get_paths("/home/pjreddie/data/voc/2007_test.txt");
     //list *plist = get_paths("data/voc.2012.test");
     char **paths = (char **)list_to_array(plist);
 
@@ -344,7 +345,7 @@
         convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
         if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);
         //draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20);
-        draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, 0, 20);
+        draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20);
         show_image(im, "predictions");
         save_image(im, "predictions");
 
@@ -359,42 +360,6 @@
     }
 }
 
-/*
-#ifdef OPENCV
-image ipl_to_image(IplImage* src);
-#include "opencv2/highgui/highgui_c.h"
-#include "opencv2/imgproc/imgproc_c.h"
-
-void demo_swag(char *cfgfile, char *weightfile, float thresh)
-{
-network net = parse_network_cfg(cfgfile);
-if(weightfile){
-load_weights(&net, weightfile);
-}
-detection_layer layer = net.layers[net.n-1];
-CvCapture *capture = cvCaptureFromCAM(-1);
-set_batch_network(&net, 1);
-srand(2222222);
-while(1){
-IplImage* frame = cvQueryFrame(capture);
-image im = ipl_to_image(frame);
-cvReleaseImage(&frame);
-rgbgr_image(im);
-
-image sized = resize_image(im, net.w, net.h);
-float *X = sized.data;
-float *predictions = network_predict(net, X);
-draw_swag(im, predictions, layer.side, layer.n, "predictions", thresh);
-free_image(im);
-free_image(sized);
-cvWaitKey(10);
-}
-}
-#else
-void demo_swag(char *cfgfile, char *weightfile, float thresh){}
-#endif
- */
-
 void demo_yolo(char *cfgfile, char *weightfile, float thresh, int cam_index, char *filename);
 
 void run_yolo(int argc, char **argv)
diff --git a/src/yolo_demo.c b/src/yolo_demo.c
index 4e3f839..194a236 100644
--- a/src/yolo_demo.c
+++ b/src/yolo_demo.c
@@ -12,7 +12,6 @@
 #include "opencv2/imgproc/imgproc.hpp"
 image ipl_to_image(IplImage* src);
 void convert_yolo_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
-void draw_yolo(image im, int num, float thresh, box *boxes, float **probs);
 
 extern char *voc_names[];
 extern image voc_labels[];
diff --git a/src/yolo_kernels.cu b/src/yolo_kernels.cu
index d7f1b26..b320026 100644
--- a/src/yolo_kernels.cu
+++ b/src/yolo_kernels.cu
@@ -18,7 +18,6 @@
 #include "opencv2/imgproc/imgproc.hpp"
 extern "C" image ipl_to_image(IplImage* src);
 extern "C" void convert_yolo_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
-extern "C" void draw_yolo(image im, int num, float thresh, box *boxes, float **probs);
 
 extern "C" char *voc_names[];
 extern "C" image voc_labels[];

--
Gitblit v1.10.0