Browse Source

[update] change data-loader to pipe

tags/v0.4.10
yunfan 5 years ago
parent
commit
44af647839
1 changed files with 8 additions and 22 deletions
  1. +8
    -22
      reproduction/text_classification/train_dpcnn.py

+ 8
- 22
reproduction/text_classification/train_dpcnn.py View File

@@ -8,21 +8,18 @@ from fastNLP.core.trainer import Trainer
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.embeddings import StaticEmbedding
from reproduction.text_classification.model.dpcnn import DPCNN
from fastNLP.io.data_loader import YelpLoader
from fastNLP.core.sampler import BucketSampler
from fastNLP.core import LRScheduler
from fastNLP.core.const import Const as C
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.core.dist_trainer import DistTrainer
from utils.util_init import set_rng_seeds
from fastNLP import logger
import os
# os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
# os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
from fastNLP.io import YelpFullPipe, YelpPolarityPipe
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# hyper
logger.add_file('log', 'INFO')
print(logger.handlers)

class Config():
seed = 12345
@@ -50,18 +47,14 @@ class Config():
ops = Config()

set_rng_seeds(ops.seed)
# print('RNG SEED: {}'.format(ops.seed))
logger.info('RNG SEED %d'%ops.seed)

# 1.task相关信息:利用dataloader载入dataInfo

#datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])


@cache_results(ops.model_dir_or_name+'-data-cache')
def load_data():
datainfo = YelpLoader(fine_grained=True, lower=True).process(
paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op)
datainfo = YelpFullPipe(lower=True, tokenizer='raw').process_from_file(ops.datapath)
for ds in datainfo.datasets.values():
ds.apply_field(len, C.INPUT, C.INPUT_LEN)
ds.set_input(C.INPUT, C.INPUT_LEN)
@@ -79,11 +72,8 @@ print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.st

# 2.或直接复用fastNLP的模型

# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
# print(datainfo)
# print(datainfo.datasets['train'][0])
# datainfo.datasets['train'] = datainfo.datasets['train'][:1000] # for debug purpose
# datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
logger.info(datainfo)

model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
@@ -99,14 +89,7 @@ optimizer = SGD([param for param in model.parameters() if param.requires_grad ==
callbacks = []

callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
# callbacks.append(
# LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
# ops.train_epoch * 0.8 else ops.lr * 0.1))
# )

# callbacks.append(
# FitlogCallback(data=datainfo.datasets, verbose=1)
# )

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

@@ -114,12 +97,15 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
logger.info(device)

# 4.定义train方法
# normal trainer
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
metrics=[metric], use_tqdm=False, save_path='save',
dev_data=datainfo.datasets['test'], device=device,
check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
n_epochs=ops.train_epoch, num_workers=4)

# distributed trainer
# trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
# metrics=[metric],
# dev_data=datainfo.datasets['test'], device='cuda',


Loading…
Cancel
Save