// Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
//
// This file is part of JBTCRF.
//
// JBTCRF is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// JBTCRF is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with JBTCRF. If not, see <http://www.gnu.org/licenses/>.

package com.pawjaw.classification.crf.lmcbt;

import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
import com.pawjaw.classification.crf.lmcbt.points.Point;
import java.util.List;
import com.pawjaw.classification.crf.lmcbt.trees.BoostedTrees;
import com.pawjaw.classification.crf.lmcbt.sequences.ExpandedPointSequence;

import static com.pawjaw.classification.crf.lmcbt.CRFMath.*;

/**
 * Class to take trained and serialized CRF model and apply to new data for
 * labeling. Create new instance then use set() method to specify model and
 * configuration. Configuration specified should match that used to train
 * the specified CRF model.
 *
 * <p>Throughout this class {@code K} is {@code sbt_k.length - 1}. Index
 * {@code K} is passed as the "previous label" of the first point of a
 * sequence and is never returned as a predicted label, so it appears to act
 * as a start-of-sequence pseudo-label. NOTE(review): confirm against the
 * training code.
 */
public class CRFInference {
    private Configuration c;
    private double[][][] sbt_k;

    /**
     * Set the configuration options and the CRF model. Configuration options
     * should match those used to train CRF model.
     *
     * @param c configuration options
     * @param sbt_k serialized CRF model (first index is the label)
     */
    public void set(Configuration c, double[][][] sbt_k) {
        this.c = c;
        this.sbt_k = sbt_k;
    }

    /**
     * Score of assigning {@code label} at {@code sequence_position} given
     * {@code previous_label}, as produced by the boosted trees stored for
     * that label. The decoders below treat the value as a log-domain score.
     */
    private double potential(int label, int sequence_position, int previous_label, ExpandedPointSequence eps) {
        return BoostedTrees.boostedOutput(c, eps, sequence_position, previous_label, sbt_k[label]);
    }

    /**
     * Count the positions at which the two label arrays agree.
     * {@code true_labels} must be at least as long as {@code found_labels}.
     */
    private static int countMatches(int[] found_labels, int[] true_labels) {
        int correct = 0;
        for(int t = 0;t < found_labels.length;t++)
            if(found_labels[t] == true_labels[t])
                correct++;
        return correct;
    }

    /**
     * Shared implementation of the two accuracy methods.
     *
     * @param viterbi true to decode with Viterbi, false for Forward-Backward
     * @return correctly labeled points divided by total points
     */
    private double accuracy(List<Point[]> pointss, List<int[]> true_labelss, boolean viterbi) {
        int i = 0, correct = 0, total = 0;
        // one sequence object is re-pointed at each input in turn
        ExpandedPointSequence eps = new ExpandedPointSequence();
        for(Point[] points : pointss) {
            eps.set(points);
            int[] true_labels = true_labelss.get(i++);
            correct += viterbi
                    ? correctViterbi(eps, true_labels)
                    : correctForwardBackward(eps, true_labels);
            total += points.length;
        }
        return (double)correct / (double)total;
    }

    /**
     * Label given sequence and compare to given labels for testing purposes.
     * Use Viterbi method for labeling (most likely overall label sequence).
     *
     * @param eps sequence to label
     * @param true_labels true labels of the points in the sequence
     * @return number of labels correctly identified by the CRF model
     */
    public int correctViterbi(ExpandedPointSequence eps, int[] true_labels) {
        return countMatches(labelViterbi(eps), true_labels);
    }

    /**
     * Misspelled alias of {@link #correctViterbi(ExpandedPointSequence, int[])},
     * retained so existing callers keep compiling.
     *
     * @param eps sequence to label
     * @param true_labels true labels of the points in the sequence
     * @return number of labels correctly identified by the CRF model
     * @deprecated use {@link #correctViterbi(ExpandedPointSequence, int[])}
     */
    @Deprecated
    public int correctVirterbi(ExpandedPointSequence eps, int[] true_labels) {
        return correctViterbi(eps, true_labels);
    }

    /**
     * Label given sequence and compare to given labels for testing purposes.
     * Use Forward-Backward method for labeling (most likely individual
     * point-wise labels).
     *
     * @param eps sequence to label
     * @param true_labels true labels of the points in the sequence
     * @return number of labels correctly identified by the CRF model
     */
    public int correctForwardBackward(ExpandedPointSequence eps, int[] true_labels) {
        return countMatches(labelForwardBackward(eps), true_labels);
    }

    /**
     * Label given sequences and compare to given labels for testing purposes.
     * Use Viterbi method for labeling (most likely overall label sequence).
     *
     * @param pointss sequences to label
     * @param true_labelss true labels of the points in the sequences, in the
     *        same order as {@code pointss}
     * @return fraction (0 to 1) of correctly labeled points; NaN when the
     *         sequences contain no points at all
     */
    public double accuracyViterbi(List<Point[]> pointss, List<int[]> true_labelss) {
        return accuracy(pointss, true_labelss, true);
    }

    /**
     * Label given sequences and compare to given labels for testing purposes.
     * Use Forward-Backward method for labeling (most likely individual
     * point-wise labels).
     *
     * @param pointss sequences to label
     * @param true_labelss true labels of the points in the sequences, in the
     *        same order as {@code pointss}
     * @return fraction (0 to 1) of correctly labeled points; NaN when the
     *         sequences contain no points at all
     */
    public double accuracyForwardBackward(List<Point[]> pointss, List<int[]> true_labelss) {
        return accuracy(pointss, true_labelss, false);
    }

    /**
     * Infer the labels of the given sequence using the Viterbi method
     * (most likely overall label sequence). Labels are in the range specified
     * in the configuration given to the set() method.
     *
     * @param eps given sequence to label
     * @return the labels of each point in the sequence; an empty array for an
     *         empty sequence
     */
    public int[] labelViterbi(ExpandedPointSequence eps) {
        int t, T = eps.length(), k, kp, K = sbt_k.length - 1, maxarg;
        double logtransition, logmaxtransition;
        int[] label_t = new int[T];
        if(T == 0)
            return label_t;
        // maxarg_k_t[t][k]: best previous label when position t has label k
        int[][] maxarg_k_t = new int[T][K + 1];
        // logmaxtransition_k_t[t][k]: best log score of any path ending in k at t
        double[][] logmaxtransition_k_t = new double[T][K + 1];
        // Seed with the potential of each label at the first point, using K as
        // the previous label exactly as labelForwardBackward does. Seeding
        // with a constant would ignore the first point's own potential.
        for(k = 0;k <= K;k++) {
            maxarg_k_t[0][k] = -1;
            logmaxtransition_k_t[0][k] = potential(k, 0, K, eps);
        }
        for(t = 1;t < T;t++)
            for(k = 0;k <= K;k++) {
                maxarg = -1;
                logmaxtransition = Double.NEGATIVE_INFINITY;
                for(kp = 0;kp <= K;kp++)
                    if((logtransition = eLogProduct(potential(k, t, kp, eps), logmaxtransition_k_t[t - 1][kp])) > logmaxtransition) {
                        maxarg = kp;
                        logmaxtransition = logtransition;
                    }
                maxarg_k_t[t][k] = maxarg;
                logmaxtransition_k_t[t][k] = logmaxtransition;
            }
        // Pick the best final label; index K is deliberately excluded
        // (presumably the start pseudo-label -- see class comment).
        maxarg = -1;
        logmaxtransition = Double.NEGATIVE_INFINITY;
        for(k = 0;k < K;k++)
            if(logmaxtransition_k_t[T - 1][k] > logmaxtransition) {
                maxarg = k;
                logmaxtransition = logmaxtransition_k_t[T - 1][k];
            }
        label_t[T - 1] = maxarg;
        // follow the back-pointers from the last position to the first
        for(t = T - 2;t >= 0;t--)
            label_t[t] = maxarg_k_t[t + 1][label_t[t + 1]];
        return label_t;
    }

    /**
     * Infer the labels of the given sequence using the Forward-Backward method
     * (most likely individual point-wise labels). Labels are in the range
     * specified in the configuration given to the set() method.
     *
     * @param eps given sequence to label
     * @return the labels of each point in the sequence; an empty array for an
     *         empty sequence
     */
    public int[] labelForwardBackward(ExpandedPointSequence eps) {
        int t, k, kp, T = eps.length(), K = sbt_k.length - 1, maxarg;
        double logsum, logfb, maxlogfb;
        int[] label_t = new int[T];
        if(T == 0)
            return label_t;
        // Both tables are offset by one: row t + 1 describes sequence position t.
        double[][] logalpha_k_ta = new double[T + 1][K + 1];
        // Row T of logbeta is left at Java's default 0.0, i.e. log(1).
        double[][] logbeta_k_ta = new double[T + 1][K + 1];
        for(k = 0;k <= K;k++)
            logalpha_k_ta[1][k] = potential(k, 0, K, eps);
        // forward pass: alpha(t, k) = sum over kp of psi(k at t | kp) * alpha(t - 1, kp)
        for(t = 1;t < T;t++) {
            for(k = 0;k <= K;k++) {
                logsum = LOGZERO;
                for(kp = 0;kp <= K;kp++)
                    logsum = eLogSum(logsum, eLogProduct(potential(k, t, kp, eps), logalpha_k_ta[t][kp]));
                logalpha_k_ta[t + 1][k] = logsum;
            }
        }
        // backward pass: beta(t - 1, k) = sum over kp of psi(kp at t | k) * beta(t, kp).
        // Here k is the PREVIOUS label, so it goes in potential()'s
        // previous_label slot. Row 0 is never read, so the loop stops at t = 1.
        for(t = T - 1;t >= 1;t--) {
            for(k = 0;k <= K;k++) {
                logsum = LOGZERO;
                for(kp = 0;kp <= K;kp++)
                    logsum = eLogSum(logsum, eLogProduct(potential(kp, t, k, eps), logbeta_k_ta[t + 1][kp]));
                logbeta_k_ta[t][k] = logsum;
            }
        }
        // Per position, choose the label with the largest alpha * beta; index K
        // is deliberately excluded (presumably the start pseudo-label).
        for(t = 0;t < T;t++) {
            maxarg = -1;
            maxlogfb = Double.NEGATIVE_INFINITY;
            for(k = 0;k < K;k++)
                if((logfb = eLogProduct(logalpha_k_ta[t + 1][k], logbeta_k_ta[t + 1][k])) > maxlogfb) {
                    maxarg = k;
                    maxlogfb = logfb;
                }
            label_t[t] = maxarg;
        }
        return label_t;
    }
}