ukui-search/libchinese-segmentation/cppjieba/HMMModel.hpp

177 lines
4.9 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
*
* Copyright (C) 2023, KylinSoft Co., Ltd.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
#pragma once
#include "limonp/StringUtil.hpp"
//#define USE_CEDAR_SEGMENT // Enables the cedar-backed tables: preliminary tests show roughly a 3%-5% performance loss and close to 1 MB lower memory usage
#ifdef USE_CEDAR_SEGMENT
#include "cedar/cedar.h"
#endif
namespace cppjieba {
// Makes the names declared in namespace limonp usable unqualified inside cppjieba.
using namespace limonp;
// EmitProbMap is the container type used for each per-state emission table of HMMModel.
#ifdef USE_CEDAR_SEGMENT
// Cedar build: a cedar::da holding float values. HMMModel stores each entry under the
// decimal string form of the rune. The -1 / -2 / false template arguments are cedar
// configuration; see cedar.h for their exact meaning.
typedef cedar::da<float, -1, -2, false> EmitProbMap;
#else
// Default build: hash map keyed directly by Rune with a double value.
typedef unordered_map<Rune, double> EmitProbMap;
#endif
struct HMMModel {
/*
* STATUS:
* 0: HMMModel::B, 1: HMMModel::E, 2: HMMModel::M, 3:HMMModel::S
* */
enum {B = 0, E = 1, M = 2, S = 3, STATUS_SUM = 4};
HMMModel(const string& modelPath) {
memset(startProb, 0, sizeof(startProb));
memset(transProb, 0, sizeof(transProb));
statMap[0] = 'B';
statMap[1] = 'E';
statMap[2] = 'M';
statMap[3] = 'S';
emitProbVec.push_back(&emitProbB);
emitProbVec.push_back(&emitProbE);
emitProbVec.push_back(&emitProbM);
emitProbVec.push_back(&emitProbS);
LoadModel(modelPath);
}
~HMMModel() {
}
void LoadModel(const string& filePath) {
ifstream ifile(filePath.c_str());
XCHECK(ifile.is_open()) << "open " << filePath << " failed";
string line;
vector<string> tmp;
vector<string> tmp2;
//Load startProb
XCHECK(GetLine(ifile, line));
Split(line, tmp, " ");
XCHECK(tmp.size() == STATUS_SUM);
for (size_t j = 0; j < tmp.size(); j++) {
startProb[j] = atof(tmp[j].c_str());
}
//Load transProb
for (size_t i = 0; i < STATUS_SUM; i++) {
XCHECK(GetLine(ifile, line));
Split(line, tmp, " ");
XCHECK(tmp.size() == STATUS_SUM);
for (size_t j = 0; j < tmp.size(); j++) {
transProb[i][j] = atof(tmp[j].c_str());
}
}
//Load emitProbB
XCHECK(GetLine(ifile, line));
XCHECK(LoadEmitProb(line, emitProbB));
//Load emitProbE
XCHECK(GetLine(ifile, line));
XCHECK(LoadEmitProb(line, emitProbE));
//Load emitProbM
XCHECK(GetLine(ifile, line));
XCHECK(LoadEmitProb(line, emitProbM));
//Load emitProbS
XCHECK(GetLine(ifile, line));
XCHECK(LoadEmitProb(line, emitProbS));
}
double GetEmitProb(const EmitProbMap* ptMp, Rune key,
double defVal)const {
#ifdef USE_CEDAR_SEGMENT
char str_key[8];
snprintf(str_key, sizeof(str_key), "%d", key);
float result = ptMp->exactMatchSearch<float>(str_key);
return result < 0 ? defVal : result;
#else
EmitProbMap::const_iterator cit = ptMp->find(key);
if (cit == ptMp->end()) {
return defVal;
}
return cit->second;
#endif
}
bool GetLine(ifstream& ifile, string& line) {
while (getline(ifile, line)) {
Trim(line);
if (line.empty()) {
continue;
}
if (StartsWith(line, "#")) {
continue;
}
return true;
}
return false;
}
bool LoadEmitProb(const string& line, EmitProbMap& mp) {
if (line.empty()) {
return false;
}
vector<string> tmp, tmp2;
RuneArray unicode;
Split(line, tmp, ",");
for (size_t i = 0; i < tmp.size(); i++) {
Split(tmp[i], tmp2, ":");
if (2 != tmp2.size()) {
XLOG(ERROR) << "emitProb illegal.";
return false;
}
if (!DecodeRunesInString(tmp2[0], unicode) || unicode.size() != 1) {
XLOG(ERROR) << "TransCode failed.";
return false;
}
#ifdef USE_CEDAR_SEGMENT
char str_key[8];
snprintf(str_key, sizeof(str_key), "%d", unicode[0]);
mp.update(str_key, std::strlen(str_key), atof(tmp2[1].c_str()));
#else
mp[unicode[0]] = atof(tmp2[1].c_str());
#endif
}
return true;
}
char statMap[STATUS_SUM];
double startProb[STATUS_SUM];
double transProb[STATUS_SUM][STATUS_SUM];
EmitProbMap emitProbB;
EmitProbMap emitProbE;
EmitProbMap emitProbM;
EmitProbMap emitProbS;
vector<EmitProbMap* > emitProbVec;
}; // struct HMMModel
} // namespace cppjieba