@@ -13,6 +13,10 @@ | |||||
>3. 训练一个情感分类器,判断这个样本是该领域的正面信息还是负面信息。 | >3. 训练一个情感分类器,判断这个样本是该领域的正面信息还是负面信息。 | ||||
### 新增说明4:将模型训练和模型生成应用分离,提炼一些测试用例。 | |||||
1. 新增 NGramClassierTrainer 用于基于 NGram 特征的分类器训练 | |||||
2. 增加模型训练配置类:ClassModelConfiguration | |||||
### 新增说明3:增加基于 TF-IDF(词向量) 特征的文本分类程序。 | ### 新增说明3:增加基于 TF-IDF(词向量) 特征的文本分类程序。 | ||||
1. 主程序:DfIdfClassifier.java | 1. 主程序:DfIdfClassifier.java | ||||
2. 效果如下: | 2. 效果如下: | ||||
@@ -40,7 +44,7 @@ | |||||
### 新增说明1:2015-04-10测试了不用中文分词器,分词之后 LingPipe 情感分类的准确率,同时测试了去除停用词之后的情感分类的准确率。 | ### 新增说明1:2015-04-10测试了不用中文分词器,分词之后 LingPipe 情感分类的准确率,同时测试了去除停用词之后的情感分类的准确率。 | ||||
注意:有时候不用中文分词器效果更好,一定要测试。 | |||||
1. 发现用HanLP的NLPTokenizer分词器,准确率最高,但是速度有点慢。 | 1. 发现用HanLP的NLPTokenizer分词器,准确率最高,但是速度有点慢。 | ||||
2. 如果用HanLP的标准分词器就会准确率低一点点,但是速度快。 | 2. 如果用HanLP的标准分词器就会准确率低一点点,但是速度快。 | ||||
3. 分词之后去除停用词效果更加差。 | 3. 分词之后去除停用词效果更加差。 | ||||
@@ -0,0 +1,102 @@ | |||||
package eshore.cn.it.classification; | |||||
import java.io.File; | |||||
import java.io.FileOutputStream; | |||||
import java.io.IOException; | |||||
import java.io.ObjectOutputStream; | |||||
import java.util.logging.Logger; | |||||
import com.aliasi.classify.Classification; | |||||
import com.aliasi.classify.Classified; | |||||
import com.aliasi.classify.DynamicLMClassifier; | |||||
import com.aliasi.lm.NGramProcessLM; | |||||
import com.aliasi.util.Files; | |||||
import eshore.cn.it.configuration.ClassModelConfiguration; | |||||
/*** | |||||
* NGramClassierTrainer 是利用 NGram 特征提取的方式训练分类模型 | |||||
* 最后保存模型到指定的位置。 | |||||
* 1、训练模型(利用全新的参数) | |||||
* 2、在已有的模型上面继续更新训练 | |||||
* 模型一定要通过NGramClassierTest测试之后,达到一定的精度(目前还没有写十则交叉验证,只有二择交叉验证) | |||||
* 才能利用全部样本训练模型,投入实际生产环境。 | |||||
* @author clebeg 2015-04-14 10:40 | |||||
*/ | |||||
public class NGramClassierTrainer { | |||||
private static Logger logger = | |||||
Logger.getLogger(NGramClassierTrainer.class.getName()); | |||||
/** | |||||
* ngram特征提取,当前字符与前几个字符相关 | |||||
* 已模型训练时候的参数设置为准,默认设置为3 | |||||
* 太大容易出现过拟合。 | |||||
*/ | |||||
private int ngramSize = 3; | |||||
public int getNgramSize() { | |||||
return ngramSize; | |||||
} | |||||
public void setNgramSize(int ngramSize) { | |||||
this.ngramSize = ngramSize; | |||||
} | |||||
private ClassModelConfiguration modelConfig; | |||||
public ClassModelConfiguration getModelConfig() { | |||||
return modelConfig; | |||||
} | |||||
public void setModelConfig(ClassModelConfiguration modelConfig) { | |||||
this.modelConfig = modelConfig; | |||||
} | |||||
//基于NGram提取特征之后的分类器 | |||||
private DynamicLMClassifier<NGramProcessLM> classifier; | |||||
public void trainModel() throws IOException { | |||||
classifier | |||||
= DynamicLMClassifier.createNGramProcess(this.getModelConfig().getCategories(), | |||||
this.getNgramSize()); | |||||
for(int i = 0; i < this.getModelConfig().getCategories().length; ++i) { | |||||
File classDir = new File(this.getModelConfig().getTrainRootDir(), | |||||
this.getModelConfig().getCategories()[i]); | |||||
if (!classDir.isDirectory()) { | |||||
String msg = "无法找到包含训练数据的路径 =" | |||||
+ classDir | |||||
+ "\n你是否未在训练root目录建立 " | |||||
+ this.getModelConfig().getCategories().length | |||||
+ "类别?"; | |||||
logger.severe(msg); // in case exception gets lost in shell | |||||
throw new IllegalArgumentException(msg); | |||||
} | |||||
String[] trainingFiles = classDir.list(); | |||||
for (int j = 0; j < trainingFiles.length; ++j) { | |||||
File file = new File(classDir, trainingFiles[j]); | |||||
String text = Files.readFromFile(file, "GBK"); | |||||
System.out.println("正在训练 -> " + this.getModelConfig().getCategories()[i] + "/" + trainingFiles[j]); | |||||
Classification classification | |||||
= new Classification(this.getModelConfig().getCategories()[i]); | |||||
Classified<CharSequence> classified | |||||
= new Classified<CharSequence>(text, classification); | |||||
classifier.handle(classified); | |||||
} | |||||
} | |||||
File modelFile = new File(this.getModelConfig().getModelFile()); | |||||
System.out.println("所有类别均训练完成.开始保存模型到 " + modelFile.getAbsolutePath()); | |||||
if (modelFile.exists() == false) { | |||||
logger.info("指定模型文件不存在,将自动建立上层目录,并建立文件..."); | |||||
modelFile.getParentFile().mkdirs(); | |||||
} else { | |||||
logger.warning("指定模型文件已经存在,将覆盖..."); | |||||
} | |||||
ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream( | |||||
modelFile)); | |||||
classifier.compileTo(os); | |||||
os.close(); | |||||
System.out.println("模型保持成功!"); | |||||
} | |||||
} |
@@ -0,0 +1,47 @@ | |||||
package eshore.cn.it.configuration; | |||||
/** | |||||
* 分类模型配置类 | |||||
* 配置模型分类的类别个数,配置模型训练样本的root位置 | |||||
* 目前支持的格式存放样本的格式有: | |||||
* 1、在训练样本存放的root目录里面,按照类别名称,分别存放每个类别样本 | |||||
* @author clebeg | |||||
* @time 2015-04-14 10:48 | |||||
*/ | |||||
/**
 * Configuration for classification-model training.
 *
 * Holds the category labels, the root directory of the training samples, and
 * the output path of the trained model. Currently supported corpus layout:
 *   1. Under the training root directory, one subdirectory per category
 *      (named after the category) containing that category's sample files.
 *
 * @author clebeg
 * @time 2015-04-14 10:48
 */
public class ClassModelConfiguration {
    /**
     * Category labels, e.g. {"军事", "经济", "政治"}.
     */
    private String[] categories;

    /**
     * @return a defensive copy of the category labels, or null if unset;
     *         callers cannot mutate internal state through the returned array
     */
    public String[] getCategories() {
        return categories == null ? null : categories.clone();
    }

    /**
     * Stores a defensive copy so later external mutation of the argument
     * array does not affect this configuration.
     */
    public void setCategories(String[] categories) {
        this.categories = categories == null ? null : categories.clone();
    }

    /**
     * Root directory of the training samples; may be a relative path.
     * Defaults to "data/training". With the example categories above,
     * data/training must contain the subdirectories 军事/ 经济/ 政治/,
     * each holding that category's sample files.
     */
    private String trainRootDir = "data/training";

    public String getTrainRootDir() {
        return trainRootDir;
    }

    public void setTrainRootDir(String trainRootDir) {
        this.trainRootDir = trainRootDir;
    }

    /**
     * Path (directory + file name) where the trained model is saved.
     */
    private String modelFile = "data/model/model.class";

    public String getModelFile() {
        return modelFile;
    }

    public void setModelFile(String modelFile) {
        this.modelFile = modelFile;
    }
}
@@ -29,7 +29,7 @@ import com.hankcs.hanlp.seg.common.Term; | |||||
* @author clebeg | * @author clebeg | ||||
* @time 2015-04-13 | * @time 2015-04-13 | ||||
* */ | * */ | ||||
public class DfIdfClassier { | |||||
public class DfIdfClassierTest { | |||||
private static String[] CATEGORIES = { | private static String[] CATEGORIES = { | ||||
"government", | "government", | ||||
"others" | "others" | ||||
@@ -119,7 +119,7 @@ public class DfIdfClassier { | |||||
+ File.separator + testingFiles[j]); | + File.separator + testingFiles[j]); | ||||
ScoredClassification classification = compiledClassifier | ScoredClassification classification = compiledClassifier | ||||
.classify(segWords.subSequence(0, text.length())); | |||||
.classify(segWords.subSequence(0, segWords.length())); | |||||
confMatrix.increment(CATEGORIES[i], | confMatrix.increment(CATEGORIES[i], | ||||
classification.bestCategory()); | classification.bestCategory()); | ||||
System.out.println("最适合的分类: " | System.out.println("最适合的分类: " |
@@ -0,0 +1,32 @@ | |||||
package eshore.cn.it.classification; | |||||
import java.io.IOException; | |||||
import eshore.cn.it.configuration.ClassModelConfiguration; | |||||
import junit.framework.TestCase; | |||||
/*** | |||||
* 测试建立 与政府相关 以及 与政府不相关的 分类器 | |||||
* @author clebeg 2015-04-14 11:45 | |||||
*/ | |||||
public class GovernClassModelTest extends TestCase { | |||||
private String[] CATEGORIES = { | |||||
"government", | |||||
"others" | |||||
}; | |||||
public void trainModel() { | |||||
ClassModelConfiguration modelConfig = new ClassModelConfiguration(); | |||||
modelConfig.setCategories(CATEGORIES); | |||||
modelConfig.setTrainRootDir("data/text_classification/training"); | |||||
modelConfig.setModelFile("C:/model/mymodel.class"); | |||||
NGramClassierTrainer ngct = new NGramClassierTrainer(); | |||||
ngct.setNgramSize(3); | |||||
ngct.setModelConfig(modelConfig); | |||||
try { | |||||
ngct.trainModel(); | |||||
} catch (IOException e) { | |||||
e.printStackTrace(); | |||||
} | |||||
} | |||||
} |
@@ -2,7 +2,6 @@ package eshore.cn.it.classification; | |||||
import java.io.File; | import java.io.File; | ||||
import java.io.IOException; | import java.io.IOException; | ||||
import java.util.List; | |||||
import com.aliasi.classify.Classification; | import com.aliasi.classify.Classification; | ||||
import com.aliasi.classify.Classified; | import com.aliasi.classify.Classified; | ||||
@@ -14,8 +13,6 @@ import com.aliasi.classify.JointClassifierEvaluator; | |||||
import com.aliasi.lm.NGramProcessLM; | import com.aliasi.lm.NGramProcessLM; | ||||
import com.aliasi.util.AbstractExternalizable; | import com.aliasi.util.AbstractExternalizable; | ||||
import com.aliasi.util.Files; | import com.aliasi.util.Files; | ||||
import com.hankcs.hanlp.HanLP; | |||||
import com.hankcs.hanlp.seg.common.Term; | |||||
/** | /** | ||||
* 基于LingPipe的文本分类器,主要分类成两类 | * 基于LingPipe的文本分类器,主要分类成两类 | ||||
@@ -25,12 +22,12 @@ import com.hankcs.hanlp.seg.common.Term; | |||||
* @author clebeg | * @author clebeg | ||||
* @time 2015-04-13 | * @time 2015-04-13 | ||||
* */ | * */ | ||||
public class NGramClassier { | |||||
public class NGramClassierTest { | |||||
private static String[] CATEGORIES = { | private static String[] CATEGORIES = { | ||||
"government", | "government", | ||||
"others" | "others" | ||||
}; | }; | ||||
private static int NGRAM_SIZE = 3; | |||||
private static int NGRAM_SIZE = 4; | |||||
private static String TEXT_CLASSIFICATION_TRAINING = "data/text_classification/training"; | private static String TEXT_CLASSIFICATION_TRAINING = "data/text_classification/training"; | ||||
private static String TEXT_CLASSIFICATION_TESTING = "data/text_classification/testing"; | private static String TEXT_CLASSIFICATION_TESTING = "data/text_classification/testing"; | ||||
@@ -62,15 +59,15 @@ public class NGramClassier { | |||||
String text = Files.readFromFile(file, "GBK"); | String text = Files.readFromFile(file, "GBK"); | ||||
System.out.println("Training on " + CATEGORIES[i] + "/" + trainingFiles[j]); | System.out.println("Training on " + CATEGORIES[i] + "/" + trainingFiles[j]); | ||||
String segWords = ""; | |||||
List<Term> terms = HanLP.segment(text); | |||||
for (Term term : terms) | |||||
segWords += term.word + " "; | |||||
// String segWords = ""; | |||||
// List<Term> terms = HanLP.segment(text); | |||||
// for (Term term : terms) | |||||
// segWords += term.word + " "; | |||||
Classification classification | Classification classification | ||||
= new Classification(CATEGORIES[i]); | = new Classification(CATEGORIES[i]); | ||||
Classified<CharSequence> classified | Classified<CharSequence> classified | ||||
= new Classified<CharSequence>(segWords, classification); | |||||
= new Classified<CharSequence>(text, classification); | |||||
classifier.handle(classified); | classifier.handle(classified); | ||||
} | } | ||||
} | } | ||||
@@ -99,13 +96,13 @@ public class NGramClassier { | |||||
Classification classification | Classification classification | ||||
= new Classification(CATEGORIES[i]); | = new Classification(CATEGORIES[i]); | ||||
String segWords = ""; | |||||
List<Term> terms = HanLP.segment(text); | |||||
for (Term term : terms) | |||||
segWords += term.word + " "; | |||||
// String segWords = ""; | |||||
// List<Term> terms = HanLP.segment(text); | |||||
// for (Term term : terms) | |||||
// segWords += term.word + " "; | |||||
Classified<CharSequence> classified | Classified<CharSequence> classified | ||||
= new Classified<CharSequence>(segWords, classification); | |||||
= new Classified<CharSequence>(text, classification); | |||||
evaluator.handle(classified); | evaluator.handle(classified); | ||||
JointClassification jc = | JointClassification jc = | ||||
compiledClassifier.classify(text); | compiledClassifier.classify(text); |
@@ -21,8 +21,6 @@ import com.aliasi.classify.DynamicLMClassifier; | |||||
import com.aliasi.lm.NGramProcessLM; | import com.aliasi.lm.NGramProcessLM; | ||||
import com.aliasi.util.Files; | import com.aliasi.util.Files; | ||||
import com.hankcs.hanlp.HanLP; | |||||
import com.hankcs.hanlp.seg.common.Term; | |||||
//下面的分词器准确率最高,去除停用词反而准确率不高了。 | //下面的分词器准确率最高,去除停用词反而准确率不高了。 | ||||
//import com.hankcs.hanlp.tokenizer.NLPTokenizer; | //import com.hankcs.hanlp.tokenizer.NLPTokenizer; | ||||
@@ -51,6 +49,7 @@ public class ChinesePolarityBasic { | |||||
"data/polarity_corpus/hotel_reviews/test2.rlabelclass"; | "data/polarity_corpus/hotel_reviews/test2.rlabelclass"; | ||||
private static final String ENCODING = "GBK"; | private static final String ENCODING = "GBK"; | ||||
public static void main(String[] args) { | public static void main(String[] args) { | ||||
try { | try { | ||||
new ChinesePolarityBasic().run(); | new ChinesePolarityBasic().run(); | ||||
@@ -64,7 +63,7 @@ public class ChinesePolarityBasic { | |||||
public ChinesePolarityBasic() { | public ChinesePolarityBasic() { | ||||
super(); | super(); | ||||
int nGram = 8; | |||||
int nGram = 3; | |||||
mClassifier | mClassifier | ||||
= DynamicLMClassifier | = DynamicLMClassifier | ||||
.createNGramProcess(mCategories,nGram); | .createNGramProcess(mCategories,nGram); | ||||
@@ -114,16 +113,9 @@ public class ChinesePolarityBasic { | |||||
throws IOException { | throws IOException { | ||||
Classification classification = new Classification(category); | Classification classification = new Classification(category); | ||||
String review = Files.readFromFile(trainFile, fileEncoding); | String review = Files.readFromFile(trainFile, fileEncoding); | ||||
//此处加入中文分词器,得到分词之后的字符串 | |||||
String segWords = ""; | |||||
List<Term> terms = HanLP.segment(review); | |||||
for (Term term : terms) | |||||
segWords += term.word + " "; | |||||
Classified<CharSequence> classified | Classified<CharSequence> classified | ||||
= new Classified<CharSequence>(segWords,classification); | |||||
= new Classified<CharSequence>(review,classification); | |||||
mClassifier.handle(classified); | mClassifier.handle(classified); | ||||
} | } | ||||
@@ -136,14 +128,9 @@ public class ChinesePolarityBasic { | |||||
String review | String review | ||||
= Files.readFromFile(testFile, fileEncoding); | = Files.readFromFile(testFile, fileEncoding); | ||||
//同理,这里可以加入分词器,这样可以试试效果如何。 | |||||
String segWords = ""; | |||||
List<Term> terms = HanLP.segment(review); | |||||
for (Term term : terms) | |||||
segWords += term.word + " "; | |||||
++numTests; | ++numTests; | ||||
Classification classification | Classification classification | ||||
= mClassifier.classify(segWords); | |||||
= mClassifier.classify(review); | |||||
//得到训练结果 | //得到训练结果 | ||||
String resultCategory | String resultCategory | ||||
= classification.bestCategory(); | = classification.bestCategory(); |