001 // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company) 002 // 003 // This file is part of JBTCRF. 004 // 005 // JBTCRF is free software: you can redistribute it and/or modify 006 // it under the terms of the GNU General Public License as published by 007 // the Free Software Foundation, either version 3 of the License, or 008 // (at your option) any later version. 009 // 010 // JBTCRF is distributed in the hope that it will be useful, 011 // but WITHOUT ANY WARRANTY; without even the implied warranty of 012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 013 // GNU General Public License for more details. 014 // 015 // You should have received a copy of the GNU General Public License 016 // along with JBTCRF. If not, see <http://www.gnu.org/licenses/>. 017 018 package com.pawjaw.classification.crf.lmcbt.readers; 019 020 import com.pawjaw.classification.crf.lmcbt.points.Point; 021 import com.pawjaw.classification.crf.lmcbt.points.SparsePoint; 022 import java.io.BufferedReader; 023 import java.io.EOFException; 024 import java.io.IOException; 025 import java.util.ArrayList; 026 import java.util.List; 027 import java.util.regex.Pattern; 028 029 public class TreeCRFLabeledSequenceDataReader extends LabeledSequenceDataReader { 030 private int max_point = -1, max_label = -1; 031 private int previous_sequence = -1; 032 private Pattern splitter = Pattern.compile("\\s+"); 033 034 private void readPoint(List<List<Point>> pointss, List<List<Integer>> true_labelss, BufferedReader br) throws IOException { 035 int k, K, sequence, label; 036 String s = br.readLine(); 037 List<Point> points; 038 List<Integer> labels; 039 SparsePoint sp = new SparsePoint(); 040 String[] ss; 041 int[] feature_indexes; 042 if(s == null) 043 throw new EOFException(); 044 if((ss = splitter.split(s)).length >= 5) { 045 feature_indexes = new int[K = Integer.parseInt(ss[2])]; 046 for(k = 0;k < K;k++) 047 if((feature_indexes[k] = Integer.parseInt(ss[k + 3])) > max_point) 048 max_point = feature_indexes[k]; 049 if((sequence = Integer.parseInt(ss[0])) != previous_sequence) { 050 pointss.add(points = new ArrayList()); 051 true_labelss.add(labels = new ArrayList()); 052 previous_sequence = sequence; 053 } else { 054 points = pointss.get(pointss.size() - 1); 055 labels = true_labelss.get(true_labelss.size() - 1); 056 } 057 sp.setFeatures(feature_indexes); 058 points.add(sp); 059 labels.add(label = Integer.parseInt(ss[ss.length - 1])); 060 if(label > max_label) 061 max_label = label; 062 } 063 } 064 065 public void readPoints(List<Point[]> pointss, List<int[]> true_labelss, BufferedReader br) throws IOException { 066 int i, I, k, K; 067 List<List<Point>> _pointss = new ArrayList(); 068 List<List<Integer>> _true_labelss = new ArrayList(); 069 Point[] points; 070 int[] labels; 071 br.readLine(); 072 try { 073 while(true) 074 readPoint(_pointss, _true_labelss, br); 075 } catch(EOFException e) {} 076 K = _pointss.size(); 077 for(k = 0;k < K;k++) { 078 points = new Point[I = _pointss.get(k).size()]; 079 labels = new int[I]; 080 for(i = 0;i < I;i++) { 081 points[i] = _pointss.get(k).get(i); 082 labels[i] = _true_labelss.get(k).get(i); 083 } 084 pointss.add(points); 085 true_labelss.add(labels); 086 } 087 } 088 089 public int getPointFeatureCount() { 090 return max_point + 1; 091 } 092 093 public int getLabelCount() { 094 return max_label + 1; 095 } 096 }