001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt.trees;
019    
020    import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
021    import com.pawjaw.classification.crf.lmcbt.sequences.ExpandedPointSequence;
022    import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences;
023    import java.util.ArrayList;
024    import java.util.LinkedList;
025    import java.util.List;
026    import java.util.Queue;
027    import java.util.concurrent.Semaphore;
028    
029    public class BoostedTrees extends Thread {
030        private int label;
031        private int examples;
032        private Tree tree;
033        private Configuration c;
034        private Splitter splitter;
035        private TrainingSequences ts;
036        private Semaphore updateable, proceed, update_complete;
037        private List<double[]> serialized_trees = new ArrayList();
038        private int[] training_examples;
039    
040        public BoostedTrees(Configuration c, TrainingSequences ts, int label, Semaphore proceed, Semaphore update_complete) {
041            int i;
042            this.c = c;
043            this.ts = ts;
044            this.label = label;
045            this.proceed = proceed;
046            this.update_complete = update_complete;
047            tree = new Tree();
048            examples = ts.examples();
049            updateable = new Semaphore(0);
050            splitter = new Splitter(c, ts);
051            training_examples = new int[examples];
052            for(i = 0;i < examples;i++)
053                training_examples[i] = i;
054        }
055    
056        private void buildTree() {
057            tree.build(c, label, splitter, training_examples);
058            serialized_trees.add(tree.serialize());
059        }
060    
061        private void updatePotentials() {
062            Node node;
063            Queue<Node> nodeq = new LinkedList();
064            nodeq.add(tree.root);
065            while((node = nodeq.poll()) != null) {
066                if(node.isLeaf())
067                    for(int example_index : node.example_indexes)
068                        ts.incrementPotential(label, example_index, node.output);
069                else {
070                    nodeq.add(node.true_child);
071                    nodeq.add(node.false_child);
072                }
073            }
074        }
075    
076        public void setUpdateable() {
077            updateable.release();
078        }
079    
080        public void finished() {
081            splitter.finished();
082            this.interrupt();
083        }
084    
085        public void run() {
086            try {
087                while(true) {
088                    updateable.acquire();
089                    proceed.acquire();
090                    buildTree();
091                    updatePotentials();
092                    proceed.release();
093                    update_complete.release();
094                }
095            } catch(InterruptedException e) {}
096        }
097    
098        public double[][] serialize() {
099            int i = 0;
100            double[][] serialized_trees = new double[this.serialized_trees.size()][];
101            for(double[] serialized_tree : this.serialized_trees)
102                serialized_trees[i++] = serialized_tree;
103            return serialized_trees;
104        }
105    
106        public static double boostedOutput(Configuration c, ExpandedPointSequence eps, int expanded_sequence_position, double[][] serialized_trees) {
107            double potential = 0;
108            for(double[] serialized_tree : serialized_trees)
109                potential += Tree.output(c, eps, expanded_sequence_position, serialized_tree);
110            return potential;
111        }
112    
113        public static double boostedOutput(Configuration c, ExpandedPointSequence eps, int sequence_position, int previous_label, double[][] serialized_trees) {
114            double potential = 0;
115            for(double[] serialized_tree : serialized_trees)
116                potential += Tree.output(c, eps, sequence_position, previous_label, serialized_tree);
117            return potential;
118        }
119    }