207 lines
6.6 KiB
C++
207 lines
6.6 KiB
C++
/*

Copyright (C) 1999,2000,2001 Franz Josef Och (RWTH Aachen - Lehrstuhl fuer Informatik VI)

This file is part of GIZA++ ( extension of GIZA ).

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
USA.

*/
|
|
#ifndef HMM_TABLES_H_ASDF_DEFINED
|
|
#define HMM_TABLES_H_ASDF_DEFINED
|
|
#include "FlexArray.h"
|
|
|
|
#if __GNUC__>2
|
|
#include <ext/hash_map>
|
|
using __gnu_cxx::hash_map;
|
|
#else
|
|
#include <hash_map>
|
|
#endif
|
|
#include "Array.h"
|
|
#include <map>
|
|
#include "mymath.h"
|
|
#include "syncObj.h"
|
|
|
|
/**
 * Normalizes the values in [a,b) in place so that they sum to one.
 *
 * If the values sum to zero, every element is instead set to the uniform
 * value 1/(b-a).  An empty range is left untouched.  T must support
 * +=, /= and construction/assignment from 0 and from double.
 *
 * @param a  pointer to the first element
 * @param b  pointer one past the last element
 * @return   the sum of the elements before normalization (0 for an empty range)
 */
template<class T>
T normalize_if_possible(T*a,T*b)
{
  // Empty range: nothing to normalize.  Returning early also avoids
  // evaluating 1.0/(b-a) with b-a == 0 in the uniform fallback below.
  if( a==b )
    return 0;
  T sum=0;
  for(T*i=a; i!=b; ++i)
    sum+=*i;
  if( sum ) {
    for(T*i=a; i!=b; ++i)
      *i/=sum;
  } else {
    // No mass at all: fall back to a uniform distribution.  A plain loop
    // is used so the template does not depend on an unqualified `fill`.
    const T uniform=static_cast<T>(1.0/(b-a));
    for(T*i=a; i!=b; ++i)
      *i=uniform;
  }
  return sum;
}
|
|
|
|
/** Bitmask selecting which AlDeps fields take part in ordering, equality,
 *  hashing and printing: 1 = englishSentenceLength, 2 = classPrevious,
 *  4 = previous, 8 = j, 16 = Cj.  Defined elsewhere (extern). */
extern short CompareAlDeps;

/**
 * Conditioning context under which alignment probabilities are stored
 * (used as the key of HMMTables::alProb).  Only the fields enabled in
 * CompareAlDeps participate in operator<, operator== and hashing; disabled
 * fields are ignored.
 */
template<class CLS>
class AlDeps
{
public:
  int englishSentenceLength;  // printed as "sentenceLength"
  CLS classPrevious;          // printed as "previousClass"
  int previous;               // printed as "previousPosition"
  int j;                      // printed as "FrenchPosition"
  CLS Cj;                     // printed as "FrenchClass"

  /** Value-initializes every field (0 for arithmetic CLS) so that a
   *  default-constructed object never exposes indeterminate values to
   *  operator< or the hash functor. */
  AlDeps()
    : englishSentenceLength(0),classPrevious(),previous(0),j(0),Cj()
  {}

  /**
   * @param l     value for englishSentenceLength
   * @param p     value for previous
   * @param jPos  value for j
   * @param s1    value for classPrevious
   * @param cj    value for Cj
   */
  AlDeps(int l,int p=0,int jPos=0,CLS s1=0,CLS cj=0)
    : englishSentenceLength(l),classPrevious(s1),previous(p),j(jPos),Cj(cj)
  {}

  /** Lexicographic strict weak ordering over the enabled fields, compared
   *  in bit order (1, 2, 4, 8, 16). */
  friend bool operator<(const AlDeps&x,const AlDeps&y) {
    if( CompareAlDeps&1 ) {
      if( x.englishSentenceLength<y.englishSentenceLength ) return true;
      if( y.englishSentenceLength<x.englishSentenceLength ) return false;
    }
    if( CompareAlDeps&2 ) {
      if( x.classPrevious<y.classPrevious ) return true;
      if( y.classPrevious<x.classPrevious ) return false;
    }
    if( CompareAlDeps&4 ) {
      if( x.previous<y.previous ) return true;
      if( y.previous<x.previous ) return false;
    }
    if( CompareAlDeps&8 ) {
      if( x.j<y.j ) return true;
      if( y.j<x.j ) return false;
    }
    if( CompareAlDeps&16 ) {
      if( x.Cj<y.Cj ) return true;
      if( y.Cj<x.Cj ) return false;
    }
    return false;
  }

  /** Equivalence under operator<: equal on every enabled field. */
  friend bool operator==(const AlDeps&x,const AlDeps&y) {
    return !( x<y || y<x );
  }
};
|
|
|
|
template<class CLS>
|
|
class Hash_AlDeps
|
|
{
|
|
public:
|
|
unsigned
|
|
int
|
|
operator()
|
|
(const AlDeps<CLS>&x)
|
|
const {
|
|
unsigned int hash=0;
|
|
if( (CompareAlDeps&1) ) {
|
|
hash=hash+x.englishSentenceLength;
|
|
hash*=31;
|
|
}
|
|
if( (CompareAlDeps&2) ) {
|
|
hash=hash+x.classPrevious;
|
|
hash*=31;
|
|
}
|
|
if( (CompareAlDeps&4) ) {
|
|
hash=hash+x.previous;
|
|
hash*=31;
|
|
}
|
|
if( (CompareAlDeps&8) ) {
|
|
hash=hash+x.j;
|
|
hash*=31;
|
|
}
|
|
if( (CompareAlDeps&16) ) {
|
|
hash=hash+x.Cj;
|
|
hash*=31;
|
|
}
|
|
return hash;
|
|
|
|
}
|
|
};
|
|
|
|
#ifdef WIN32
|
|
typedef pair<Array<double>,Mutex*> hmmentry_type;
|
|
#else
|
|
typedef pair<Array<double>,Mutex> hmmentry_type;
|
|
#endif
|
|
|
|
template<class CLS,class MAPPERCLASSTOSTRING>
|
|
class HMMTables
|
|
{
|
|
Mutex* lock;
|
|
Mutex* alphalock,*betalock;
|
|
public:
|
|
|
|
double probabilityForEmpty;
|
|
bool updateProbabilityForEmpty;
|
|
hash_map<int, hmmentry_type > init_alpha;
|
|
hash_map<int, hmmentry_type > init_beta;
|
|
map<AlDeps<CLS>,FlexArray<double> > alProb;
|
|
map<AlDeps<CLS>,FlexArray<double> > alProbPredicted;
|
|
int globalCounter;
|
|
double divSum;
|
|
double p0_count,np0_count;
|
|
const MAPPERCLASSTOSTRING*mapper1;
|
|
const MAPPERCLASSTOSTRING*mapper2;
|
|
public:
|
|
bool merge(HMMTables<CLS,MAPPERCLASSTOSTRING> & ht);
|
|
const HMMTables<CLS,MAPPERCLASSTOSTRING>*getThis()const {
|
|
return this;
|
|
}
|
|
HMMTables(double _probForEmpty,const MAPPERCLASSTOSTRING&m1,const MAPPERCLASSTOSTRING&m2);
|
|
HMMTables(const HMMTables& ref);
|
|
void operator=(const HMMTables& ref);
|
|
virtual ~HMMTables();
|
|
virtual double getAlProb(int i,int k,int sentLength,int J,CLS w1,CLS w2,int j,int iter=0) const;
|
|
virtual void writeJumps(ostream&) const;
|
|
/**By Edward Gao, write out all things needed to rebuild the count table*/
|
|
virtual bool writeJumps(const char* alprob, const char* alpredict, const char* alpha, const char* beta )const;
|
|
virtual bool readJumps(const char* alprob, const char* alpredict, const char* alpha, const char* beta );
|
|
void addAlCount(int i,int k,int sentLength,int J,CLS w1,CLS w2,int j,double value,double valuePredicted);
|
|
virtual void readJumps(istream&);
|
|
virtual bool getAlphaInit(int I,Array<double>&x)const;
|
|
virtual bool getBetaInit(int I,Array<double> &x)const;
|
|
hmmentry_type &doGetAlphaInit(int I);
|
|
hmmentry_type &doGetBetaInit(int I);
|
|
virtual double getProbabilityForEmpty()const {
|
|
return probabilityForEmpty;
|
|
}
|
|
void performGISIteration(const HMMTables<CLS,MAPPERCLASSTOSTRING>*old) {
|
|
cout << "OLDSIZE: " << (old?(old->alProb.size()):0) << " NEWSIZE:"<< alProb.size()<< endl;
|
|
for(typename map<AlDeps<CLS>,FlexArray<double> >::iterator i=alProb.begin(); i!=alProb.end(); ++i) {
|
|
if( alProbPredicted.count(i->first)) {
|
|
normalize_if_possible(i->second.begin(),i->second.end());
|
|
normalize_if_possible(alProbPredicted[i->first].begin(),alProbPredicted[i->first].end());
|
|
for(int j=i->second.low(); j<=i->second.high(); ++j) {
|
|
if( i->second[j] )
|
|
if(alProbPredicted[i->first][j]>0.0 ) {
|
|
double op=1.0;
|
|
if( old && old->alProb.count(i->first) )
|
|
op=(old->alProb.find(i->first)->second)[j];
|
|
//cerr << "GIS: " << j << ' ' << " OLD:"
|
|
// << op << "*true:"
|
|
// << i->second[j] << "/pred:" << alProbPredicted[i->first][j] << " -> ";
|
|
|
|
|
|
i->second[j]= op*(i->second[j]/alProbPredicted[i->first][j]);
|
|
//cerr << i->second[j] << endl;
|
|
} else {
|
|
cerr << "ERROR2 in performGISiteration: " << i->second[j] << endl;
|
|
}
|
|
}
|
|
} else
|
|
cerr << "ERROR in performGISIteration: " << alProbPredicted.count(i->first) << endl;
|
|
}
|
|
}
|
|
};
|
|
|
|
template<class CLS,class MAPPERCLASSTOSTRING>
|
|
inline void printAlDeps(ostream&out,const AlDeps<CLS>&x,const MAPPERCLASSTOSTRING&mapper1,const MAPPERCLASSTOSTRING&mapper2)
|
|
{
|
|
if( (CompareAlDeps&1) ) out << "sentenceLength: " << x.englishSentenceLength<< ' ';
|
|
if( (CompareAlDeps&2) ) out << "previousClass: " << mapper1.classString(x.classPrevious) << ' ';
|
|
if( (CompareAlDeps&4) ) out << "previousPosition: " << x.previous << ' ';
|
|
if( (CompareAlDeps&8) ) out << "FrenchPosition: " << x.j << ' ';
|
|
if( (CompareAlDeps&16) ) out << "FrenchClass: " << mapper2.classString(x.Cj) << ' ';
|
|
//out << '\n';
|
|
}
|
|
|
|
#endif
|