001    // Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
002    // 
003    // This file is part of JBTCRF.
004    // 
005    // JBTCRF is free software: you can redistribute it and/or modify
006    // it under the terms of the GNU General Public License as published by
007    // the Free Software Foundation, either version 3 of the License, or
008    // (at your option) any later version.
009    // 
010    // JBTCRF is distributed in the hope that it will be useful,
011    // but WITHOUT ANY WARRANTY; without even the implied warranty of
012    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013    // GNU General Public License for more details.
014    // 
015    // You should have received a copy of the GNU General Public License
016    // along with JBTCRF.  If not, see <http://www.gnu.org/licenses/>.
017    
018    package com.pawjaw.classification.crf.lmcbt.readers;
019    
020    import com.pawjaw.classification.crf.lmcbt.points.Point;
021    import com.pawjaw.classification.crf.lmcbt.points.SparsePoint;
022    import java.io.BufferedReader;
023    import java.io.EOFException;
024    import java.io.IOException;
025    import java.util.ArrayList;
026    import java.util.List;
027    import java.util.Map;
028    import java.util.TreeMap;
029    import java.util.regex.Pattern;
030    
031    public class CoNLL2000LabeledSequenceDataReader extends LabeledSequenceDataReader {
032        private List<String> tokens = new ArrayList();
033        private List<String> labels = new ArrayList();
034        private Map<String, Integer> idx_token = new TreeMap();
035        private Map<String, Integer> idx_label = new TreeMap();
036        private Pattern splitter = Pattern.compile("\\s+");
037    
038        public int getPointFeatureCount() {
039            return tokens.size();
040        }
041    
042        public int getLabelCount() {
043            return labels.size();
044        }
045    
046        private class LabeledPoint {
047            public int label_idx;
048            public SparsePoint sp = new SparsePoint();
049        }
050    
051        private int getTokenIdx(String token) {
052            Integer idx = idx_token.get(token);
053            if(idx == null) {
054                idx_token.put(token, idx = tokens.size());
055                tokens.add(token);
056            }
057            return idx;
058        }
059    
060        private int getLabelIdx(String label) {
061            Integer idx = idx_label.get(label);
062            if(idx == null) {
063                idx_label.put(label, idx = labels.size());
064                labels.add(label);
065            }
066            return idx;
067        }
068    
069        private LabeledPoint readNextWord(BufferedReader br) throws IOException {
070            String s = br.readLine();
071            LabeledPoint lp = new LabeledPoint();
072            String[] ss;
073            if(s == null)
074                throw new EOFException();
075            if((s = s.trim()).isEmpty() || (ss = splitter.split(s)).length != 3)
076                return null;
077            lp.sp.setFeatures(new int[] { getTokenIdx(ss[0]), getTokenIdx("POS_" + ss[1]) });
078            lp.label_idx = getLabelIdx(ss[2]);
079            return lp;
080        }
081    
082        private void readNextSentence(BufferedReader br, List<LabeledPoint> lps) throws IOException {
083            LabeledPoint lp;
084            try {
085                while((lp = readNextWord(br)) != null)
086                    lps.add(lp);
087            } catch(EOFException e) {
088                if(lps.isEmpty())
089                    throw e;
090            }
091        }
092    
093        public void readPoints(List<Point[]> pointss, List<int[]> true_labelss, BufferedReader br) throws IOException {
094            int i, I;
095            int[] is;
096            Point[] ps;
097            List<LabeledPoint> lps = new ArrayList();
098            try {
099                while(true) {
100                    readNextSentence(br, lps);
101                    if((I = lps.size()) != 0) {
102                        is = new int[I];
103                        ps = new Point[I];
104                        for(i = 0;i < I;i++) {
105                            ps[i] = lps.get(i).sp;
106                            is[i] = lps.get(i).label_idx;
107                        }
108                        pointss.add(ps);
109                        true_labelss.add(is);
110                    }
111                    lps.clear();
112                }
113            } catch(EOFException e) {}
114        }
115    }