00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00026
00027 #ifdef _WIN32
00028 #include "lib/strtok_r.h"
00029 #endif
00030
00031 #include "conf/conf_int.h"
00032 #include "conf/conf_float.h"
00033 #include "EventStream.h"
00034 #include "State.h"
00035 #include "WordCounts.h"
00036
00037
00038 #include <algorithm>
00039 #include <list>
00040
00041 #ifdef HAVE_TR1_RANDOM
00042 #include <random>
00043 #endif
00044
00045
00046 #include <boost/numeric/ublas/matrix.hpp>
00047 #include <boost/numeric/ublas/matrix_proxy.hpp>
00048
00049 using namespace std;
00050 using namespace boost::numeric::ublas;
00051
00052 #include "io/Format.h"
00053 using IXE::io::Format;
00054
00055 using Tanl::Classifier::Event;
00056 using Tanl::Classifier::Context;
00057 using Tanl::Classifier::PID;
00058 using Tanl::Classifier::ClassID;
00059
00060 namespace Parser {
00061
00063
00064
// Number of units in each hidden layer of the MLP.
IXE::conf<int> numHidden("MlpHidden", 100);

// Number of hidden layers (1 or 2).
IXE::conf<int> numLayers("MlpLayers", 1);

// SGD learning rate.
IXE::conf<float> LR("MlpLearningRate", 0.01F);

// Weights are initialized uniformly in [-RandomRange, +RandomRange].
IXE::conf<float> randomRange("RandomRange", 0.05F);

// Maximum number of training epochs.
IXE::conf<int> mlpIterations("MlpIterations", 100);
// Minimum number of epochs before early stopping may trigger.
IXE::conf<int> mlpMinIterations("MlpMinIterations", 5);

// Stop training after this many epochs without improving the best
// validation accuracy.
IXE::conf<int> mlpVainIterations("MlpVainIterations", 3);

// When defined, parse() uses beam search; otherwise greedy search.
#define BEAM

// Maximum length of a text line in a model file.
// NOTE(review): 8196 looks like a typo for 8192; harmless, it is only
// a buffer size.
#define MAX_LINE_LEN 8196

// Dense numeric types from boost uBLAS.
typedef boost::numeric::ublas::vector<double> Vector;
typedef matrix<double> Matrix;
00094
00095 void softsign(Vector& x)
00096 {
00097 for (size_t i = 0; i < x.size(); i++)
00098 x[i] /= 1.0 + fabs(x[i]);
00099 }
00100
// Read one text line from the stream into `line', throwing a FileError
// mentioning `file' on failure.
// NOTE(review): both macros reference a local variable `ifs' that must
// be in scope at the expansion site (they are used only inside
// MlpModel::load, which declares it as a parameter).
#define READLINE(line, file) if (!ifs.getline(line, MAX_LINE_LEN)) \
throw IXE::FileError(string("Wrong file format: ") + file)

// Read one raw binary weight into `w', throwing a FileError on failure.
#define READ_WEIGHT(w, file) if (!ifs.read((char*)&w, sizeof(w))) \
throw IXE::FileError(string("Wrong file format: ") + file)
00106
/**
 *	Multi-Layer Perceptron classifier used to score parser actions.
 *	One (or optionally two) hidden layers with softsign activation and
 *	a softmax output over parser outcomes.
 *	Inherits the label tables (outcomeLabels, predLabels, predIndex)
 *	from Classifier::Classifier.
 */
struct MlpModel : public Classifier::Classifier
{
// feature vector of a training case: predicate IDs of active features
typedef std::vector<Tanl::Classifier::PID> X;
// outcome (class) of a training case
typedef Tanl::Classifier::ClassID Y;
typedef std::pair<X, Y> Case;
typedef std::vector<Case> Cases;
typedef std::vector<Case*> ValidationSet;

// Default constructor: sizes taken from the global configuration.
MlpModel() :
numHidden(::Parser::numHidden),
numLayers(::Parser::numLayers)
{ }

/**
 *	Build a model with explicit sizes.
 *	@param numFeatures	input (predicate) count
 *	@param numOutcomes	output (class) count
 *	@param numHidden	hidden units per layer
 *	@param numLayers	1 or 2 hidden layers
 */
MlpModel(int numFeatures, int numOutcomes, int numHidden, int numLayers = 1) :
w1(numFeatures, numHidden),
b1(numHidden),
w2(numHidden, numOutcomes),
b2(numOutcomes),
numHidden(numHidden),
numLayers(numLayers)
{
if (numLayers == 2) {
wh.resize(numHidden, numHidden);
bh.resize(numHidden);
}
}

// outcomeLabels entries are strdup'ed C strings: release them here.
~MlpModel() {
FOR_EACH (std::vector<char const*>, outcomeLabels, it)
free((void*)*it);
}

// Extract training events from the sentences and build the label tables.
list<Event*>
collectEvents(Enumerator<Sentence*>& sentenceReader, GlobalInfo& info);

// Turn events into (feature IDs, class ID) cases and initialize weights.
void buildCases(list<Event*>& events, Cases& cases);

// SGD update on a single case; returns the loss, sets the prediction.
double train(Case&, int&);

// Full training loop with per-epoch validation and early stopping;
// the best model is written (in place) to ofs.
void train(Cases& cases, int epoch, ofstream& ofs);

// Accuracy (and a dispersion estimate) on a held-out validation set.
void validate(ValidationSet& vs, double& avg, double& std);

// Stable softmax of x into sm[]; returns the argmax index.
int crossentropy_softmax(Vector& x, double sm[]);

// Gradient of the cross-entropy loss wrt the output pre-activations.
Vector gradCrossentropy(Vector& x, int y);

// Forward pass: class probabilities for a set of active features.
int estimate(std::vector<PID>& features, double prob[]);

// Read labels and weights from a model file.
void load(ifstream& ifs, char const* file = "");

// Write labels and weights to a model file.
void save(ofstream& ofs);

// Write the class/predicate label tables (text section of the model).
void writeLabels(ofstream& ofs);

// Write the weights (binary section); returns its start position.
streampos writeData(ofstream& ofs);

// Free the label tables, returning their memory.
void clearLabels();

protected:
// w1: input->hidden, w2: hidden->output, wh: hidden->hidden (2-layer)
Matrix w1, w2, wh;
// corresponding bias vectors
Vector b1, b2, bh;
// NOTE(review): numLayers is declared before numHidden, the opposite
// of the constructor initializer-list order; harmless here since both
// are initialized from constructor parameters, not from each other.
int numLayers;
int numHidden;
int numFeatures;
WordIndex outcomeIndex;

};
00187
/**
 *	Exponentially-weighted running statistics of a stream of values:
 *	a smoothed mean and a smoothed average of squared deviations.
 *	Used only for progress reporting during training.
 */
struct MovingAverage
{
  double mean;          // smoothed mean of the values seen so far
  double variance;      // smoothed average of squared deviations
  int count;            // number of values accumulated

  MovingAverage() :
    mean(0.0),
    variance(0.0),
    count(0)
  { }

  /**
   *	Fold a new value into the running statistics.
   *	FIX: the smoothing factor was 2/count, which on the very first
   *	sample (count == 1) applied a factor of 2 and set mean = 2*v.
   *	Use the standard EWMA factor 2/(count+1), which sets mean = v
   *	exactly on the first sample and then decays smoothly.
   */
  void add(double v) {
    count++;
    double alpha = 2. / (count + 1);
    mean += alpha * (v - mean);
    double this_variance = (v - mean) * (v - mean);
    variance += alpha * (this_variance - variance);
  }
};
00228
/**
 *	Serialize the whole model: first the label tables in text form,
 *	then the weight matrices in binary form. The order matters and
 *	must match load().
 */
void MlpModel::save(ofstream& ofs)
{
writeLabels(ofs);
writeData(ofs);
}
00234
00235 void MlpModel::writeLabels(ofstream& ofs)
00236 {
00237
00238 ofs << outcomeLabels.size() << endl;
00239 FOR_EACH (std::vector<char const*>, outcomeLabels, pit)
00240 ofs << *pit << endl;
00241
00242 ofs << predLabels.size() << endl;
00243 FOR_EACH (std::vector<string>, predLabels, pit)
00244 ofs << *pit << endl;
00245 }
00246
00247 void MlpModel::clearLabels()
00248 {
00249 predIndex.clear();
00250 WordIndex().swap(predIndex);
00251 predLabels.clear();
00252 std::vector<std::string>().swap(predLabels);
00253
00254 outcomeIndex.clear();
00255 WordIndex().swap(outcomeIndex);
00256 FOR_EACH (std::vector<char const*>, outcomeLabels, it)
00257 free((void*)*it);
00258 outcomeLabels.clear();
00259 std::vector<char const*>().swap(outcomeLabels);
00260 }
00261
/**
 *	Write the weight section of the model file.
 *	Each array is serialized as a text line with its element count
 *	followed by that many raw little/big-endian doubles (host order).
 *	Since all array sizes are fixed once the model is built, training
 *	can seek back to the returned position and overwrite this section
 *	in place without disturbing what follows it.
 *	@return the stream position where the section starts.
 */
streampos MlpModel::writeData(ofstream& ofs)
{
streampos startPos = ofs.tellp();

// number of hidden layers
ofs << numLayers << endl;

// hidden layer size
ofs << numHidden << endl;

// w1: input -> hidden weights, row-major
ofs << numFeatures * numHidden << endl;
for (int i = 0; i < numFeatures; i++)
for (int j = 0; j < numHidden; j++)
ofs.write((char*)&w1(i, j), sizeof(double));

// b1: hidden bias
ofs << numHidden << endl;
for (int i = 0; i < numHidden; i++)
ofs.write((char*)&b1[i], sizeof(double));

// w2: hidden -> output weights, row-major
ofs << numHidden * numOutcomes << endl;
for (int i = 0; i < numHidden; i++)
for (unsigned j = 0; j < numOutcomes; j++)
ofs.write((char*)&w2(i, j), sizeof(double));

// b2: output bias
ofs << numOutcomes << endl;
for (unsigned i = 0; i < numOutcomes; i++)
ofs.write((char*)&b2[i], sizeof(double));

// optional second hidden layer
if (numLayers == 2) {
// wh: hidden -> hidden weights, row-major
ofs << numHidden * numHidden << endl;
for (int i = 0; i < numHidden; i++)
for (int j = 0; j < numHidden; j++)
ofs.write((char*)&wh(i, j), sizeof(double));

// bh: second hidden layer bias
ofs << numHidden << endl;
for (int i = 0; i < numHidden; i++)
ofs.write((char*)&bh[i], sizeof(double));
}
return startPos;
}
00311
/**
 *	Read a model from @param ifs (format written by writeLabels +
 *	writeData): the label tables in text form, then the weight arrays,
 *	each preceded by a text line with its element count.
 *	@param file	file name, used only in error messages.
 *	@throw IXE::FileError on truncated input or size mismatch.
 */
void MlpModel::load(ifstream& ifs, char const* file)
{
// outcome (class) labels
char line[MAX_LINE_LEN];
READLINE(line, file);
int len = atoi(line);
numOutcomes = len;
outcomeLabels.resize(numOutcomes);
int n = 0;
while (len--) {
READLINE(line, file);
outcomeLabels[n] = strdup(line);
// NOTE(review): the key passed here is the stack buffer `line';
// this is only safe if WordIndex copies its keys — confirm.
outcomeIndex[(char const*)line] = n++;
}

// predicate (feature) labels
READLINE(line, file);
numFeatures = len = atoi(line);
predLabels.resize(numFeatures);
n = 0;
while (len--) {
READLINE(line, file);
predLabels[n] = line;
predIndex[(char const*)line] = n++;
}

// number of hidden layers
READLINE(line, file);
numLayers = atoi(line);

// Backward compatibility: older model files stored numHidden here
// (no layer count); any value > 3 cannot be a layer count, so treat
// it as the hidden size of a single-layer model.
if (numLayers > 3) {
numHidden = numLayers;
numLayers = 1;
} else {
// hidden layer size
READLINE(line, file);
numHidden = atoi(line);
}

w1.resize(numFeatures, numHidden);

// w1: input -> hidden weights, row-major
READLINE(line, file);
len = atoi(line);
if (len != numFeatures * numHidden)
throw IXE::FileError(string("Wrong w1 size: ") + file);
n = 0;
double w;
while (len--) {
READ_WEIGHT(w, file);
w1(n / numHidden, n % numHidden) = w;
n++;
}

// b1: hidden bias
b1.resize(numHidden);
READLINE(line, file);
len = atoi(line);
n = 0;
while (len--) {
READ_WEIGHT(w, file);
b1[n++] = w;
}

// w2: hidden -> output weights, row-major
w2.resize(numHidden, numOutcomes);
READLINE(line, file);
len = atoi(line);
if (len != numHidden * numOutcomes)
throw IXE::FileError(string("Wrong w2 size: ") + file);
n = 0;
while (len--) {
READ_WEIGHT(w, file);
w2(n / numOutcomes, n % numOutcomes) = w;
n++;
}

// b2: output bias
b2.resize(numOutcomes);
READLINE(line, file);
len = atoi(line);
n = 0;
while (len--) {
READ_WEIGHT(w, file);
b2[n++] = w;
}

// optional second hidden layer
if (numLayers == 2) {
// wh: hidden -> hidden weights, row-major
wh.resize(numHidden, numHidden);
READLINE(line, file);
len = atoi(line);
if (len != numHidden * numHidden)
throw IXE::FileError(string("Wrong wh size: ") + file);
n = 0;
while (len--) {
READ_WEIGHT(w, file);
wh(n / numHidden, n % numHidden) = w;
n++;
}

// bh: second hidden layer bias
bh.resize(numHidden);
READLINE(line, file);
len = atoi(line);
n = 0;
while (len--) {
READ_WEIGHT(w, file);
bh[n++] = w;
}
}
}
00426
00443 int MlpModel::crossentropy_softmax(Vector& x, double sm[])
00444 {
00445
00446 double m = 0.0;
00447 int am = 0;
00448 for (unsigned i = 0; i < numOutcomes; i++)
00449 if (x[i] > m) {
00450 am = i;
00451 m = x[i];
00452 }
00453
00454 double sum_j = 0.0;
00455 for (unsigned i = 0; i < numOutcomes; i++)
00456 sum_j += sm[i] = exp(x[i] - m);
00457
00458
00459 for (unsigned i = 0; i < numOutcomes; i++)
00460 sm[i] /= sum_j;
00461 return am;
00462 }
00463
00464 Vector MlpModel::gradCrossentropy(Vector& x, int y)
00465 {
00466 Vector sm(numOutcomes);
00467 crossentropy_softmax(x, &sm[0]);
00468 sm[y] -= 1.0;
00469 return sm;
00470 }
00471
00472 int MlpModel::estimate(std::vector<PID>& features, double prob[])
00473 {
00474 Vector h(numHidden);
00475 for (int i = 0; i < numHidden; i++)
00476 h[i] = 0.0;
00477 for (size_t f = 0; f < features.size(); f++)
00478 h += row(w1, features[f]);
00479 h += b1;
00480 softsign(h);
00481 if (numLayers > 1) {
00482 h = prod(wh, h) + bh;
00483 softsign(h);
00484 }
00485 Vector o(numOutcomes);
00486 noalias(o) = prod(h, w2) + b2;
00487 return crossentropy_softmax(o, prob);
00488 }
00489
00490 void show(Vector& x)
00491 {
00492 for (size_t i = 0; i < x.size(); i++) {
00493 cerr << x[i] << ' '; if ((i+1) % 5 == 0) cerr << endl; }
00494 cerr << endl;
00495 }
00496
00517 double MlpModel::train(Case& cas, int& argmax)
00518 {
00519 X& features = cas.first;
00520 Y y = cas.second;
00521 Vector xw1(numHidden);
00522 for (int i = 0; i < numHidden; i++)
00523 xw1[i] = 0.0;
00524 for (size_t f = 0; f < features.size(); f++)
00525 xw1 += row(w1, features[f]);
00526 Vector h(numHidden);
00527 noalias(h) = xw1 + b1;
00528 softsign(h);
00529
00530
00531
00532
00533
00534
00535 Vector o(numOutcomes);
00536 noalias(o) = prod(h, w2) + b2;
00537
00538 std::vector<double> softmax(numOutcomes);
00539 argmax = crossentropy_softmax(o, &softmax[0]);
00540 double kl = -log(softmax[y]);
00541
00542 Vector gb2(numOutcomes);
00543 gb2 = gradCrossentropy(o, y);
00544
00545 Matrix gw2(numHidden, numOutcomes);
00546 noalias(gw2) = outer_prod(h, gb2);
00547
00548 Vector hprimeInv(numHidden);
00549 for (int i = 0; i < numHidden; i++) {
00550 double hi = 1. + abs(xw1[i] + b1[i]);
00551 hprimeInv[i] = hi * hi;
00552 }
00553
00554 Vector gb1(numHidden);
00555 noalias(gb1) = element_div(prod(w2, gb2), hprimeInv);
00556
00557
00558 for (size_t i = 0; i < features.size(); i++)
00559 for (int j = 0; j < numHidden; j++)
00560 w1(features[i], j) -= gb1(j) * LR;
00561
00562 b1 -= gb1 * LR;
00563 if (numLayers == 2) {
00564 Matrix gwh(numHidden, numHidden);
00565 Vector gbh(numHidden);
00566 wh -= gwh * LR;
00567 bh -= gbh * LR;
00568 }
00569 w2 -= gw2 * LR;
00570 b2 -= gb2 * LR;
00571
00572 return kl;
00573 }
00574
00580
/**
 *	Fill @param v with a uniformly random permutation of
 *	0 .. v.size()-1 (Fisher-Yates shuffle driven by rand()).
 *	FIX: for an empty vector the old loop bound N-1 underflowed
 *	(size_t arithmetic) and the shuffle ran off the end of the vector;
 *	sizes < 2 need no shuffling at all.
 */
static void rand_permutation(std::vector<int>& v) {
  size_t N = v.size();
  for (size_t i = 0; i < N; i++)
    v[i] = i;
  if (N < 2)
    return;
  for (size_t i = 0; i < N - 1; i++) {
    size_t r = i + rand() % (N - i);
    std::swap(v[i], v[r]);
  }
}
00590
/**
 *	Training loop: up to @param epochs passes of shuffled SGD over
 *	@param cases. About 1% of the cases of each epoch are held out for
 *	validation instead of being trained on. Whenever the validation
 *	accuracy improves, the weight section of the model file is
 *	rewritten in place at its original offset in @param ofs. Training
 *	stops early after MlpVainIterations epochs without improvement
 *	(once at least MlpMinIterations epochs have run).
 */
void MlpModel::train(Cases& cases, int epochs, ofstream& ofs)
{
if (verbose)
cerr << "Starting training.." << endl;

// remember where the weight section starts, to rewrite it in place
streampos dataStart = ofs.tellp();

MovingAverage mvgavg_accuracy;
MovingAverage mvgavg_loss;

double best_validation_accuracy = 0.0;
int best_validation_at = 0;	// example count at the last improvement
int cnt = 0;			// total examples seen so far

ValidationSet vs;

int numCases = cases.size();
std::vector<int> perm(numCases);
for (int it = 0; it < epochs; ++it) {
if (verbose)
cerr << " EPOCH #" << it << " (" << numCases*it << " examples)" << endl;
// visit the cases in a fresh random order each epoch
rand_permutation(perm);
vs.clear();
for (int i = 0; i < numCases; i++) {
cnt++;
Case& cas = cases[perm[i]];
// hold out ~1% of the cases for validation; train on the rest
if (rand() < RAND_MAX / 100)
vs.push_back(&cas);
else {
int argmax;
double kl = train(cas, argmax);
double accuracy = (argmax == cas.second) ? 100.0 : 0.0;
mvgavg_accuracy.add(accuracy);
mvgavg_loss.add(kl);
}
// progress ticks: '.' every 1000, '+' every 10000, stats every 80000
if (verbose && cnt % 1000 == 0) {
if (cnt % 10000 == 0) {
cerr << '+' << flush;
if (cnt % 80000 == 0) {
cerr << endl
<< Format("After %d examples: training accuracy = %f, loss = %f",
cnt, mvgavg_accuracy.mean, mvgavg_loss.mean)
<< endl;
}
} else if (cnt % 1000 == 0)
cerr << '.' << flush;
}
}

// end of epoch: measure accuracy on the held-out cases
double valacc, valstd;
validate(vs, valacc, valstd);
if (verbose)
cerr << endl
<< Format("After %d examples: validation accuracy = %.2f%%, stddev= %.2f%% (former best=%.2f%% at %d)\n",
cnt, valacc*100, valstd*100, best_validation_accuracy*100, best_validation_at)
<< Parser::procStat() << endl;;

if (best_validation_accuracy < valacc) {
// improved: checkpoint the weights over the previous ones
best_validation_accuracy = valacc;
best_validation_at = cnt;
if (verbose)
cerr << Format("NEW BEST VALIDATION ACCURACY: %.2f%%.", valacc*100.)
<< " Saving model...";
ofs.seekp(dataStart);
writeData(ofs);
if (verbose)
cerr << endl;
} else if (cnt - best_validation_at >= mlpVainIterations * numCases &&
cnt / numCases >= mlpMinIterations) {
// early stopping: no improvement for too many epochs
if (verbose) {
cerr << "Have not beaten best validation accuracy for "
<< *mlpVainIterations << " iterations. Terminating training..."
<< endl;;
}
break;
}
}
}
00671
/**
 *	Measure accuracy on the validation set.
 *	@param vs	held-out cases
 *	@param avg	out: fraction of correctly predicted cases
 *	@param std	out: running average of squared deviations from the
 *			running mean.
 *	NOTE(review): `std' is an incremental approximation of the
 *	variance (each deviation is taken from the mean *so far*, not the
 *	final mean), yet the caller prints it as a standard deviation —
 *	it is used for reporting only. If vs is empty, both outputs are 0.
 */
void MlpModel::validate(ValidationSet& vs, double& avg, double& std)
{
avg = 0.0;
std = 0.0;
for (size_t i = 0; i < vs.size(); i++) {
Case& cas = *vs[i];
X& x = cas.first;
Y y = cas.second;
// predicted class for this case
std::vector<double> sm(numOutcomes);
int argmax = estimate(x, &sm[0]);
// incremental mean of the 0/1 accuracy
double acc = double(argmax == y);
avg = (avg * i + acc) / (i+1);
// incremental average of squared deviations (see NOTE above)
std = (std * i + (acc - avg)*(acc - avg)) / (i+1);
}
}
00690
/**
 *	Transition-based parser that scores its actions with an MLP
 *	classifier (MlpModel).
 */
struct MlpParser : public Parser
{
// Load the model from modelFile (if non-null); iter = training epochs.
MlpParser(char const* modelFile, int iter);

// Train a model on the sentences and write it to modelFile.
void train(SentenceReader* sentenceReader, char const* modelFile);

// Parse one sentence (beam or greedy search, see BEAM).
Sentence* parse(Sentence* sentence);

// Revision pass: not implemented for this parser.
void revise(SentenceReader* sentenceReader, char const* actionFile = 0);

// the classifier scoring the parser actions
MlpModel model;

// number of training epochs
int iter;

};
00711
/**
 *	Factory registered under the name "MLP": creates an MlpParser,
 *	loading the model from modelFile when one is given.
 *	The caller takes ownership of the returned parser.
 */
Parser* MlpParserFactory(char const* modelFile = 0)
{
return new MlpParser(modelFile, mlpIterations);
}

REGISTER_PARSER(MLP, MlpParserFactory);
00723
/**
 *	Construct the parser; when @param modelFile is given, read the
 *	model header, the MLP weights and the global info from it.
 *	NOTE(review): the base Parser is initialized with
 *	model.PredIndex() *before* the member `model' has been
 *	constructed; this is only safe if PredIndex() merely returns a
 *	reference to member storage without touching state — confirm.
 */
MlpParser::MlpParser(char const* modelFile, int iter) :
Parser(model.PredIndex()),
iter(iter)
{
model.verbose = verbose;
if (!modelFile)
return;		// training mode: nothing to load
ifstream ifs(modelFile, ios::binary);
if (!ifs)
throw IXE::FileError(string("Missing model file: ") + modelFile);
// file layout: header, labels+weights, global info
readHeader(ifs);
model.load(ifs);

info.load(ifs);
ifs.close();
}
00741
/**
 *	Run the oracle over the training sentences collecting one Event
 *	(feature strings + gold action) per parser transition, while
 *	building the outcome and predicate label tables.
 *	Predicates are added to predIndex only once they have been seen at
 *	least Parser::featureCutoff times (rare ones wait in predCount).
 *	Sets numOutcomes and numFeatures.
 *	@return the collected events; ownership passes to the caller
 *	(buildCases deletes them).
 */
list<Event*>
MlpModel::collectEvents(Enumerator<Sentence*>& sentenceReader, GlobalInfo& info)
{
if (verbose)
cerr << "Collecting events.." << endl;

list<Event*> events;
// occurrence counts of predicates still below the frequency cutoff
WordCounts predCount;

int evCount = 0;
PID pID = 0;	// next predicate ID to assign

EventStream eventStream(&sentenceReader, &info);
ClassID oID = 0;	// next outcome ID to assign

while (eventStream.hasNext()) {
Event* ev = eventStream.next();
events.push_back(ev);
evCount++;
// progress ticks
if (verbose) {
if (evCount % 10000 == 0)
cerr << '+' << flush;
else if (evCount % 1000 == 0)
cerr << '.' << flush;
}
// register the outcome (action) label on first sight
char const* c = ev->className.c_str();
if (outcomeIndex.find(c) == outcomeIndex.end()) {
outcomeIndex[c] = oID++;
outcomeLabels.push_back(strdup(c));
}
std::vector<string>& ec = ev->features;
// register predicates that have reached the frequency cutoff
for (unsigned j = 0; j < ec.size(); j++) {
string& pred = ec[j];

if (predIndex.find(pred.c_str()) == predIndex.end()) {
// not yet indexed: bump its count and index it once frequent
// NOTE(review): the key is pred.c_str(), a pointer into the
// event's own string — safe only if WordIndex copies keys.
int count = predCount.add(pred);
if (count >= Parser::featureCutoff) {
predLabels.push_back(pred);
predIndex[pred.c_str()] = pID++;
predCount.erase(pred);
}
}
}
}
if (verbose)
cerr << endl;
numOutcomes = oID;
numFeatures = pID;
// release the counting table (swap trick frees its capacity)
predCount.clear();
WordCounts().swap(predCount);

return events;
}
00804
// Emulate POSIX drand48/srand48 where they are unavailable, using a
// TR1 linear congruential generator with drand48's own parameters
// (a = 25214903917, c = 11, m = 2^48).
#ifdef HAVE_TR1_RANDOM
#define TWO48 281474976710656
std::tr1::linear_congruential<unsigned long long, 25214903917, 11, TWO48> drand48Gen;
#define srand48(x) drand48Gen.seed(x)
#define drand48() drand48Gen()/double(TWO48)
#endif

// A random initial weight, uniform in [-randomRange, +randomRange].
#define RAND_WEIGHT (randomRange * (2.0 * drand48() - 1.0))
00813
/**
 *	Initialize the network weights (random w1/w2/wh, zero biases,
 *	fixed seed for reproducibility) and convert the collected events
 *	into numeric training cases, consuming (and deleting) the events.
 *	Events whose features were all below the frequency cutoff produce
 *	no case: the case slot is reused for the next event and the vector
 *	is trimmed to the real count at the end.
 */
void MlpModel::buildCases(list<Event*>& events, Cases& cases)
{
// fixed seed: weight initialization is reproducible across runs
srand48(1);

w1.resize(numFeatures, numHidden);
// random input -> hidden weights
for (int i = 0; i < numFeatures; i++)
for (int j = 0; j < numHidden; j++)
w1(i, j) = RAND_WEIGHT;

b1.resize(numHidden);
for (int i = 0; i < numHidden; i++)
b1(i) = 0.0;
w2.resize(numHidden, numOutcomes);
// random hidden -> output weights
for (int i = 0; i < numHidden; i++)
for (unsigned j = 0; j < numOutcomes; j++)
w2(i, j) = RAND_WEIGHT;

b2.resize(numOutcomes);
for (unsigned i = 0; i < numOutcomes; i++)
b2(i) = 0.0;
if (numLayers == 2) {
wh.resize(numHidden, numHidden);
for (int i = 0; i < numHidden; i++)
for (int j = 0; j < numHidden; j++)
wh(i, j) = RAND_WEIGHT;
bh.resize(numHidden);
for (int i = 0; i < numHidden; i++)
bh(i) = 0.0;
}

// convert events to cases, freeing each event as it is consumed
size_t evCount = events.size();
cases.reserve(evCount);
int n = 0;	// number of non-empty cases built so far
while (!events.empty()) {
Event* ev = events.front();
events.pop_front();
// slot n is reused until an event yields at least one feature
cases.push_back(Case());
X& x = cases[n].first;

std::vector<string>& ec = ev->features;
char const* c = ev->className.c_str();
// keep only features that survived the frequency cutoff
for (unsigned j = 0; j < ec.size(); j++) {
string& pred = ec[j];
WordIndex::const_iterator pit = predIndex.find(pred.c_str());
if (pit != predIndex.end()) {
x.push_back(pit->second);
}
}
if (x.size()) {
cases[n].second = outcomeIndex[c];
n++;
// progress ticks
if (verbose) {
if (n % 10000 == 0)
cerr << '+' << flush;
else if (n % 1000 == 0)
cerr << '.' << flush;
}

}
delete ev;
}
// drop the trailing unused slots
cases.resize(n);
if (verbose)
cerr << endl;
if (verbose) {
cerr << "\t Number of events: " << evCount << endl;
cerr << "\t Number of Classes: " << outcomeLabels.size() << endl;
cerr << "\tNumber of Predicates: " << predIndex.size() << endl;
}
}
00887
/**
 *	Train an MLP model on the sentences and write it to modelFile.
 *	File layout: header, label tables, weight data, global info.
 *	The stream is then rewound to the start of the weight data so that
 *	MlpModel::train can overwrite the weights in place each time the
 *	validation accuracy improves (the weight section has a fixed size,
 *	so the info section after it stays intact).
 */
void MlpParser::train(SentenceReader* sentenceReader, char const* modelFile)
{
// read all sentences, then extract training events from them
std::deque<Sentence*> sentences = collectSentences(sentenceReader);
SentenceQueueReader sr(sentences);
list<Event*> events = model.collectEvents(sr, info);

MlpModel::Cases cases;
model.buildCases(events, cases);

ofstream ofs(modelFile, ios::binary | ios::trunc);

writeHeader(ofs);

// labels, then a first (untrained) copy of the weights
model.writeLabels(ofs);
streampos dataStart = model.writeData(ofs);

// labels are no longer needed in memory once written
model.clearLabels();

// shrink and save the global info, then free it
info.clearRareEntities();

info.save(ofs);
info.clear();
// rewind so training checkpoints overwrite the weight section
ofs.seekp(dataStart);

// train, checkpointing the best weights into ofs
model.train(cases, iter, ofs);
}
00919
#ifdef BEAM

/**
 *	Insert state @param s into the beam @param states, kept sorted by
 *	decreasing lprob and bounded by `beam' entries. The caller must
 *	guarantee (see the assert) that s beats the current worst when the
 *	beam is full; the evicted worst state is released and pruned.
 *	Reference counts: s gains a reference, the evicted state loses one.
 *	@return the lprob of the worst state now in the beam.
 */
static double addState(ParseState* s, std::vector<ParseState*>& states)
{
int size = states.size();
assert(size < beam || s->lprob > states[size-1]->lprob);
s->incRef();
// full beam: evict the current worst (last) state
if (size == beam) {
ParseState* last = states.back();
last->decRef();
last->prune();
states.pop_back();
}
// insert in sorted position (descending lprob)
TO_EACH (std::vector<ParseState*>, states, it)
if (s->lprob > (*it)->lprob) {
states.insert(it, s);
return states.back()->lprob;
}
// s is the new worst: append it
states.push_back(s);
return s->lprob;
}
#endif
00948
00949 Sentence* MlpParser::parse(Sentence* sentence)
00950 {
00951 int numOutcomes = model.NumOutcomes();
00952 # ifdef _MSC_VER
00953 std::vector<double> paramsV(numOutcomes);
00954 double* params = ¶msV[0];
00955 # else
00956 double params[numOutcomes];
00957 # endif
00958 preprocess(sentence);
00959 ParseState* state = new ParseState(*sentence, &info, predIndex);
00960
00961 # ifdef BEAM
00962 Language const* lang = sentence->language;
00963
00964 std::vector<ParseState*> currStates; currStates.reserve(beam);
00965 std::vector<ParseState*> nextStates; nextStates.reserve(beam);
00966 std::vector<ParseState*>* bestStates = &currStates;
00967 std::vector<ParseState*>* bestNextStates = &nextStates;
00968 addState(state, *bestStates);
00969 int step = 0;
00970
00971 while (true) {
00972 # ifdef DEBUG_MORPH
00973 cerr << "STEP: " << step++ << endl;
00974 # endif
00975 int finished = 0;
00976 int numBest = bestStates->size();
00977 double worstProb = -numeric_limits<double>::infinity();
00978 for (int i = 0; i < numBest; i++) {
00979 state = (*bestStates)[i];
00980 if (state->hasNext()) {
00981 Tanl::Classifier::Context& context = *state->next();
00982 model.estimate(context, params);
00983 for (int o = 0; o < numOutcomes; o++) {
00984 double prob = params[o];
00985 if (prob < 1e-4)
00986 continue;
00987 double lprob = log(prob) + state->lprob;
00988 if (bestNextStates->size() == beam && lprob < worstProb)
00989 continue;
00990 char const* outcome = model.OutcomeName(o);
00991 if (state->stack.size() > 1 && state->input.size()) {
00992 Token* top = state->stack.back()->token;
00993 Token* next = state->input.back()->token;
00994 if ((outcome[0] == 'L' && !lang->morphoRight(*next->pos()) ||
00995 outcome[0] == 'R' && !lang->morphoLeft(*top->pos())) &&
00996 ((top->morpho.number && next->morpho.number &&
00997 !lang->numbAgree(top->morpho.number, next->morpho.number)) ||
00998 (top->morpho.gender && next->morpho.gender &&
00999 !lang->gendAgree(top->morpho.gender, next->morpho.gender)))) {
01000 lprob = log(prob / 10.) + state->lprob;
01001 if (bestNextStates->size() == beam && lprob < worstProb)
01002 continue;
01003 }
01004 }
01005 ParseState* next = state->transition(outcome);
01006 if (!next)
01007 continue;
01008 # ifdef DEBUG_MORPH
01009 cerr << i << " " << outcome << " " << prob << " " << lprob << " ";
01010 for (ParseState* s = next; s; s = (ParseState*)s->previous)
01011 if (s->action)
01012 cerr << s->action << ' ';
01013 cerr << endl;
01014 # endif
01015 next->lprob = lprob;
01016 worstProb = addState(next, *bestNextStates);
01017 }
01018 } else if (bestNextStates->size() < (size_t)beam || state->lprob > worstProb) {
01019
01020 worstProb = addState(state, *bestNextStates);
01021 finished++;
01022 }
01023 }
01024 if (bestNextStates->empty())
01025 break;
01026
01027
01028
01029 FOR_EACH (std::vector<ParseState*>, *bestStates, it) {
01030 ParseState* state = (*it);
01031 state->decRef();
01032 state->prune();
01033 }
01034 bestStates->clear();
01035
01036 std::vector<ParseState*>* tmp = bestStates;
01037 bestStates = bestNextStates;
01038 bestNextStates = tmp;
01039 if (finished == numBest)
01040 break;
01041 }
01042
01043 Sentence* s = (*bestStates)[0]->getSentence();
01044 # ifdef DEBUG_MORPH
01045 for (int i = 0; i < bestStates->size(); i++) {
01046 ParseState* state = (*bestStates)[i];
01047 for (ParseState* s = state; s; s = (ParseState*)s->previous)
01048 if (s->action)
01049 cerr << s->action << ' ';
01050 cerr << endl;
01051 cerr << "lprob: " << state->lprob << endl;
01052 }
01053 # endif
01054 if (showTreelets) {
01055 for (ParseState* s = (*bestStates)[0]; s; s = (ParseState*)s->previous) {
01056 Action action = s->action;
01057 if (action && s->input.size()) {
01058 if (action[0] == 'R' || action[0] == 'r' ||
01059 action[0] == 'L' || action[0] == 'l') {
01060 TreeToken* tok = s->input.back();
01061 cout << '\t';
01062 tok->printLeaves(cout);
01063 cout << endl;
01064 }
01065 }
01066 }
01067 }
01068 # ifdef SHOW_LIKELIHOOD
01069 if (verbose) {
01070 double avg = 0.0;
01071 double min = 1.0;
01072 double n = 0.0;
01073 for (ParseState* s = (*bestStates)[0]; s; s = (ParseState*)s->previous) {
01074 double prob = exp(s->lprob);
01075 avg += prob;
01076 if (prob < min)
01077 min = prob;
01078 n++;
01079 }
01080 avg /= n;
01081 cout << "LogLikelihood: " << (*bestStates)[0]->lprob;
01082 cout << '\t' << avg << '\t' << min << endl;
01083 }
01084 # endif
01085 FOR_EACH (std::vector<ParseState*>, *bestStates, it) {
01086 (*it)->decRef();
01087 (*it)->prune();
01088 }
01089
01090 # ifdef ORACLE
01091
01092 TrainState ts(*sentence, &info);
01093 while (ts.hasNext()) {
01094 Event* ev = ts.next();
01095 string& action = ev->className;
01096 Context context(ev->features, predIndex);
01097 ClassID best = model.estimate(context, params);
01098 if (action == model.OutcomeName(best))
01099 oracleCorrect++;
01100 oracleCount++;
01101 ts.transition(action.c_str());
01102 }
01103 # endif
01104
01105 return s;
01106 # else
01107
01108 while (state->hasNext()) {
01109 Context& context = *state->next();
01110 model.estimate(context, params);
01111 int best = model.BestOutcome(params);
01112 char const* outcome = model.OutcomeName(best);
01113 ParseState* next = state->transition(outcome);
01114 if (!next)
01115 next = state->transition("S");
01116 state = next;
01117 }
01118 Sentence* s = state->getSentence();
01119 state.prune();
01120 return s;
01121 # endif // BEAM
01122 }
01123
// Revision pass over parsed sentences: intentionally not implemented
// for the MLP parser (required by the Parser interface).
void MlpParser::revise(SentenceReader* sentenceReader, char const* actionFile)
{}
01126
01127 }