001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt;
019    
020    import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
021    import com.pawjaw.classification.crf.lmcbt.points.Point;
022    import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences;
023    import com.pawjaw.classification.crf.lmcbt.trees.BoostedTrees;
024    import java.util.ArrayList;
025    import java.util.Collections;
026    import java.util.List;
027    import java.util.concurrent.Semaphore;
028    
029    public class CRFValidator {
030        private double training_percent;
031        private Configuration c;
032        private Semaphore proceed;
033        private TrainingSequences ts;
034        private List<Point[]> all_pointss;
035        private List<int[]> all_true_labelss;
036        private Semaphore update_complete = new Semaphore(0);
037        private List<Integer> example_indexes = new ArrayList();
038        private List<Point[]> validation_points = new ArrayList();
039        private List<int[]> validation_true_labels = new ArrayList();
040        private BoostedTrees[] bt_k;
041    
042        public List<Double> accuracies = new ArrayList();
043        public List<Integer> boosting_iterations = new ArrayList();
044    
045        public CRFValidator(Configuration c, List<Point[]> pointss, List<int[]> true_labelss) {
046            int i, I = pointss.size();
047            this.c = c;
048            this.all_pointss = pointss;
049            this.all_true_labelss = true_labelss;
050            proceed = new Semaphore(c.boosted_tree_threads);
051            bt_k = new BoostedTrees[c.label_count_including_start_label];
052            for(i = 0;i < I;i++)
053                example_indexes.add(i);
054        }
055        
056        public void run(int cross_validation_passes, double training_percent) {
057            int i;
058            this.training_percent = training_percent;
059            for(i = 0;i < cross_validation_passes;i++) {
060                setupPass();
061                runPass(i);
062            }
063        }
064    
065        private void setupPass() {
066            int i = 0, I, J = (int)Math.ceil((double)(I = example_indexes.size()) * training_percent);
067            int k, K = bt_k.length, example_index;
068            Collections.sort(example_indexes);
069            ts = new TrainingSequences();
070            ts.allocate(c, J);
071            validation_points.clear();
072            validation_true_labels.clear();
073            while(i < J) {
074                example_index = example_indexes.get(i);
075                ts.set(i, all_pointss.get(example_index), all_true_labelss.get(example_index));
076                i++;
077            }
078            while(i < I) {
079                example_index = example_indexes.get(i);
080                validation_points.add(all_pointss.get(example_index));
081                validation_true_labels.add(all_true_labelss.get(example_index));
082                i++;
083            }
084            ts.indexExamples();
085            if(c.cache_expanded_true_features)
086                ts.cacheTrueFeatures();
087            for(k = 0;k < K;k++)
088                (bt_k[k] = new BoostedTrees(c, ts, k, proceed, update_complete)).start();
089        }
090    
091        private double[][][] serialize() {
092            int k, K = bt_k.length;
093            double[][][] sbt_k = new double[K][][];
094            for(k = 0;k < K;k++)
095                sbt_k[k] = bt_k[k].serialize();
096            return sbt_k;
097        }
098    
099        private void runPass(int cross_validation_pass) {
100            int k, K = bt_k.length, m;
101            double accuracy;
102            CRFInference crfi = new CRFInference();
103            for(m = 1;m <= c.max_boosting_iterations;m++) {
104                crfi.set(c, serialize());
105                boosting_iterations.add(m);
106                accuracies.add(accuracy = crfi.accuracyViterbi(validation_points, validation_true_labels));
107                System.out.println("cross validation pass: " + cross_validation_pass + ", iterations: " + m + ", accuracy: " + accuracy);
108                ts.update();
109                for(k = 0;k < K;k++)
110                    bt_k[k].setUpdateable();
111                try {
112                    update_complete.acquire(K);
113                } catch(InterruptedException e) {
114                    break;
115                }
116            }
117            for(k = 0;k < K;k++)
118                bt_k[k].finished();
119        }
120    }