001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt.sequences;
019    
020    import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
021    
022    import static com.pawjaw.classification.crf.lmcbt.CRFMath.*;
023    
024    public class TrainingSequence extends ExpandedPointSequence {
025        private int[] label_t; // T length vector of labels in range (0 to K - 1 inclusive)
026        private double[][] logalpha_k_ta; // T + 1 x K + 1 forward recursion matrix (ta index is augmented with start state)
027        private double[][] logbeta_k_ta; //  T + 1 x K + 1 backward recursion matrix (ta index is augmented with start state)
028        private double[][][] target_kp_k_t; //  T square matrices length K + 1 for holding regression targets
029        private double[][][] potential_kp_k_t; //  T square matrices length K + 1 for holding potential values
030    
031        public void allocate(Configuration c) {
032            int T = length(), K = c.label_count_excluding_start_label;
033            logalpha_k_ta = new double[T + 1][K + 1];
034            logbeta_k_ta = new double[T + 1][K + 1];
035            target_kp_k_t = new double[T][K + 1][K + 1];
036            potential_kp_k_t = new double[T][K + 1][K + 1];
037            initializePotential(T, K);
038            initializeForwardBackward(T, K);
039        }
040    
041        public void setLabels(int[] label_t) {
042            this.label_t = label_t;
043        }
044    
045        // k = K is start label position
046        private void initializeForwardBackward(int T, int K) {
047            int k;
048            for(k = 0;k < K;k++)
049                logalpha_k_ta[0][k] = LOGZERO;
050            logalpha_k_ta[0][K] = LOGONE;
051            for(k = 0;k <= K;k++)
052                logbeta_k_ta[T][k] = LOGONE;
053        }
054    
055        private void initializePotential(int T, int K) {
056            int t, k, kp;
057            for(t = 0;t < T;t++)
058                for(k = 0;k <= K;k++)
059                    for(kp = 0;kp <= K;kp++)
060                        potential_kp_k_t[t][k][kp] = 0;
061        }
062    
063        public void updateForwardBackward() {
064            int t, k, kp, T = length(), K = logalpha_k_ta[0].length - 1;
065            double logsum;
066            for(k = 0;k <= K;k++)
067                logalpha_k_ta[1][k] = potential_kp_k_t[0][k][K];
068            for(t = 1;t < T;t++) {
069                for(k = 0;k <= K;k++) {
070                    logsum = LOGZERO;
071                    for(kp = 0;kp <= K;kp++)
072                        logsum = eLogSum(logsum, eLogProduct(potential_kp_k_t[t][k][kp], logalpha_k_ta[t][kp]));
073                    logalpha_k_ta[t + 1][k] = logsum;
074                }
075            }
076            for(t = T - 1;t >= 0;t--) {
077                for(k = 0;k <= K;k++) {
078                    logsum = LOGZERO;
079                    for(kp = 0;kp <= K;kp++)
080                        logsum = eLogSum(logsum, eLogProduct(potential_kp_k_t[t][k][kp], logbeta_k_ta[t + 1][kp]));
081                    logbeta_k_ta[t][k] = logsum;
082                }
083            }
084        }
085        
086        public void updateTargets() {
087            int t, k, kp, T = length(), K = logalpha_k_ta[0].length - 1;
088            double lognormalizer = 0, logtarget;
089            for(t = 0;t < T;t++) {
090                lognormalizer = LOGZERO;
091                for(k = 0;k <= K;k++)
092                    for(kp = 0;kp <= K;kp++) {
093                        logtarget = eLogProduct(logalpha_k_ta[t][kp],
094                                eLogProduct(potential_kp_k_t[t][k][kp], logbeta_k_ta[t + 1][k]));
095                        target_kp_k_t[t][k][kp] = logtarget;
096                        lognormalizer = eLogSum(lognormalizer, logtarget);
097                    }
098                for(k = 0;k <= K;k++)
099                    for(kp = 0;kp <= K;kp++)
100                        target_kp_k_t[t][k][kp] = eExp(eLogProduct(target_kp_k_t[t][k][kp], -lognormalizer));
101            }
102            for(t = 0;t < T;t++)
103                for(k = 0;k <= K;k++)
104                    for(kp = 0;kp <= K;kp++)
105                        if(label_t[t] == k) {
106                            if(t == 0) {
107                                if(kp == K)
108                                    target_kp_k_t[t][k][kp] = 1.0 - target_kp_k_t[t][k][kp];
109                                else
110                                    target_kp_k_t[t][k][kp] = 0.0 - target_kp_k_t[t][k][kp];
111                            } else {
112                                if(kp == label_t[t - 1])
113                                    target_kp_k_t[t][k][kp] = 1.0 - target_kp_k_t[t][k][kp];
114                                else
115                                    target_kp_k_t[t][k][kp] = 0.0 - target_kp_k_t[t][k][kp];
116                            }
117                        } else
118                            target_kp_k_t[t][k][kp] = 0.0 - target_kp_k_t[t][k][kp];
119        }
120    
121        public double getTarget(int label, int example_index) {
122            int K = logalpha_k_ta[0].length - 1;
123            int t = example_index / (K + 1);
124            int kp = example_index % (K + 1);
125            return target_kp_k_t[t][label][kp];
126        }
127    
128        public void incrementPotential(int label, int example_index, double potential) {
129            int K = logalpha_k_ta[0].length - 1;
130            int t = example_index / (K + 1);
131            int kp = example_index % (K + 1);
132            potential_kp_k_t[t][label][kp] += potential;
133        }
134    }