// ukui-search/libchinese-segmentation/cppjieba/HMMModel.hpp

#pragma once
#include "limonp/StringUtil.hpp"
namespace cppjieba {
using namespace limonp;
// Table from a rune to a double; HMMModel keeps one such table per state
// (emitProbB/E/M/S) and fills it in HMMModel::LoadEmitProb.
typedef unordered_map<Rune, double> EmitProbMap;
struct HMMModel {
2021-04-26 15:06:47 +08:00
/*
* STATUS:
* 0: HMMModel::B, 1: HMMModel::E, 2: HMMModel::M, 3:HMMModel::S
* */
enum {B = 0, E = 1, M = 2, S = 3, STATUS_SUM = 4};
2021-04-26 15:06:47 +08:00
HMMModel(const string& modelPath) {
memset(startProb, 0, sizeof(startProb));
memset(transProb, 0, sizeof(transProb));
statMap[0] = 'B';
statMap[1] = 'E';
statMap[2] = 'M';
statMap[3] = 'S';
emitProbVec.push_back(&emitProbB);
emitProbVec.push_back(&emitProbE);
emitProbVec.push_back(&emitProbM);
emitProbVec.push_back(&emitProbS);
LoadModel(modelPath);
}
2021-04-26 15:06:47 +08:00
~HMMModel() {
}
2021-04-26 15:06:47 +08:00
void LoadModel(const string& filePath) {
ifstream ifile(filePath.c_str());
XCHECK(ifile.is_open()) << "open " << filePath << " failed";
string line;
vector<string> tmp;
vector<string> tmp2;
//Load startProb
XCHECK(GetLine(ifile, line));
Split(line, tmp, " ");
XCHECK(tmp.size() == STATUS_SUM);
for (size_t j = 0; j < tmp.size(); j++) {
2021-04-26 15:06:47 +08:00
startProb[j] = atof(tmp[j].c_str());
}
2021-04-26 15:06:47 +08:00
//Load transProb
for (size_t i = 0; i < STATUS_SUM; i++) {
2021-04-26 15:06:47 +08:00
XCHECK(GetLine(ifile, line));
Split(line, tmp, " ");
XCHECK(tmp.size() == STATUS_SUM);
for (size_t j = 0; j < tmp.size(); j++) {
2021-04-26 15:06:47 +08:00
transProb[i][j] = atof(tmp[j].c_str());
}
}
2021-04-26 15:06:47 +08:00
//Load emitProbB
XCHECK(GetLine(ifile, line));
XCHECK(LoadEmitProb(line, emitProbB));
2021-04-26 15:06:47 +08:00
//Load emitProbE
XCHECK(GetLine(ifile, line));
XCHECK(LoadEmitProb(line, emitProbE));
2021-04-26 15:06:47 +08:00
//Load emitProbM
XCHECK(GetLine(ifile, line));
XCHECK(LoadEmitProb(line, emitProbM));
//Load emitProbS
XCHECK(GetLine(ifile, line));
XCHECK(LoadEmitProb(line, emitProbS));
}
2021-04-26 15:06:47 +08:00
double GetEmitProb(const EmitProbMap* ptMp, Rune key,
double defVal)const {
EmitProbMap::const_iterator cit = ptMp->find(key);
if (cit == ptMp->end()) {
2021-04-26 15:06:47 +08:00
return defVal;
}
2021-04-26 15:06:47 +08:00
return cit->second;
}
2021-04-26 15:06:47 +08:00
bool GetLine(ifstream& ifile, string& line) {
while (getline(ifile, line)) {
2021-04-26 15:06:47 +08:00
Trim(line);
if (line.empty()) {
2021-04-26 15:06:47 +08:00
continue;
}
if (StartsWith(line, "#")) {
2021-04-26 15:06:47 +08:00
continue;
}
2021-04-26 15:06:47 +08:00
return true;
}
return false;
}
2021-04-26 15:06:47 +08:00
bool LoadEmitProb(const string& line, EmitProbMap& mp) {
if (line.empty()) {
2021-04-26 15:06:47 +08:00
return false;
}
2021-04-26 15:06:47 +08:00
vector<string> tmp, tmp2;
RuneArray unicode;
2021-04-26 15:06:47 +08:00
Split(line, tmp, ",");
for (size_t i = 0; i < tmp.size(); i++) {
2021-04-26 15:06:47 +08:00
Split(tmp[i], tmp2, ":");
if (2 != tmp2.size()) {
2021-04-26 15:06:47 +08:00
XLOG(ERROR) << "emitProb illegal.";
return false;
}
if (!DecodeRunesInString(tmp2[0], unicode) || unicode.size() != 1) {
2021-04-26 15:06:47 +08:00
XLOG(ERROR) << "TransCode failed.";
return false;
}
2021-04-26 15:06:47 +08:00
mp[unicode[0]] = atof(tmp2[1].c_str());
}
2021-04-26 15:06:47 +08:00
return true;
}
2021-04-26 15:06:47 +08:00
char statMap[STATUS_SUM];
double startProb[STATUS_SUM];
double transProb[STATUS_SUM][STATUS_SUM];
EmitProbMap emitProbB;
EmitProbMap emitProbE;
EmitProbMap emitProbM;
EmitProbMap emitProbS;
vector<EmitProbMap* > emitProbVec;
}; // struct HMMModel
} // namespace cppjieba