001 // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company) 002 // 003 // This file is part of JBTCRF. 004 // 005 // JBTCRF is free software: you can redistribute it and/or modify 006 // it under the terms of the GNU General Public License as published by 007 // the Free Software Foundation, either version 3 of the License, or 008 // (at your option) any later version. 009 // 010 // JBTCRF is distributed in the hope that it will be useful, 011 // but WITHOUT ANY WARRANTY; without even the implied warranty of 012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 013 // GNU General Public License for more details. 014 // 015 // You should have received a copy of the GNU General Public License 016 // along with JBTCRF. If not, see <http://www.gnu.org/licenses/>. 017 018 package com.pawjaw.classification.crf.lmcbt.configurations; 019 020 /** 021 * Single point for configuring all options for training and inference. 022 */ 023 public abstract class Configuration { 024 /** 025 * One side of the sliding window (e.g. window_radius = 3 means 026 * use features from previous 3 and next 3 elements in the 027 * sequence). 028 */ 029 public final int window_radius; 030 /** 031 * Number of features in each individual element (aka point). 032 */ 033 public final int point_features; 034 /** 035 * Number of threads used in multithreaded sufficient 036 * statistics collection portion of node splitting during 037 * regression tree construction. 038 */ 039 public final int splitter_threads = getSplitterThreads(); 040 /** 041 * Number of features in each individual element (aka point) 042 * including features incorporated through sliding window and 043 * previous point label. 044 */ 045 public final int expanded_features; 046 /** 047 * Offset into feature index where valid window features start. 048 */ 049 public final int valid_window_offset; 050 /** 051 * Maximum leaves per boosted regression tree. 052 */ 053 public final int max_leaves_per_tree = getMaxLeavesPerTree(); 054 /** 055 * Number of threads used in multithreaded tree boosting. 056 * Can utilize as many simultaneous threads as the value in 057 * {@link label_count_including_start_label}. 058 */ 059 public final int boosted_tree_threads = getBoostedTreeThreads(); 060 /** 061 * Maximum number of boosted tree iterations. May be less if 062 * training includes validation data and 063 * {@link min_relative_accuracy_improvement} is greater than zero. 064 */ 065 public final int max_boosting_iterations = getMaxBoostingIterations(); 066 /** 067 * Minimum number of boosting iterations regardless of validation 068 * data and {@link min_relative_accuracy_improvement}. 069 */ 070 public final int min_boosting_iterations = getMinBoostingIterations(); 071 /** 072 * Boosting itereations to run between evaluating stopping criteria when 073 * validation data is provided for training. 074 */ 075 public final int boosting_iterations_between_tests = getBoostingIterationsBetweenTests(); 076 /** 077 * Number of labels for elements in the data not including the start 078 * pseudo-label. 079 */ 080 public final int label_count_excluding_start_label; 081 /** 082 * Number of labels for elements in the data including the start 083 * pseudo-label. 084 */ 085 public final int label_count_including_start_label; 086 /** 087 * Shrinkage for regression tree smoothing. Typical range is 0 to 100. 088 */ 089 public final double regression_tree_shrinkage = getRegressionTreeShrinkage(); 090 /** 091 * Minimum relative accuracy improvement required to continue boosting 092 * when training includes validation data. Typical range is 1e-3 to 1e-4. 093 */ 094 public final double min_relative_accuracy_improvement = getMinRelativeAccuracyImprovement(); 095 /* 096 * Print to standard error some basic reporting of training progress. 097 */ 098 public final boolean report_training_progress = getReportTrainingProgress(); 099 /* 100 * Cache expanded true features for training examples. 101 */ 102 public final boolean cache_expanded_true_features = getCacheExpandedTrueFeatures(); 103 104 protected abstract int getSplitterThreads(); 105 protected abstract int getMaxLeavesPerTree(); 106 protected abstract int getBoostedTreeThreads(); 107 protected abstract int getMaxBoostingIterations(); 108 protected abstract int getMinBoostingIterations(); 109 protected abstract int getBoostingIterationsBetweenTests(); 110 protected abstract double getRegressionTreeShrinkage(); 111 protected abstract double getMinRelativeAccuracyImprovement(); 112 protected abstract boolean getReportTrainingProgress(); 113 protected abstract boolean getCacheExpandedTrueFeatures(); 114 115 public Configuration(int point_features, int window_radius, int label_count_excluding_start_label) { 116 this.point_features = point_features; 117 this.window_radius = window_radius; 118 this.label_count_excluding_start_label = label_count_excluding_start_label; 119 label_count_including_start_label = label_count_excluding_start_label + 1; 120 valid_window_offset = label_count_including_start_label + point_features * (window_radius * 2 + 1); 121 expanded_features = valid_window_offset + window_radius * 2 + 1; 122 } 123 }