From 13209df7bb53de19aa3f82e870db11eb5b7587f1 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 13 May 2016 18:59:43 +0000
Subject: [PATCH] art, cudnn

---
 src/cuda.c                   |   25 +++-
 src/art.c                    |   76 ++++++++++++
 src/network.h                |    2 
 Makefile                     |    9 +
 src/convolutional_layer.c    |   92 +++++++++++++-
 src/network_kernels.cu       |    2 
 src/parser.c                 |    7 +
 src/cuda.h                   |   20 ++-
 src/convolutional_kernels.cu |   67 +++++++++-
 src/darknet.c                |    3 
 src/layer.h                  |   13 ++
 11 files changed, 286 insertions(+), 30 deletions(-)

diff --git a/Makefile b/Makefile
index 1ef1b3b..03be143 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,5 @@
 GPU=0
+CUDNN=0
 OPENCV=0
 DEBUG=0
 
@@ -34,7 +35,13 @@
 LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
 endif
 
-OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o batchnorm_layer.o
+ifeq ($(CUDNN), 1) 
+COMMON+= -DCUDNN 
+CFLAGS+= -DCUDNN
+LDFLAGS+= -lcudnn
+endif
+
+OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o batchnorm_layer.o art.o
 ifeq ($(GPU), 1) 
 LDFLAGS+= -lstdc++ 
 OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
diff --git a/src/art.c b/src/art.c
new file mode 100644
index 0000000..785ab52
--- /dev/null
+++ b/src/art.c
@@ -0,0 +1,76 @@
+#include "network.h"
+#include "utils.h"
+#include "parser.h"
+#include "option_list.h"
+#include "blas.h"
+#include "classifier.h"
+#include <sys/time.h>
+
+#ifdef OPENCV
+#include "opencv2/highgui/highgui_c.h"
+#endif
+
+
+void demo_art(char *cfgfile, char *weightfile, int cam_index)
+{
+#ifdef OPENCV
+    network net = parse_network_cfg(cfgfile);
+    if(weightfile){
+        load_weights(&net, weightfile);
+    }
+    set_batch_network(&net, 1);
+
+    srand(2222222);
+    CvCapture * cap;
+
+    cap = cvCaptureFromCAM(cam_index);
+
+    char *window = "ArtJudgementBot9000!!!";
+    if(!cap) error("Couldn't connect to webcam.\n");
+    cvNamedWindow(window, CV_WINDOW_NORMAL); 
+    cvResizeWindow(window, 512, 512);
+    int i;
+    int idx[] = {37, 401, 434};
+    int n = sizeof(idx)/sizeof(idx[0]);
+
+    while(1){
+        image in = get_image_from_stream(cap);
+        image in_s = resize_image(in, net.w, net.h);
+        show_image(in, window);
+
+        float *p = network_predict(net, in_s.data);
+
+        printf("\033[2J");
+        printf("\033[1;1H");
+
+        float score = 0;
+        for(i = 0; i < n; ++i){
+            float s = p[idx[i]];
+            if (s > score) score = s;
+        }
+        score = score;
+        printf("I APPRECIATE THIS ARTWORK: %10.7f%%\n", score*100);
+        printf("[");
+	int upper = 30;
+        for(i = 0; i < upper; ++i){
+            printf("%s", ((i+.5) < score*upper) ? "\u2588" : " ");
+        }
+        printf("]\n");
+
+        free_image(in_s);
+        free_image(in);
+
+        cvWaitKey(1);
+    }
+#endif
+}
+
+
+void run_art(int argc, char **argv)
+{
+    int cam_index = find_int_arg(argc, argv, "-c", 0);
+    char *cfg = argv[2];
+    char *weights = argv[3];
+    demo_art(cfg, weights, cam_index);
+}
+
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 62d6079..0cd5124 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -85,7 +85,6 @@
 
     if(l.xnor){
         binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
-        //binarize_gpu(l.filters_gpu, l.n*l.c*l.size*l.size, l.binary_filters_gpu);
         swap_binary(&l);
         for(i = 0; i < l.batch; ++i){
             binarize_input_gpu(state.input + i*l.inputs, l.c, l.h*l.w, l.binary_input_gpu + i*l.inputs);
@@ -93,13 +92,31 @@
         state.input = l.binary_input_gpu;
     }
 
+#ifdef CUDNN
+    float one = 1;
+    cudnnConvolutionForward(cudnn_handle(),
+                &one,
+                l.srcTensorDesc,
+                state.input,
+                l.filterDesc,
+                l.filters_gpu,
+                l.convDesc,
+                l.fw_algo,
+                state.workspace,
+                l.workspace_size,
+                &one,
+                l.dstTensorDesc,
+                l.output_gpu);
+
+#else
     for(i = 0; i < l.batch; ++i){
-        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, l.col_image_gpu);
+        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.workspace);
         float * a = l.filters_gpu;
-        float * b = l.col_image_gpu;
+        float * b = state.workspace;
         float * c = l.output_gpu;
         gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
     }
+#endif
 
     if (l.batch_normalize) {
         forward_batchnorm_layer_gpu(l, state);
@@ -113,7 +130,6 @@
 
 void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 {
-    int i;
     int m = l.n;
     int n = l.size*l.size*l.c;
     int k = convolutional_out_height(l)*
@@ -128,26 +144,61 @@
     }
 
     if(l.xnor) state.input = l.binary_input_gpu;
+#ifdef CUDNN
+    float one = 1;
+    cudnnConvolutionBackwardFilter(cudnn_handle(),
+            &one,
+            l.srcTensorDesc,
+            state.input,
+            l.ddstTensorDesc,
+            l.delta_gpu,
+            l.convDesc,
+            l.bf_algo,
+            state.workspace,
+            l.workspace_size,
+            &one,
+            l.dfilterDesc,
+            l.filter_updates_gpu);
+
+    if(state.delta){
+        cudnnConvolutionBackwardData(cudnn_handle(),
+                &one,
+                l.filterDesc,
+                l.filters_gpu,
+                l.ddstTensorDesc,
+                l.delta_gpu,
+                l.convDesc,
+                l.bd_algo,
+                state.workspace,
+                l.workspace_size,
+                &one,
+                l.dsrcTensorDesc,
+                state.delta);
+    }
+
+#else
+    int i;
     for(i = 0; i < l.batch; ++i){
         float * a = l.delta_gpu;
-        float * b = l.col_image_gpu;
+        float * b = state.workspace;
         float * c = l.filter_updates_gpu;
 
-        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, l.col_image_gpu);
+        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.workspace);
         gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
 
         if(state.delta){
             if(l.binary || l.xnor) swap_binary(&l);
             float * a = l.filters_gpu;
             float * b = l.delta_gpu;
-            float * c = l.col_image_gpu;
+            float * c = state.workspace;
 
             gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
 
-            col2im_ongpu(l.col_image_gpu, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
+            col2im_ongpu(state.workspace, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
             if(l.binary || l.xnor) swap_binary(&l);
         }
     }
+#endif
 }
 
 void pull_convolutional_layer(convolutional_layer layer)
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index d76dfcd..a93087f 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -88,6 +88,38 @@
     return float_to_image(w,h,c,l.delta);
 }
 
+#ifdef CUDNN
+size_t get_workspace_size(layer l){
+    size_t most = 0;
+    size_t s = 0;
+    cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc,
+            l.filterDesc,
+            l.convDesc,
+            l.dstTensorDesc,
+            l.fw_algo,
+            &s);
+    if (s > most) most = s;
+    cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dfilterDesc,
+            l.bf_algo,
+            &s);
+    if (s > most) most = s;
+    cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+            l.filterDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dsrcTensorDesc,
+            l.bd_algo,
+            &s);
+    if (s > most) most = s;
+    return most;
+}
+#endif
+
 convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
 {
     int i;
@@ -156,7 +188,7 @@
     l.scales_gpu = cuda_make_array(l.scales, n);
     l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
 
-    l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c);
+    l.workspace_size = out_h*out_w*size*size*c;
     l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
     l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
 
@@ -182,6 +214,50 @@
         l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
         l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
     }
+#ifdef CUDNN
+    cudnnCreateTensorDescriptor(&l.srcTensorDesc);
+    cudnnCreateTensorDescriptor(&l.dstTensorDesc);
+    cudnnCreateFilterDescriptor(&l.filterDesc);
+    cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
+    cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
+    cudnnCreateFilterDescriptor(&l.dfilterDesc);
+    cudnnCreateConvolutionDescriptor(&l.convDesc);
+    cudnnSetTensor4dDescriptor(l.dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w); 
+    cudnnSetTensor4dDescriptor(l.ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); 
+    cudnnSetFilter4dDescriptor(l.dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size); 
+
+    cudnnSetTensor4dDescriptor(l.srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w); 
+    cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); 
+    cudnnSetFilter4dDescriptor(l.filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size); 
+    int padding = l.pad ? l.size/2 : 0;
+    cudnnSetConvolution2dDescriptor(l.convDesc, padding, padding, l.stride, l.stride, 1, 1, CUDNN_CROSS_CORRELATION);
+    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
+            l.srcTensorDesc,
+            l.filterDesc,
+            l.convDesc,
+            l.dstTensorDesc,
+            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+            0,
+            &l.fw_algo);
+    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
+            l.filterDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dsrcTensorDesc,
+            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+            0,
+            &l.bd_algo);
+    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
+            l.srcTensorDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dfilterDesc,
+            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+            0,
+            &l.bf_algo);
+    l.workspace_size = get_workspace_size(l);
+
+#endif
 #endif
     l.activation = activation;
 
@@ -247,11 +323,9 @@
             l->batch*out_h * out_w * l->n*sizeof(float));
 
 #ifdef GPU
-    cuda_free(l->col_image_gpu);
     cuda_free(l->delta_gpu);
     cuda_free(l->output_gpu);
 
-    l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
     l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
     l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
 #endif
@@ -299,12 +373,12 @@
 
     fill_cpu(l.outputs*l.batch, 0, l.output, 1);
     /*
-    if(l.binary){
-        binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
-        binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
-        swap_binary(&l);
-    }
-    */
+       if(l.binary){
+       binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
+       binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
+       swap_binary(&l);
+       }
+     */
 
     if(l.binary){
         int m = l.n;
diff --git a/src/cuda.c b/src/cuda.c
index d773d0b..327813d 100644
--- a/src/cuda.c
+++ b/src/cuda.c
@@ -46,6 +46,19 @@
     return d;
 }
 
+#ifdef CUDNN
+cudnnHandle_t cudnn_handle()
+{
+    static int init = 0;
+    static cudnnHandle_t handle;
+    if(!init) {
+        cudnnCreate(&handle);
+        init = 1;
+    }
+    return handle;
+}
+#endif
+
 cublasHandle_t blas_handle()
 {
     static int init = 0;
@@ -57,7 +70,7 @@
     return handle;
 }
 
-float *cuda_make_array(float *x, int n)
+float *cuda_make_array(float *x, size_t n)
 {
     float *x_gpu;
     size_t size = sizeof(float)*n;
@@ -71,7 +84,7 @@
     return x_gpu;
 }
 
-void cuda_random(float *x_gpu, int n)
+void cuda_random(float *x_gpu, size_t n)
 {
     static curandGenerator_t gen;
     static int init = 0;
@@ -84,7 +97,7 @@
     check_error(cudaPeekAtLastError());
 }
 
-float cuda_compare(float *x_gpu, float *x, int n, char *s)
+float cuda_compare(float *x_gpu, float *x, size_t n, char *s)
 {
     float *tmp = calloc(n, sizeof(float));
     cuda_pull_array(x_gpu, tmp, n);
@@ -97,7 +110,7 @@
     return err;
 }
 
-int *cuda_make_int_array(int n)
+int *cuda_make_int_array(size_t n)
 {
     int *x_gpu;
     size_t size = sizeof(int)*n;
@@ -112,14 +125,14 @@
     check_error(status);
 }
 
-void cuda_push_array(float *x_gpu, float *x, int n)
+void cuda_push_array(float *x_gpu, float *x, size_t n)
 {
     size_t size = sizeof(float)*n;
     cudaError_t status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
     check_error(status);
 }
 
-void cuda_pull_array(float *x_gpu, float *x, int n)
+void cuda_pull_array(float *x_gpu, float *x, size_t n)
 {
     size_t size = sizeof(float)*n;
     cudaError_t status = cudaMemcpy(x, x_gpu, size, cudaMemcpyDeviceToHost);
diff --git a/src/cuda.h b/src/cuda.h
index 9d949b0..cdd6db9 100644
--- a/src/cuda.h
+++ b/src/cuda.h
@@ -11,16 +11,24 @@
 #include "curand.h"
 #include "cublas_v2.h"
 
+#ifdef CUDNN
+#include "cudnn.h"
+#endif
+
 void check_error(cudaError_t status);
 cublasHandle_t blas_handle();
-float *cuda_make_array(float *x, int n);
-int *cuda_make_int_array(int n);
-void cuda_push_array(float *x_gpu, float *x, int n);
-void cuda_pull_array(float *x_gpu, float *x, int n);
+float *cuda_make_array(float *x, size_t n);
+int *cuda_make_int_array(size_t n);
+void cuda_push_array(float *x_gpu, float *x, size_t n);
+void cuda_pull_array(float *x_gpu, float *x, size_t n);
 void cuda_free(float *x_gpu);
-void cuda_random(float *x_gpu, int n);
-float cuda_compare(float *x_gpu, float *x, int n, char *s);
+void cuda_random(float *x_gpu, size_t n);
+float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
 dim3 cuda_gridsize(size_t n);
 
+#ifdef CUDNN
+cudnnHandle_t cudnn_handle();
+#endif
+
 #endif
 #endif
diff --git a/src/darknet.c b/src/darknet.c
index f2982ac..bf662d9 100644
--- a/src/darknet.c
+++ b/src/darknet.c
@@ -26,6 +26,7 @@
 extern void run_tag(int argc, char **argv);
 extern void run_cifar(int argc, char **argv);
 extern void run_go(int argc, char **argv);
+extern void run_art(int argc, char **argv);
 
 void change_rate(char *filename, float scale, float add)
 {
@@ -259,6 +260,8 @@
         run_coco(argc, argv);
     } else if (0 == strcmp(argv[1], "classifier")){
         run_classifier(argc, argv);
+    } else if (0 == strcmp(argv[1], "art")){
+        run_art(argc, argv);
     } else if (0 == strcmp(argv[1], "tag")){
         run_tag(argc, argv);
     } else if (0 == strcmp(argv[1], "compare")){
diff --git a/src/layer.h b/src/layer.h
index 2376929..c3697ce 100644
--- a/src/layer.h
+++ b/src/layer.h
@@ -2,6 +2,7 @@
 #define BASE_LAYER_H
 
 #include "activations.h"
+#include "stddef.h"
 
 struct layer;
 typedef struct layer layer;
@@ -157,6 +158,8 @@
     struct layer *input_h_layer;
     struct layer *state_h_layer;
 
+    size_t workspace_size;
+
     #ifdef GPU
     float *z_gpu;
     float *r_gpu;
@@ -207,6 +210,16 @@
     float * rand_gpu;
     float * squared_gpu;
     float * norms_gpu;
+    #ifdef CUDNN
+    cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
+    cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
+    cudnnFilterDescriptor_t filterDesc;
+    cudnnFilterDescriptor_t dfilterDesc;
+    cudnnConvolutionDescriptor_t convDesc;
+    cudnnConvolutionFwdAlgo_t fw_algo;
+    cudnnConvolutionBwdDataAlgo_t bd_algo;
+    cudnnConvolutionBwdFilterAlgo_t bf_algo;
+    #endif
     #endif
 };
 
diff --git a/src/network.h b/src/network.h
index c3a13db..15b58b8 100644
--- a/src/network.h
+++ b/src/network.h
@@ -11,6 +11,7 @@
 } learning_rate_policy;
 
 typedef struct network{
+    float *workspace;
     int n;
     int batch;
     int *seen;
@@ -49,6 +50,7 @@
     float *truth;
     float *input;
     float *delta;
+    float *workspace;
     int train;
     int index;
     network net;
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index 986a808..285f72c 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -41,6 +41,7 @@
 
 void forward_network_gpu(network net, network_state state)
 {
+    state.workspace = net.workspace;
     int i;
     for(i = 0; i < net.n; ++i){
         state.index = i;
@@ -93,6 +94,7 @@
 
 void backward_network_gpu(network net, network_state state)
 {
+    state.workspace = net.workspace;
     int i;
     float * original_input = state.input;
     float * original_delta = state.delta;
diff --git a/src/parser.c b/src/parser.c
index 5a9d0a3..d5288aa 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -524,6 +524,7 @@
     params.batch = net.batch;
     params.time_steps = net.time_steps;
 
+    size_t workspace_size = 0;
     n = n->next;
     int count = 0;
     free_section(s);
@@ -584,6 +585,7 @@
         l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
         option_unused(options);
         net.layers[count] = l;
+        if (l.workspace_size > workspace_size) workspace_size = l.workspace_size;
         free_section(s);
         n = n->next;
         ++count;
@@ -597,6 +599,11 @@
     free_list(sections);
     net.outputs = get_network_output_size(net);
     net.output = get_network_output(net);
+    if(workspace_size){
+#ifdef GPU
+        net.workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1);
+#endif
+    }
     return net;
 }
 

--
Gitblit v1.10.0