// Copyright 2012, 2013 Brad Block, Pawjaw, LLC. (an Ohio Limited Liability Company)
//
// This file is part of JBTCRF.
//
// JBTCRF is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// JBTCRF is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with JBTCRF. If not, see <http://www.gnu.org/licenses/>.

package com.pawjaw.classification.crf.lmcbt.readers;

import com.pawjaw.classification.crf.lmcbt.points.Point;
import com.pawjaw.classification.crf.lmcbt.points.SparsePoint;
import java.io.BufferedReader;
import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.regex.Pattern;

/**
 * Reads labeled sequences from line-oriented text in which every word line
 * holds exactly three whitespace-separated columns: a token, a tag and a label.
 *
 * <p>Each word becomes a {@link SparsePoint} with two features: the index of
 * the token and the index of the tag prefixed with {@code "POS_"}. Tokens and
 * prefixed tags share a single index space; labels have their own. Indices are
 * handed out in order of first appearance and persist across calls to
 * {@link #readPoints}, so the counts reported by {@link #getPointFeatureCount()}
 * and {@link #getLabelCount()} grow as more data is read.
 *
 * <p>A sentence ends at a blank line, at a line that does not have exactly
 * three columns, or at end of input.
 */
public class CoNLL2000LabeledSequenceDataReader extends LabeledSequenceDataReader {
    /** Number of columns a word line must have; any other count ends the sentence. */
    private static final int COLUMNS_PER_WORD = 3;

    /** Column separator; compiled once because it never varies per instance. */
    private static final Pattern SPLITTER = Pattern.compile("\\s+");

    /** Feature names (tokens and "POS_"-prefixed tags), positioned by their index. */
    private final List<String> tokens = new ArrayList<String>();
    /** Label names, positioned by their index. */
    private final List<String> labels = new ArrayList<String>();
    /** Reverse lookup from feature name to its position in {@link #tokens}. */
    private final Map<String, Integer> idx_token = new TreeMap<String, Integer>();
    /** Reverse lookup from label name to its position in {@link #labels}. */
    private final Map<String, Integer> idx_label = new TreeMap<String, Integer>();

    /**
     * Returns the number of distinct features seen so far.
     *
     * @return count of distinct tokens plus distinct "POS_"-prefixed tags
     */
    public int getPointFeatureCount() {
        return tokens.size();
    }

    /**
     * Returns the number of distinct labels seen so far.
     *
     * @return count of distinct label strings
     */
    public int getLabelCount() {
        return labels.size();
    }

    /** One parsed word: its sparse feature point and the index of its label. */
    private static class LabeledPoint {
        public int label_idx;
        public SparsePoint sp = new SparsePoint();
    }

    /**
     * Returns the index of {@code key} in {@code names}, appending it (and
     * recording it in {@code index}) when it has not been seen before.
     */
    private static int intern(String key, Map<String, Integer> index, List<String> names) {
        Integer idx = index.get(key);
        if (idx == null) {
            idx = names.size();
            index.put(key, idx);
            names.add(key);
        }
        return idx;
    }

    /** Returns the feature index for {@code token}, assigning a new one if needed. */
    private int getTokenIdx(String token) {
        return intern(token, idx_token, tokens);
    }

    /** Returns the label index for {@code label}, assigning a new one if needed. */
    private int getLabelIdx(String label) {
        return intern(label, idx_label, labels);
    }

    /**
     * Parses one input line into a labeled point.
     *
     * @param line a line as returned by {@link BufferedReader#readLine()}; not null
     * @return the parsed word, or {@code null} when the line is blank or does
     *         not have exactly three columns (both mark a sentence boundary)
     */
    private LabeledPoint parseWord(String line) {
        String trimmed = line.trim();
        if (trimmed.isEmpty()) {
            return null;
        }
        String[] columns = SPLITTER.split(trimmed);
        // NOTE(review): a malformed line is silently treated like a blank line
        // and its content is dropped; confirm this is the intended policy.
        if (columns.length != COLUMNS_PER_WORD) {
            return null;
        }
        // Allocate only once the line is known to be a word, and intern the
        // token before the tag so index assignment order is token, then tag.
        LabeledPoint lp = new LabeledPoint();
        int token_idx = getTokenIdx(columns[0]);
        int pos_idx = getTokenIdx("POS_" + columns[1]);
        lp.sp.setFeatures(new int[] { token_idx, pos_idx });
        lp.label_idx = getLabelIdx(columns[2]);
        return lp;
    }

    /**
     * Appends the words of the next sentence to {@code lps}.
     *
     * @param br  source of lines
     * @param lps receives the parsed words, in input order
     * @return {@code true} if the sentence ended at a boundary line and more
     *         input may follow; {@code false} if end of input was reached
     * @throws IOException if reading from {@code br} fails
     */
    private boolean readNextSentence(BufferedReader br, List<LabeledPoint> lps) throws IOException {
        String line;
        while ((line = br.readLine()) != null) {
            LabeledPoint lp = parseWord(line);
            if (lp == null) {
                return true;
            }
            lps.add(lp);
        }
        return false;
    }

    /**
     * Reads every sentence remaining in {@code br} and appends, for each
     * non-empty sentence, one array of points to {@code pointss} and one
     * parallel array of label indices to {@code true_labelss}.
     *
     * <p>End of input is detected from {@code readLine()} returning
     * {@code null}. An {@link IOException} raised by the reader itself,
     * including an {@link EOFException} from a truncated underlying stream,
     * is propagated rather than being mistaken for a normal end of data.
     * The reader is not closed by this method.
     *
     * @param pointss      receives one {@code Point[]} per sentence
     * @param true_labelss receives one {@code int[]} of label indices per sentence
     * @param br           source of lines
     * @throws IOException if reading from {@code br} fails
     */
    public void readPoints(List<Point[]> pointss, List<int[]> true_labelss, BufferedReader br) throws IOException {
        List<LabeledPoint> lps = new ArrayList<LabeledPoint>();
        boolean more;
        do {
            lps.clear();
            more = readNextSentence(br, lps);
            // Consecutive boundary lines yield empty sentences; skip those.
            if (!lps.isEmpty()) {
                int length = lps.size();
                Point[] ps = new Point[length];
                int[] is = new int[length];
                for (int i = 0; i < length; i++) {
                    LabeledPoint lp = lps.get(i);
                    ps[i] = lp.sp;
                    is[i] = lp.label_idx;
                }
                pointss.add(ps);
                true_labelss.add(is);
            }
        } while (more);
    }
}