From 5a47c46b39475fc3581b9819f488b977ea1beca3 Mon Sep 17 00:00:00 2001
From: Edmond Yoo <hj3yoo@uwaterloo.ca>
Date: Sun, 16 Sep 2018 03:11:04 +0000
Subject: [PATCH] Moving files from MTGCardDetector

---
 src/tree.c |   86 +++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 86 insertions(+), 0 deletions(-)

diff --git a/src/tree.c b/src/tree.c
index 5a758f7..d66da9f 100644
--- a/src/tree.c
+++ b/src/tree.c
@@ -2,6 +2,85 @@
 #include <stdlib.h>
 #include "tree.h"
 #include "utils.h"
+#include "data.h"
+
+void change_leaves(tree *t, char *leaf_list)
+{
+    list *llist = get_paths(leaf_list);
+    char **leaves = (char **)list_to_array(llist);
+    int n = llist->size;
+    int i,j;
+    int found = 0;
+    for(i = 0; i < t->n; ++i){
+        t->leaf[i] = 0;
+        for(j = 0; j < n; ++j){
+            if (0==strcmp(t->name[i], leaves[j])){
+                t->leaf[i] = 1;
+                ++found;
+                break;
+            }
+        }
+    }
+    fprintf(stderr, "Found %d leaves.\n", found);
+}
+
+float get_hierarchy_probability(float *x, tree *hier, int c)
+{
+    float p = 1;
+    while(c >= 0){
+        p = p * x[c];
+        c = hier->parent[c];
+    }
+    return p;
+}
+
+void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
+{
+    int j;
+    for(j = 0; j < n; ++j){
+        int parent = hier->parent[j];
+        if(parent >= 0){
+            predictions[j] *= predictions[parent]; 
+        }
+    }
+    if(only_leaves){
+        for(j = 0; j < n; ++j){
+            if(!hier->leaf[j]) predictions[j] = 0;
+        }
+    }
+}
+
+int hierarchy_top_prediction(float *predictions, tree *hier, float thresh, int stride)
+{
+    float p = 1;
+    int group = 0;
+    int i;
+    while (1) {
+        float max = 0;
+        int max_i = 0;
+
+        for (i = 0; i < hier->group_size[group]; ++i) {
+            int index = i + hier->group_offset[group];
+            float val = predictions[(i + hier->group_offset[group])*stride];
+            if (val > max) {
+                max_i = index;
+                max = val;
+            }
+        }
+        if (p*max > thresh) {
+            p = p*max;
+            group = hier->child[max_i];
+            if (hier->child[max_i] < 0) return max_i;
+        }
+        else if (group == 0) {
+            return max_i;
+        }
+        else {
+            return hier->parent[hier->group_offset[group]];
+        }
+    }
+    return 0;
+}
 
 tree *read_tree(char *filename)
 {
@@ -19,19 +98,26 @@
         sscanf(line, "%s %d", id, &parent);
         t.parent = realloc(t.parent, (n+1)*sizeof(int));
         t.parent[n] = parent;
+
         t.name = realloc(t.name, (n+1)*sizeof(char *));
         t.name[n] = id;
         if(parent != last_parent){
             ++groups;
+            t.group_offset = realloc(t.group_offset, groups * sizeof(int));
+            t.group_offset[groups - 1] = n - group_size;
             t.group_size = realloc(t.group_size, groups * sizeof(int));
             t.group_size[groups - 1] = group_size;
             group_size = 0;
             last_parent = parent;
         }
+        t.group = realloc(t.group, (n+1)*sizeof(int));
+        t.group[n] = groups;
         ++n;
         ++group_size;
     }
     ++groups;
+    t.group_offset = realloc(t.group_offset, groups * sizeof(int));
+    t.group_offset[groups - 1] = n - group_size;
     t.group_size = realloc(t.group_size, groups * sizeof(int));
     t.group_size[groups - 1] = group_size;
     t.n = n;

--
Gitblit v1.10.0