001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt.trees;
019    
020    import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
021    import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences;
022    import java.util.ArrayList;
023    import java.util.Arrays;
024    import java.util.List;
025    import java.util.concurrent.Semaphore;
026    
027    public class Splitter {
028        private int features, count, label;
029        private double sum;
030        private Configuration c;
031        private TrainingSequences ts;
032        private Semaphore proceed, update_complete;
033        private int[] count_f, example_indexes;
034        private double[] sum_f;
035        private StatsThread[] sts;
036    
037        public Splitter(Configuration c, TrainingSequences ts) {
038            int i;
039            this.c = c;
040            this.ts = ts;
041            features = ts.features();
042            count_f = new int[features];
043            sum_f = new double[features];
044            update_complete = new Semaphore(0);
045            sts = new StatsThread[c.splitter_threads];
046            proceed = new Semaphore(c.splitter_threads);
047            for(i = 0;i < c.splitter_threads;i++)
048                (sts[i] = new StatsThread(i)).start();
049        }
050    
051        public void initialize(int label, Node root) {
052            double sum = 0;
053            this.label = label;
054            if(root.example_indexes.length == 0)
055                root.output = 0;
056            else {
057                for(int example_index : root.example_indexes)
058                    sum += ts.getTarget(label, example_index);
059                root.output = sum / (c.regression_tree_shrinkage + root.example_indexes.length);
060            }
061        }
062    
063        private class StatsThread extends Thread {
064            private int id;
065            private Semaphore updateable = new Semaphore(0);
066            private List<Integer> feature_indexes = new ArrayList();
067    
068            public double _sum;
069            public int[] _count_f = new int[features];
070            public double[] _sum_f = new double[features];
071    
072            public StatsThread(int id) {
073                this.id = id;
074            }
075    
076            public void run() {
077                int i, I, length;
078                double target;
079                try {
080                    while(true) {
081                        updateable.acquire();
082                        proceed.acquire();
083                        _sum = 0;
084                        Arrays.fill(_sum_f, 0);
085                        Arrays.fill(_count_f, 0);
086                        length = (int)Math.ceil((double)example_indexes.length / (double)sts.length);
087                        I = Math.min(example_indexes.length, length * (id + 1));
088                        for(i = length * id;i < I;i++) {
089                            _sum += target = ts.getTarget(label, example_indexes[i]);
090                            if(c.cache_expanded_true_features)
091                                for(int feature_index : ts.getCachedTrueFeatures(example_indexes[i])) {
092                                    _count_f[feature_index]++;
093                                    _sum_f[feature_index] += target;
094                                }
095                            else
096                                for(int feature_index : ts.getTrueFeatures(example_indexes[i], feature_indexes)) {
097                                    _count_f[feature_index]++;
098                                    _sum_f[feature_index] += target;
099                                }
100                        }
101                        proceed.release();
102                        update_complete.release();
103                    }
104                } catch(InterruptedException e) {}
105            }
106    
107            public void setUpdateable() {
108                updateable.release();
109            }
110    
111            public void finished() {
112                this.interrupt();
113            }
114        }
115    
116        public void finished() {
117            int i;
118            for(i = 0;i < sts.length;i++)
119                sts[i].finished();
120        }
121    
122        private void collectStatistics(int[] example_indexes) {
123            int i, j;
124            this.example_indexes = example_indexes;
125            for(i = 0;i < sts.length;i++)
126                sts[i].setUpdateable();
127            Arrays.fill(sum_f, 0);
128            Arrays.fill(count_f, 0);
129            try {
130                update_complete.acquire(sts.length);
131            } catch(InterruptedException e) {
132                throw new RuntimeException(e);
133            }
134            sum = 0;
135            count = example_indexes.length;
136            for(i = 0;i < features;i++) {
137                for(j = 0;j < sts.length;j++) {
138                    sum_f[i] += sts[j]._sum_f[i];
139                    count_f[i] += sts[j]._count_f[i];
140                }
141            }
142            for(j = 0;j < sts.length;j++)
143                sum += sts[j]._sum;
144        }
145    
146        public Split findBestSplit(Node parent) {
147            int f;
148            double true_count, false_count, gain;
149            Split s = new Split(parent);
150            collectStatistics(parent.example_indexes);
151            for(f = 0;f < features;f++)
152                if((true_count = count_f[f]) >= 2 && (false_count = count - count_f[2]) >= 2) {
153                    s.true_child_output = sum_f[f] / (c.regression_tree_shrinkage + true_count);
154                    s.false_child_output = (sum - sum_f[f]) / (c.regression_tree_shrinkage + false_count);
155                    gain = (2.0 * c.regression_tree_shrinkage + true_count) * s.true_child_output * s.true_child_output +
156                            (2.0 * c.regression_tree_shrinkage + false_count) * s.false_child_output * s.false_child_output -
157                            (2.0 * c.regression_tree_shrinkage + count) * parent.output * parent.output;
158                    if(gain > s.gain) {
159                        s.split_feature = f;
160                        s.gain = gain;
161                    }
162                }
163            return s.isValid() ? s : null;
164        }
165    
166        public Node[] acceptSplit(Split s) {
167            List<Integer> true_child_example_indexes = new ArrayList();
168            List<Integer> false_child_example_indexes = new ArrayList();
169            for(int example_index : s.parent.example_indexes)
170                if(ts.hasTrueFeature(example_index, s.split_feature))
171                    true_child_example_indexes.add(example_index);
172                else
173                    false_child_example_indexes.add(example_index);
174            s.parent.example_indexes = null;
175            s.parent.split_feature = s.split_feature;
176            return new Node[] {
177                s.parent.true_child = new Node(s.true_child_output, true_child_example_indexes),
178                s.parent.false_child = new Node(s.false_child_output, false_child_example_indexes)
179            };
180        }
181    }