001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt;
019    
020    import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
021    import com.pawjaw.classification.crf.lmcbt.points.Point;
022    import java.util.List;
023    import com.pawjaw.classification.crf.lmcbt.trees.BoostedTrees;
024    import com.pawjaw.classification.crf.lmcbt.sequences.ExpandedPointSequence;
025    
026    import static com.pawjaw.classification.crf.lmcbt.CRFMath.*;
027    
028    /**
029     * Class to take trained and serialized CRF model and apply to new data for
030     * labeling.  Create new instance then use set() method to specify model and
031     * configuration.  Configuration specified should match that used to train
032     * the specified CRF model.
033     */
034    public class CRFInference {
035        private Configuration c;
036        private double[][][] sbt_k;
037    
038        /**
039         * Set the configuration options and the CRF model.  Configuration options
040         * should match those used to train CRF model.
041         *
042         * @param c configuration options
043         * @param sbt_k serialized CRF model
044         */
045        public void set(Configuration c, double[][][] sbt_k) {
046            this.c = c;
047            this.sbt_k = sbt_k;
048        }
049    
050        private double potential(int label, int sequence_position, int previous_label, ExpandedPointSequence eps) {
051            return BoostedTrees.boostedOutput(c, eps, sequence_position, previous_label, sbt_k[label]);
052        }
053    
054        /**
055         * Label given sequence and compare to given labels for testing purposes.
056         * Use Viterbi method for labeling (most likely overall label sequence).
057         *
058         * @param eps sequence to label
059         * @param true_labels true labels of the points in the sequence
060         * @return number of labels correctly identified by the CRF model
061         */
062        public int correctVirterbi(ExpandedPointSequence eps, int[] true_labels) {
063            int i = 0, correct = 0;
064            for(int found_label : labelViterbi(eps))
065                if(found_label == true_labels[i++])
066                    correct++;
067            return correct;
068        }
069    
070        /**
071         * Label given sequence and compare to given labels for testing purposes.
072         * Use Forward-Backward method for labeling (most likely individual
073         * point-wise labels).
074         *
075         * @param eps sequence to label
076         * @param true_labels true labels of the points in the sequence
077         * @return number of labels correctly identified by the CRF model
078         */
079        public int correctForwardBackward(ExpandedPointSequence eps, int[] true_labels) {
080            int i = 0, correct = 0;
081            for(int found_label : labelForwardBackward(eps))
082                if(found_label == true_labels[i++])
083                    correct++;
084            return correct;
085        }
086    
087        /**
088         * Label given sequence sand compare to given labels for testing purposes.
089         * Use Viterbi method for labeling (most likely overall label sequence).
090         *
091         * @param pointss sequences to label
092         * @param true_labelss true labels of the points in the sequences
093         * @return percentage of correctly labeled points
094         */
095        public double accuracyViterbi(List<Point[]> pointss, List<int[]> true_labelss) {
096            int i = 0, correct = 0, total = 0;
097            ExpandedPointSequence eps = new ExpandedPointSequence();
098            for(Point[] points : pointss) {
099                eps.set(points);
100                correct += correctVirterbi(eps, true_labelss.get(i++));
101                total += points.length;
102            }
103            return (double)correct / (double)total;
104        }
105    
106        /**
107         * Label given sequence sand compare to given labels for testing purposes.
108         * Use Forward-Backward method for labeling (most likely individual
109         * point-wise labels).
110         *
111         * @param pointss sequences to label
112         * @param true_labelss true labels of the points in the sequences
113         * @return percentage of correctly labeled points
114         */
115        public double accuracyForwardBackward(List<Point[]> pointss, List<int[]> true_labelss) {
116            int i = 0, correct = 0, total = 0;
117            ExpandedPointSequence eps = new ExpandedPointSequence();
118            for(Point[] points : pointss) {
119                eps.set(points);
120                correct += correctForwardBackward(eps, true_labelss.get(i++));
121                total += points.length;
122            }
123            return (double)correct / (double)total;
124        }
125    
126        /**
127         * Infer the labels of the given sequence using the Viterbi method
128         * (most likely overall label sequence).  Labels are in the range specified
129         * in the configuration given to the set() method.
130         *
131         * @param eps given sequence to label
132         * @return the labels of each point in the sequence
133         */
134        public int[] labelViterbi(ExpandedPointSequence eps) {
135            int t, T = eps.length(), k, kp, maxk, K = sbt_k.length - 1, maxarg;
136            double logtransition, logmaxtransition;
137            int[] label_t = new int[T];
138            int[][] maxarg_k_t = new int[T][K + 1];
139            double[][] logmaxtransition_k_t = new double[T][K + 1];
140            for(k = 0;k <= K;k++) {
141                maxarg_k_t[0][k] = -1;
142                logmaxtransition_k_t[0][k] = LOGONE;
143            }
144            for(t = 1;t < T;t++)
145                for(k = 0;k <= K;k++) {
146                    maxarg = -1;
147                    logmaxtransition = Double.NEGATIVE_INFINITY;
148                    for(kp = 0;kp <= K;kp++)
149                        if((logtransition = eLogProduct(potential(k, t, kp, eps), logmaxtransition_k_t[t - 1][kp])) > logmaxtransition) {
150                            maxarg = kp;
151                            logmaxtransition = logtransition;
152                        }
153                    maxarg_k_t[t][k] = maxarg;
154                    logmaxtransition_k_t[t][k] = logmaxtransition;
155                }
156            maxarg = -1;
157            logmaxtransition = Double.NEGATIVE_INFINITY;
158            for(k = 0;k < K;k++)
159                if(logmaxtransition_k_t[T - 1][k] > logmaxtransition) {
160                    maxarg = k;
161                    logmaxtransition = logmaxtransition_k_t[T - 1][k];
162                }
163            label_t[T - 1] = maxarg;
164            for(t = T - 2;t >= 0;t--)
165                label_t[t] = maxarg_k_t[t + 1][label_t[t + 1]];
166            return label_t;
167        }
168    
169        /**
170         * Infer the labels of the given sequence using the Forward-Backward method
171         * (most likely individual point-wise labels ).  Labels are in the range
172         * specified in the configuration given to the set() method.
173         *
174         * @param eps given sequence to label
175         * @return the labels of each point in the sequence
176         */
177        public int[] labelForwardBackward(ExpandedPointSequence eps) {
178            int t, k, kp, T = eps.length(), K = sbt_k.length - 1, maxarg;
179            double logsum, logfb, maxlogfb;
180            int[] label_t = new int[T];
181            double[][] logalpha_k_ta = new double[T + 1][K + 1];
182            double[][] logbeta_k_ta = new double[T + 1][K + 1];
183            for(k = 0;k <= K;k++)
184                logalpha_k_ta[1][k] = potential(k, 0, K, eps);
185            for(t = 1;t < T;t++) {
186                for(k = 0;k <= K;k++) {
187                    logsum = LOGZERO;
188                    for(kp = 0;kp <= K;kp++)
189                        logsum = eLogSum(logsum, eLogProduct(potential(k, t, kp, eps), logalpha_k_ta[t][kp]));
190                    logalpha_k_ta[t + 1][k] = logsum;
191                }
192            }
193            for(t = T - 1;t >= 0;t--) {
194                for(k = 0;k <= K;k++) {
195                    logsum = LOGZERO;
196                    for(kp = 0;kp <= K;kp++)
197                        logsum = eLogSum(logsum, eLogProduct(potential(k, t, kp, eps), logbeta_k_ta[t + 1][kp]));
198                    logbeta_k_ta[t][k] = logsum;
199                }
200            }
201            for(t = 0;t < T;t++) {
202                maxarg = -1;
203                maxlogfb = Double.NEGATIVE_INFINITY;
204                for(k = 0;k < K;k++)
205                    if((logfb = eLogProduct(logalpha_k_ta[t + 1][k], logbeta_k_ta[t + 1][k])) > maxlogfb) {
206                        maxarg = k;
207                        maxlogfb = logfb;
208                    }
209                label_t[t] = maxarg;
210            }
211            return label_t;
212        }
213    }