001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt.configurations;
019    
/**
 * Single point for configuring all options for training and inference.
 *
 * <p>The three structural values (features per point, window radius and
 * label count) are supplied to the constructor; every other option is
 * supplied by a subclass through one of the protected abstract
 * {@code get...()} methods and captured once into a public final field.
 *
 * <p><b>Note for subclass authors:</b> the {@code get...()} methods are
 * invoked from this class's field initializers, i.e. while the superclass
 * part of the object is being constructed and before the subclass's own
 * field initializers and constructor body have run.  Implementations
 * should therefore return constants (or values that do not depend on
 * subclass instance state); otherwise default values such as {@code 0},
 * {@code 0.0} or {@code false} will be captured.
 */
public abstract class Configuration {
    /**
     * One side of the sliding window (e.g. window_radius = 3 means
     * use features from previous 3 and next 3 elements in the
     * sequence).
     */
    public final int window_radius;
    /**
     * Number of features in each individual element (aka point).
     */
    public final int point_features;
    /**
     * Number of threads used in multithreaded sufficient
     * statistics collection portion of node splitting during
     * regression tree construction.
     */
    public final int splitter_threads = getSplitterThreads();
    /**
     * Number of features in each individual element (aka point)
     * including features incorporated through sliding window and
     * previous point label.  Computed as
     * {@code valid_window_offset + 2 * window_radius + 1}.
     */
    public final int expanded_features;
    /**
     * Offset into feature index where valid window features start.
     * Computed as {@code label_count_including_start_label
     * + point_features * (2 * window_radius + 1)}.
     */
    public final int valid_window_offset;
    /**
     * Maximum leaves per boosted regression tree.
     */
    public final int max_leaves_per_tree = getMaxLeavesPerTree();
    /**
     * Number of threads used in multithreaded tree boosting.
     * Can utilize as many simultaneous threads as the value in
     * {@link #label_count_including_start_label}.
     */
    public final int boosted_tree_threads = getBoostedTreeThreads();
    /**
     * Maximum number of boosted tree iterations.  May be less if
     * training includes validation data and
     * {@link #min_relative_accuracy_improvement} is greater than zero.
     */
    public final int max_boosting_iterations = getMaxBoostingIterations();
    /**
     * Minimum number of boosting iterations regardless of validation
     * data and {@link #min_relative_accuracy_improvement}.
     */
    public final int min_boosting_iterations = getMinBoostingIterations();
    /**
     * Boosting iterations to run between evaluating stopping criteria when
     * validation data is provided for training.
     */
    public final int boosting_iterations_between_tests = getBoostingIterationsBetweenTests();
    /**
     * Number of labels for elements in the data not including the start
     * pseudo-label.
     */
    public final int label_count_excluding_start_label;
    /**
     * Number of labels for elements in the data including the start
     * pseudo-label.  Always
     * {@code label_count_excluding_start_label + 1}.
     */
    public final int label_count_including_start_label;
    /**
     * Shrinkage for regression tree smoothing.  Typical range is 0 to 100.
     */
    public final double regression_tree_shrinkage = getRegressionTreeShrinkage();
    /**
     * Minimum relative accuracy improvement required to continue boosting
     * when training includes validation data.  Typical range is 1e-3 to 1e-4.
     */
    public final double min_relative_accuracy_improvement = getMinRelativeAccuracyImprovement();
    /**
     * Print to standard error some basic reporting of training progress.
     */
    public final boolean report_training_progress = getReportTrainingProgress();
    /**
     * Cache expanded true features for training examples.
     */
    public final boolean cache_expanded_true_features = getCacheExpandedTrueFeatures();

    /** @return value captured into {@link #splitter_threads} */
    protected abstract int getSplitterThreads();
    /** @return value captured into {@link #max_leaves_per_tree} */
    protected abstract int getMaxLeavesPerTree();
    /** @return value captured into {@link #boosted_tree_threads} */
    protected abstract int getBoostedTreeThreads();
    /** @return value captured into {@link #max_boosting_iterations} */
    protected abstract int getMaxBoostingIterations();
    /** @return value captured into {@link #min_boosting_iterations} */
    protected abstract int getMinBoostingIterations();
    /** @return value captured into {@link #boosting_iterations_between_tests} */
    protected abstract int getBoostingIterationsBetweenTests();
    /** @return value captured into {@link #regression_tree_shrinkage} */
    protected abstract double getRegressionTreeShrinkage();
    /** @return value captured into {@link #min_relative_accuracy_improvement} */
    protected abstract double getMinRelativeAccuracyImprovement();
    /** @return value captured into {@link #report_training_progress} */
    protected abstract boolean getReportTrainingProgress();
    /** @return value captured into {@link #cache_expanded_true_features} */
    protected abstract boolean getCacheExpandedTrueFeatures();

    /**
     * Stores the structural parameters and derives the label count
     * including the start pseudo-label, the valid-window offset and the
     * expanded feature count from them.
     *
     * @param point_features number of features in each individual point;
     *        must not be negative
     * @param window_radius one side of the sliding window; must not be
     *        negative
     * @param label_count_excluding_start_label number of labels not
     *        counting the start pseudo-label; must not be negative
     * @throws IllegalArgumentException if any argument is negative
     */
    public Configuration(int point_features, int window_radius, int label_count_excluding_start_label) {
        // Negative sizes would silently yield meaningless (possibly negative)
        // offsets and feature counts, so reject them up front.
        this.point_features = requireNonNegative(point_features, "point_features");
        this.window_radius = requireNonNegative(window_radius, "window_radius");
        this.label_count_excluding_start_label =
                requireNonNegative(label_count_excluding_start_label, "label_count_excluding_start_label");

        // The start pseudo-label adds exactly one label.
        label_count_including_start_label = label_count_excluding_start_label + 1;

        // Window width: radius on each side plus the current point.
        final int window_width = window_radius * 2 + 1;
        valid_window_offset = label_count_including_start_label + point_features * window_width;
        expanded_features = valid_window_offset + window_width;
    }

    /** Returns {@code value} unchanged, or throws if it is negative. */
    private static int requireNonNegative(int value, String name) {
        if (value < 0) {
            throw new IllegalArgumentException(name + " must not be negative: " + value);
        }
        return value;
    }
}