00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00026
00027 #include "platform.h"
00028
00029 #ifdef _WIN32
00030 #include <io.h>
00031 #include <algorithm>
00032 #endif
00033
00034 #include "Parser.h"
00035 #include "EventStream.h"
00036 #include "svm.h"
00037
00038
00039 #include "conf/conf_string.h"
00040
00041
00042 #include <algorithm>
00043 #include <iostream>
00044 #include <list>
00045
00046 using namespace std;
00047
00048 #define MAX_LINE_LEN 8196
00049
00050 namespace Parser {
00051
00053 extern IXE::conf<string> svmParams;
00054
00061 struct MultiSvmParser : public Parser
00062 {
00063 MultiSvmParser(char const* modelFile);
00064
00065 ~MultiSvmParser() {
00066 for (unsigned i = 0; i < model.size(); i++)
00067 svm_destroy_model(model[i]);
00068 }
00069
00070 void train(SentenceReader* sentenceReader, char const* modelFile);
00071 Sentence* parse(Sentence* sentence);
00072
00073
00074 WordIndex classIndex;
00075 vector<string> classLabels;
00076 vector<struct svm_model*> model;
00077 };
00078
00084 Parser* MultiSvmParserFactory(char const* modelFile = 0)
00085 {
00086 return new MultiSvmParser(modelFile);
00087 }
00088
00089 REGISTER_PARSER(MSVM, MultiSvmParserFactory);
00090
00091 extern void parseParameters(svm_parameter& param, char* parameters);
00092
00093 static int compare_nodes(const void* a, const void* b) {
00094 return ((svm_node const*)a)->index - ((svm_node const*)b)->index;
00095 }
00096
// Suffix letters for the per-action model files, indexed by ActionType:
// actType[Shift] == 'A' and actType[Reduce] == 'D' (e.g. "<model>.A").
static const char actType[] = "AD";

// Category of the previous parser action. The value is used as an index to
// pick the SVM model; labels starting with 'R' or 'L' count as Reduce.
enum ActionType {
  Shift = 0,
  Reduce = 1
};
00100
00101 MultiSvmParser::MultiSvmParser(char const* modelFile) :
00102 Parser(predIndex)
00103 {
00104 if (!modelFile)
00105 return;
00106 ifstream ifs(modelFile);
00107 if (!ifs)
00108 throw IXE::FileError(string("Missing model file: ") + modelFile);;
00109
00110 readHeader(ifs);
00111
00112 char line[MAX_LINE_LEN];
00113 if (!ifs.getline(line, MAX_LINE_LEN))
00114 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00115 int len = atoi(line);
00116 int n = 0;
00117 while (len--) {
00118 if (!ifs.getline(line, MAX_LINE_LEN))
00119 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00120 classIndex[(char const*)line] = n++;
00121 classLabels.push_back(line);
00122 }
00123
00124 if (!ifs.getline(line, MAX_LINE_LEN))
00125 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00126 len = atoi(line);
00127 n = 0;
00128 while (len--) {
00129 if (!ifs.getline(line, MAX_LINE_LEN))
00130 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00131 predIndex[(char const*)line] = n++;
00132 }
00133
00134 int models = 2;
00135 model.resize(models);
00136 for (int i = 0; i < models; i++) {
00137 string modeliFile = (string(modelFile) + '.') + actType[i];
00138 model[i] = svm_load_model(modeliFile.c_str());
00139 if (!model[i])
00140 throw IXE::FileError(string("can't open model file: ") + modeliFile);
00141 }
00142
00143 ifstream ent(modelFile);
00144 if (!ent)
00145 throw IXE::FileError(string("Missing entities file: ") + modelFile);
00146 info.load(ent);
00147 ent.close();
00148 }
00149
00150 void MultiSvmParser::train(SentenceReader* sentenceReader, char const* modelFile)
00151 {
00152 WordIndex labelIndex;
00153 vector<string> labels;
00154
00155 vector<string> predLabels;
00156
00157
00158 list<Tanl::Classifier::Event*> events;
00159
00160 WordCounts predCount;
00161
00162 int actionCount[2] = {0, 0};
00163 int prevAction = Shift;
00164
00165 int evCount = 0;
00166 Tanl::Classifier::PID pID = 0;
00167
00168
00169 EventStream eventStream(sentenceReader, &info);
00170 while (eventStream.hasNext()) {
00171 Tanl::Classifier::Event* ev = eventStream.next();
00172 events.push_back(ev);
00173 evCount++;
00174 if (verbose) {
00175 if (evCount % 10000 == 0)
00176 cerr << '+' << flush;
00177 else if (evCount % 1000 == 0)
00178 cerr << '.' << flush;
00179 }
00180 vector<string>& ec = ev->features;
00181 for (unsigned j = 0; j < ec.size(); j++) {
00182 string& pred = ec[j];
00183
00184 if (predIndex.find(pred.c_str()) == predIndex.end()) {
00185
00186 WordCounts::iterator wcit = predCount.find(pred);
00187
00188 int count;
00189 if (wcit == predCount.end())
00190 count = predCount[pred] = 1;
00191 else
00192 count = ++wcit->second;
00193 if (count >= featureCutoff) {
00194 predLabels.push_back(pred);
00195 predIndex[pred.c_str()] = pID++;
00196 predCount.erase(pred);
00197 }
00198 }
00199 }
00200 actionCount[prevAction]++;
00201 char a = toupper(ev->className[0]);
00202 prevAction = ActionType(a == 'R' || a == 'L');
00203 }
00204 if (verbose)
00205 cerr << endl;
00206
00207
00208 int models = 2;
00209 vector<svm_problem> problem(models);
00210 for (int i = 0; i < models; i++) {
00211 int size = actionCount[i];
00212 problem[i].y = new double[size];
00213 problem[i].x = new svm_node*[size];
00214 problem[i].l = 0;
00215 }
00216 prevAction = Shift;
00217 int nTot = 0;
00218 Tanl::Classifier::ClassID oID = 0;
00219 while (!events.empty()) {
00220 Tanl::Classifier::Event* ev = events.front();
00221 events.pop_front();
00222 char const* c = ev->className.c_str();
00223
00224 vector<string>& ec = ev->features;
00225 svm_node* preds = new svm_node[ec.size()+1];
00226 unsigned k = 0;
00227 for (unsigned j = 0; j < ec.size(); j++) {
00228 string& pred = ec[j];
00229 WordIndex::const_iterator pit = predIndex.find(pred.c_str());
00230 if (pit != predIndex.end()) {
00231 svm_node& node = preds[k++];
00232 node.index = pit->second + 1;
00233 node.value = 1.0;
00234 }
00235 }
00236 if (k) {
00237
00238 qsort(preds, k, sizeof(svm_node), compare_nodes);
00239
00240 svm_node& node = preds[k++];
00241 node.index = -1;
00242 node.value = 1.0;
00243 if (labelIndex.find(c) == labelIndex.end()) {
00244 labelIndex[c] = oID++;
00245 labels.push_back(c);
00246 }
00247 int i = prevAction;
00248 int& ni = problem[i].l;
00249 problem[i].y[ni] = labelIndex[c];
00250
00251 preds = (svm_node*)realloc(preds, k * sizeof(svm_node));
00252 problem[i].x[ni] = preds;
00253 ni++;
00254 nTot++;
00255 if (verbose) {
00256 if (nTot % 10000 == 0)
00257 cerr << '+' << flush;
00258 else if (nTot % 1000 == 0)
00259 cerr << '.' << flush;
00260 }
00261 } else {
00262 cerr << "Discarded event" << endl;
00263 delete preds;
00264 }
00265 char a = toupper(c[0]);
00266 prevAction = ActionType(a == 'R' || a == 'L');
00267 delete ev;
00268 }
00269
00270 if (verbose)
00271 cerr << endl;
00272
00273 ofstream ofs(modelFile, ios::binary | ios::trunc);
00274
00275 writeHeader(ofs);
00276
00277 ofs << labels.size() << endl;
00278 FOR_EACH (vector<string>, labels, pit)
00279 ofs << *pit << endl;
00280
00281 ofs << predLabels.size() << endl;
00282 FOR_EACH (vector<string>, predLabels, pit)
00283 ofs << *pit << endl;
00284
00285 predIndex.clear();
00286 predIndex = WordIndex();
00287 labelIndex.clear();
00288 labelIndex = WordIndex();
00289
00290 info.clearRareEntities();
00291
00292 svm_parameter param;
00293 parseParameters(param, svmParams);
00294
00295 if (dup2(fileno(stderr), fileno(stdout)) < 0)
00296 cerr << "could not redirect stdout to stderr" << endl;
00297
00298 for (int i = 0; i < models; i++) {
00299 struct svm_model* model = svm_train(&problem[i], ¶m);
00300
00301 string modeliFile = (string(modelFile) + '.') + actType[i];
00302 svm_save_model(modeliFile.c_str(), model);
00303
00304 svm_destroy_model(model);
00305 for (int j = problem[i].l - 1; j >= 0 ; j--)
00306 delete [] problem[i].x[j];
00307 delete [] problem[i].x;
00308 delete [] problem[i].y;
00309 }
00310 svm_destroy_param(¶m);
00311 }
00312
00313 Sentence* MultiSvmParser::parse(Sentence* sentence)
00314 {
00315 int prevAction = Shift;
00316 vector<svm_node> nodes(predIndex.size());
00317 preprocess(sentence);
00318 ParseState state(*sentence, &info, predIndex);
00319 while (state.hasNext()) {
00320 Tanl::Classifier::Context& preds = *state.next();
00321
00322 sort(preds.begin(), preds.end());
00323 nodes.resize(preds.size() + 1);
00324 int j = 0;
00325 FOR_EACH (vector<Tanl::Classifier::PID>, preds, pit) {
00326 nodes[j].index = *pit + 1;
00327 nodes[j++].value = 1.0;
00328 }
00329 nodes[preds.size()].index = -1;
00330 int i = prevAction;
00331 double prediction = svm_predict(model[i], &nodes[0]);
00332 string& outcome = classLabels[(int)prediction];
00333 # ifdef DUMP
00334 cerr << classIndex[rightOutcome];
00335 FOR_EACH (vector<Tanl::Classifier::PID>, preds, pit)
00336 cerr << " " << *pit << ":1";
00337 cerr << endl;
00338 # endif
00339
00340 char a = toupper(outcome[0]);
00341 prevAction = ActionType(a == 'R' || a == 'L');
00342 if (!state.transition(outcome.c_str())) {
00343 state.transition("S");
00344 }
00345 }
00346 return state.getSentence();
00347 }
00348
00349 }