00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00026
00027 #ifdef _WIN32
00028 #include "lib/strtok_r.h"
00029 #endif
00030
00031 #include "conf/conf_int.h"
00032 #include "Parser.h"
00033 #include "EventStream.h"
00034 #include "State.h"
00035
00036
00037 #include "GIS.h"
00038 #include "LBFGS.h"
00039
00040
00041 #include <math.h>
00042 #include <assert.h>
00043
00044 using namespace std;
00045
00046 namespace Parser {
00047
00048 IXE::conf<int> iterations("MEiter", 60);
00049
00050
00051 #define BEAM
00052
00058 struct MeParser : public Parser
00059 {
00060 MeParser(char const* modelFile);
00061
00062 void train(SentenceReader* sentenceReader, char const* modelFile);
00063
00064 Sentence* parse(Sentence* sentence);
00065
00066 void revise(SentenceReader* sentenceReader, char const* actionFile = 0);
00067
00068 Tanl::Classifier::MaxEnt model;
00069 };
00070
00076 Parser* MeParserFactory(char const* modelFile = 0)
00077 {
00078 return new MeParser(modelFile);
00079 }
00080
00081 REGISTER_PARSER(ME, MeParserFactory);
00082
00083 MeParser::MeParser(char const* modelFile) :
00084 Parser(model.PredIndex())
00085 {
00086 if (!modelFile)
00087 return;
00088 ifstream ifs(modelFile);
00089 if (!ifs)
00090 throw IXE::FileError(string("Missing model file: ") + modelFile);
00091
00092 readHeader(ifs);
00093 model.load(ifs);
00094
00095 info.load(ifs);
00096 ifs.close();
00097 }
00098
00099 void MeParser::train(SentenceReader* sentenceReader, char const* modelFile)
00100 {
00101 EventStream eventStream(sentenceReader, &info);
00102 Tanl::Classifier::LBFGS model(iterations, featureCutoff);
00103 model.verbose = verbose;
00104 if (verbose)
00105 cerr << "Collecting events.." << endl;
00106 model.read(eventStream);
00107 ofstream ofs(modelFile, ios::binary | ios::trunc);
00108
00109 writeHeader(ofs);
00110 model.train();
00111
00112 model.writeHeader(ofs);
00113 model.writeData(ofs);
00114
00115 info.save(ofs);
00116 }
00117
00118 #ifdef BEAM
00119
00125 static double addState(ParseState* s, vector<ParseState*>& states)
00126 {
00127 int size = states.size();
00128 assert(size < beam || s->lprob < states[size-1]->lprob);
00129 s->incRef();
00130 if (size == 0) {
00131 states.push_back(s);
00132 return s->lprob;
00133 }
00134
00135
00136
00137
00138
00139
00140 TO_EACH (std::vector<ParseState*>, states, it)
00141 if (s->lprob > (*it)->lprob) {
00142 if (size == beam) {
00143 states.back()->decRef();
00144 states.back()->prune();
00145 states.back() = s;
00146 } else
00147 states.insert(it, s);
00148 return states.back()->lprob;
00149 }
00150 if (size < beam)
00151 states.push_back(s);
00152 return states.back()->lprob;
00153 }
00154 #endif
00155
00156 Sentence* MeParser::parse(Sentence* sentence)
00157 {
00158 int numOutcomes = model.NumOutcomes();
00159 double* params = new double[numOutcomes];
00160 preprocess(sentence);
00161 ParseState* state = new ParseState(*sentence, &info, predIndex);
00162
00163 # ifdef BEAM
00164 vector<ParseState*> currStates; currStates.reserve(beam);
00165 vector<ParseState*> nextStates; nextStates.reserve(beam);
00166 vector<ParseState*>* bestStates = &currStates;
00167 vector<ParseState*>* bestNextStates = &nextStates;
00168 addState(state, *bestStates);
00169
00170 while (true) {
00171 int finished = 0;
00172 int numBest = bestStates->size();
00173 double worstProb = -numeric_limits<double>::infinity();
00174 for (int i = 0; i < numBest; i++) {
00175 state = (*bestStates)[i];
00176 if (state->hasNext()) {
00177 Tanl::Classifier::Context& context = *state->next();
00178 model.estimate(context, params);
00179 for (int o = 0; o < numOutcomes; o++) {
00180 double prob = params[o];
00181 if (prob < 1e-4)
00182 continue;
00183 double lprob = log(prob) + state->lprob;
00184 if (bestNextStates->size() == beam && lprob < worstProb)
00185 continue;
00186 char const* outcome = model.OutcomeName(o);
00187 ParseState* next = state->transition(outcome);
00188 if (!next)
00189 continue;
00190 next->lprob = lprob;
00191 worstProb = addState(next, *bestNextStates);
00192 }
00193 } else if (bestNextStates->size() < (size_t)beam || state->lprob > worstProb) {
00194
00195 worstProb = addState(state, *bestNextStates);
00196 finished++;
00197 }
00198 }
00199 if (bestNextStates->empty())
00200 break;
00201
00202
00203
00204 FOR_EACH (std::vector<ParseState*>, *bestStates, it) {
00205 ParseState* state = (*it);
00206 state->decRef();
00207 state->prune();
00208 }
00209 bestStates->clear();
00210
00211 vector<ParseState*>* tmp = bestStates;
00212 bestStates = bestNextStates;
00213 bestNextStates = tmp;
00214 if (finished == numBest)
00215 break;
00216 }
00217
00218 Sentence* s = (*bestStates)[0]->getSentence();
00219 FOR_EACH (std::vector<ParseState*>, *bestStates, it) {
00220 (*it)->decRef();
00221 (*it)->prune();
00222 }
00223 # else
00224
00225 while (state->hasNext()) {
00226 Tanl::Classifier::Context& context = *state->next();
00227 model.estimate(context, params);
00228 int best = model.BestOutcome(params);
00229 char const* outcome = model.OutcomeName(best);
00230 ParseState* next = state->transition(outcome);
00231 if (!next)
00232 next = state->transition("S");
00233 state = next;
00234 }
00235 Sentence* s = state->getSentence();
00236 state->prune();
00237 # endif // BEAM
00238
00239
00240 delete[] params;
00241
00242 return s;
00243 }
00244
00245 void MeParser::revise(SentenceReader* sentenceReader, char const* actionFile)
00246 {
00247 if (actionFile) {
00248
00249 ifstream ifs(actionFile);
00250 WordIndex predIndex;
00251
00252 ReviseContextStream contextStream(sentenceReader, predIndex);
00253
00254 char line[4000];
00255 while (contextStream.hasNext()) {
00256 ++contextStream.cur;
00257 ifs.getline(line, sizeof(line));
00258 char* next = line;
00259 char const* outcome = strtok_r(0, " \t", &next);
00260 contextStream.actions.push_back(outcome);
00261 }
00262 } else {
00263 int numOutcomes = model.NumOutcomes();
00264 double* params = new double[numOutcomes];
00265 int correct = 0;
00266 int wrong = 0;
00267
00268 ReviseContextStream contextStream(sentenceReader, model.PredIndex());
00269
00270 while (contextStream.hasNext()) {
00271 Tanl::Classifier::Context& context = *contextStream.next();
00272 model.estimate(context, params);
00273 char const* outcome = model.OutcomeName(model.BestOutcome(params));
00274 contextStream.actions.push_back(outcome);
00275 }
00276 }
00277 }
00278
00279 }