001 // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company) 002 // 003 // This file is part of JBTCRF. 004 // 005 // JBTCRF is free software: you can redistribute it and/or modify 006 // it under the terms of the GNU General Public License as published by 007 // the Free Software Foundation, either version 3 of the License, or 008 // (at your option) any later version. 009 // 010 // JBTCRF is distributed in the hope that it will be useful, 011 // but WITHOUT ANY WARRANTY; without even the implied warranty of 012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 013 // GNU General Public License for more details. 014 // 015 // You should have received a copy of the GNU General Public License 016 // along with JBTCRF. If not, see <http://www.gnu.org/licenses/>. 017 018 package com.pawjaw.classification.crf.lmcbt.trees; 019 020 import com.pawjaw.classification.crf.lmcbt.configurations.Configuration; 021 import com.pawjaw.classification.crf.lmcbt.sequences.ExpandedPointSequence; 022 import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences; 023 import java.util.ArrayList; 024 import java.util.LinkedList; 025 import java.util.List; 026 import java.util.Queue; 027 import java.util.concurrent.Semaphore; 028 029 public class BoostedTrees extends Thread { 030 private int label; 031 private int examples; 032 private Tree tree; 033 private Configuration c; 034 private Splitter splitter; 035 private TrainingSequences ts; 036 private Semaphore updateable, proceed, update_complete; 037 private List<double[]> serialized_trees = new ArrayList(); 038 private int[] training_examples; 039 040 public BoostedTrees(Configuration c, TrainingSequences ts, int label, Semaphore proceed, Semaphore update_complete) { 041 int i; 042 this.c = c; 043 this.ts = ts; 044 this.label = label; 045 this.proceed = proceed; 046 this.update_complete = update_complete; 047 tree = new Tree(); 048 examples = ts.examples(); 049 updateable = new Semaphore(0); 050 splitter = new Splitter(c, ts); 051 training_examples = new int[examples]; 052 for(i = 0;i < examples;i++) 053 training_examples[i] = i; 054 } 055 056 private void buildTree() { 057 tree.build(c, label, splitter, training_examples); 058 serialized_trees.add(tree.serialize()); 059 } 060 061 private void updatePotentials() { 062 Node node; 063 Queue<Node> nodeq = new LinkedList(); 064 nodeq.add(tree.root); 065 while((node = nodeq.poll()) != null) { 066 if(node.isLeaf()) 067 for(int example_index : node.example_indexes) 068 ts.incrementPotential(label, example_index, node.output); 069 else { 070 nodeq.add(node.true_child); 071 nodeq.add(node.false_child); 072 } 073 } 074 } 075 076 public void setUpdateable() { 077 updateable.release(); 078 } 079 080 public void finished() { 081 splitter.finished(); 082 this.interrupt(); 083 } 084 085 public void run() { 086 try { 087 while(true) { 088 updateable.acquire(); 089 proceed.acquire(); 090 buildTree(); 091 updatePotentials(); 092 proceed.release(); 093 update_complete.release(); 094 } 095 } catch(InterruptedException e) {} 096 } 097 098 public double[][] serialize() { 099 int i = 0; 100 double[][] serialized_trees = new double[this.serialized_trees.size()][]; 101 for(double[] serialized_tree : this.serialized_trees) 102 serialized_trees[i++] = serialized_tree; 103 return serialized_trees; 104 } 105 106 public static double boostedOutput(Configuration c, ExpandedPointSequence eps, int expanded_sequence_position, double[][] serialized_trees) { 107 double potential = 0; 108 for(double[] serialized_tree : serialized_trees) 109 potential += Tree.output(c, eps, expanded_sequence_position, serialized_tree); 110 return potential; 111 } 112 113 public static double boostedOutput(Configuration c, ExpandedPointSequence eps, int sequence_position, int previous_label, double[][] serialized_trees) { 114 double potential = 0; 115 for(double[] serialized_tree : serialized_trees) 116 potential += Tree.output(c, eps, sequence_position, previous_label, serialized_tree); 117 return potential; 118 } 119 }