001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt.trees;
019    
020    import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
021    import com.pawjaw.classification.crf.lmcbt.sequences.ExpandedPointSequence;
022    import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences;
023    import java.util.ArrayList;
024    import java.util.LinkedList;
025    import java.util.List;
026    import java.util.PriorityQueue;
027    import java.util.Queue;
028    
029    public class Tree {
030        public Node root;
031    
032        private Queue<Split> splitq = new PriorityQueue();
033        private List<Double> serializer_temp = new ArrayList();
034    
035        public void build(Configuration c, int label, Splitter splitter, List<Integer> example_indexes) {
036            root = new Node(example_indexes);
037            findAllSplits(c.max_leaves_per_tree, label, splitter);
038        }
039    
040        public void build(Configuration c, int label, Splitter splitter, int[] example_indexes) {
041            root = new Node(example_indexes);
042            findAllSplits(c.max_leaves_per_tree, label, splitter);
043        }
044    
045        public void printLeaves() {
046            Node node;
047            Queue<Node> nodeq = new LinkedList();
048            nodeq.add(root);
049            while(!nodeq.isEmpty())
050                if((node = nodeq.poll()).isLeaf())
051                    System.out.println("    leaf output: " + node.output);
052                else {
053                    nodeq.add(node.true_child);
054                    nodeq.add(node.false_child);
055                }
056        }
057    
058        public double mse(Configuration c, int label, TrainingSequences ts, List<Integer> example_indexes) {
059            double mse = 0;
060            Node node;
061            for(int example_index : example_indexes) {
062                node = root;
063                while(!node.isLeaf())
064                    if(ts.hasTrueFeature(example_index, node.split_feature))
065                        node = node.true_child;
066                    else
067                        node = node.false_child;
068                mse += Math.pow(node.output - ts.getTarget(label, example_index), 2.0);
069            }
070            return mse / (double)example_indexes.size();
071        }
072    
073        public static double mse(Configuration c, int label, TrainingSequences ts, double[] st) {
074            int example_index, examples = ts.examples();
075            double mse = 0;
076            for(example_index = 0;example_index < examples;example_index++)
077                mse += Math.pow(Tree.output(ts, example_index, st) - ts.getTarget(label, example_index), 2.0);
078            return mse / (double)examples;
079        }
080    
081        private void findAllSplits(int max_leaves_per_tree, int label, Splitter splitter) {
082            Split split;
083            int leaves = 1;
084            splitter.initialize(label, root);
085            if((split = splitter.findBestSplit(root)) != null) {
086                splitq.add(split);
087                while(!splitq.isEmpty() && leaves < max_leaves_per_tree) {
088                    for(Node child : splitter.acceptSplit(splitq.poll()))
089                        if((split = splitter.findBestSplit(child)) != null)
090                            splitq.add(split);
091                    leaves++;
092                }
093            }
094            splitq.clear();
095        }
096    
097        public double[] serialize() {
098            Node node;
099            Integer child_pointer;
100            Queue<Node> nodeq = new LinkedList();
101            Queue<Integer> child_pointerq = new LinkedList();
102            nodeq.add(root);
103            child_pointerq.add(-1);
104            while((node = nodeq.poll()) != null) {
105                if((child_pointer = child_pointerq.poll()) >= 0)
106                    serializer_temp.set(child_pointer, (double)serializer_temp.size());
107                if(node.isLeaf()) {
108                    serializer_temp.add(-1.0);
109                    serializer_temp.add(node.output);
110                } else {
111                    serializer_temp.add((double)node.split_feature);
112                    nodeq.add(node.true_child);
113                    child_pointerq.add(serializer_temp.size());
114                    serializer_temp.add(-1.0);
115                    nodeq.add(node.false_child);
116                    child_pointerq.add(serializer_temp.size());
117                    serializer_temp.add(-1.0);
118                }
119            }
120            return serializerTempToArray();
121        }
122    
123        private double[] serializerTempToArray() {
124            int i = 0;
125            double[] serialized_tree = new double[serializer_temp.size()];
126            for(double d : serializer_temp)
127                serialized_tree[i++] = d;
128            serializer_temp.clear();
129            return serialized_tree;
130        }
131    
132        public static double output(Configuration c, ExpandedPointSequence eps, int expanded_sequence_position, double[] serialized_tree) {
133            return output(c, eps, expanded_sequence_position / c.label_count_including_start_label,
134                    expanded_sequence_position % c.label_count_including_start_label,
135                    serialized_tree);
136        }
137    
138        public static double output(Configuration c, ExpandedPointSequence eps, int sequence_position, int previous_label, double[] serialized_tree) {
139            int i = 0;
140            while(i >= 0)
141                if(serialized_tree[i] < 0)
142                    return serialized_tree[i + 1];
143                else if(eps.hasTrueFeature(c, sequence_position, previous_label, (int)serialized_tree[i]))
144                    i = (int)serialized_tree[i + 1];
145                else
146                    i = (int)serialized_tree[i + 2];
147            return 0;
148        }
149    
150        public static double output(TrainingSequences ts, int example_index, double[] serialized_tree) {
151            int i = 0;
152            while(i >= 0)
153                if(serialized_tree[i] < 0)
154                    return serialized_tree[i + 1];
155                else if(ts.hasTrueFeature(example_index, (int)serialized_tree[i]))
156                    i = (int)serialized_tree[i + 1];
157                else
158                    i = (int)serialized_tree[i + 2];
159            return 0;
160        }
161    
162        public double output(Configuration c, ExpandedPointSequence eps, int sequence_position, int previous_label) {
163            Node node = root;
164            while(!node.isLeaf())
165                if(eps.hasTrueFeature(c, sequence_position, previous_label, node.split_feature))
166                    node = node.true_child;
167                else
168                    node = node.false_child;
169            return node.output;
170        }
171    }