From 040bd2ab0774842c759b7424f355ae3a23648a64 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Wed, 22 May 2019 18:06:52 +0800
Subject: [PATCH] - add star-transformer reproduction

---
 fastNLP/models/star_transformer.py          |  19 +-
 fastNLP/modules/encoder/star_transformer.py |   2 +-
 reproduction/Star_transformer/datasets.py   | 157 ++++++++++++++
 reproduction/Star_transformer/modules.py    |  56 +++++
 reproduction/Star_transformer/run.sh        |   5 +
 reproduction/Star_transformer/train.py      | 214 ++++++++++++++++++++
 reproduction/Star_transformer/util.py       | 112 ++++++++++
 7 files changed, 555 insertions(+), 10 deletions(-)
 create mode 100644 reproduction/Star_transformer/datasets.py
 create mode 100644 reproduction/Star_transformer/modules.py
 create mode 100644 reproduction/Star_transformer/run.sh
 create mode 100644 reproduction/Star_transformer/train.py
 create mode 100644 reproduction/Star_transformer/util.py

diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py
index c67e5938..4c944a54 100644
--- a/fastNLP/models/star_transformer.py
+++ b/fastNLP/models/star_transformer.py
@@ -26,13 +26,11 @@ class StarTransEnc(nn.Module):
     :param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
         embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
         此时就以传入的对象作为embedding
-    :param num_cls: 输出类别个数
     :param hidden_size: 模型中特征维度.
     :param num_layers: 模型层数.
     :param num_head: 模型中multi-head的head个数.
     :param head_dim: 模型中multi-head中每个head特征维度.
     :param max_len: 模型能接受的最大输入长度.
-    :param cls_hidden_size: 分类器隐层维度.
     :param emb_dropout: 词嵌入的dropout概率.
     :param dropout: 模型除词嵌入外的dropout概率.
     """
@@ -59,7 +57,7 @@ class StarTransEnc(nn.Module):
 
     def forward(self, x, mask):
         """
-        :param FloatTensor data: [batch, length, hidden] 输入的序列
+        :param FloatTensor x: [batch, length, hidden] 输入的序列
         :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0,
             否则为 1
         :return: [batch, length, hidden] 编码后的输出序列
@@ -110,8 +108,9 @@ class STSeqLabel(nn.Module):
 
     用于序列标注的Star-Transformer模型
 
-    :param vocab_size: 词嵌入的词典大小
-    :param emb_dim: 每个词嵌入的特征维度
+    :param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
+        embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
+        此时就以传入的对象作为embedding
     :param num_cls: 输出类别个数
     :param hidden_size: 模型中特征维度. Default: 300
     :param num_layers: 模型层数. Default: 4
@@ -174,8 +173,9 @@ class STSeqCls(nn.Module):
 
     用于分类任务的Star-Transformer
 
-    :param vocab_size: 词嵌入的词典大小
-    :param emb_dim: 每个词嵌入的特征维度
+    :param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
+        embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
+        此时就以传入的对象作为embedding
     :param num_cls: 输出类别个数
     :param hidden_size: 模型中特征维度. Default: 300
     :param num_layers: 模型层数. Default: 4
@@ -238,8 +238,9 @@ class STNLICls(nn.Module):
 
     用于自然语言推断(NLI)的Star-Transformer
 
-    :param vocab_size: 词嵌入的词典大小
-    :param emb_dim: 每个词嵌入的特征维度
+    :param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
+        embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
+        此时就以传入的对象作为embedding
     :param num_cls: 输出类别个数
     :param hidden_size: 模型中特征维度. Default: 300
     :param num_layers: 模型层数. Default: 4
diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py
index 5a7f3d67..1eec7c13 100644
--- a/fastNLP/modules/encoder/star_transformer.py
+++ b/fastNLP/modules/encoder/star_transformer.py
@@ -43,7 +43,7 @@ class StarTransformer(nn.Module):
             for _ in range(self.iters)])
 
         if max_len is not None:
-            self.pos_emb = self.pos_emb = nn.Embedding(max_len, hidden_size)
+            self.pos_emb = nn.Embedding(max_len, hidden_size)
         else:
             self.pos_emb = None
 
diff --git a/reproduction/Star_transformer/datasets.py b/reproduction/Star_transformer/datasets.py
new file mode 100644
index 00000000..a9257fd4
--- /dev/null
+++ b/reproduction/Star_transformer/datasets.py
@@ -0,0 +1,157 @@
+import torch
+import json
+import os
+from fastNLP import Vocabulary
+from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader
+from fastNLP.core import Const as C
+import numpy as np
+
+MAX_LEN = 128
+
+
+def update_v(vocab, data, field):
+    data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None)
+
+
+def to_index(vocab, data, field, name):
+    def func(x):
+        try:
+            return [vocab.to_index(w) for w in x[field]]
+        except ValueError:
+            return [vocab.padding_idx for _ in x[field]]
+    data.apply(func, new_field_name=name)
+
+
+def load_seqtag(path, files, indexs):
+    word_h, tag_h = 'words', 'tags'
+    loader = ConllLoader(headers=[word_h, tag_h], indexes=indexs)
+    ds_list = []
+    for fn in files:
+        ds_list.append(loader.load(os.path.join(path, fn)))
+    word_v = Vocabulary(min_freq=2)
+    tag_v = Vocabulary(unknown=None)
+    update_v(word_v, ds_list[0], word_h)
+    update_v(tag_v, ds_list[0], tag_h)
+
+    def process_data(ds):
+        to_index(word_v, ds, word_h, C.INPUT)
+        to_index(tag_v, ds, tag_h, C.TARGET)
+        ds.apply(lambda x: x[C.INPUT][:MAX_LEN], new_field_name=C.INPUT)
+        ds.apply(lambda x: x[C.TARGET][:MAX_LEN], new_field_name=C.TARGET)
+        ds.apply(lambda x: len(x[word_h]), new_field_name=C.INPUT_LEN)
+        ds.set_input(C.INPUT, C.INPUT_LEN)
+        ds.set_target(C.TARGET, C.INPUT_LEN)
+    for i in range(len(ds_list)):
+        process_data(ds_list[i])
+    return ds_list, word_v, tag_v
+
+
+def load_sst(path, files):
+    loaders = [SSTLoader(subtree=sub, fine_grained=True)
+               for sub in [True, False, False]]
+    ds_list = [loader.load(os.path.join(path, fn))
+               for fn, loader in zip(files, loaders)]
+    word_v = Vocabulary(min_freq=2)
+    tag_v = Vocabulary(unknown=None, padding=None)
+    for ds in ds_list:
+        ds.apply(lambda x: [w.lower()
+                            for w in x['words']], new_field_name='words')
+    ds_list[0].drop(lambda x: len(x['words']) < 3)
+    update_v(word_v, ds_list[0], 'words')
+    ds_list[0].apply(lambda x: tag_v.add_word(
+        x['target']), new_field_name=None)
+
+    def process_data(ds):
+        to_index(word_v, ds, 'words', C.INPUT)
+        ds.apply(lambda x: tag_v.to_index(x['target']), new_field_name=C.TARGET)
+        ds.apply(lambda x: x[C.INPUT][:MAX_LEN], new_field_name=C.INPUT)
+        ds.apply(lambda x: len(x['words']), new_field_name=C.INPUT_LEN)
+        ds.set_input(C.INPUT, C.INPUT_LEN)
+        ds.set_target(C.TARGET)
+    for i in range(len(ds_list)):
+        process_data(ds_list[i])
+    return ds_list, word_v, tag_v
+
+
+def load_snli(path, files):
+    loader = SNLILoader()
+    ds_list = [loader.load(os.path.join(path, f)) for f in files]
+    word_v = Vocabulary(min_freq=2)
+    tag_v = Vocabulary(unknown=None, padding=None)
+    for ds in ds_list:
+        ds.apply(lambda x: [w.lower()
+                            for w in x['words1']], new_field_name='words1')
+        ds.apply(lambda x: [w.lower()
+                            for w in x['words2']], new_field_name='words2')
+    update_v(word_v, ds_list[0], 'words1')
+    update_v(word_v, ds_list[0], 'words2')
+    ds_list[0].apply(lambda x: tag_v.add_word(
+        x['target']), new_field_name=None)
+
+    def process_data(ds):
+        to_index(word_v, ds, 'words1', C.INPUTS(0))
+        to_index(word_v, ds, 'words2', C.INPUTS(1))
+        ds.apply(lambda x: tag_v.to_index(x['target']), new_field_name=C.TARGET)
+        ds.apply(lambda x: x[C.INPUTS(0)][:MAX_LEN], new_field_name=C.INPUTS(0))
+        ds.apply(lambda x: x[C.INPUTS(1)][:MAX_LEN], new_field_name=C.INPUTS(1))
+        ds.apply(lambda x: len(x[C.INPUTS(0)]), new_field_name=C.INPUT_LENS(0))
+        ds.apply(lambda x: len(x[C.INPUTS(1)]), new_field_name=C.INPUT_LENS(1))
+        ds.set_input(C.INPUTS(0), C.INPUTS(1), C.INPUT_LENS(0), C.INPUT_LENS(1))
+        ds.set_target(C.TARGET)
+    for i in range(len(ds_list)):
+        process_data(ds_list[i])
+    return ds_list, word_v, tag_v
+
+
+class EmbedLoader:
+    @staticmethod
+    def parse_glove_line(line):
+        line = line.split()
+        if len(line) <= 2:
+            raise RuntimeError(
+                "something goes wrong in parsing glove embedding")
+        return line[0], line[1:]
+
+    @staticmethod
+    def str_list_2_vec(line):
+        return torch.Tensor(list(map(float, line)))
+
+    @staticmethod
+    def fast_load_embedding(emb_dim, emb_file, vocab):
+        """Fast load the pre-trained embedding and combine with the given dictionary.
+        This loading method uses line-by-line operation.
+
+        :param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
+        :param str emb_file: the pre-trained embedding file path.
+        :param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
+        :return embedding_matrix: numpy.ndarray
+
+        """
+        if vocab is None:
+            raise RuntimeError("You must provide a vocabulary.")
+        embedding_matrix = np.zeros(
+            shape=(len(vocab), emb_dim), dtype=np.float32)
+        hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
+        with open(emb_file, "r", encoding="utf-8") as f:
+            startline = f.readline()
+            if len(startline.split()) > 2:
+                f.seek(0)
+            for line in f:
+                word, vector = EmbedLoader.parse_glove_line(line)
+                try:
+                    if word in vocab:
+                        vector = EmbedLoader.str_list_2_vec(vector)
+                        if emb_dim != vector.size(0):
+                            continue
+                        embedding_matrix[vocab[word]] = vector
+                        hit_flags[vocab[word]] = 1
+                except Exception:
+                    continue
+
+        if np.sum(hit_flags) < len(vocab):
+            # some words from vocab are missing in pre-trained embedding
+            # we normally sample each dimension
+            vocab_embed = embedding_matrix[np.where(hit_flags)]
+            sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
+                                               size=(len(vocab) - np.sum(hit_flags), emb_dim))
+            embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
+        return embedding_matrix
diff --git a/reproduction/Star_transformer/modules.py b/reproduction/Star_transformer/modules.py
new file mode 100644
index 00000000..61a61d25
--- /dev/null
+++ b/reproduction/Star_transformer/modules.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from fastNLP.core.losses import LossBase
+
+
+reduce_func = {
+    'none': lambda x, mask: x*mask,
+    'sum': lambda x, mask: (x*mask).sum(),
+    'mean': lambda x, mask: (x*mask).sum() / mask.sum(),
+}
+
+
+class LabelSmoothCrossEntropy(nn.Module):
+    def __init__(self, smoothing=0.1, ignore_index=-100, reduction='mean'):
+        global reduce_func
+        super().__init__()
+        if smoothing < 0 or smoothing > 1:
+            raise ValueError('invalid smoothing value: {}'.format(smoothing))
+        self.smoothing = smoothing
+        self.ignore_index = ignore_index
+        if reduction not in reduce_func:
+            raise ValueError('invalid reduce type: {}'.format(reduction))
+        self.reduce_func = reduce_func[reduction]
+
+    def forward(self, input, target):
+        input = F.log_softmax(input, dim=1)  # [N, C, ...]
+        smooth_val = self.smoothing / input.size(1)  # [N, C, ...]
+        target_logit = input.new_full(input.size(), fill_value=smooth_val)
+        target_logit.scatter_(1, target[:, None], 1 - self.smoothing)
+        result = -(target_logit * input).sum(1)  # [N, ...]
+        mask = (target != self.ignore_index).float()
+        return self.reduce_func(result, mask)
+
+
+class SmoothCE(LossBase):
+    def __init__(self, pred=None, target=None, **kwargs):
+        super().__init__()
+        self.loss_fn = LabelSmoothCrossEntropy(**kwargs)
+        self._init_param_map(pred=pred, target=target)
+
+    def get_loss(self, pred, target):
+        return self.loss_fn(pred, target)
+
+
+if __name__ == '__main__':
+    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
+    sm_loss_fn = LabelSmoothCrossEntropy(smoothing=0, ignore_index=0)
+    predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
+                            [0, 0.9, 0.2, 0.1, 0],
+                            [1, 0.2, 0.7, 0.1, 0]])
+    target = torch.tensor([2, 1, 0])
+    loss = loss_fn(predict, target)
+    sm_loss = sm_loss_fn(predict, target)
+    print(loss, sm_loss)
diff --git a/reproduction/Star_transformer/run.sh b/reproduction/Star_transformer/run.sh
new file mode 100644
index 00000000..0972c662
--- /dev/null
+++ b/reproduction/Star_transformer/run.sh
@@ -0,0 +1,5 @@
+#python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 &
+#python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 &
+#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log &
+#python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log &
+python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log &
diff --git a/reproduction/Star_transformer/train.py b/reproduction/Star_transformer/train.py
new file mode 100644
index 00000000..dee85c38
--- /dev/null
+++ b/reproduction/Star_transformer/train.py
@@ -0,0 +1,214 @@
+from util import get_argparser, set_gpu, set_rng_seeds, add_model_args
+from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN
+import torch.nn as nn
+import torch
+import numpy as np
+import fastNLP as FN
+from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls
+from fastNLP.core.const import Const as C
+import sys
+sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/')
+
+
+g_model_select = {
+    'pos': STSeqLabel,
+    'ner': STSeqLabel,
+    'cls': STSeqCls,
+    'nli': STNLICls,
+}
+
+g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt',
+                   'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'}
+
+g_args = None
+g_model_cfg = None
+
+
+def get_ptb_pos():
+    pos_dir = '/remote-home/yfshao/workdir/datasets/pos'
+    pos_files = ['train.pos', 'dev.pos', 'test.pos', ]
+    return load_seqtag(pos_dir, pos_files, [0, 1])
+
+
+def get_ctb_pos():
+    ctb_dir = '/remote-home/yfshao/workdir/datasets/ctb9_hy'
+    files = ['train.conllx', 'dev.conllx', 'test.conllx']
+    return load_seqtag(ctb_dir, files, [1, 4])
+
+
+def get_conll2012_pos():
+    path = '/remote-home/yfshao/workdir/datasets/ontonotes/pos'
+    files = ['ontonotes-conll.train',
+             'ontonotes-conll.dev',
+             'ontonotes-conll.conll-2012-test']
+    return load_seqtag(path, files, [0, 1])
+
+
+def get_conll2012_ner():
+    path = '/remote-home/yfshao/workdir/datasets/ontonotes/ner'
+    files = ['bieso-ontonotes-conll-ner.train',
+             'bieso-ontonotes-conll-ner.dev',
+             'bieso-ontonotes-conll-ner.conll-2012-test']
+    return load_seqtag(path, files, [0, 1])
+
+
+def get_sst():
+    path = '/remote-home/yfshao/workdir/datasets/SST'
+    files = ['train.txt', 'dev.txt', 'test.txt']
+    return load_sst(path, files)
+
+
+def get_snli():
+    path = '/remote-home/yfshao/workdir/datasets/nli-data/snli_1.0'
+    files = ['snli_1.0_train.jsonl',
+             'snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']
+    return load_snli(path, files)
+
+
+g_datasets = {
+    'ptb-pos': get_ptb_pos,
+    'ctb-pos': get_ctb_pos,
+    'conll-pos': get_conll2012_pos,
+    'conll-ner': get_conll2012_ner,
+    'sst-cls': get_sst,
+    'snli-nli': get_snli,
+}
+
+
+def load_pretrain_emb(word_v, lang='en'):
+    print('loading pre-train embeddings')
+    emb = EmbedLoader.fast_load_embedding(300, g_emb_file_path[lang], word_v)
+    emb /= np.linalg.norm(emb, axis=1, keepdims=True)
+    emb = torch.tensor(emb, dtype=torch.float32)
+    print('embedding mean: {:.6}, std: {:.6}'.format(emb.mean(), emb.std()))
+    emb[word_v.padding_idx].fill_(0)
+    return emb
+
+
+class MyCallback(FN.core.callback.Callback):
+    def on_train_begin(self):
+        super(MyCallback, self).on_train_begin()
+        self.init_lrs = [pg['lr'] for pg in self.optimizer.param_groups]
+
+    def on_backward_end(self):
+        nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0)
+
+    def on_step_end(self):
+        warm_steps = 6000
+        # learning rate warm-up & decay
+        if self.step <= warm_steps:
+            for lr, pg in zip(self.init_lrs, self.optimizer.param_groups):
+                pg['lr'] = lr * (self.step / float(warm_steps))
+
+        elif self.step % 3000 == 0:
+            for pg in self.optimizer.param_groups:
+                cur_lr = pg['lr']
+                pg['lr'] = max(1e-5, cur_lr*g_args.lr_decay)
+
+
+def train():
+    seed = set_rng_seeds(1234)
+    print('RNG SEED {}'.format(seed))
+    print('loading data')
+    ds_list, word_v, tag_v = g_datasets['{}-{}'.format(
+        g_args.ds, g_args.task)]()
+    print(ds_list[0][:2])
+    embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en')
+    g_model_cfg['num_cls'] = len(tag_v)
+    print(g_model_cfg)
+    g_model_cfg['init_embed'] = embed
+    model = g_model_select[g_args.task.lower()](**g_model_cfg)
+
+    def init_model(model):
+        for p in model.parameters():
+            if p.size(0) != len(word_v):
+                nn.init.normal_(p, 0.0, 0.05)
+    init_model(model)
+    train_data = ds_list[0]
+    dev_data = ds_list[2]
+    test_data = ds_list[1]
+    print(tag_v.word2idx)
+
+    if g_args.task in ['pos', 'ner']:
+        padding_idx = tag_v.padding_idx
+    else:
+        padding_idx = -100
+    print('padding_idx ', padding_idx)
+    loss = FN.CrossEntropyLoss(padding_idx=padding_idx)
+    metrics = {
+        'pos': (None, FN.AccuracyMetric()),
+        'ner': ('f', FN.core.metrics.SpanFPreRecMetric(
+            tag_vocab=tag_v, encoding_type='bmeso', ignore_labels=[''], )),
+        'cls': (None, FN.AccuracyMetric()),
+        'nli': (None, FN.AccuracyMetric()),
+    }
+    metric_key, metric = metrics[g_args.task]
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    ex_param = [x for x in model.parameters(
+    ) if x.requires_grad and x.size(0) != len(word_v)]
+    optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
+                 {'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ]
+    trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data,
+                         loss=loss, metrics=metric, metric_key=metric_key,
+                         optimizer=torch.optim.Adam(optim_cfg),
+                         n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=10, validate_every=3000,
+                         device=device,
+                         use_tqdm=False, prefetch=False,
+                         save_path=g_args.log,
+                         callbacks=[MyCallback()])
+
+    trainer.train()
+    tester = FN.Tester(data=test_data, model=model, metrics=metric,
+                       batch_size=128, device=device)
+    tester.test()
+
+
+def test():
+    pass
+
+
+def infer():
+    pass
+
+
+run_select = {
+    'train': train,
+    'test': test,
+    'infer': infer,
+}
+
+
+def main():
+    global g_args, g_model_cfg
+    import signal
+
+    def signal_handler(signal, frame):
+        raise KeyboardInterrupt
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+    parser = get_argparser()
+    parser.add_argument('--task', choices=['pos', 'ner', 'cls', 'nli'])
+    parser.add_argument('--mode', choices=['train', 'test', 'infer'])
+    parser.add_argument('--ds', type=str)
+    add_model_args(parser)
+    g_args = parser.parse_args()
+    print(g_args.__dict__)
+    set_gpu(g_args.gpu)
+    g_model_cfg = {
+        'init_embed': (None, 300),
+        'num_cls': None,
+        'hidden_size': g_args.hidden,
+        'num_layers': 4,
+        'num_head': g_args.nhead,
+        'head_dim': g_args.hdim,
+        'max_len': MAX_LEN,
+        'cls_hidden_size': 600,
+        'emb_dropout': 0.3,
+        'dropout': g_args.drop,
+    }
+    run_select[g_args.mode.lower()]()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/reproduction/Star_transformer/util.py b/reproduction/Star_transformer/util.py
new file mode 100644
index 00000000..ecd1e18d
--- /dev/null
+++ b/reproduction/Star_transformer/util.py
@@ -0,0 +1,112 @@
+import fastNLP as FN
+import argparse
+import os
+import random
+import numpy
+import torch
+
+
+def get_argparser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--lr', type=float, required=True)
+    parser.add_argument('--w_decay', type=float, required=True)
+    parser.add_argument('--lr_decay', type=float, required=True)
+    parser.add_argument('--bsz', type=int, required=True)
+    parser.add_argument('--ep', type=int, required=True)
+    parser.add_argument('--drop', type=float, required=True)
+    parser.add_argument('--gpu', type=str, required=True)
+    parser.add_argument('--log', type=str, default=None)
+    return parser
+
+
+def add_model_args(parser):
+    parser.add_argument('--nhead', type=int, default=6)
+    parser.add_argument('--hdim', type=int, default=50)
+    parser.add_argument('--hidden', type=int, default=300)
+    return parser
+
+
+def set_gpu(gpu_str):
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str
+
+
+def set_rng_seeds(seed=None):
+    if seed is None:
+        seed = numpy.random.randint(0, 65536)
+    random.seed(seed)
+    numpy.random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # print('RNG_SEED {}'.format(seed))
+    return seed
+
+
+class TensorboardCallback(FN.Callback):
+    """
+    接受以下一个或多个字符串作为参数:
+    - "model"
+    - "loss"
+    - "metric"
+    """
+
+    def __init__(self, *options):
+        super(TensorboardCallback, self).__init__()
+        args = {"model", "loss", "metric"}
+        for opt in options:
+            if opt not in args:
+                raise ValueError(
+                    "Unrecognized argument {}. Expect one of {}".format(opt, args))
+        self.options = options
+        self._summary_writer = None
+        self.graph_added = False
+
+    def on_train_begin(self):
+        save_dir = self.trainer.save_path
+        if save_dir is None:
+            path = os.path.join(
+                "./", 'tensorboard_logs_{}'.format(self.trainer.start_time))
+        else:
+            path = os.path.join(
+                save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
+        self._summary_writer = SummaryWriter(path)
+
+    def on_batch_begin(self, batch_x, batch_y, indices):
+        if "model" in self.options and self.graph_added is False:
+            # tesorboardX 这里有大bug,暂时没法画模型图
+            # from fastNLP.core.utils import _build_args
+            # inputs = _build_args(self.trainer.model, **batch_x)
+            # args = tuple([value for value in inputs.values()])
+            # args = args[0] if len(args) == 1 else args
+            # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
+            self.graph_added = True
+
+    def on_backward_begin(self, loss):
+        if "loss" in self.options:
+            self._summary_writer.add_scalar(
+                "loss", loss.item(), global_step=self.trainer.step)
+
+        if "model" in self.options:
+            for name, param in self.trainer.model.named_parameters():
+                if param.requires_grad:
+                    self._summary_writer.add_scalar(
+                        name + "_mean", param.mean(), global_step=self.trainer.step)
+                    # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step)
+                    self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
+                                                    global_step=self.trainer.step)
+
+    def on_valid_end(self, eval_result, metric_key):
+        if "metric" in self.options:
+            for name, metric in eval_result.items():
+                for metric_key, metric_val in metric.items():
+                    self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
+                                                    global_step=self.trainer.step)
+
+    def on_train_end(self):
+        self._summary_writer.close()
+        del self._summary_writer
+
+    def on_exception(self, exception):
+        if hasattr(self, "_summary_writer"):
+            self._summary_writer.close()
+            del self._summary_writer