Browse Source

[verify]charcnn use pipe,remove dataloader

tags/v0.5.0
wyg 5 years ago
parent
commit
c29aca77ba
2 changed files with 20 additions and 19 deletions
  1. +7
    -0
      reproduction/text_classification/README.md
  2. +13
    -19
      reproduction/text_classification/train_char_cnn.py

+ 7
- 0
reproduction/text_classification/README.md View File

@@ -18,6 +18,13 @@ SST:https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
yelp_full:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M
yelp_polarity:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M

| dataset | classes | train samples | dev samples | test samples | reference |
| :---: | :---: | :---: | :---: | :---: | :---: |
| yelp_polarity | 2 | 560k | - | 38k | [char_cnn](https://arxiv.org/pdf/1509.01626v3.pdf) |
| yelp_full | 5 | 650k | - | 50k | [char_cnn](https://arxiv.org/pdf/1509.01626v3.pdf) |
| IMDB | 2 | 25k | - | 25k | [IMDB](https://ai.stanford.edu/~ang/papers/acl11-WordVectorsSentimentAnalysis.pdf) |
| sst-2 | 2 | 67k | 872 | 1.8k | [GLUE](https://arxiv.org/pdf/1804.07461.pdf) |

# 数据集及复现结果汇总

使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道,-表示论文没有在该数据集上列出结果)


+ 13
- 19
reproduction/text_classification/train_char_cnn.py View File

@@ -1,15 +1,8 @@
# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'

import sys
sys.path.append('../..')
from fastNLP.core.const import Const as C
import torch.nn as nn
from fastNLP.io.data_loader import YelpLoader
from fastNLP.io.pipe.classification import YelpFullPipe,YelpPolarityPipe,SST2Pipe,IMDBPipe
#from data.sstLoader import sst2Loader
from model.char_cnn import CharacterLevelCNN
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.core.trainer import Trainer
@@ -27,9 +20,9 @@ class Config():
model_dir_or_name="en-base-uncased"
embedding_grad= False,
bert_embedding_larers= '4,-2,-1'
train_epoch= 50
train_epoch= 100
num_classes=2
task= "yelp_p"
task= "sst-2"
#yelp_p
datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
"test": "/remote-home/ygwang/yelp_polarity/test.csv"}
@@ -132,15 +125,17 @@ elif ops.task == 'sst-2':
else:
raise RuntimeError(f'NOT support {ops.task} task yet!')

print(data_bundle)

def wordtochar(words):
    """Split the raw text of one sample into a list of its characters.

    This is the post-commit form of the function: the input is iterated
    directly (no per-word nesting, no separator entries, no lower-casing),
    so every element it yields becomes one entry of the result.

    Args:
        words: The value of the ``raw_words`` field for one sample.
            Presumably a plain string -- confirm against the pipe output.
            Any iterable is accepted.

    Returns:
        A new list holding the elements of ``words`` in their original
        order. An empty input yields an empty list.
    """
    return list(words)

#chartoindex
@@ -162,10 +157,14 @@ def chartoindex(chars):
char_index_list=[zero_index]*max_seq_len
return char_index_list


for dataset in data_bundle.datasets.values():
dataset.apply_field(wordtochar, field_name="raw_words", new_field_name='chars')
dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars')

# print(data_bundle.datasets['train'][0]['chars'])
# print(data_bundle.datasets['train'][0]['raw_words'])

data_bundle.datasets['train'].set_input('chars')
data_bundle.datasets['test'].set_input('chars')
data_bundle.datasets['train'].set_target('target')
@@ -216,7 +215,6 @@ model=CharacterLevelCNN(ops,embedding)
## 3. 声明loss,metric,optimizer
loss=CrossEntropyLoss
metric=AccuracyMetric
#optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr)
optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)
callbacks = []
@@ -236,8 +234,4 @@ def train(model,datainfo,loss,metrics,optimizer,num_epochs=100):


if __name__=="__main__":
#print(vocab_label)

#print(datainfo.datasets["train"])
train(model,data_bundle,loss,metric,optimizer,num_epochs=ops.train_epoch)

Loading…
Cancel
Save