00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00026
00027 #include "Parser.h"
00028 #include "EventStream.h"
00029 #include "svm.h"
00030 #include "conf_feature.h"
00031
00032
00033 #include "conf/conf_string.h"
00034 #include "include/unordered_map.h"
00035
00036
00037 #include <algorithm>
00038 #ifdef _WIN32
00039 #include <functional>
00040 #include <algorithm>
00041 #include <stdlib.h>
00042 #include <stdio.h>
00043 #include <io.h>
00044 #include "lib/strtok_r.h"
00045 #else
00046 #include <ext/functional>
00047 #endif
00048
00049 #include <iostream>
00050 #include <list>
00051
00052 using namespace std;
00053
00054 #define MAX_LINE_LEN 8196
00055
00056 namespace Parser {
00057
/// Training option "SvmSkip": sub-models whose index is below this value
/// are not trained by SvmParser::train() (their problem data is still freed).
IXE::conf<int> svmSkip("SvmSkip", 0);

/// Training option "SvmParams": libsvm training options in svm-train style
/// ("-x value" pairs), interpreted by parseParameters().
IXE::conf<string> svmParams("SvmParams", "-t 1 -d 2 -g 0.2 -c 0.4 -e 0.1");

/// Defined in another translation unit; not referenced by the code shown here.
extern IXE::conf<bool> CompositeActions;

/// Feature used to partition training events among separate SVM models;
/// splitting is enabled when it is non-empty (see SvmParser::splitModel()).
extern conf_feature SplitFeature;
00068
00074 struct SvmParser : public Parser
00075 {
00076 SvmParser(char const* modelFile);
00077
00078 ~SvmParser() {
00079 for (unsigned i = 0; i < model.size(); i++)
00080 svm_destroy_model(model[i]);
00081 }
00082
00083 void train(SentenceReader* sentenceReader, char const* modelFile);
00084 Sentence* parse(Sentence* sentence);
00085
00086
00087 WordIndex splits;
00088 vector<string> splitNames;
00089 unordered_map<char, char> splitGroup;
00090
00091 WordIndex predIndex;
00092 WordIndex classIndex;
00093 vector<string> classLabels;
00094 vector<svm_model*> model;
00095
00096 private:
00097 void collectEvents(Enumerator<Sentence*>& sentenceReader,
00098 char const* modelFile,
00099 vector<svm_problem>& problem);
00100
00102 bool splitModel() { return !SplitFeature->empty(); }
00103 };
00104
/**
 * Build the file-name suffix of sub-model @a i: a dot followed by two
 * base-26 letters, i.e. ".aa", ".ab", ..., ".az", ".ba", ...
 *
 * @param ext  destination buffer; must have room for 4 chars (incl. NUL).
 * @param i    sub-model index.
 * @return @a ext.
 */
static char* mkext(char* ext, int i)
{
  int const high = i / 26;
  int const low = i % 26;
  char* p = ext;
  *p++ = '.';
  *p++ = (char)('a' + high);
  *p++ = (char)('a' + low);
  *p = '\0';
  return ext;
}
00110
00111 SvmParser::SvmParser(char const* modelFile) :
00112 Parser(predIndex)
00113 {
00114 if (!modelFile)
00115 return;
00116 ifstream ifs(modelFile);
00117 if (!ifs)
00118 throw IXE::FileError(string("Missing symbols file: ") + modelFile);
00119
00120 readHeader(ifs);
00121
00122 char line[MAX_LINE_LEN];
00123 if (!ifs.getline(line, MAX_LINE_LEN))
00124 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00125 int len = atoi(line);
00126 int n = 0;
00127 while (len--) {
00128 if (!ifs.getline(line, MAX_LINE_LEN))
00129 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00130 classIndex[(char const*)line] = n++;
00131 classLabels.push_back(line);
00132 }
00133
00134 if (!ifs.getline(line, MAX_LINE_LEN))
00135 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00136 len = atoi(line);
00137 n = 0;
00138 while (len--) {
00139 if (!ifs.getline(line, MAX_LINE_LEN))
00140 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00141 predIndex[(char const*)line] = n++;
00142 }
00143
00144 if (!ifs.getline(line, MAX_LINE_LEN))
00145 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00146 len = atoi(line);
00147 n = 0;
00148 int models = 0;
00149 int skipGroup = (len != 0);
00150 while (len--) {
00151 ifs.getline(line, MAX_LINE_LEN);
00152 char* next = line;
00153
00154 char* code = (line[0] == ' ' || line[0] == '\t') ? (char*)"" : strtok_r(0, " \t", &next);
00155 splits.insert(code);
00156 splitNames.push_back(code);
00157 code = strtok_r(0, " \t", &next);
00158 int group = atoi(code);
00159 splitGroup[n] = group;
00160 models = max(models, group);
00161 if (group == 0) skipGroup = 0;
00162 n++;
00163 }
00164 models++;
00165
00166 info.load(ifs);
00167
00168 model.resize(models);
00169 for (int i = skipGroup; i < models; i++) {
00170 char ext[3];
00171 string modeliFile = string(modelFile) + mkext(ext, i);
00172 model[i] = svm_load_model(modeliFile.c_str());
00173 if (!model[i])
00174 throw IXE::FileError(string("can't open model file ") + modeliFile);
00175 }
00176 }
00177
00183 Parser* SvmParserFactory(char const* modelFile = 0)
00184 {
00185 SvmParser* p = new SvmParser(modelFile);
00186 if (modelFile && p->model.empty()) {
00187 delete p;
00188 return 0;
00189 }
00190 return p;
00191 }
00192
00193 REGISTER_PARSER(SVM, SvmParserFactory);
00194
00195 void parseParameters(svm_parameter& param, char* parameters)
00196 {
00197
00198 param.svm_type = C_SVC;
00199 param.kernel_type = RBF;
00200 param.degree = 3;
00201 param.gamma = 0;
00202 param.coef0 = 0;
00203 param.nu = 0.5;
00204 param.cache_size = 100;
00205 param.C = 1;
00206 param.eps = 1e-3;
00207 param.p = 0.1;
00208 param.shrinking = 1;
00209 param.probability = 0;
00210 param.nr_weight = 0;
00211 param.weight_label = NULL;
00212 param.weight = NULL;
00213
00214 char const* opt = "";
00215 char* next = parameters;
00216
00217 while (opt = strtok_r(0, " \t", &next)) {
00218 if (opt[0] != '-') {
00219 cerr << "Missing option: " << opt << endl;
00220 return;
00221 }
00222 char* tok = strtok_r(0, " \t", &next);
00223 if (!tok) {
00224 cerr << "Missing option value: " << opt << endl;
00225 return;
00226 }
00227 switch (opt[1]) {
00228 case 's':
00229 param.svm_type = atoi(tok);
00230 break;
00231 case 't':
00232 param.kernel_type = atoi(tok);
00233 break;
00234 case 'd':
00235 param.degree = atoi(tok);
00236 break;
00237 case 'g':
00238 param.gamma = atof(tok);
00239 break;
00240 case 'r':
00241 param.coef0 = atof(tok);
00242 break;
00243 case 'n':
00244 param.nu = atof(tok);
00245 break;
00246 case 'm':
00247 param.cache_size = atof(tok);
00248 break;
00249 case 'c':
00250 param.C = atof(tok);
00251 break;
00252 case 'e':
00253 param.eps = atof(tok);
00254 break;
00255 case 'p':
00256 param.p = atof(tok);
00257 break;
00258 case 'h':
00259 param.shrinking = atoi(tok);
00260 break;
00261 case 'b':
00262 param.probability = atoi(tok);
00263 break;
00264 case 'w':
00265 ++param.nr_weight;
00266 param.weight_label = (int *)realloc(param.weight_label,
00267 sizeof(int)*param.nr_weight);
00268 param.weight = (double *)realloc(param.weight,
00269 sizeof(double)*param.nr_weight);
00270 param.weight_label[param.nr_weight-1] = atoi(opt+2);
00271 param.weight[param.nr_weight-1] = atof(tok);
00272 break;
00273 default:
00274 cerr << "unknown option: " << opt << endl;
00275 return;
00276 }
00277 }
00278 }
00279
00280 int compare_nodes(const void* a, const void* b) {
00281 return ((svm_node const*)a)->index - ((svm_node const*)b)->index;
00282 }
00283
00284 int MinimumSvmSize = 5000;
00285
00286 void SvmParser::collectEvents(Enumerator<Sentence*>& sentenceReader,
00287 char const* modelFile,
00288 vector<svm_problem>& problem)
00289 {
00290 WordIndex labelIndex;
00291 vector<string> labels;
00292
00293 vector<string> predLabels;
00294
00295
00296 list<Tanl::Classifier::Event*> events;
00297
00298 WordCounts predCount;
00299
00300 WordCounts splitCount;
00301 vector<char> splitEvents;
00302
00303 int evCount = 0;
00304 Tanl::Classifier::PID pID = 0;
00305
00306
00307 EventStream eventStream(&sentenceReader, &info);
00308
00309 bool doSplit = splitModel();
00310
00311 while (eventStream.hasNext()) {
00312 Tanl::Classifier::Event* ev = eventStream.next();
00313 events.push_back(ev);
00314 evCount++;
00315 if (verbose) {
00316 if (evCount % 10000 == 0)
00317 cerr << '+' << flush;
00318 else if (evCount % 1000 == 0)
00319 cerr << '.' << flush;
00320 }
00321 vector<string>& ec = ev->features;
00322
00323 for (unsigned j = 0; j < ec.size(); j++) {
00324 string& pred = ec[j];
00325
00326 if (predIndex.find(pred.c_str()) == predIndex.end()) {
00327
00328
00329 int count = predCount.add(pred);
00330 if (count >= featureCutoff) {
00331 predLabels.push_back(pred);
00332 predIndex[pred.c_str()] = pID++;
00333 predCount.erase(pred);
00334 }
00335 }
00336 }
00337 if (doSplit) {
00338 string code = eventStream.splitFeature();
00339
00340 char const* ccode = code.c_str();
00341 if (splits.index(ccode) == -1) {
00342 splits.insert(ccode);
00343 splitNames.push_back(ccode);
00344 }
00345
00346 splitEvents.push_back((char)splits[ccode]);
00347 splitCount.add(code);
00348 }
00349 }
00350 if (verbose)
00351 cerr << endl;
00352
00353 predCount.clear();
00354 predCount = WordCounts();
00355
00356
00357 if (doSplit) {
00358
00359 vector<int> splitNewSize(max(1, (int)splits.size()));
00360 splitNewSize[0] = 0;
00361 int models = 1;
00362 FOR_EACH (WordCounts, splitCount, sit) {
00363 char splitCode = (char)splits[sit->first.c_str()];
00364 if (sit->second < MinimumSvmSize || splitCount.size() == 1) {
00365 splitGroup[splitCode] = 0;
00366 splitNewSize[0] += sit->second;
00367 } else {
00368 splitGroup[splitCode] = models;
00369 splitNewSize[models] = sit->second;
00370 models++;
00371 }
00372 }
00373
00374 problem.resize(models);
00375 int skipGroup = splitNewSize[0] == 0;
00376 problem[0].l = 0;
00377 for (int i = skipGroup; i < models; i++) {
00378 int size = splitNewSize[i];
00379 problem[i].y = new double[size];
00380 problem[i].x = new svm_node*[size];
00381 problem[i].l = 0;
00382 }
00383 } else {
00384 problem.resize(1);
00385 problem[0].y = new double[evCount];
00386 problem[0].x = new svm_node*[evCount];
00387 problem[0].l = 0;
00388 }
00389 int nTot = 0;
00390 Tanl::Classifier::ClassID oID = 0;
00391 evCount = 0;
00392 while (!events.empty()) {
00393 Tanl::Classifier::Event* ev = events.front();
00394 events.pop_front();
00395 char const* c = ev->className.c_str();
00396
00397 vector<string>& ec = ev->features;
00398 svm_node* preds = new svm_node[ec.size()+1];
00399 unsigned k = 0;
00400 for (unsigned j = 0; j < ec.size(); j++) {
00401 string& pred = ec[j];
00402 WordIndex::const_iterator pit = predIndex.find(pred.c_str());
00403 if (pit != predIndex.end()) {
00404 svm_node& node = preds[k++];
00405 node.index = pit->second + 1;
00406 node.value = 1.0;
00407 }
00408 }
00409 if (k) {
00410
00411 qsort(preds, k, sizeof(svm_node), compare_nodes);
00412
00413 svm_node& node = preds[k++];
00414 node.index = -1;
00415 node.value = 1.0;
00416 if (labelIndex.find(c) == labelIndex.end()) {
00417 labelIndex[c] = oID++;
00418 labels.push_back(c);
00419 }
00420 int i = 0;
00421 if (!splitEvents.empty())
00422 i = splitGroup[splitEvents[evCount]];
00423 int& ni = problem[i].l;
00424 problem[i].y[ni] = labelIndex[c];
00425
00426 preds = (svm_node*)realloc(preds, k * sizeof(svm_node));
00427 problem[i].x[ni] = preds;
00428 ni++;
00429 nTot++;
00430 if (verbose) {
00431 if (nTot % 10000 == 0)
00432 cerr << '+' << flush;
00433 else if (nTot % 1000 == 0)
00434 cerr << '.' << flush;
00435 }
00436 } else {
00437 cerr << "Discarded event" << endl;
00438 delete preds;
00439 }
00440 evCount++;
00441 delete ev;
00442 }
00443
00444 if (verbose)
00445 cerr << endl;
00446
00447
00448 ofstream ofs(modelFile, ios::binary | ios::trunc);
00449
00450 writeHeader(ofs);
00451
00452 ofs << labels.size() << endl;
00453 FOR_EACH (vector<string>, labels, pit)
00454 ofs << *pit << endl;
00455
00456 ofs << predLabels.size() << endl;
00457 FOR_EACH (vector<string>, predLabels, pit)
00458 ofs << *pit << endl;
00459
00460 ofs << splitNames.size() << endl;
00461 FOR_EACH (vector<string>, splitNames, pit)
00462 ofs << *pit << "\t" << (int)splitGroup[(char)splits[pit->c_str()]] << endl;
00463 info.save(ofs);
00464 ofs.close();
00465
00466 if (verbose) {
00467 cerr << "\t Number of events: " << evCount << endl;
00468 cerr << "\t Number of Classes: " << labels.size() << endl;
00469 cerr << "\tNumber of Predicates: " << predIndex.size() << endl;
00470 }
00471
00472
00473 labels.clear();
00474 predLabels.clear();
00475 predIndex.clear();
00476 predIndex = WordIndex();
00477 labelIndex.clear();
00478 labelIndex = WordIndex();
00479
00480 info.clearRareEntities();
00481 }
00482
00483 void SvmParser::train(SentenceReader* sentenceReader, char const* modelFile)
00484 {
00485
00486 std::deque<Sentence*> sentences = collectSentences(sentenceReader);
00487 SentenceQueueReader sr(sentences);
00488 vector<svm_problem> problem;
00489 collectEvents(sr, modelFile, problem);
00490
00491
00492 svm_parameter param;
00493 parseParameters(param, svmParams);
00494
00495 if (dup2(fileno(stderr), fileno(stdout)) < 0)
00496 cerr << "could not redirect stdout to stderr" << endl;
00497
00498 int models = problem.size();
00499 int skipGroup = problem[0].l == 0;
00500 for (int i = skipGroup; i < models; i++) {
00501 if (i >= svmSkip) {
00502 struct svm_model* model = svm_train(&problem[i], ¶m);
00503
00504 char ext[4];
00505 string modeliFile = string(modelFile) + mkext(ext, i);
00506 svm_save_model(modeliFile.c_str(), model);
00507
00508 svm_destroy_model(model);
00509 }
00510 for (int j = problem[i].l - 1; j >= 0 ; j--)
00511 delete [] problem[i].x[j];
00512 delete [] problem[i].x;
00513 delete [] problem[i].y;
00514 }
00515 svm_destroy_param(¶m);
00516 }
00517
00518 Sentence* SvmParser::parse(Sentence* sentence)
00519 {
00520 vector<svm_node> nodes(predIndex.size());
00521 preprocess(sentence);
00522 ParseState* state = new ParseState(*sentence, &info, predIndex);
00523 while (state->hasNext()) {
00524 Tanl::Classifier::Context& preds = *state->next();
00525
00526 sort(preds.begin(), preds.end());
00527 nodes.resize(preds.size() + 1);
00528 int j = 0;
00529 FOR_EACH (vector<Tanl::Classifier::PID>, preds, pit) {
00530 nodes[j].index = *pit + 1;
00531 nodes[j++].value = 1.0;
00532 }
00533 nodes[preds.size()].index = -1;
00534 string code = state->splitFeature;
00535 int i = splitGroup[splits[code.c_str()]];
00536 double prediction = svm_predict(model[i], &nodes[0]);
00537 string& outcome = classLabels[(int)prediction];
00538 # ifdef DUMP
00539 cerr << classIndex[rightOutcome];
00540 FOR_EACH (vector<Tanl::Classifier::PID>, preds, pit)
00541 cerr << " " << *pit << ":1";
00542 cerr << endl;
00543 # endif
00544 ParseState* next = state->transition(outcome.c_str());
00545 if (!next)
00546 next = state->transition("S");
00547 state = next;
00548 }
00549 Sentence* s = state->getSentence();
00550 state->prune();
00551 return s;
00552 }
00553
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565
00566
00567
00568
00569
00570
00571
00572
00573
00574
00575
00576
00577
00578
00579
00580
00581
00582
00583
00584
00585
00586
00587
00588
00589 }