From 5a47c46b39475fc3581b9819f488b977ea1beca3 Mon Sep 17 00:00:00 2001
From: Edmond Yoo <hj3yoo@uwaterloo.ca>
Date: Sun, 16 Sep 2018 03:11:04 +0000
Subject: [PATCH] Moving files from MTGCardDetector
---
src/compare.c | 115 +++++++++++++++++++++++++++++++++++++++++----------------
1 files changed, 82 insertions(+), 33 deletions(-)
diff --git a/src/compare.c b/src/compare.c
index 0408f80..803d812 100644
--- a/src/compare.c
+++ b/src/compare.c
@@ -9,7 +9,6 @@
void train_compare(char *cfgfile, char *weightfile)
{
- data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
@@ -149,18 +148,21 @@
typedef struct {
network net;
char *filename;
- int class;
+ int class_id;
+ int classes;
float elo;
+ float *elos;
} sortable_bbox;
int total_compares = 0;
+int current_class_id = 0;
int elo_comparator(const void*a, const void *b)
{
sortable_bbox box1 = *(sortable_bbox*)a;
sortable_bbox box2 = *(sortable_bbox*)b;
- if(box1.elo == box2.elo) return 0;
- if(box1.elo > box2.elo) return -1;
+ if(box1.elos[current_class_id] == box2.elos[current_class_id]) return 0;
+ if(box1.elos[current_class_id] > box2.elos[current_class_id]) return -1;
return 1;
}
@@ -170,7 +172,7 @@
sortable_bbox box1 = *(sortable_bbox*)a;
sortable_bbox box2 = *(sortable_bbox*)b;
network net = box1.net;
- int class = box1.class;
+ int class_id = box1.class_id;
image im1 = load_image_color(box1.filename, net.w, net.h);
image im2 = load_image_color(box2.filename, net.w, net.h);
@@ -182,22 +184,44 @@
free_image(im1);
free_image(im2);
free(X);
- if (predictions[class*2] > predictions[class*2+1]){
+ if (predictions[class_id*2] > predictions[class_id*2+1]){
return 1;
}
return -1;
}
-void bbox_fight(sortable_bbox *a, sortable_bbox *b)
+void bbox_update(sortable_bbox *a, sortable_bbox *b, int class_id, int result)
{
int k = 32;
- int result = bbox_comparator(a,b);
- float EA = 1./(1+pow(10, (b->elo - a->elo)/400.));
- float EB = 1./(1+pow(10, (a->elo - b->elo)/400.));
- float SA = 1.*(result > 0);
- float SB = 1.*(result < 0);
- a->elo = a->elo + k*(SA - EA);
- b->elo = b->elo + k*(SB - EB);
+ float EA = 1./(1+pow(10, (b->elos[class_id] - a->elos[class_id])/400.));
+ float EB = 1./(1+pow(10, (a->elos[class_id] - b->elos[class_id])/400.));
+ float SA = result ? 1 : 0;
+ float SB = result ? 0 : 1;
+ a->elos[class_id] += k*(SA - EA);
+ b->elos[class_id] += k*(SB - EB);
+}
+
+void bbox_fight(network net, sortable_bbox *a, sortable_bbox *b, int classes, int class_id)
+{
+ image im1 = load_image_color(a->filename, net.w, net.h);
+ image im2 = load_image_color(b->filename, net.w, net.h);
+ float *X = calloc(net.w*net.h*net.c, sizeof(float));
+ memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float));
+ memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
+ float *predictions = network_predict(net, X);
+ ++total_compares;
+
+ int i;
+ for(i = 0; i < classes; ++i){
+ if(class_id < 0 || class_id == i){
+ int result = predictions[i*2] > predictions[i*2+1];
+ bbox_update(a, b, i, result);
+ }
+ }
+
+ free_image(im1);
+ free_image(im2);
+ free(X);
}
void SortMaster3000(char *filename, char *weightfile)
@@ -220,7 +244,7 @@
for(i = 0; i < N; ++i){
boxes[i].filename = paths[i];
boxes[i].net = net;
- boxes[i].class = 7;
+ boxes[i].class_id = 7;
boxes[i].elo = 1500;
}
clock_t time=clock();
@@ -233,7 +257,8 @@
void BattleRoyaleWithCheese(char *filename, char *weightfile)
{
- int i = 0;
+ int classes = 20;
+ int i,j;
network net = parse_network_cfg(filename);
if(weightfile){
load_weights(&net, weightfile);
@@ -242,41 +267,65 @@
set_batch_network(&net, 1);
list *plist = get_paths("data/compare.sort.list");
+ //list *plist = get_paths("data/compare.small.list");
+ //list *plist = get_paths("data/compare.cat.list");
//list *plist = get_paths("data/compare.val.old");
char **paths = (char **)list_to_array(plist);
int N = plist->size;
+ int total = N;
free_list(plist);
sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
printf("Battling %d boxes...\n", N);
for(i = 0; i < N; ++i){
boxes[i].filename = paths[i];
boxes[i].net = net;
- boxes[i].class = 7;
- boxes[i].elo = 1500;
+ boxes[i].classes = classes;
+ boxes[i].elos = calloc(classes, sizeof(float));;
+ for(j = 0; j < classes; ++j){
+ boxes[i].elos[j] = 1500;
+ }
}
int round;
clock_t time=clock();
- for(round = 1; round <= 40; ++round){
+ for(round = 1; round <= 4; ++round){
clock_t round_time=clock();
printf("Round: %d\n", round);
- qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
- sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
+ shuffle(boxes, N, sizeof(sortable_bbox));
for(i = 0; i < N/2; ++i){
- bbox_fight(boxes+i*2, boxes+i*2+1);
- }
- if(round >= 4){
- qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
- if(round == 4){
- N = N/2;
- }else{
- N = (N*9/10)/2*2;
- }
+ bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, -1);
}
printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
}
- qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
- for(i = 0; i < N; ++i){
- printf("%s %f\n", boxes[i].filename, boxes[i].elo);
+
+ int class_id;
+
+ for (class_id = 0; class_id < classes; ++class_id){
+
+ N = total;
+ current_class_id = class_id;
+ qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
+ N /= 2;
+
+ for(round = 1; round <= 100; ++round){
+ clock_t round_time=clock();
+ printf("Round: %d\n", round);
+
+ sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
+ for(i = 0; i < N/2; ++i){
+ bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, class_id);
+ }
+ qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
+ if(round <= 20) N = (N*9/10)/2*2;
+
+ printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
+ }
+ char buff[256];
+ sprintf(buff, "results/battle_%d.log", class_id);
+ FILE *outfp = fopen(buff, "w");
+ for(i = 0; i < N; ++i){
+ fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elos[class_id]);
+ }
+ fclose(outfp);
}
printf("Tournament in %d compares, %f secs\n", total_compares, sec(clock()-time));
}
--
Gitblit v1.10.0