Edmond Yoo
2018-09-16 0dab894a5be9f7d10d85e89dea91d02c71bae84d
src/tree.c
@@ -24,6 +24,16 @@
    fprintf(stderr, "Found %d leaves.\n", found);
}
float get_hierarchy_probability(float *x, tree *hier, int c)
{
    float p = 1;
    while(c >= 0){
        p = p * x[c];
        c = hier->parent[c];
    }
    return p;
}
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
{
    int j;
@@ -40,6 +50,38 @@
    }
}
int hierarchy_top_prediction(float *predictions, tree *hier, float thresh, int stride)
{
    float p = 1;
    int group = 0;
    int i;
    while (1) {
        float max = 0;
        int max_i = 0;
        for (i = 0; i < hier->group_size[group]; ++i) {
            int index = i + hier->group_offset[group];
            float val = predictions[(i + hier->group_offset[group])*stride];
            if (val > max) {
                max_i = index;
                max = val;
            }
        }
        if (p*max > thresh) {
            p = p*max;
            group = hier->child[max_i];
            if (hier->child[max_i] < 0) return max_i;
        }
        else if (group == 0) {
            return max_i;
        }
        else {
            return hier->parent[hier->group_offset[group]];
        }
    }
    return 0;
}
tree *read_tree(char *filename)
{
    tree t = {0};