// Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
//
// This file is part of JBTCRF.
//
// JBTCRF is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// JBTCRF is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with JBTCRF. If not, see <http://www.gnu.org/licenses/>.

package com.pawjaw.classification.crf.lmcbt;

import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
import com.pawjaw.classification.crf.lmcbt.trees.BoostedTrees;
import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences;
import com.pawjaw.classification.crf.lmcbt.points.Point;
import java.util.List;
import java.util.concurrent.Semaphore;

/**
 * Main entry point for training a new CRF model. Following instantiation,
 * set all training examples using the set() method, then run one of the
 * train() methods.
 */
public class CRFTrainer {
    private final Configuration c;
    private final TrainingSequences ts;
    // One boosted-trees worker thread per label (including the start label).
    private final BoostedTrees[] bt_k;

    /**
     * Instantiate this, then use set() repeatedly finally followed by train().
     * Same configuration should be used for inference.
     *
     * @param c options for training specified here
     * @param example_sequences number of example sequences to be used in training
     */
    public CRFTrainer(Configuration c, int example_sequences) {
        this.c = c;
        ts = new TrainingSequences();
        ts.allocate(c, example_sequences);
        bt_k = new BoostedTrees[c.label_count_including_start_label];
    }

    /**
     * Set a new training example. sequence_index ranges up to example_sequences
     * specified in constructor. True labels range to the value specified in
     * the configuration given to the constructor.
     *
     * @param sequence_index ranges from 0 to example_sequences - 1 (inclusive)
     * @param points data points in the given sequence
     * @param labels true labels for the given sequence
     */
    public void set(int sequence_index, Point[] points, int[] labels) {
        ts.set(sequence_index, points, labels);
    }

    /**
     * Train a new CRF model with the number of boosting iterations specified
     * in the configuration given to the constructor.
     */
    public void train() {
        Semaphore update_complete = new Semaphore(0);
        startWorkers(update_complete);
        for(int m = 1; m <= c.max_boosting_iterations; m++) {
            reportProgress("boosting iteration: " + m + " / " + c.max_boosting_iterations);
            if(!boostOnce(update_complete))
                break; // training thread was interrupted; shut workers down
        }
        stopWorkers();
    }

    /**
     * Train a new CRF model with the maximum number of boosting iterations
     * specified in the configuration given to the constructor but with the
     * option to stop early if the incremental improvement in Viterbi accuracy
     * is less than the amount specified in the configuration. Use the
     * arguments to this method to validate the model for accuracy.
     *
     * @param pointss validation data points
     * @param true_labelss validation labels
     */
    public void train(List<Point[]> pointss, List<int[]> true_labelss) {
        // Start at the test interval so the first eligible iteration runs a validation pass.
        int i = c.boosting_iterations_between_tests;
        double previous_accuracy = 0;
        CRFInference crfi = new CRFInference();
        Semaphore update_complete = new Semaphore(0);
        startWorkers(update_complete);
        for(int m = 1; m <= c.max_boosting_iterations; m++) {
            reportProgress("boosting iteration: " + m + " / " + c.max_boosting_iterations);
            if(m >= c.min_boosting_iterations && i++ == c.boosting_iterations_between_tests) {
                // Validate the current model snapshot against the held-out sequences.
                crfi.set(c, serialize());
                double accuracy = crfi.accuracyViterbi(pointss, true_labelss);
                reportProgress("viterbi validation accuracy at iteration " + m + ": " + accuracy);
                if(previous_accuracy > 0) {
                    double relative_gain = (accuracy - previous_accuracy) / previous_accuracy;
                    if(relative_gain < c.min_relative_accuracy_improvement) {
                        System.err.println("stopping early, relative accuracy gained: " + relative_gain);
                        break;
                    }
                }
                i = 1;
                previous_accuracy = accuracy;
            }
            if(!boostOnce(update_complete))
                break; // training thread was interrupted; shut workers down
        }
        stopWorkers();
    }

    /**
     * Serialize trained CRF model to be written to disk and/or used for
     * subsequent inference (labeling new data).
     *
     * @return CRF model for subsequent inference
     */
    public double[][][] serialize() {
        double[][][] sbt_k = new double[bt_k.length][][];
        for(int k = 0; k < bt_k.length; k++)
            sbt_k[k] = bt_k[k].serialize();
        return sbt_k;
    }

    // Index the training examples (optionally caching expanded true features)
    // and launch one boosted-trees worker thread per label.
    private void startWorkers(Semaphore update_complete) {
        Semaphore proceed = new Semaphore(c.boosted_tree_threads);
        ts.indexExamples();
        if(c.cache_expanded_true_features)
            ts.cacheTrueFeatures();
        for(int k = 0; k < bt_k.length; k++)
            (bt_k[k] = new BoostedTrees(c, ts, k, proceed, update_complete)).start();
    }

    // Run one boosting iteration: refresh the training sequences, release all
    // workers, and wait for every worker to report completion.
    // Returns false if this thread was interrupted while waiting.
    private boolean boostOnce(Semaphore update_complete) {
        ts.update();
        for(int k = 0; k < bt_k.length; k++)
            bt_k[k].setUpdateable();
        try {
            update_complete.acquire(bt_k.length);
            return true;
        } catch(InterruptedException e) {
            // Fix: re-assert the interrupt status instead of swallowing it, so
            // callers of train() can still observe the interruption.
            Thread.currentThread().interrupt();
            return false;
        }
    }

    // Signal every worker thread that training is complete.
    private void stopWorkers() {
        for(int k = 0; k < bt_k.length; k++)
            bt_k[k].finished();
    }

    // Emit a progress line to stderr when progress reporting is enabled.
    private void reportProgress(String message) {
        if(c.report_training_progress)
            System.err.println(message);
    }
}