src/data.c
@@ -148,6 +148,16 @@ return d; } void get_batch(data d, int n, float *X, float *y) { int j; for(j = 0; j < n; ++j){ int index = rand()%d.X.rows; memcpy(X+j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float)); memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float)); } } data load_all_cifar10() { data d;