001 // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company) 002 // 003 // This file is part of JBTCRF. 004 // 005 // JBTCRF is free software: you can redistribute it and/or modify 006 // it under the terms of the GNU General Public License as published by 007 // the Free Software Foundation, either version 3 of the License, or 008 // (at your option) any later version. 009 // 010 // JBTCRF is distributed in the hope that it will be useful, 011 // but WITHOUT ANY WARRANTY; without even the implied warranty of 012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 013 // GNU General Public License for more details. 014 // 015 // You should have received a copy of the GNU General Public License 016 // along with JBTCRF. If not, see <http://www.gnu.org/licenses/>. 017 018 package com.pawjaw.classification.crf.lmcbt.sequences; 019 020 import com.pawjaw.classification.crf.lmcbt.configurations.Configuration; 021 022 import static com.pawjaw.classification.crf.lmcbt.CRFMath.*; 023 024 public class TrainingSequence extends ExpandedPointSequence { 025 private int[] label_t; // T length vector of labels in range (0 to K - 1 inclusive) 026 private double[][] logalpha_k_ta; // T + 1 x K + 1 forward recursion matrix (ta index is augmented with start state) 027 private double[][] logbeta_k_ta; // T + 1 x K + 1 backward recursion matrix (ta index is augmented with start state) 028 private double[][][] target_kp_k_t; // T square matrices length K + 1 for holding regression targets 029 private double[][][] potential_kp_k_t; // T square matrices length K + 1 for holding potential values 030 031 public void allocate(Configuration c) { 032 int T = length(), K = c.label_count_excluding_start_label; 033 logalpha_k_ta = new double[T + 1][K + 1]; 034 logbeta_k_ta = new double[T + 1][K + 1]; 035 target_kp_k_t = new double[T][K + 1][K + 1]; 036 potential_kp_k_t = new double[T][K + 1][K + 1]; 037 initializePotential(T, K); 038 initializeForwardBackward(T, K); 039 } 040 041 public void setLabels(int[] label_t) { 042 this.label_t = label_t; 043 } 044 045 // k = K is start label position 046 private void initializeForwardBackward(int T, int K) { 047 int k; 048 for(k = 0;k < K;k++) 049 logalpha_k_ta[0][k] = LOGZERO; 050 logalpha_k_ta[0][K] = LOGONE; 051 for(k = 0;k <= K;k++) 052 logbeta_k_ta[T][k] = LOGONE; 053 } 054 055 private void initializePotential(int T, int K) { 056 int t, k, kp; 057 for(t = 0;t < T;t++) 058 for(k = 0;k <= K;k++) 059 for(kp = 0;kp <= K;kp++) 060 potential_kp_k_t[t][k][kp] = 0; 061 } 062 063 public void updateForwardBackward() { 064 int t, k, kp, T = length(), K = logalpha_k_ta[0].length - 1; 065 double logsum; 066 for(k = 0;k <= K;k++) 067 logalpha_k_ta[1][k] = potential_kp_k_t[0][k][K]; 068 for(t = 1;t < T;t++) { 069 for(k = 0;k <= K;k++) { 070 logsum = LOGZERO; 071 for(kp = 0;kp <= K;kp++) 072 logsum = eLogSum(logsum, eLogProduct(potential_kp_k_t[t][k][kp], logalpha_k_ta[t][kp])); 073 logalpha_k_ta[t + 1][k] = logsum; 074 } 075 } 076 for(t = T - 1;t >= 0;t--) { 077 for(k = 0;k <= K;k++) { 078 logsum = LOGZERO; 079 for(kp = 0;kp <= K;kp++) 080 logsum = eLogSum(logsum, eLogProduct(potential_kp_k_t[t][k][kp], logbeta_k_ta[t + 1][kp])); 081 logbeta_k_ta[t][k] = logsum; 082 } 083 } 084 } 085 086 public void updateTargets() { 087 int t, k, kp, T = length(), K = logalpha_k_ta[0].length - 1; 088 double lognormalizer = 0, logtarget; 089 for(t = 0;t < T;t++) { 090 lognormalizer = LOGZERO; 091 for(k = 0;k <= K;k++) 092 for(kp = 0;kp <= K;kp++) { 093 logtarget = eLogProduct(logalpha_k_ta[t][kp], 094 eLogProduct(potential_kp_k_t[t][k][kp], logbeta_k_ta[t + 1][k])); 095 target_kp_k_t[t][k][kp] = logtarget; 096 lognormalizer = eLogSum(lognormalizer, logtarget); 097 } 098 for(k = 0;k <= K;k++) 099 for(kp = 0;kp <= K;kp++) 100 target_kp_k_t[t][k][kp] = eExp(eLogProduct(target_kp_k_t[t][k][kp], -lognormalizer)); 101 } 102 for(t = 0;t < T;t++) 103 for(k = 0;k <= K;k++) 104 for(kp = 0;kp <= K;kp++) 105 if(label_t[t] == k) { 106 if(t == 0) { 107 if(kp == K) 108 target_kp_k_t[t][k][kp] = 1.0 - target_kp_k_t[t][k][kp]; 109 else 110 target_kp_k_t[t][k][kp] = 0.0 - target_kp_k_t[t][k][kp]; 111 } else { 112 if(kp == label_t[t - 1]) 113 target_kp_k_t[t][k][kp] = 1.0 - target_kp_k_t[t][k][kp]; 114 else 115 target_kp_k_t[t][k][kp] = 0.0 - target_kp_k_t[t][k][kp]; 116 } 117 } else 118 target_kp_k_t[t][k][kp] = 0.0 - target_kp_k_t[t][k][kp]; 119 } 120 121 public double getTarget(int label, int example_index) { 122 int K = logalpha_k_ta[0].length - 1; 123 int t = example_index / (K + 1); 124 int kp = example_index % (K + 1); 125 return target_kp_k_t[t][label][kp]; 126 } 127 128 public void incrementPotential(int label, int example_index, double potential) { 129 int K = logalpha_k_ta[0].length - 1; 130 int t = example_index / (K + 1); 131 int kp = example_index % (K + 1); 132 potential_kp_k_t[t][label][kp] += potential; 133 } 134 }