diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md index 96ea7a10..5767d9e8 100644 --- a/reproduction/text_classification/README.md +++ b/reproduction/text_classification/README.md @@ -18,6 +18,13 @@ SST:https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip yelp_full:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M yelp_polarity:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M +dataset |classes | train samples | dev samples | test samples|refer| +:---: | :---: | :---: | :---: | :---: | :---: | +yelp_polarity | 2 |560k | - |38k|[char_cnn](https://arxiv.org/pdf/1509.01626v3.pdf)| +yelp_full | 5|650k | - |50k|[char_cnn](https://arxiv.org/pdf/1509.01626v3.pdf)| +IMDB | 2 |25k | - |25k|[IMDB](https://ai.stanford.edu/~ang/papers/acl11-WordVectorsSentimentAnalysis.pdf)| +sst-2 | 2 |67k | 872 |1.8k|[GLUE](https://arxiv.org/pdf/1804.07461.pdf)| + # 数据集及复现结果汇总 使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道,-表示论文没有在该数据集上列出结果) diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py index 6b56608a..93a15add 100644 --- a/reproduction/text_classification/train_char_cnn.py +++ b/reproduction/text_classification/train_char_cnn.py @@ -1,15 +1,8 @@ -# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 -import os -os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' -os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' - import sys sys.path.append('../..') from fastNLP.core.const import Const as C import torch.nn as nn -from fastNLP.io.data_loader import YelpLoader from fastNLP.io.pipe.classification import YelpFullPipe,YelpPolarityPipe,SST2Pipe,IMDBPipe -#from data.sstLoader import sst2Loader from model.char_cnn import CharacterLevelCNN from fastNLP import CrossEntropyLoss, AccuracyMetric from fastNLP.core.trainer import Trainer @@ -27,9 +20,9 @@ class Config(): model_dir_or_name="en-base-uncased" embedding_grad= False, bert_embedding_larers= '4,-2,-1' - train_epoch= 50 + train_epoch= 100 num_classes=2 - task= "yelp_p" + task= "sst-2" #yelp_p datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv", "test": "/remote-home/ygwang/yelp_polarity/test.csv"} @@ -132,15 +125,17 @@ elif ops.task == 'sst-2': else: raise RuntimeError(f'NOT support {ops.task} task yet!') +print(data_bundle) def wordtochar(words): chars = [] - for word in words: + + #for word in words: #word = word.lower() - for char in word: - chars.append(char) - chars.append('') - chars.pop() + for char in words: + chars.append(char) + #chars.append('') + #chars.pop() return chars #chartoindex @@ -162,10 +157,14 @@ def chartoindex(chars): char_index_list=[zero_index]*max_seq_len return char_index_list + for dataset in data_bundle.datasets.values(): dataset.apply_field(wordtochar, field_name="raw_words", new_field_name='chars') dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars') +# print(data_bundle.datasets['train'][0]['chars']) +# print(data_bundle.datasets['train'][0]['raw_words']) + data_bundle.datasets['train'].set_input('chars') data_bundle.datasets['test'].set_input('chars') data_bundle.datasets['train'].set_target('target') @@ -216,7 +215,6 @@ model=CharacterLevelCNN(ops,embedding) ## 3. 声明loss,metric,optimizer loss=CrossEntropyLoss metric=AccuracyMetric -#optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr) optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay) callbacks = [] @@ -236,8 +234,4 @@ def train(model,datainfo,loss,metrics,optimizer,num_epochs=100): if __name__=="__main__": - #print(vocab_label) - - #print(datainfo.datasets["train"]) train(model,data_bundle,loss,metric,optimizer,num_epochs=ops.train_epoch) - \ No newline at end of file