@@ -13,6 +13,10 @@ | |||
>3. 训练一个情感分类器,判断这个样本是该领域的正面信息还是负面信息。 | |||
### 新增说明4:将模型训练和模型生成应用分离,提炼一些测试用例。 | |||
1. 新增 NGramClassierTrainer 用于基于 NGram 特征的分类器训练 | |||
2. 增加模型训练配置类:ClassModelConfiguration | |||
### 新增说明3:增加基于 TF-IDF(词频-逆文档频率) 特征的文本分类程序。 | |||
1. 主程序:DfIdfClassifier.java | |||
2. 效果如下: | |||
@@ -40,7 +44,7 @@ | |||
### 新增说明1:2015-04-10测试了不用中文分词器,分词之后 LingPipe 情感分类的准确率,同时测试了去除停用词之后的情感分类的准确率。 | |||
注意:有时候不用中文分词器效果更好,一定要测试。 | |||
1. 发现用HanLP的NLPTokenizer分词器,准确率最高,但是速度有点慢。 | |||
2. 如果用HanLP的标准分词器就会准确率低一点点,但是速度快。 | |||
3. 分词之后去除停用词效果更加差。 | |||
@@ -0,0 +1,102 @@ | |||
package eshore.cn.it.classification; | |||
import java.io.File; | |||
import java.io.FileOutputStream; | |||
import java.io.IOException; | |||
import java.io.ObjectOutputStream; | |||
import java.util.logging.Logger; | |||
import com.aliasi.classify.Classification; | |||
import com.aliasi.classify.Classified; | |||
import com.aliasi.classify.DynamicLMClassifier; | |||
import com.aliasi.lm.NGramProcessLM; | |||
import com.aliasi.util.Files; | |||
import eshore.cn.it.configuration.ClassModelConfiguration; | |||
/*** | |||
* NGramClassierTrainer 是利用 NGram 特征提取的方式训练分类模型 | |||
* 最后保存模型到指定的位置。 | |||
* 1、训练模型(利用全新的参数) | |||
* 2、在已有的模型上面继续更新训练 | |||
* 模型一定要通过NGramClassierTest测试之后,达到一定的精度(目前还没有写十则交叉验证,只有二择交叉验证) | |||
* 才能利用全部样本训练模型,投入实际生产环境。 | |||
* @author clebeg 2015-04-14 10:40 | |||
*/ | |||
public class NGramClassierTrainer { | |||
private static Logger logger = | |||
Logger.getLogger(NGramClassierTrainer.class.getName()); | |||
/** | |||
* ngram特征提取,当前字符与前几个字符相关 | |||
* 已模型训练时候的参数设置为准,默认设置为3 | |||
* 太大容易出现过拟合。 | |||
*/ | |||
private int ngramSize = 3; | |||
public int getNgramSize() { | |||
return ngramSize; | |||
} | |||
public void setNgramSize(int ngramSize) { | |||
this.ngramSize = ngramSize; | |||
} | |||
private ClassModelConfiguration modelConfig; | |||
public ClassModelConfiguration getModelConfig() { | |||
return modelConfig; | |||
} | |||
public void setModelConfig(ClassModelConfiguration modelConfig) { | |||
this.modelConfig = modelConfig; | |||
} | |||
//基于NGram提取特征之后的分类器 | |||
private DynamicLMClassifier<NGramProcessLM> classifier; | |||
public void trainModel() throws IOException { | |||
classifier | |||
= DynamicLMClassifier.createNGramProcess(this.getModelConfig().getCategories(), | |||
this.getNgramSize()); | |||
for(int i = 0; i < this.getModelConfig().getCategories().length; ++i) { | |||
File classDir = new File(this.getModelConfig().getTrainRootDir(), | |||
this.getModelConfig().getCategories()[i]); | |||
if (!classDir.isDirectory()) { | |||
String msg = "无法找到包含训练数据的路径 =" | |||
+ classDir | |||
+ "\n你是否未在训练root目录建立 " | |||
+ this.getModelConfig().getCategories().length | |||
+ "类别?"; | |||
logger.severe(msg); // in case exception gets lost in shell | |||
throw new IllegalArgumentException(msg); | |||
} | |||
String[] trainingFiles = classDir.list(); | |||
for (int j = 0; j < trainingFiles.length; ++j) { | |||
File file = new File(classDir, trainingFiles[j]); | |||
String text = Files.readFromFile(file, "GBK"); | |||
System.out.println("正在训练 -> " + this.getModelConfig().getCategories()[i] + "/" + trainingFiles[j]); | |||
Classification classification | |||
= new Classification(this.getModelConfig().getCategories()[i]); | |||
Classified<CharSequence> classified | |||
= new Classified<CharSequence>(text, classification); | |||
classifier.handle(classified); | |||
} | |||
} | |||
File modelFile = new File(this.getModelConfig().getModelFile()); | |||
System.out.println("所有类别均训练完成.开始保存模型到 " + modelFile.getAbsolutePath()); | |||
if (modelFile.exists() == false) { | |||
logger.info("指定模型文件不存在,将自动建立上层目录,并建立文件..."); | |||
modelFile.getParentFile().mkdirs(); | |||
} else { | |||
logger.warning("指定模型文件已经存在,将覆盖..."); | |||
} | |||
ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream( | |||
modelFile)); | |||
classifier.compileTo(os); | |||
os.close(); | |||
System.out.println("模型保持成功!"); | |||
} | |||
} |
@@ -0,0 +1,47 @@ | |||
package eshore.cn.it.configuration; | |||
/**
 * Configuration for classification-model training.
 * Holds the category labels, the root directory of training samples, and the
 * file path the trained model is saved to.
 * Currently supported corpus layout:
 * 1. Inside the training root directory, one sub-directory per category
 *    (named after the category) containing that category's sample files.
 * @author clebeg
 * @time 2015-04-14 10:48
 */
public class ClassModelConfiguration {

    // Category labels, e.g. initialized as {"军事", "经济", "政治"}.
    private String[] categories;

    // Root directory of the training corpus; may be a relative path.
    // With the example categories above, this directory must contain the
    // sub-directories 军事/ 经济/ 政治/, each holding that category's samples.
    private String trainRootDir = "data/training";

    // Destination file the trained model is persisted to.
    private String modelFile = "data/model/model.class";

    /** @return the configured category labels (null until set) */
    public String[] getCategories() {
        return this.categories;
    }

    public void setCategories(String[] categories) {
        this.categories = categories;
    }

    /** @return the training-corpus root directory (default {@code data/training}) */
    public String getTrainRootDir() {
        return this.trainRootDir;
    }

    public void setTrainRootDir(String trainRootDir) {
        this.trainRootDir = trainRootDir;
    }

    /** @return the model output file path (default {@code data/model/model.class}) */
    public String getModelFile() {
        return this.modelFile;
    }

    public void setModelFile(String modelFile) {
        this.modelFile = modelFile;
    }
}
@@ -29,7 +29,7 @@ import com.hankcs.hanlp.seg.common.Term; | |||
* @author clebeg | |||
* @time 2015-04-13 | |||
* */ | |||
public class DfIdfClassier { | |||
public class DfIdfClassierTest { | |||
private static String[] CATEGORIES = { | |||
"government", | |||
"others" | |||
@@ -119,7 +119,7 @@ public class DfIdfClassier { | |||
+ File.separator + testingFiles[j]); | |||
ScoredClassification classification = compiledClassifier | |||
.classify(segWords.subSequence(0, text.length())); | |||
.classify(segWords.subSequence(0, segWords.length())); | |||
confMatrix.increment(CATEGORIES[i], | |||
classification.bestCategory()); | |||
System.out.println("最适合的分类: " |
@@ -0,0 +1,32 @@ | |||
package eshore.cn.it.classification; | |||
import java.io.IOException; | |||
import eshore.cn.it.configuration.ClassModelConfiguration; | |||
import junit.framework.TestCase; | |||
/*** | |||
* 测试建立 与政府相关 以及 与政府不相关的 分类器 | |||
* @author clebeg 2015-04-14 11:45 | |||
*/ | |||
public class GovernClassModelTest extends TestCase { | |||
private String[] CATEGORIES = { | |||
"government", | |||
"others" | |||
}; | |||
public void trainModel() { | |||
ClassModelConfiguration modelConfig = new ClassModelConfiguration(); | |||
modelConfig.setCategories(CATEGORIES); | |||
modelConfig.setTrainRootDir("data/text_classification/training"); | |||
modelConfig.setModelFile("C:/model/mymodel.class"); | |||
NGramClassierTrainer ngct = new NGramClassierTrainer(); | |||
ngct.setNgramSize(3); | |||
ngct.setModelConfig(modelConfig); | |||
try { | |||
ngct.trainModel(); | |||
} catch (IOException e) { | |||
e.printStackTrace(); | |||
} | |||
} | |||
} |
@@ -2,7 +2,6 @@ package eshore.cn.it.classification; | |||
import java.io.File; | |||
import java.io.IOException; | |||
import java.util.List; | |||
import com.aliasi.classify.Classification; | |||
import com.aliasi.classify.Classified; | |||
@@ -14,8 +13,6 @@ import com.aliasi.classify.JointClassifierEvaluator; | |||
import com.aliasi.lm.NGramProcessLM; | |||
import com.aliasi.util.AbstractExternalizable; | |||
import com.aliasi.util.Files; | |||
import com.hankcs.hanlp.HanLP; | |||
import com.hankcs.hanlp.seg.common.Term; | |||
/** | |||
* 基于LingPipe的文本分类器,主要分类成两类 | |||
@@ -25,12 +22,12 @@ import com.hankcs.hanlp.seg.common.Term; | |||
* @author clebeg | |||
* @time 2015-04-13 | |||
* */ | |||
public class NGramClassier { | |||
public class NGramClassierTest { | |||
private static String[] CATEGORIES = { | |||
"government", | |||
"others" | |||
}; | |||
private static int NGRAM_SIZE = 3; | |||
private static int NGRAM_SIZE = 4; | |||
private static String TEXT_CLASSIFICATION_TRAINING = "data/text_classification/training"; | |||
private static String TEXT_CLASSIFICATION_TESTING = "data/text_classification/testing"; | |||
@@ -62,15 +59,15 @@ public class NGramClassier { | |||
String text = Files.readFromFile(file, "GBK"); | |||
System.out.println("Training on " + CATEGORIES[i] + "/" + trainingFiles[j]); | |||
String segWords = ""; | |||
List<Term> terms = HanLP.segment(text); | |||
for (Term term : terms) | |||
segWords += term.word + " "; | |||
// String segWords = ""; | |||
// List<Term> terms = HanLP.segment(text); | |||
// for (Term term : terms) | |||
// segWords += term.word + " "; | |||
Classification classification | |||
= new Classification(CATEGORIES[i]); | |||
Classified<CharSequence> classified | |||
= new Classified<CharSequence>(segWords, classification); | |||
= new Classified<CharSequence>(text, classification); | |||
classifier.handle(classified); | |||
} | |||
} | |||
@@ -99,13 +96,13 @@ public class NGramClassier { | |||
Classification classification | |||
= new Classification(CATEGORIES[i]); | |||
String segWords = ""; | |||
List<Term> terms = HanLP.segment(text); | |||
for (Term term : terms) | |||
segWords += term.word + " "; | |||
// String segWords = ""; | |||
// List<Term> terms = HanLP.segment(text); | |||
// for (Term term : terms) | |||
// segWords += term.word + " "; | |||
Classified<CharSequence> classified | |||
= new Classified<CharSequence>(segWords, classification); | |||
= new Classified<CharSequence>(text, classification); | |||
evaluator.handle(classified); | |||
JointClassification jc = | |||
compiledClassifier.classify(text); |
@@ -21,8 +21,6 @@ import com.aliasi.classify.DynamicLMClassifier; | |||
import com.aliasi.lm.NGramProcessLM; | |||
import com.aliasi.util.Files; | |||
import com.hankcs.hanlp.HanLP; | |||
import com.hankcs.hanlp.seg.common.Term; | |||
//下面的分词器准确率最高,去除停用词反而准确率不高了。 | |||
//import com.hankcs.hanlp.tokenizer.NLPTokenizer; | |||
@@ -51,6 +49,7 @@ public class ChinesePolarityBasic { | |||
"data/polarity_corpus/hotel_reviews/test2.rlabelclass"; | |||
private static final String ENCODING = "GBK"; | |||
public static void main(String[] args) { | |||
try { | |||
new ChinesePolarityBasic().run(); | |||
@@ -64,7 +63,7 @@ public class ChinesePolarityBasic { | |||
public ChinesePolarityBasic() { | |||
super(); | |||
int nGram = 8; | |||
int nGram = 3; | |||
mClassifier | |||
= DynamicLMClassifier | |||
.createNGramProcess(mCategories,nGram); | |||
@@ -114,16 +113,9 @@ public class ChinesePolarityBasic { | |||
throws IOException { | |||
Classification classification = new Classification(category); | |||
String review = Files.readFromFile(trainFile, fileEncoding); | |||
//此处加入中文分词器,得到分词之后的字符串 | |||
String segWords = ""; | |||
List<Term> terms = HanLP.segment(review); | |||
for (Term term : terms) | |||
segWords += term.word + " "; | |||
Classified<CharSequence> classified | |||
= new Classified<CharSequence>(segWords,classification); | |||
= new Classified<CharSequence>(review,classification); | |||
mClassifier.handle(classified); | |||
} | |||
@@ -136,14 +128,9 @@ public class ChinesePolarityBasic { | |||
String review | |||
= Files.readFromFile(testFile, fileEncoding); | |||
//同理,这里可以加入分词器,这样可以试试效果如何。 | |||
String segWords = ""; | |||
List<Term> terms = HanLP.segment(review); | |||
for (Term term : terms) | |||
segWords += term.word + " "; | |||
++numTests; | |||
Classification classification | |||
= mClassifier.classify(segWords); | |||
= mClassifier.classify(review); | |||
//得到训练结果 | |||
String resultCategory | |||
= classification.bestCategory(); |