001 // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company) 002 // 003 // This file is part of JBTCRF. 004 // 005 // JBTCRF is free software: you can redistribute it and/or modify 006 // it under the terms of the GNU General Public License as published by 007 // the Free Software Foundation, either version 3 of the License, or 008 // (at your option) any later version. 009 // 010 // JBTCRF is distributed in the hope that it will be useful, 011 // but WITHOUT ANY WARRANTY; without even the implied warranty of 012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 013 // GNU General Public License for more details. 014 // 015 // You should have received a copy of the GNU General Public License 016 // along with JBTCRF. If not, see <http://www.gnu.org/licenses/>. 017 018 package com.pawjaw.classification.crf.lmcbt.trees; 019 020 import com.pawjaw.classification.crf.lmcbt.configurations.Configuration; 021 import com.pawjaw.classification.crf.lmcbt.sequences.ExpandedPointSequence; 022 import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences; 023 import java.util.ArrayList; 024 import java.util.LinkedList; 025 import java.util.List; 026 import java.util.PriorityQueue; 027 import java.util.Queue; 028 029 public class Tree { 030 public Node root; 031 032 private Queue<Split> splitq = new PriorityQueue(); 033 private List<Double> serializer_temp = new ArrayList(); 034 035 public void build(Configuration c, int label, Splitter splitter, List<Integer> example_indexes) { 036 root = new Node(example_indexes); 037 findAllSplits(c.max_leaves_per_tree, label, splitter); 038 } 039 040 public void build(Configuration c, int label, Splitter splitter, int[] example_indexes) { 041 root = new Node(example_indexes); 042 findAllSplits(c.max_leaves_per_tree, label, splitter); 043 } 044 045 public void printLeaves() { 046 Node node; 047 Queue<Node> nodeq = new LinkedList(); 048 nodeq.add(root); 049 while(!nodeq.isEmpty()) 050 if((node = nodeq.poll()).isLeaf()) 051 System.out.println(" leaf output: " + node.output); 052 else { 053 nodeq.add(node.true_child); 054 nodeq.add(node.false_child); 055 } 056 } 057 058 public double mse(Configuration c, int label, TrainingSequences ts, List<Integer> example_indexes) { 059 double mse = 0; 060 Node node; 061 for(int example_index : example_indexes) { 062 node = root; 063 while(!node.isLeaf()) 064 if(ts.hasTrueFeature(example_index, node.split_feature)) 065 node = node.true_child; 066 else 067 node = node.false_child; 068 mse += Math.pow(node.output - ts.getTarget(label, example_index), 2.0); 069 } 070 return mse / (double)example_indexes.size(); 071 } 072 073 public static double mse(Configuration c, int label, TrainingSequences ts, double[] st) { 074 int example_index, examples = ts.examples(); 075 double mse = 0; 076 for(example_index = 0;example_index < examples;example_index++) 077 mse += Math.pow(Tree.output(ts, example_index, st) - ts.getTarget(label, example_index), 2.0); 078 return mse / (double)examples; 079 } 080 081 private void findAllSplits(int max_leaves_per_tree, int label, Splitter splitter) { 082 Split split; 083 int leaves = 1; 084 splitter.initialize(label, root); 085 if((split = splitter.findBestSplit(root)) != null) { 086 splitq.add(split); 087 while(!splitq.isEmpty() && leaves < max_leaves_per_tree) { 088 for(Node child : splitter.acceptSplit(splitq.poll())) 089 if((split = splitter.findBestSplit(child)) != null) 090 splitq.add(split); 091 leaves++; 092 } 093 } 094 splitq.clear(); 095 } 096 097 public double[] serialize() { 098 Node node; 099 Integer child_pointer; 100 Queue<Node> nodeq = new LinkedList(); 101 Queue<Integer> child_pointerq = new LinkedList(); 102 nodeq.add(root); 103 child_pointerq.add(-1); 104 while((node = nodeq.poll()) != null) { 105 if((child_pointer = child_pointerq.poll()) >= 0) 106 serializer_temp.set(child_pointer, (double)serializer_temp.size()); 107 if(node.isLeaf()) { 108 serializer_temp.add(-1.0); 109 serializer_temp.add(node.output); 110 } else { 111 serializer_temp.add((double)node.split_feature); 112 nodeq.add(node.true_child); 113 child_pointerq.add(serializer_temp.size()); 114 serializer_temp.add(-1.0); 115 nodeq.add(node.false_child); 116 child_pointerq.add(serializer_temp.size()); 117 serializer_temp.add(-1.0); 118 } 119 } 120 return serializerTempToArray(); 121 } 122 123 private double[] serializerTempToArray() { 124 int i = 0; 125 double[] serialized_tree = new double[serializer_temp.size()]; 126 for(double d : serializer_temp) 127 serialized_tree[i++] = d; 128 serializer_temp.clear(); 129 return serialized_tree; 130 } 131 132 public static double output(Configuration c, ExpandedPointSequence eps, int expanded_sequence_position, double[] serialized_tree) { 133 return output(c, eps, expanded_sequence_position / c.label_count_including_start_label, 134 expanded_sequence_position % c.label_count_including_start_label, 135 serialized_tree); 136 } 137 138 public static double output(Configuration c, ExpandedPointSequence eps, int sequence_position, int previous_label, double[] serialized_tree) { 139 int i = 0; 140 while(i >= 0) 141 if(serialized_tree[i] < 0) 142 return serialized_tree[i + 1]; 143 else if(eps.hasTrueFeature(c, sequence_position, previous_label, (int)serialized_tree[i])) 144 i = (int)serialized_tree[i + 1]; 145 else 146 i = (int)serialized_tree[i + 2]; 147 return 0; 148 } 149 150 public static double output(TrainingSequences ts, int example_index, double[] serialized_tree) { 151 int i = 0; 152 while(i >= 0) 153 if(serialized_tree[i] < 0) 154 return serialized_tree[i + 1]; 155 else if(ts.hasTrueFeature(example_index, (int)serialized_tree[i])) 156 i = (int)serialized_tree[i + 1]; 157 else 158 i = (int)serialized_tree[i + 2]; 159 return 0; 160 } 161 162 public double output(Configuration c, ExpandedPointSequence eps, int sequence_position, int previous_label) { 163 Node node = root; 164 while(!node.isLeaf()) 165 if(eps.hasTrueFeature(c, sequence_position, previous_label, node.split_feature)) 166 node = node.true_child; 167 else 168 node = node.false_child; 169 return node.output; 170 } 171 }