From 0f7f2899b65343e56b0a1188f703d459d824d398 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 16 Nov 2015 03:51:26 +0000
Subject: [PATCH] Fix for cuda 7.5
---
src/network.c | 26 ++++++++++++--------------
1 files changed, 12 insertions(+), 14 deletions(-)
diff --git a/src/network.c b/src/network.c
index 7f19318..6c7461d 100644
--- a/src/network.c
+++ b/src/network.c
@@ -8,10 +8,10 @@
#include "crop_layer.h"
#include "connected_layer.h"
+#include "local_layer.h"
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "detection_layer.h"
-#include "region_layer.h"
#include "normalization_layer.h"
#include "maxpool_layer.h"
#include "avgpool_layer.h"
@@ -60,6 +60,8 @@
switch(a){
case CONVOLUTIONAL:
return "convolutional";
+ case LOCAL:
+ return "local";
case DECONVOLUTIONAL:
return "deconvolutional";
case CONNECTED:
@@ -72,8 +74,6 @@
return "softmax";
case DETECTION:
return "detection";
- case REGION:
- return "region";
case DROPOUT:
return "dropout";
case CROP:
@@ -115,12 +115,12 @@
forward_convolutional_layer(l, state);
} else if(l.type == DECONVOLUTIONAL){
forward_deconvolutional_layer(l, state);
+ } else if(l.type == LOCAL){
+ forward_local_layer(l, state);
} else if(l.type == NORMALIZATION){
forward_normalization_layer(l, state);
} else if(l.type == DETECTION){
forward_detection_layer(l, state);
- } else if(l.type == REGION){
- forward_region_layer(l, state);
} else if(l.type == CONNECTED){
forward_connected_layer(l, state);
} else if(l.type == CROP){
@@ -155,6 +155,8 @@
update_deconvolutional_layer(l, rate, net.momentum, net.decay);
} else if(l.type == CONNECTED){
update_connected_layer(l, update_batch, rate, net.momentum, net.decay);
+ } else if(l.type == LOCAL){
+ update_local_layer(l, update_batch, rate, net.momentum, net.decay);
}
}
}
@@ -180,10 +182,6 @@
sum += net.layers[i].cost[0];
++count;
}
- if(net.layers[i].type == REGION){
- sum += net.layers[i].cost[0];
- ++count;
- }
}
return sum/count;
}
@@ -224,12 +222,12 @@
backward_dropout_layer(l, state);
} else if(l.type == DETECTION){
backward_detection_layer(l, state);
- } else if(l.type == REGION){
- backward_region_layer(l, state);
} else if(l.type == SOFTMAX){
if(i != 0) backward_softmax_layer(l, state);
} else if(l.type == CONNECTED){
backward_connected_layer(l, state);
+ } else if(l.type == LOCAL){
+ backward_local_layer(l, state);
} else if(l.type == COST){
backward_cost_layer(l, state);
} else if(l.type == ROUTE){
@@ -540,12 +538,12 @@
return acc;
}
-float *network_accuracies(network net, data d)
+float *network_accuracies(network net, data d, int n)
{
static float acc[2];
matrix guess = network_predict_data(net, d);
- acc[0] = matrix_topk_accuracy(d.y, guess,1);
- acc[1] = matrix_topk_accuracy(d.y, guess,5);
+ acc[0] = matrix_topk_accuracy(d.y, guess, 1);
+ acc[1] = matrix_topk_accuracy(d.y, guess, n);
free_matrix(guess);
return acc;
}
--
Gitblit v1.10.0