|
|
@@ -8,21 +8,18 @@ from fastNLP.core.trainer import Trainer |
|
|
|
from fastNLP import CrossEntropyLoss, AccuracyMetric |
|
|
|
from fastNLP.embeddings import StaticEmbedding |
|
|
|
from reproduction.text_classification.model.dpcnn import DPCNN |
|
|
|
from fastNLP.io.data_loader import YelpLoader |
|
|
|
from fastNLP.core.sampler import BucketSampler |
|
|
|
from fastNLP.core import LRScheduler |
|
|
|
from fastNLP.core.const import Const as C |
|
|
|
from fastNLP.core.vocabulary import VocabularyOption |
|
|
|
from fastNLP.core.dist_trainer import DistTrainer |
|
|
|
from utils.util_init import set_rng_seeds |
|
|
|
from fastNLP import logger |
|
|
|
import os |
|
|
|
# os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' |
|
|
|
# os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' |
|
|
|
from fastNLP.io import YelpFullPipe, YelpPolarityPipe |
|
|
|
|
|
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
|
|
# hyper |
|
|
|
logger.add_file('log', 'INFO') |
|
|
|
print(logger.handlers) |
|
|
|
|
|
|
|
class Config(): |
|
|
|
seed = 12345 |
|
|
@@ -50,18 +47,14 @@ class Config(): |
|
|
|
ops = Config() |
|
|
|
|
|
|
|
set_rng_seeds(ops.seed) |
|
|
|
# print('RNG SEED: {}'.format(ops.seed)) |
|
|
|
logger.info('RNG SEED %d'%ops.seed) |
|
|
|
|
|
|
|
# 1.task相关信息:利用dataloader载入dataInfo |
|
|
|
|
|
|
|
#datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) |
|
|
|
|
|
|
|
|
|
|
|
@cache_results(ops.model_dir_or_name+'-data-cache') |
|
|
|
def load_data(): |
|
|
|
datainfo = YelpLoader(fine_grained=True, lower=True).process( |
|
|
|
paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op) |
|
|
|
datainfo = YelpFullPipe(lower=True, tokenizer='raw').process_from_file(ops.datapath) |
|
|
|
for ds in datainfo.datasets.values(): |
|
|
|
ds.apply_field(len, C.INPUT, C.INPUT_LEN) |
|
|
|
ds.set_input(C.INPUT, C.INPUT_LEN) |
|
|
@@ -79,11 +72,8 @@ print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.st |
|
|
|
|
|
|
|
# 2.或直接复用fastNLP的模型 |
|
|
|
|
|
|
|
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) |
|
|
|
datainfo.datasets['train'] = datainfo.datasets['train'][:1000] |
|
|
|
datainfo.datasets['test'] = datainfo.datasets['test'][:1000] |
|
|
|
# print(datainfo) |
|
|
|
# print(datainfo.datasets['train'][0]) |
|
|
|
# datainfo.datasets['train'] = datainfo.datasets['train'][:1000] # for debug purpose |
|
|
|
# datainfo.datasets['test'] = datainfo.datasets['test'][:1000] |
|
|
|
logger.info(datainfo) |
|
|
|
|
|
|
|
model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]), |
|
|
@@ -99,14 +89,7 @@ optimizer = SGD([param for param in model.parameters() if param.requires_grad == |
|
|
|
callbacks = [] |
|
|
|
|
|
|
|
callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) |
|
|
|
# callbacks.append( |
|
|
|
# LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch < |
|
|
|
# ops.train_epoch * 0.8 else ops.lr * 0.1)) |
|
|
|
# ) |
|
|
|
|
|
|
|
# callbacks.append( |
|
|
|
# FitlogCallback(data=datainfo.datasets, verbose=1) |
|
|
|
# ) |
|
|
|
|
|
|
|
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
|
|
@@ -114,12 +97,15 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu' |
|
|
|
logger.info(device) |
|
|
|
|
|
|
|
# 4.定义train方法 |
|
|
|
# normal trainer |
|
|
|
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, |
|
|
|
sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), |
|
|
|
metrics=[metric], use_tqdm=False, save_path='save', |
|
|
|
dev_data=datainfo.datasets['test'], device=device, |
|
|
|
check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, |
|
|
|
n_epochs=ops.train_epoch, num_workers=4) |
|
|
|
|
|
|
|
# distributed trainer |
|
|
|
# trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, |
|
|
|
# metrics=[metric], |
|
|
|
# dev_data=datainfo.datasets['test'], device='cuda', |
|
|
|