diff --git a/README.md b/README.md index 4968a75..57b9e64 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,10 @@ >3. 训练一个情感分类器,判断这个样本是该领域的正面信息还是负面信息。 +### 新增说明4:将模型训练和模型生成应用分离,提炼一些测试用例。 +1. 新增 NGramClassierTrainer 用于基于 NGram 特征的分类器训练 +2. 增加模型训练配置类:ClassModelConfiguration + ### 新增说明3:增加基于 TF-IDF(词向量) 特征的文本分类程序。 1. 主程序:DfIdfClassifier.java 2. 效果如下: @@ -40,7 +44,7 @@ ### 新增说明1:2015-04-10测试了不用中文分词器,分词之后 LingPipe 情感分类的准确率,同时测试了去除停用词之后的情感分类的准确率。 - +注意:有时候不用中文分词器效果更好,一定要测试。 1. 发现用HanLP的NLPTokenizer分词器,准确率最高,但是速度有点慢。 2. 如果用HanLP的标准分词器就会准确率低一点点,但是速度快。 3. 分词之后去除停用词效果更加差。 diff --git a/src/main/java/eshore/cn/it/classification/NGramClassierTrainer.java b/src/main/java/eshore/cn/it/classification/NGramClassierTrainer.java new file mode 100644 index 0000000..3a46ea9 --- /dev/null +++ b/src/main/java/eshore/cn/it/classification/NGramClassierTrainer.java @@ -0,0 +1,102 @@ +package eshore.cn.it.classification; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.util.logging.Logger; + +import com.aliasi.classify.Classification; +import com.aliasi.classify.Classified; +import com.aliasi.classify.DynamicLMClassifier; +import com.aliasi.lm.NGramProcessLM; +import com.aliasi.util.Files; + +import eshore.cn.it.configuration.ClassModelConfiguration; + +/*** + * NGramClassierTrainer 是利用 NGram 特征提取的方式训练分类模型 + * 最后保存模型到指定的位置。 + * 1、训练模型(利用全新的参数) + * 2、在已有的模型上面继续更新训练 + * 模型一定要通过NGramClassierTest测试之后,达到一定的精度(目前还没有写十则交叉验证,只有二择交叉验证) + * 才能利用全部样本训练模型,投入实际生产环境。 + * @author clebeg 2015-04-14 10:40 + */ +public class NGramClassierTrainer { + private static Logger logger = + Logger.getLogger(NGramClassierTrainer.class.getName()); + /** + * ngram特征提取,当前字符与前几个字符相关 + * 已模型训练时候的参数设置为准,默认设置为3 + * 太大容易出现过拟合。 + */ + private int ngramSize = 3; + public int getNgramSize() { + return ngramSize; + } + public void setNgramSize(int ngramSize) { + this.ngramSize = ngramSize; + } + + private ClassModelConfiguration modelConfig; + public ClassModelConfiguration getModelConfig() { + return modelConfig; + } + public void setModelConfig(ClassModelConfiguration modelConfig) { + this.modelConfig = modelConfig; + } + + //基于NGram提取特征之后的分类器 + private DynamicLMClassifier classifier; + + + public void trainModel() throws IOException { + classifier + = DynamicLMClassifier.createNGramProcess(this.getModelConfig().getCategories(), + this.getNgramSize()); + + for(int i = 0; i < this.getModelConfig().getCategories().length; ++i) { + File classDir = new File(this.getModelConfig().getTrainRootDir(), + this.getModelConfig().getCategories()[i]); + + if (!classDir.isDirectory()) { + String msg = "无法找到包含训练数据的路径 =" + + classDir + + "\n你是否未在训练root目录建立 " + + this.getModelConfig().getCategories().length + + "类别?"; + logger.severe(msg); // in case exception gets lost in shell + throw new IllegalArgumentException(msg); + } + + String[] trainingFiles = classDir.list(); + for (int j = 0; j < trainingFiles.length; ++j) { + File file = new File(classDir, trainingFiles[j]); + String text = Files.readFromFile(file, "GBK"); + System.out.println("正在训练 -> " + this.getModelConfig().getCategories()[i] + "/" + trainingFiles[j]); + Classification classification + = new Classification(this.getModelConfig().getCategories()[i]); + Classified classified + = new Classified(text, classification); + classifier.handle(classified); + } + } + File modelFile = new File(this.getModelConfig().getModelFile()); + System.out.println("所有类别均训练完成.开始保存模型到 " + modelFile.getAbsolutePath()); + if (modelFile.exists() == false) { + logger.info("指定模型文件不存在,将自动建立上层目录,并建立文件..."); + modelFile.getParentFile().mkdirs(); + } else { + logger.warning("指定模型文件已经存在,将覆盖..."); + } + + ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream( + modelFile)); + classifier.compileTo(os); + os.close(); + + System.out.println("模型保持成功!"); + } + +} diff --git a/src/main/java/eshore/cn/it/configuration/ClassModelConfiguration.java b/src/main/java/eshore/cn/it/configuration/ClassModelConfiguration.java new file mode 100644 index 0000000..442ef41 --- /dev/null +++ b/src/main/java/eshore/cn/it/configuration/ClassModelConfiguration.java @@ -0,0 +1,47 @@ +package eshore.cn.it.configuration; +/** + * 分类模型配置类 + * 配置模型分类的类别个数,配置模型训练样本的root位置 + * 目前支持的格式存放样本的格式有: + * 1、在训练样本存放的root目录里面,按照类别名称,分别存放每个类别样本 + * @author clebeg + * @time 2015-04-14 10:48 + */ +public class ClassModelConfiguration { + /** + * 类别数值,例如初始化为:{"军事", "经济", "政治"} + */ + private String[] categories; + public String[] getCategories() { + return categories; + } + public void setCategories(String[] categories) { + this.categories = categories; + } + + /** + * 训练样本的主目录,可以是相对路径,默认是data/training + * 如果有上面的类别,那么data/training目录中必须包含 + * 军事/ 经济/ 政治/ 这三个目录,分别存放对应类别的数据 + */ + private String trainRootDir = "data/training"; + public String getTrainRootDir() { + return trainRootDir; + } + public void setTrainRootDir(String trainRootDir) { + this.trainRootDir = trainRootDir; + } + + /** + * 训练好模型之后就需要保持模型,这就模型保持的目录及文件 + */ + private String modelFile = "data/model/model.class"; + public String getModelFile() { + return modelFile; + } + public void setModelFile(String modelFile) { + this.modelFile = modelFile; + } + + +} diff --git a/src/main/java/eshore/cn/it/classification/DfIdfClassier.java b/src/test/java/eshore/cn/it/classification/DfIdfClassierTest.java similarity index 98% rename from src/main/java/eshore/cn/it/classification/DfIdfClassier.java rename to src/test/java/eshore/cn/it/classification/DfIdfClassierTest.java index 125d6b6..39e8485 100644 --- a/src/main/java/eshore/cn/it/classification/DfIdfClassier.java +++ b/src/test/java/eshore/cn/it/classification/DfIdfClassierTest.java @@ -29,7 +29,7 @@ import com.hankcs.hanlp.seg.common.Term; * @author clebeg * @time 2015-04-13 * */ -public class DfIdfClassier { +public class DfIdfClassierTest { private static String[] CATEGORIES = { "government", "others" @@ -119,7 +119,7 @@ public class DfIdfClassier { + File.separator + testingFiles[j]); ScoredClassification classification = compiledClassifier - .classify(segWords.subSequence(0, text.length())); + .classify(segWords.subSequence(0, segWords.length())); confMatrix.increment(CATEGORIES[i], classification.bestCategory()); System.out.println("最适合的分类: " diff --git a/src/test/java/eshore/cn/it/classification/GovernClassModelTest.java b/src/test/java/eshore/cn/it/classification/GovernClassModelTest.java new file mode 100644 index 0000000..c609c91 --- /dev/null +++ b/src/test/java/eshore/cn/it/classification/GovernClassModelTest.java @@ -0,0 +1,32 @@ +package eshore.cn.it.classification; + +import java.io.IOException; + +import eshore.cn.it.configuration.ClassModelConfiguration; +import junit.framework.TestCase; +/*** + * 测试建立 与政府相关 以及 与政府不相关的 分类器 + * @author clebeg 2015-04-14 11:45 + */ +public class GovernClassModelTest extends TestCase { + private String[] CATEGORIES = { + "government", + "others" + }; + + public void trainModel() { + ClassModelConfiguration modelConfig = new ClassModelConfiguration(); + modelConfig.setCategories(CATEGORIES); + modelConfig.setTrainRootDir("data/text_classification/training"); + modelConfig.setModelFile("C:/model/mymodel.class"); + + NGramClassierTrainer ngct = new NGramClassierTrainer(); + ngct.setNgramSize(3); + ngct.setModelConfig(modelConfig); + try { + ngct.trainModel(); + } catch (IOException e) { + e.printStackTrace(); + } + } +} diff --git a/src/main/java/eshore/cn/it/classification/NGramClassier.java b/src/test/java/eshore/cn/it/classification/NGramClassierTest.java similarity index 87% rename from src/main/java/eshore/cn/it/classification/NGramClassier.java rename to src/test/java/eshore/cn/it/classification/NGramClassierTest.java index 44e54e6..e71a048 100644 --- a/src/main/java/eshore/cn/it/classification/NGramClassier.java +++ b/src/test/java/eshore/cn/it/classification/NGramClassierTest.java @@ -2,7 +2,6 @@ package eshore.cn.it.classification; import java.io.File; import java.io.IOException; -import java.util.List; import com.aliasi.classify.Classification; import com.aliasi.classify.Classified; @@ -14,8 +13,6 @@ import com.aliasi.classify.JointClassifierEvaluator; import com.aliasi.lm.NGramProcessLM; import com.aliasi.util.AbstractExternalizable; import com.aliasi.util.Files; -import com.hankcs.hanlp.HanLP; -import com.hankcs.hanlp.seg.common.Term; /** * 基于LingPipe的文本分类器,主要分类成两类 @@ -25,12 +22,12 @@ import com.hankcs.hanlp.seg.common.Term; * @author clebeg * @time 2015-04-13 * */ -public class NGramClassier { +public class NGramClassierTest { private static String[] CATEGORIES = { "government", "others" }; - private static int NGRAM_SIZE = 3; + private static int NGRAM_SIZE = 4; private static String TEXT_CLASSIFICATION_TRAINING = "data/text_classification/training"; private static String TEXT_CLASSIFICATION_TESTING = "data/text_classification/testing"; @@ -62,15 +59,15 @@ public class NGramClassier { String text = Files.readFromFile(file, "GBK"); System.out.println("Training on " + CATEGORIES[i] + "/" + trainingFiles[j]); - String segWords = ""; - List terms = HanLP.segment(text); - for (Term term : terms) - segWords += term.word + " "; +// String segWords = ""; +// List terms = HanLP.segment(text); +// for (Term term : terms) +// segWords += term.word + " "; Classification classification = new Classification(CATEGORIES[i]); Classified classified - = new Classified(segWords, classification); + = new Classified(text, classification); classifier.handle(classified); } } @@ -99,13 +96,13 @@ public class NGramClassier { Classification classification = new Classification(CATEGORIES[i]); - String segWords = ""; - List terms = HanLP.segment(text); - for (Term term : terms) - segWords += term.word + " "; +// String segWords = ""; +// List terms = HanLP.segment(text); +// for (Term term : terms) +// segWords += term.word + " "; Classified classified - = new Classified(segWords, classification); + = new Classified(text, classification); evaluator.handle(classified); JointClassification jc = compiledClassifier.classify(text); diff --git a/src/main/java/eshore/cn/it/sentiment/ChinesePolarityBasic.java b/src/test/java/eshore/cn/it/sentiment/ChinesePolarityBasic.java similarity index 86% rename from src/main/java/eshore/cn/it/sentiment/ChinesePolarityBasic.java rename to src/test/java/eshore/cn/it/sentiment/ChinesePolarityBasic.java index e3429cc..cb8ddb4 100644 --- a/src/main/java/eshore/cn/it/sentiment/ChinesePolarityBasic.java +++ b/src/test/java/eshore/cn/it/sentiment/ChinesePolarityBasic.java @@ -21,8 +21,6 @@ import com.aliasi.classify.DynamicLMClassifier; import com.aliasi.lm.NGramProcessLM; import com.aliasi.util.Files; -import com.hankcs.hanlp.HanLP; -import com.hankcs.hanlp.seg.common.Term; //下面的分词器准确率最高,去除停用词反而准确率不高了。 //import com.hankcs.hanlp.tokenizer.NLPTokenizer; @@ -51,6 +49,7 @@ public class ChinesePolarityBasic { "data/polarity_corpus/hotel_reviews/test2.rlabelclass"; private static final String ENCODING = "GBK"; + public static void main(String[] args) { try { new ChinesePolarityBasic().run(); @@ -64,7 +63,7 @@ public class ChinesePolarityBasic { public ChinesePolarityBasic() { super(); - int nGram = 8; + int nGram = 3; mClassifier = DynamicLMClassifier .createNGramProcess(mCategories,nGram); @@ -114,16 +113,9 @@ public class ChinesePolarityBasic { throws IOException { Classification classification = new Classification(category); String review = Files.readFromFile(trainFile, fileEncoding); - - //此处加入中文分词器,得到分词之后的字符串 - - String segWords = ""; - List terms = HanLP.segment(review); - for (Term term : terms) - segWords += term.word + " "; - + Classified classified - = new Classified(segWords,classification); + = new Classified(review,classification); mClassifier.handle(classified); } @@ -136,14 +128,9 @@ public class ChinesePolarityBasic { String review = Files.readFromFile(testFile, fileEncoding); - //同理,这里可以加入分词器,这样可以试试效果如何。 - String segWords = ""; - List terms = HanLP.segment(review); - for (Term term : terms) - segWords += term.word + " "; ++numTests; Classification classification - = mClassifier.classify(segWords); + = mClassifier.classify(review); //得到训练结果 String resultCategory = classification.bestCategory(); diff --git a/src/main/java/eshore/cn/it/sentiment/PolarityBasic.java b/src/test/java/eshore/cn/it/sentiment/PolarityBasic.java similarity index 100% rename from src/main/java/eshore/cn/it/sentiment/PolarityBasic.java rename to src/test/java/eshore/cn/it/sentiment/PolarityBasic.java diff --git a/src/main/java/eshore/cn/it/sentiment/Sentiment.java b/src/test/java/eshore/cn/it/sentiment/Sentiment.java similarity index 100% rename from src/main/java/eshore/cn/it/sentiment/Sentiment.java rename to src/test/java/eshore/cn/it/sentiment/Sentiment.java