001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt;
019    
020    import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
021    import com.pawjaw.classification.crf.lmcbt.trees.BoostedTrees;
022    import com.pawjaw.classification.crf.lmcbt.sequences.TrainingSequences;
023    import com.pawjaw.classification.crf.lmcbt.points.Point;
024    import java.util.List;
025    import java.util.concurrent.Semaphore;
026    
027    /**
028     * Main entry point for training new CRF model.  Following instantiation,
029     * set all training examples using set() method, then run one of the
030     * train() methods.
031     * 
032     */
033    public class CRFTrainer {
034        private Configuration c;
035        private TrainingSequences ts;
036        private BoostedTrees[] bt_k;
037    
038        /**
039         * Instantiate this, then use set() repeatedly finally followed by train().
040         * Same configuration should be used for inference.
041         *
042         * @param c options for training specified here
043         * @param example_sequences number of example sequences to be used in training
044         */
045        public CRFTrainer(Configuration c, int example_sequences) {
046            this.c = c;
047            ts = new TrainingSequences();
048            ts.allocate(c, example_sequences);
049            bt_k = new BoostedTrees[c.label_count_including_start_label];
050        }
051    
052        /**
053         * Set a new training example.  sequence_index ranges up to example_sequences
054         * specified in constructor.  True labels range to the value specified in
055         * the configuraton given to the constructor.
056         *
057         * @param sequence_index ranges from 0 to example_sequences - 1 (inclusive)
058         * @param points data points in the given sequence
059         * @param labels true labels for the given sequence
060         */
061        public void set(int sequence_index, Point[] points, int[] labels) {
062            ts.set(sequence_index, points, labels);
063        }
064    
065        /**
066         * Train a new CRF model with the number of boosting iterations specified
067         * in the configuration given to the constructor.
068         */
069        public void train() {
070            int k, K = bt_k.length, m;
071            Semaphore update_complete = new Semaphore(0);
072            Semaphore proceed = new Semaphore(c.boosted_tree_threads);
073            ts.indexExamples();
074            if(c.cache_expanded_true_features)
075                ts.cacheTrueFeatures();
076            for(k = 0;k < K;k++)
077                (bt_k[k] = new BoostedTrees(c, ts, k, proceed, update_complete)).start();
078            for(m = 1;m <= c.max_boosting_iterations;m++) {
079                if(c.report_training_progress)
080                    System.err.println("boosting iteration: " + m + " / " + c.max_boosting_iterations);
081                ts.update();
082                for(k = 0;k < K;k++)
083                    bt_k[k].setUpdateable();
084                try {
085                    update_complete.acquire(K);
086                } catch(InterruptedException e) {
087                    break;
088                }
089            }
090            for(k = 0;k < K;k++)
091                bt_k[k].finished();
092        }
093    
094        /**
095         * Train a new CRF model with the maximum number of boosting iterations
096         * specified in the configuration given to the constructor but with the
097         * option to stop early if the incremental improvement in Viterbi accuracy
098         * is less than the amount specified in the configuration.  Use the
099         * arguments to this method to validate the model for accuracy.
100         *
101         * @param pointss validation data points
102         * @param true_labelss validation labels
103         */
104        public void train(List<Point[]> pointss, List<int[]> true_labelss) {
105            int i = c.boosting_iterations_between_tests, k, K = bt_k.length, m;
106            double accuracy, previous_accuracy = 0;
107            CRFInference crfi = new CRFInference();
108            Semaphore update_complete = new Semaphore(0);
109            Semaphore proceed = new Semaphore(c.boosted_tree_threads);
110            ts.indexExamples();
111            if(c.cache_expanded_true_features)
112                ts.cacheTrueFeatures();
113            for(k = 0;k < K;k++)
114                (bt_k[k] = new BoostedTrees(c, ts, k, proceed, update_complete)).start();
115            for(m = 1;m <= c.max_boosting_iterations;m++) {
116                if(c.report_training_progress)
117                    System.err.println("boosting iteration: " + m + " / " + c.max_boosting_iterations);
118                if(m >= c.min_boosting_iterations && i++ == c.boosting_iterations_between_tests) {
119                    crfi.set(c, serialize());
120                    accuracy = crfi.accuracyViterbi(pointss, true_labelss);
121                    if(c.report_training_progress)
122                        System.err.println("viterbi validation accuracy at iteration " + m + ": " + accuracy);
123                    if(previous_accuracy > 0 && (accuracy - previous_accuracy) / previous_accuracy < c.min_relative_accuracy_improvement) {
124                        System.err.println("stopping early, relative accuracy gained: " + ((accuracy - previous_accuracy) / previous_accuracy));
125                        break;
126                    } else {
127                        i = 1;
128                        previous_accuracy = accuracy;
129                    }
130                }
131                ts.update();
132                for(k = 0;k < K;k++)
133                    bt_k[k].setUpdateable();
134                try {
135                    update_complete.acquire(K);
136                } catch(InterruptedException e) {
137                    break;
138                }
139            }
140            for(k = 0;k < K;k++)
141                bt_k[k].finished();
142        }
143    
144        /**
145         * Serialize trained CRF model to be written to disk and/or used for
146         * subsequent inference (labeling new data).
147         *
148         * @return CRF model for subsequent inference
149         */
150        public double[][][] serialize() {
151            int k, K = bt_k.length;
152            double[][][] sbt_k = new double[K][][];
153            for(k = 0;k < K;k++)
154                sbt_k[k] = bt_k[k].serialize();
155            return sbt_k;
156        }
157    }