AlexeyAB
2018-04-23 3df335bb50f890b12fa1a9965e91b0cf46d7902c
src/tree.c
@@ -24,6 +24,16 @@
    fprintf(stderr, "Found %d leaves.\n", found);
}
float get_hierarchy_probability(float *x, tree *hier, int c)
{
    float p = 1;
    while(c >= 0){
        p = p * x[c];
        c = hier->parent[c];
    }
    return p;
}
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
{
    int j;
@@ -40,6 +50,38 @@
    }
}
int hierarchy_top_prediction(float *predictions, tree *hier, float thresh, int stride)
{
   float p = 1;
   int group = 0;
   int i;
   while (1) {
      float max = 0;
      int max_i = 0;
      for (i = 0; i < hier->group_size[group]; ++i) {
         int index = i + hier->group_offset[group];
         float val = predictions[(i + hier->group_offset[group])*stride];
         if (val > max) {
            max_i = index;
            max = val;
         }
      }
      if (p*max > thresh) {
         p = p*max;
         group = hier->child[max_i];
         if (hier->child[max_i] < 0) return max_i;
      }
      else if (group == 0) {
         return max_i;
      }
      else {
         return hier->parent[hier->group_offset[group]];
      }
   }
   return 0;
}
tree *read_tree(char *filename)
{
    tree t = {0};