// Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
//
// This file is part of JBTCRF.
//
// JBTCRF is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// JBTCRF is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with JBTCRF. If not, see <http://www.gnu.org/licenses/>.

package com.pawjaw.classification.crf.lmcbt;

import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
import com.pawjaw.classification.crf.lmcbt.points.Point;
import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences;
import com.pawjaw.classification.crf.lmcbt.trees.BoostedTrees;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Semaphore;

/**
 * Runs repeated train/validate passes over a set of labeled sequences.
 *
 * <p>Each pass splits the examples into a training part and a validation
 * part, creates and starts one {@link BoostedTrees} per label, and then, for
 * every boosting iteration, records the Viterbi accuracy of the current model
 * on the validation part in {@link #accuracies} and the iteration number in
 * {@link #boosting_iterations}.
 */
public class CRFValidator {
    /** Fraction of the examples placed in the training part; set by {@link #run}. */
    private double training_percent;
    private final Configuration c;
    // Created with c.boosted_tree_threads permits and handed to every
    // BoostedTrees; presumably limits how many of them work concurrently
    // (TODO confirm against BoostedTrees).
    private final Semaphore proceed;
    private TrainingSequences ts;
    private final List<Point[]> all_pointss;
    private final List<int[]> all_true_labelss;
    // runPass() acquires one permit per BoostedTrees from this semaphore
    // after each update round.
    private final Semaphore update_complete = new Semaphore(0);
    private final List<Integer> example_indexes = new ArrayList<Integer>();
    private final List<Point[]> validation_points = new ArrayList<Point[]>();
    private final List<int[]> validation_true_labels = new ArrayList<int[]>();
    private final BoostedTrees[] bt_k;

    /**
     * One accuracy value per (pass, boosting iteration), in the order they
     * were measured. Not cleared between passes or between calls to
     * {@link #run}.
     */
    public List<Double> accuracies = new ArrayList<Double>();

    /**
     * Boosting iteration number for the entry at the same position in
     * {@link #accuracies}.
     */
    public List<Integer> boosting_iterations = new ArrayList<Integer>();

    /**
     * Creates a validator over the given examples.
     *
     * @param c configuration; this class reads {@code boosted_tree_threads},
     *        {@code label_count_including_start_label},
     *        {@code cache_expanded_true_features} and
     *        {@code max_boosting_iterations} from it
     * @param pointss one {@code Point[]} per example
     * @param true_labelss one label array per example, looked up with the same
     *        index as {@code pointss}
     */
    public CRFValidator(Configuration c, List<Point[]> pointss, List<int[]> true_labelss) {
        int example_count = pointss.size();
        this.c = c;
        this.all_pointss = pointss;
        this.all_true_labelss = true_labelss;
        proceed = new Semaphore(c.boosted_tree_threads);
        bt_k = new BoostedTrees[c.label_count_including_start_label];
        for (int i = 0; i < example_count; i++) {
            example_indexes.add(i);
        }
    }

    /**
     * Runs the requested number of passes; results are appended to
     * {@link #accuracies} and {@link #boosting_iterations}.
     *
     * @param cross_validation_passes number of passes to run
     * @param training_percent fraction of the examples used for training in
     *        each pass; the training count is the ceiling of
     *        {@code example count * training_percent}
     */
    public void run(int cross_validation_passes, double training_percent) {
        this.training_percent = training_percent;
        for (int pass = 0; pass < cross_validation_passes; pass++) {
            setupPass();
            runPass(pass);
        }
    }

    /**
     * Draws a fresh train/validation split, fills the training sequences and
     * the validation lists from it, and creates and starts one
     * {@link BoostedTrees} per label.
     */
    private void setupPass() {
        int total = example_indexes.size();
        int training_count = (int) Math.ceil((double) total * training_percent);
        // Shuffle so that each pass gets a different split. Sorting here was
        // a no-op on the already ordered index list and made every pass use
        // exactly the same training and validation examples.
        Collections.shuffle(example_indexes);
        ts = new TrainingSequences();
        ts.allocate(c, training_count);
        validation_points.clear();
        validation_true_labels.clear();
        // The first training_count shuffled indexes form the training part.
        for (int i = 0; i < training_count; i++) {
            int example_index = example_indexes.get(i);
            ts.set(i, all_pointss.get(example_index), all_true_labelss.get(example_index));
        }
        // The remaining indexes form the validation part.
        for (int i = training_count; i < total; i++) {
            int example_index = example_indexes.get(i);
            validation_points.add(all_pointss.get(example_index));
            validation_true_labels.add(all_true_labelss.get(example_index));
        }
        ts.indexExamples();
        if (c.cache_expanded_true_features) {
            ts.cacheTrueFeatures();
        }
        for (int k = 0; k < bt_k.length; k++) {
            bt_k[k] = new BoostedTrees(c, ts, k, proceed, update_complete);
            bt_k[k].start();
        }
    }

    /**
     * Collects the serialized form of every per-label {@link BoostedTrees}.
     *
     * @return an array indexed by label holding each model's
     *         {@code serialize()} result
     */
    private double[][][] serialize() {
        double[][][] sbt_k = new double[bt_k.length][][];
        for (int k = 0; k < bt_k.length; k++) {
            sbt_k[k] = bt_k[k].serialize();
        }
        return sbt_k;
    }

    /**
     * Runs the boosting iterations of one pass. In each iteration the current
     * serialized models are scored on the validation part first, then the
     * training sequences are updated and every {@link BoostedTrees} is marked
     * updateable; the loop then blocks until one permit per model has been
     * released on {@code update_complete}. All models are told they are
     * finished when the loop ends.
     *
     * @param cross_validation_pass pass number, used only in the progress line
     */
    private void runPass(int cross_validation_pass) {
        int label_count = bt_k.length;
        CRFInference crfi = new CRFInference();
        for (int m = 1; m <= c.max_boosting_iterations; m++) {
            crfi.set(c, serialize());
            boosting_iterations.add(m);
            double accuracy = crfi.accuracyViterbi(validation_points, validation_true_labels);
            accuracies.add(accuracy);
            System.out.println("cross validation pass: " + cross_validation_pass + ", iterations: " + m + ", accuracy: " + accuracy);
            ts.update();
            for (int k = 0; k < label_count; k++) {
                bt_k[k].setUpdateable();
            }
            try {
                update_complete.acquire(label_count);
            } catch (InterruptedException e) {
                // Restore the interrupt status so callers can see that this
                // thread was interrupted, then stop iterating.
                Thread.currentThread().interrupt();
                break;
            }
        }
        for (int k = 0; k < label_count; k++) {
            bt_k[k].finished();
        }
    }
}