001 // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company) 002 // 003 // This file is part of JBTCRF. 004 // 005 // JBTCRF is free software: you can redistribute it and/or modify 006 // it under the terms of the GNU General Public License as published by 007 // the Free Software Foundation, either version 3 of the License, or 008 // (at your option) any later version. 009 // 010 // JBTCRF is distributed in the hope that it will be useful, 011 // but WITHOUT ANY WARRANTY; without even the implied warranty of 012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 013 // GNU General Public License for more details. 014 // 015 // You should have received a copy of the GNU General Public License 016 // along with JBTCRF. If not, see <http://www.gnu.org/licenses/>. 017 018 package com.pawjaw.classification.crf.lmcbt.trees; 019 020 import com.pawjaw.classification.crf.lmcbt.configurations.Configuration; 021 import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences; 022 import java.util.ArrayList; 023 import java.util.Arrays; 024 import java.util.List; 025 import java.util.concurrent.Semaphore; 026 027 public class Splitter { 028 private int features, count, label; 029 private double sum; 030 private Configuration c; 031 private TrainingSequences ts; 032 private Semaphore proceed, update_complete; 033 private int[] count_f, example_indexes; 034 private double[] sum_f; 035 private StatsThread[] sts; 036 037 public Splitter(Configuration c, TrainingSequences ts) { 038 int i; 039 this.c = c; 040 this.ts = ts; 041 features = ts.features(); 042 count_f = new int[features]; 043 sum_f = new double[features]; 044 update_complete = new Semaphore(0); 045 sts = new StatsThread[c.splitter_threads]; 046 proceed = new Semaphore(c.splitter_threads); 047 for(i = 0;i < c.splitter_threads;i++) 048 (sts[i] = new StatsThread(i)).start(); 049 } 050 051 public void initialize(int label, Node root) { 052 double sum = 0; 053 this.label = label; 054 if(root.example_indexes.length == 0) 055 root.output = 0; 056 else { 057 for(int example_index : root.example_indexes) 058 sum += ts.getTarget(label, example_index); 059 root.output = sum / (c.regression_tree_shrinkage + root.example_indexes.length); 060 } 061 } 062 063 private class StatsThread extends Thread { 064 private int id; 065 private Semaphore updateable = new Semaphore(0); 066 private List<Integer> feature_indexes = new ArrayList(); 067 068 public double _sum; 069 public int[] _count_f = new int[features]; 070 public double[] _sum_f = new double[features]; 071 072 public StatsThread(int id) { 073 this.id = id; 074 } 075 076 public void run() { 077 int i, I, length; 078 double target; 079 try { 080 while(true) { 081 updateable.acquire(); 082 proceed.acquire(); 083 _sum = 0; 084 Arrays.fill(_sum_f, 0); 085 Arrays.fill(_count_f, 0); 086 length = (int)Math.ceil((double)example_indexes.length / (double)sts.length); 087 I = Math.min(example_indexes.length, length * (id + 1)); 088 for(i = length * id;i < I;i++) { 089 _sum += target = ts.getTarget(label, example_indexes[i]); 090 if(c.cache_expanded_true_features) 091 for(int feature_index : ts.getCachedTrueFeatures(example_indexes[i])) { 092 _count_f[feature_index]++; 093 _sum_f[feature_index] += target; 094 } 095 else 096 for(int feature_index : ts.getTrueFeatures(example_indexes[i], feature_indexes)) { 097 _count_f[feature_index]++; 098 _sum_f[feature_index] += target; 099 } 100 } 101 proceed.release(); 102 update_complete.release(); 103 } 104 } catch(InterruptedException e) {} 105 } 106 107 public void setUpdateable() { 108 updateable.release(); 109 } 110 111 public void finished() { 112 this.interrupt(); 113 } 114 } 115 116 public void finished() { 117 int i; 118 for(i = 0;i < sts.length;i++) 119 sts[i].finished(); 120 } 121 122 private void collectStatistics(int[] example_indexes) { 123 int i, j; 124 this.example_indexes = example_indexes; 125 for(i = 0;i < sts.length;i++) 126 sts[i].setUpdateable(); 127 Arrays.fill(sum_f, 0); 128 Arrays.fill(count_f, 0); 129 try { 130 update_complete.acquire(sts.length); 131 } catch(InterruptedException e) { 132 throw new RuntimeException(e); 133 } 134 sum = 0; 135 count = example_indexes.length; 136 for(i = 0;i < features;i++) { 137 for(j = 0;j < sts.length;j++) { 138 sum_f[i] += sts[j]._sum_f[i]; 139 count_f[i] += sts[j]._count_f[i]; 140 } 141 } 142 for(j = 0;j < sts.length;j++) 143 sum += sts[j]._sum; 144 } 145 146 public Split findBestSplit(Node parent) { 147 int f; 148 double true_count, false_count, gain; 149 Split s = new Split(parent); 150 collectStatistics(parent.example_indexes); 151 for(f = 0;f < features;f++) 152 if((true_count = count_f[f]) >= 2 && (false_count = count - count_f[2]) >= 2) { 153 s.true_child_output = sum_f[f] / (c.regression_tree_shrinkage + true_count); 154 s.false_child_output = (sum - sum_f[f]) / (c.regression_tree_shrinkage + false_count); 155 gain = (2.0 * c.regression_tree_shrinkage + true_count) * s.true_child_output * s.true_child_output + 156 (2.0 * c.regression_tree_shrinkage + false_count) * s.false_child_output * s.false_child_output - 157 (2.0 * c.regression_tree_shrinkage + count) * parent.output * parent.output; 158 if(gain > s.gain) { 159 s.split_feature = f; 160 s.gain = gain; 161 } 162 } 163 return s.isValid() ? s : null; 164 } 165 166 public Node[] acceptSplit(Split s) { 167 List<Integer> true_child_example_indexes = new ArrayList(); 168 List<Integer> false_child_example_indexes = new ArrayList(); 169 for(int example_index : s.parent.example_indexes) 170 if(ts.hasTrueFeature(example_index, s.split_feature)) 171 true_child_example_indexes.add(example_index); 172 else 173 false_child_example_indexes.add(example_index); 174 s.parent.example_indexes = null; 175 s.parent.split_feature = s.split_feature; 176 return new Node[] { 177 s.parent.true_child = new Node(s.true_child_output, true_child_example_indexes), 178 s.parent.false_child = new Node(s.false_child_output, false_child_example_indexes) 179 }; 180 } 181 }