00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #ifndef Tanl_Classifier_LBFGS_H
00025 #define Tanl_Classifier_LBFGS_H
00026
00027 #include <float.h>
00028
00029 #include "MaxEnt.h"
00030
00031 namespace Tanl {
00032 namespace Classifier {
00033
00059 typedef std::vector<std::vector<std::pair<ClassID, int> > > ContribTable;
00060
00061 struct Options {
00062 int n;
00063 int m;
00064 int niter;
00065 int nfuns;
00066 int iflag;
00067 int iprint[2];
00068 double eps;
00069 double xtol;
00070 int diagco;
00071 double* diag;
00072 double* w;
00073
00074 Options(int n, double eps) :
00075 n(0),
00076 m(5),
00077 niter(0),
00078 nfuns(0),
00079 iflag(0),
00080 eps(eps),
00081 xtol(DBL_EPSILON),
00082 diagco(0),
00083 diag(0),
00084 w(0)
00085 {
00086 iprint[0] = -1;
00087 iprint[1] = 0;
00088 }
00089
00090 ~Options() {
00091 delete[] w;
00092 delete[] diag;
00093 }
00094 };
00095
00099 class LBFGS : public MaxEnt
00100 {
00101 public:
00110 LBFGS(EventStream& es, int iterations, int cutoff = 0, double eps = 1E-05);
00111
00121 LBFGS(int iterations = 50, int cutoff = 0, double eps = 1E-05);
00122
00123 ~LBFGS() {
00124 delete[] lambda;
00125 }
00126
00130 void train();
00131
00135 void train(EventStream& es);
00136
00137 void save(char const* path);
00138
00139 void writeHeader(std::ofstream& ofs);
00140 void writeData(std::ofstream& ofs);
00141
00142 protected:
00143
00144 double* lambda;
00145
00146 private:
00147 ClassID estimate(const std::vector<PID>& preds, double probs[]);
00148
00149 int numEvents;
00150
00151 Options opt;
00152 ContribTable contribTable;
00153
00154 };
00155
00156 }
00157 }
00158
00159 #endif // Tanl_Classifier_LBFGS_H