// Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
//
// This file is part of JBTCRF.
//
// JBTCRF is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// JBTCRF is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.

package com.pawjaw.classification.crf.lmcbt.tests;

import com.pawjaw.classification.crf.lmcbt.CRFValidator;
import com.pawjaw.classification.crf.lmcbt.configurations.Configuration;
import com.pawjaw.classification.crf.lmcbt.configurations.DefaultConfiguration;
import com.pawjaw.classification.crf.lmcbt.points.Point;
import com.pawjaw.classification.crf.lmcbt.readers.LabeledSequenceDataReader;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Base harness for running a {@link CRFValidator} over a labeled sequence data file.
 *
 * <p>Subclasses supply the reader used to parse the file and the window radius used to
 * build the {@link DefaultConfiguration}. After {@link #test(File)} returns, the
 * validator's results are exposed through {@link #accuracies} and
 * {@link #boosting_iterations}.
 */
public abstract class CRFTest {

    /**
     * Returns the window radius passed to {@link DefaultConfiguration}'s constructor.
     *
     * @return the window radius for the configuration built in {@link #test(File)}
     */
    public abstract int getWindowRadius();

    /**
     * Returns the reader used by {@link #test(File)} to parse the sequence data file.
     * The same reader is then queried for the feature and label counts.
     *
     * @return the reader for this test's data format
     */
    public abstract LabeledSequenceDataReader getLabeledSequenceDataReader();

    /**
     * Returns the first argument passed to {@link CRFValidator#run}; defaults to 8.
     * Presumably the number of cross-validation passes (per the name); confirm
     * against {@code CRFValidator.run}.
     *
     * @return the pass count handed to the validator
     */
    public int getCrossValidationPasses() {
        return 8;
    }

    /**
     * Returns the second argument passed to {@link CRFValidator#run}; defaults to 0.8.
     * NOTE(review): despite the name, the default looks like a fraction in [0, 1]
     * rather than a percentage; confirm against {@code CRFValidator.run}.
     *
     * @return the training share handed to the validator
     */
    public double getTrainingPercentage() {
        return 0.8;
    }

    /**
     * The validator's {@code accuracies} list from the most recent {@link #test(File)} call.
     * This is the same list object, not a copy. It is {@code null} until {@code test} runs.
     */
    public List<Double> accuracies;

    /**
     * The validator's {@code boosting_iterations} list from the most recent
     * {@link #test(File)} call. This is the same list object, not a copy. It is
     * {@code null} until {@code test} runs.
     */
    public List<Integer> boosting_iterations;

    /**
     * Reads {@code sequence_data_file} and builds a default configuration from it.
     * It then runs the validator and stores the validator's result lists in
     * {@link #accuracies} and {@link #boosting_iterations}.
     *
     * @param sequence_data_file the labeled sequence data to read
     * @throws IOException if the reader fails to read the file
     */
    public void test(File sequence_data_file) throws IOException {
        // Explicitly parameterized (not raw) so no unchecked conversion is needed.
        List<Point[]> pointss = new ArrayList<Point[]>();
        List<int[]> true_labelss = new ArrayList<int[]>();

        // The reader fills both lists and also knows the feature and label counts.
        LabeledSequenceDataReader r = getLabeledSequenceDataReader();
        r.readFile(pointss, true_labelss, sequence_data_file);

        Configuration c = new DefaultConfiguration(
                r.getPointFeatureCount(), getWindowRadius(), r.getLabelCount());
        CRFValidator v = new CRFValidator(c, pointss, true_labelss);
        v.run(getCrossValidationPasses(), getTrainingPercentage());

        accuracies = v.accuracies;
        boosting_iterations = v.boosting_iterations;
    }
}