src/data.c
@@ -172,7 +172,7 @@ return d; } void get_batch(data d, int n, float *X, float *y) void get_random_batch(data d, int n, float *X, float *y) { int j; for(j = 0; j < n; ++j){ @@ -182,6 +182,17 @@ } } void get_next_batch(data d, int n, int offset, float *X, float *y) { int j; for(j = 0; j < n; ++j){ int index = offset + j; memcpy(X+j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float)); memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float)); } } data load_all_cifar10() { data d;