001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt.readers;
019    
020    import com.pawjaw.classification.crf.lmcbt.points.Point;
021    import com.pawjaw.classification.crf.lmcbt.points.SparsePoint;
022    import java.io.BufferedReader;
023    import java.io.EOFException;
024    import java.io.IOException;
025    import java.util.ArrayList;
026    import java.util.List;
027    import java.util.regex.Pattern;
028    
029    public class TreeCRFLabeledSequenceDataReader extends LabeledSequenceDataReader {
030        private int max_point = -1, max_label = -1;
031        private int previous_sequence = -1;
032        private Pattern splitter = Pattern.compile("\\s+");
033    
034        private void readPoint(List<List<Point>> pointss, List<List<Integer>> true_labelss, BufferedReader br) throws IOException {
035            int k, K, sequence, label;
036            String s = br.readLine();
037            List<Point> points;
038            List<Integer> labels;
039            SparsePoint sp = new SparsePoint();
040            String[] ss;
041            int[] feature_indexes;
042            if(s == null)
043                throw new EOFException();
044            if((ss = splitter.split(s)).length >= 5) {
045                feature_indexes = new int[K = Integer.parseInt(ss[2])];
046                for(k = 0;k < K;k++)
047                    if((feature_indexes[k] = Integer.parseInt(ss[k + 3])) > max_point)
048                        max_point = feature_indexes[k];
049                if((sequence = Integer.parseInt(ss[0])) != previous_sequence) {
050                    pointss.add(points = new ArrayList());
051                    true_labelss.add(labels = new ArrayList());
052                    previous_sequence = sequence;
053                } else {
054                    points = pointss.get(pointss.size() - 1);
055                    labels = true_labelss.get(true_labelss.size() - 1);
056                }
057                sp.setFeatures(feature_indexes);
058                points.add(sp);
059                labels.add(label = Integer.parseInt(ss[ss.length - 1]));
060                if(label > max_label)
061                    max_label = label;
062            }
063        }
064    
065        public void readPoints(List<Point[]> pointss, List<int[]> true_labelss, BufferedReader br) throws IOException {
066            int i, I, k, K;
067            List<List<Point>> _pointss = new ArrayList();
068            List<List<Integer>> _true_labelss = new ArrayList();
069            Point[] points;
070            int[] labels;
071            br.readLine();
072            try {
073                while(true)
074                    readPoint(_pointss, _true_labelss, br);
075            } catch(EOFException e) {}
076            K = _pointss.size();
077            for(k = 0;k < K;k++) {
078                points = new Point[I = _pointss.get(k).size()];
079                labels = new int[I];
080                for(i = 0;i < I;i++) {
081                    points[i] = _pointss.get(k).get(i);
082                    labels[i] = _true_labelss.get(k).get(i);
083                }
084                pointss.add(points);
085                true_labelss.add(labels);
086            }
087        }
088    
089        public int getPointFeatureCount() {
090            return max_point + 1;
091        }
092    
093        public int getLabelCount() {
094            return max_label + 1;
095        }
096    }