
Merge pull request #213 from SrWYG/dev0.5.0

[verify] char_cnn: switch to Pipe; remove the dataloader and the environment-variable related code
tags/v0.5.0
Yige Xu (via GitHub) committed 5 years ago
commit b5b4de745a
2 changed files with 19 additions and 28 deletions:
  1. reproduction/text_classification/README.md (+7 / -0)
  2. reproduction/text_classification/train_char_cnn.py (+12 / -28)

reproduction/text_classification/README.md (+7 / -0)

@@ -18,6 +18,13 @@ SST:https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
 yelp_full:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M
 yelp_polarity:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M


+dataset | classes | train samples | dev samples | test samples | reference |
+:---: | :---: | :---: | :---: | :---: | :---: |
+yelp_polarity | 2 | 560k | - | 38k | [char_cnn](https://arxiv.org/pdf/1509.01626v3.pdf) |
+yelp_full | 5 | 650k | - | 50k | [char_cnn](https://arxiv.org/pdf/1509.01626v3.pdf) |
+IMDB | 2 | 25k | - | 25k | [IMDB](https://ai.stanford.edu/~ang/papers/acl11-WordVectorsSentimentAnalysis.pdf) |
+sst-2 | 2 | 67k | 872 | 1.8k | [GLUE](https://arxiv.org/pdf/1804.07461.pdf) |

 # Summary of datasets and reproduced results


 Results reproduced with fastNLP vs. results reported in the papers (the number before the / is the fastNLP implementation, the one after is the paper's; - means the paper does not report a result on that dataset)


reproduction/text_classification/train_char_cnn.py (+12 / -28)

@@ -1,15 +1,8 @@
-# The following paths first had to be added to the environment variables; the data was open for internal testing only, so the paths needed to be declared manually
-import os
-os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
-os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
-
 import sys
 sys.path.append('../..')
 from fastNLP.core.const import Const as C
 import torch.nn as nn
-from fastNLP.io.data_loader import YelpLoader
 from fastNLP.io.pipe.classification import YelpFullPipe,YelpPolarityPipe,SST2Pipe,IMDBPipe
-#from data.sstLoader import sst2Loader
 from model.char_cnn import CharacterLevelCNN
 from fastNLP import CrossEntropyLoss, AccuracyMetric
 from fastNLP.core.trainer import Trainer
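For context on the commit: the removed loader and environment variables are superseded by the Pipe API. A minimal sketch of the new loading path, assuming fastNLP 0.5's `process_from_file` behaviour (nothing below this line is taken from the diff itself):

```python
from fastNLP.io.pipe.classification import YelpPolarityPipe

# A Pipe loads, tokenizes, builds vocabularies, and indexes the data in one
# step, returning a DataBundle; called without explicit paths it falls back
# to the library's own download/cache mechanism, which is why the hard-coded
# FASTNLP_BASE_URL / FASTNLP_CACHE_DIR variables above could be dropped.
data_bundle = YelpPolarityPipe().process_from_file()
print(data_bundle)  # summarizes the datasets and vocabularies it produced
```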
@@ -27,19 +20,9 @@ class Config():
     model_dir_or_name="en-base-uncased"
     embedding_grad= False,
     bert_embedding_larers= '4,-2,-1'
-    train_epoch= 50
+    train_epoch= 100
     num_classes=2
     task= "yelp_p"
-    #yelp_p
-    datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
-                "test": "/remote-home/ygwang/yelp_polarity/test.csv"}
-    #IMDB
-    #datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv",
-    #            "test": "/remote-home/ygwang/IMDB_data/test.csv"}
-    # sst
-    # datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
-    #             "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
-
     lr=0.01
     batch_size=128
     model_size="large"
@@ -132,15 +115,17 @@ elif ops.task == 'sst-2':
 else:
     raise RuntimeError(f'NOT support {ops.task} task yet!')

+print(data_bundle)

 def wordtochar(words):
     chars = []
-    for word in words:
+    #for word in words:
         #word = word.lower()
-        for char in word:
-            chars.append(char)
-        chars.append('')
-        chars.pop()
+    for char in words:
+        chars.append(char)
+    #chars.append('')
+    #chars.pop()
     return chars


 #chartoindex
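The effect of the rewrite: `wordtochar` no longer walks word by word (the old version appended and then popped an empty-string separator); since the `raw_words` field holds the untokenized sentence string, it simply iterates that string, so whitespace survives as ordinary characters. A quick usage check:

```python
>>> wordtochar("it was great")
['i', 't', ' ', 'w', 'a', 's', ' ', 'g', 'r', 'e', 'a', 't']
```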
@@ -162,10 +147,14 @@ def chartoindex(chars):
     char_index_list=[zero_index]*max_seq_len
     return char_index_list



 for dataset in data_bundle.datasets.values():
     dataset.apply_field(wordtochar, field_name="raw_words", new_field_name='chars')
     dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars')


+# print(data_bundle.datasets['train'][0]['chars'])
+# print(data_bundle.datasets['train'][0]['raw_words'])

 data_bundle.datasets['train'].set_input('chars')
 data_bundle.datasets['test'].set_input('chars')
 data_bundle.datasets['train'].set_target('target')
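Only the all-padding tail of `chartoindex` is visible in the hunk above. For readers without the full file, one plausible shape of the helper, with `char_dict`, `zero_index`, and `max_seq_len` assumed to be the module-level names those visible lines use, is:

```python
# Hypothetical reconstruction -- only the padding tail is taken from the
# diff; the lookup/truncation logic here is an assumption.
def chartoindex(chars):
    # map each character to its vocabulary index; unknown characters fall
    # back to zero_index, which doubles as the padding index
    char_index_list = [char_dict.get(char, zero_index) for char in chars]
    if len(char_index_list) > max_seq_len:
        # truncate long sequences to the fixed CNN input width
        char_index_list = char_index_list[:max_seq_len]
    else:
        # right-pad short sequences up to max_seq_len
        char_index_list += [zero_index] * (max_seq_len - len(char_index_list))
    return char_index_list
```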
@@ -216,7 +205,6 @@ model=CharacterLevelCNN(ops,embedding)
 ## 3. Declare the loss, metric, and optimizer
 loss=CrossEntropyLoss
 metric=AccuracyMetric
-#optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr)
 optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
                 lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)
 callbacks = []
@@ -236,8 +224,4 @@ def train(model,datainfo,loss,metrics,optimizer,num_epochs=100):




 if __name__=="__main__":
-    #print(vocab_label)
-
-    #print(datainfo.datasets["train"])
     train(model,data_bundle,loss,metric,optimizer,num_epochs=ops.train_epoch)
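The body of `train()` sits outside these hunks; judging from the imports and this call site it presumably wraps fastNLP's Trainer, roughly as sketched below. The argument wiring is an assumption, not the file's actual code; note the script stores the CrossEntropyLoss and AccuracyMetric classes, so they are instantiated here:

```python
def train(model, data_bundle, loss, metrics, optimizer, num_epochs=100):
    # Assumed wiring: yelp_polarity ships no dev split, so the test set is
    # passed as dev_data for per-epoch evaluation.
    trainer = Trainer(
        train_data=data_bundle.datasets['train'],
        model=model,
        optimizer=optimizer,
        loss=loss(),          # CrossEntropyLoss class -> instance
        metrics=metrics(),    # AccuracyMetric class -> instance
        dev_data=data_bundle.datasets['test'],
        n_epochs=num_epochs,
        batch_size=ops.batch_size,
        callbacks=callbacks,
    )
    trainer.train()
```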
