Browse Source

Separate model training from model usage; add a configuration class

fetches/sdlf/master
gitclebeg 9 years ago
parent
commit
acb17f1f75
9 changed files with 205 additions and 36 deletions
  1. +5
    -1
      README.md
  2. +102
    -0
      src/main/java/eshore/cn/it/classification/NGramClassierTrainer.java
  3. +47
    -0
      src/main/java/eshore/cn/it/configuration/ClassModelConfiguration.java
  4. +2
    -2
      src/test/java/eshore/cn/it/classification/DfIdfClassierTest.java
  5. +32
    -0
      src/test/java/eshore/cn/it/classification/GovernClassModelTest.java
  6. +12
    -15
      src/test/java/eshore/cn/it/classification/NGramClassierTest.java
  7. +5
    -18
      src/test/java/eshore/cn/it/sentiment/ChinesePolarityBasic.java
  8. +0
    -0
      src/test/java/eshore/cn/it/sentiment/PolarityBasic.java
  9. +0
    -0
      src/test/java/eshore/cn/it/sentiment/Sentiment.java

+ 5
- 1
README.md View File

@@ -13,6 +13,10 @@
>3. 训练一个情感分类器,判断这个样本是该领域的正面信息还是负面信息。


### 新增说明4:将模型训练和模型生成应用分离,提炼一些测试用例。
1. 新增 NGramClassierTrainer 用于基于 NGram 特征的分类器训练
2. 增加模型训练配置类:ClassModelConfiguration

### 新增说明3:增加基于 TF-IDF(词向量) 特征的文本分类程序。
1. 主程序:DfIdfClassifier.java
2. 效果如下:
@@ -40,7 +44,7 @@

### 新增说明1:2015-04-10测试了不用中文分词器,分词之后 LingPipe 情感分类的准确率,同时测试了去除停用词之后的情感分类的准确率。

注意:有时候不用中文分词器效果更好,一定要测试。
1. 发现用HanLP的NLPTokenizer分词器,准确率最高,但是速度有点慢。
2. 如果用HanLP的标准分词器就会准确率低一点点,但是速度快。
3. 分词之后去除停用词效果更加差。


+ 102
- 0
src/main/java/eshore/cn/it/classification/NGramClassierTrainer.java View File

@@ -0,0 +1,102 @@
package eshore.cn.it.classification;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.DynamicLMClassifier;
import com.aliasi.lm.NGramProcessLM;
import com.aliasi.util.Files;

import eshore.cn.it.configuration.ClassModelConfiguration;

/***
 * NGramClassierTrainer trains a classification model using NGram feature
 * extraction and saves the resulting model to the configured location.
 * Supported workflows:
 *   1. train a model from scratch (with fresh parameters)
 *   2. continue training (updating) an existing model
 * A model must first be validated with NGramClassierTest (currently only
 * 2-fold cross validation is implemented, not 10-fold) and reach the required
 * accuracy before being retrained on the full sample set and deployed to
 * production.
 * @author clebeg 2015-04-14 10:40
 */
public class NGramClassierTrainer {

    private static final Logger logger =
            Logger.getLogger(NGramClassierTrainer.class.getName());

    /**
     * NGram size: how many preceding characters the current character is
     * conditioned on. Must match the value the model was trained with when it
     * is later applied; defaults to 3. Too large a value risks over-fitting.
     */
    private int ngramSize = 3;

    public int getNgramSize() {
        return ngramSize;
    }

    public void setNgramSize(int ngramSize) {
        this.ngramSize = ngramSize;
    }

    /** Training configuration: categories, sample root directory, model file. */
    private ClassModelConfiguration modelConfig;

    public ClassModelConfiguration getModelConfig() {
        return modelConfig;
    }

    public void setModelConfig(ClassModelConfiguration modelConfig) {
        this.modelConfig = modelConfig;
    }

    /** Classifier built on NGram feature extraction. */
    private DynamicLMClassifier<NGramProcessLM> classifier;

    /**
     * Trains the classifier on every file found under each category directory
     * below the configured training root, then serializes the compiled model
     * to the configured model file (creating parent directories as needed).
     *
     * @throws IOException              if a training file cannot be read, a
     *                                  category directory cannot be listed,
     *                                  or the model cannot be written
     * @throws IllegalArgumentException if a category directory is missing
     */
    public void trainModel() throws IOException {
        String[] categories = this.getModelConfig().getCategories();
        classifier = DynamicLMClassifier.createNGramProcess(categories, this.getNgramSize());
        for (int i = 0; i < categories.length; ++i) {
            File classDir = new File(this.getModelConfig().getTrainRootDir(), categories[i]);
            if (!classDir.isDirectory()) {
                String msg = "无法找到包含训练数据的路径 ="
                        + classDir
                        + "\n你是否未在训练root目录建立 "
                        + categories.length
                        + "类别?";
                logger.severe(msg); // in case exception gets lost in shell
                throw new IllegalArgumentException(msg);
            }
            // File.list() returns null on I/O error; fail loudly instead of NPE.
            String[] trainingFiles = classDir.list();
            if (trainingFiles == null) {
                throw new IOException("无法列出训练目录中的文件: " + classDir);
            }
            for (int j = 0; j < trainingFiles.length; ++j) {
                File file = new File(classDir, trainingFiles[j]);
                // Training corpus is GBK-encoded Chinese text — TODO confirm for new corpora.
                String text = Files.readFromFile(file, "GBK");
                System.out.println("正在训练 -> " + categories[i] + "/" + trainingFiles[j]);
                Classification classification = new Classification(categories[i]);
                Classified<CharSequence> classified =
                        new Classified<CharSequence>(text, classification);
                classifier.handle(classified);
            }
        }
        File modelFile = new File(this.getModelConfig().getModelFile());
        System.out.println("所有类别均训练完成.开始保存模型到 " + modelFile.getAbsolutePath());
        if (modelFile.exists() == false) {
            logger.info("指定模型文件不存在,将自动建立上层目录,并建立文件...");
            modelFile.getParentFile().mkdirs();
        } else {
            logger.warning("指定模型文件已经存在,将覆盖...");
        }
        // try-with-resources guarantees the stream is closed even if
        // compileTo throws; the original leaked it on failure.
        try (ObjectOutputStream os =
                new ObjectOutputStream(new FileOutputStream(modelFile))) {
            classifier.compileTo(os);
        }
        System.out.println("模型保持成功!");
    }
}

+ 47
- 0
src/main/java/eshore/cn/it/configuration/ClassModelConfiguration.java View File

@@ -0,0 +1,47 @@
package eshore.cn.it.configuration;
/**
 * Configuration for classification model training.
 * Carries the category list, the root directory of the training samples,
 * and the location where the trained model is persisted.
 * Supported sample layout:
 *   1. under the training root directory, one sub-directory per category,
 *      each holding that category's sample files.
 * @author clebeg
 * @time 2015-04-14 10:48
 */
public class ClassModelConfiguration {

    /** Category names, e.g. initialized as {"军事", "经济", "政治"}. */
    private String[] categories;

    /**
     * Root directory of the training samples; may be a relative path.
     * Defaults to data/training. With the categories above, data/training
     * must contain the directories 军事/ 经济/ 政治/, each holding the
     * samples of the corresponding category.
     */
    private String trainRootDir = "data/training";

    /** Directory plus file name the trained model is saved to. */
    private String modelFile = "data/model/model.class";

    public String[] getCategories() {
        return categories;
    }

    public void setCategories(String[] categories) {
        this.categories = categories;
    }

    public String getTrainRootDir() {
        return trainRootDir;
    }

    public void setTrainRootDir(String trainRootDir) {
        this.trainRootDir = trainRootDir;
    }

    public String getModelFile() {
        return modelFile;
    }

    public void setModelFile(String modelFile) {
        this.modelFile = modelFile;
    }
}

src/main/java/eshore/cn/it/classification/DfIdfClassier.java → src/test/java/eshore/cn/it/classification/DfIdfClassierTest.java View File

@@ -29,7 +29,7 @@ import com.hankcs.hanlp.seg.common.Term;
* @author clebeg
* @time 2015-04-13
* */
public class DfIdfClassier {
public class DfIdfClassierTest {
private static String[] CATEGORIES = {
"government",
"others"
@@ -119,7 +119,7 @@ public class DfIdfClassier {
+ File.separator + testingFiles[j]);
ScoredClassification classification = compiledClassifier
.classify(segWords.subSequence(0, text.length()));
.classify(segWords.subSequence(0, segWords.length()));
confMatrix.increment(CATEGORIES[i],
classification.bestCategory());
System.out.println("最适合的分类: "

+ 32
- 0
src/test/java/eshore/cn/it/classification/GovernClassModelTest.java View File

@@ -0,0 +1,32 @@
package eshore.cn.it.classification;

import java.io.IOException;

import eshore.cn.it.configuration.ClassModelConfiguration;
import junit.framework.TestCase;
/***
 * Tests building a classifier for government-related vs. unrelated text.
 * @author clebeg 2015-04-14 11:45
 */
public class GovernClassModelTest extends TestCase {

    /** The two target categories (directory names under the training root). */
    private String[] CATEGORIES = {
            "government",
            "others"
    };

    /**
     * JUnit 3 entry point. TestCase discovers only methods whose names start
     * with "test", so without this wrapper trainModel() was never executed
     * by the test runner.
     */
    public void testTrainModel() {
        trainModel();
    }

    /**
     * Configures and runs NGram classifier training on the government corpus.
     * Kept public so existing callers of trainModel() keep working.
     */
    public void trainModel() {
        ClassModelConfiguration modelConfig = new ClassModelConfiguration();
        modelConfig.setCategories(CATEGORIES);
        modelConfig.setTrainRootDir("data/text_classification/training");
        modelConfig.setModelFile("C:/model/mymodel.class");
        NGramClassierTrainer ngct = new NGramClassierTrainer();
        ngct.setNgramSize(3);
        ngct.setModelConfig(modelConfig);
        try {
            ngct.trainModel();
        } catch (IOException e) {
            // Fail the test instead of silently swallowing the exception,
            // so a broken training setup is actually reported.
            fail("trainModel threw IOException: " + e);
        }
    }
}

src/main/java/eshore/cn/it/classification/NGramClassier.java → src/test/java/eshore/cn/it/classification/NGramClassierTest.java View File

@@ -2,7 +2,6 @@ package eshore.cn.it.classification;

import java.io.File;
import java.io.IOException;
import java.util.List;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
@@ -14,8 +13,6 @@ import com.aliasi.classify.JointClassifierEvaluator;
import com.aliasi.lm.NGramProcessLM;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Files;
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.seg.common.Term;

/**
* 基于LingPipe的文本分类器,主要分类成两类
@@ -25,12 +22,12 @@ import com.hankcs.hanlp.seg.common.Term;
* @author clebeg
* @time 2015-04-13
* */
public class NGramClassier {
public class NGramClassierTest {
private static String[] CATEGORIES = {
"government",
"others"
};
private static int NGRAM_SIZE = 3;
private static int NGRAM_SIZE = 4;
private static String TEXT_CLASSIFICATION_TRAINING = "data/text_classification/training";
private static String TEXT_CLASSIFICATION_TESTING = "data/text_classification/testing";
@@ -62,15 +59,15 @@ public class NGramClassier {
String text = Files.readFromFile(file, "GBK");
System.out.println("Training on " + CATEGORIES[i] + "/" + trainingFiles[j]);
String segWords = "";
List<Term> terms = HanLP.segment(text);
for (Term term : terms)
segWords += term.word + " ";
// String segWords = "";
// List<Term> terms = HanLP.segment(text);
// for (Term term : terms)
// segWords += term.word + " ";
Classification classification
= new Classification(CATEGORIES[i]);
Classified<CharSequence> classified
= new Classified<CharSequence>(segWords, classification);
= new Classified<CharSequence>(text, classification);
classifier.handle(classified);
}
}
@@ -99,13 +96,13 @@ public class NGramClassier {
Classification classification
= new Classification(CATEGORIES[i]);
String segWords = "";
List<Term> terms = HanLP.segment(text);
for (Term term : terms)
segWords += term.word + " ";
// String segWords = "";
// List<Term> terms = HanLP.segment(text);
// for (Term term : terms)
// segWords += term.word + " ";
Classified<CharSequence> classified
= new Classified<CharSequence>(segWords, classification);
= new Classified<CharSequence>(text, classification);
evaluator.handle(classified);
JointClassification jc =
compiledClassifier.classify(text);

src/main/java/eshore/cn/it/sentiment/ChinesePolarityBasic.java → src/test/java/eshore/cn/it/sentiment/ChinesePolarityBasic.java View File

@@ -21,8 +21,6 @@ import com.aliasi.classify.DynamicLMClassifier;
import com.aliasi.lm.NGramProcessLM;
import com.aliasi.util.Files;

import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.seg.common.Term;

//下面的分词器准确率最高,去除停用词反而准确率不高了。
//import com.hankcs.hanlp.tokenizer.NLPTokenizer;
@@ -51,6 +49,7 @@ public class ChinesePolarityBasic {
"data/polarity_corpus/hotel_reviews/test2.rlabelclass";
private static final String ENCODING = "GBK";

public static void main(String[] args) {
try {
new ChinesePolarityBasic().run();
@@ -64,7 +63,7 @@ public class ChinesePolarityBasic {

public ChinesePolarityBasic() {
super();
int nGram = 8;
int nGram = 3;
mClassifier
= DynamicLMClassifier
.createNGramProcess(mCategories,nGram);
@@ -114,16 +113,9 @@ public class ChinesePolarityBasic {
throws IOException {
Classification classification = new Classification(category);
String review = Files.readFromFile(trainFile, fileEncoding);
//此处加入中文分词器,得到分词之后的字符串
String segWords = "";
List<Term> terms = HanLP.segment(review);
for (Term term : terms)
segWords += term.word + " ";
Classified<CharSequence> classified
= new Classified<CharSequence>(segWords,classification);
= new Classified<CharSequence>(review,classification);
mClassifier.handle(classified);
}
@@ -136,14 +128,9 @@ public class ChinesePolarityBasic {
String review
= Files.readFromFile(testFile, fileEncoding);
//同理,这里可以加入分词器,这样可以试试效果如何。
String segWords = "";
List<Term> terms = HanLP.segment(review);
for (Term term : terms)
segWords += term.word + " ";
++numTests;
Classification classification
= mClassifier.classify(segWords);
= mClassifier.classify(review);
//得到训练结果
String resultCategory
= classification.bestCategory();

src/main/java/eshore/cn/it/sentiment/PolarityBasic.java → src/test/java/eshore/cn/it/sentiment/PolarityBasic.java View File


src/main/java/eshore/cn/it/sentiment/Sentiment.java → src/test/java/eshore/cn/it/sentiment/Sentiment.java View File


Loading…
Cancel
Save