@@ -26,13 +26,11 @@ class StarTransEnc(nn.Module): | |||
:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, | |||
此时就以传入的对象作为embedding | |||
:param num_cls: 输出类别个数 | |||
:param hidden_size: 模型中特征维度. | |||
:param num_layers: 模型层数. | |||
:param num_head: 模型中multi-head的head个数. | |||
:param head_dim: 模型中multi-head中每个head特征维度. | |||
:param max_len: 模型能接受的最大输入长度. | |||
:param cls_hidden_size: 分类器隐层维度. | |||
:param emb_dropout: 词嵌入的dropout概率. | |||
:param dropout: 模型除词嵌入外的dropout概率. | |||
""" | |||
@@ -59,7 +57,7 @@ class StarTransEnc(nn.Module): | |||
def forward(self, x, mask): | |||
""" | |||
:param FloatTensor data: [batch, length, hidden] 输入的序列 | |||
:param FloatTensor x: [batch, length, hidden] 输入的序列 | |||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | |||
否则为 1 | |||
:return: [batch, length, hidden] 编码后的输出序列 | |||
@@ -110,8 +108,9 @@ class STSeqLabel(nn.Module): | |||
用于序列标注的Star-Transformer模型 | |||
:param vocab_size: 词嵌入的词典大小 | |||
:param emb_dim: 每个词嵌入的特征维度 | |||
:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, | |||
此时就以传入的对象作为embedding | |||
:param num_cls: 输出类别个数 | |||
:param hidden_size: 模型中特征维度. Default: 300 | |||
:param num_layers: 模型层数. Default: 4 | |||
@@ -174,8 +173,9 @@ class STSeqCls(nn.Module): | |||
用于分类任务的Star-Transformer | |||
:param vocab_size: 词嵌入的词典大小 | |||
:param emb_dim: 每个词嵌入的特征维度 | |||
:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, | |||
此时就以传入的对象作为embedding | |||
:param num_cls: 输出类别个数 | |||
:param hidden_size: 模型中特征维度. Default: 300 | |||
:param num_layers: 模型层数. Default: 4 | |||
@@ -238,8 +238,9 @@ class STNLICls(nn.Module): | |||
用于自然语言推断(NLI)的Star-Transformer | |||
:param vocab_size: 词嵌入的词典大小 | |||
:param emb_dim: 每个词嵌入的特征维度 | |||
:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, | |||
此时就以传入的对象作为embedding | |||
:param num_cls: 输出类别个数 | |||
:param hidden_size: 模型中特征维度. Default: 300 | |||
:param num_layers: 模型层数. Default: 4 | |||
@@ -43,7 +43,7 @@ class StarTransformer(nn.Module): | |||
for _ in range(self.iters)]) | |||
if max_len is not None: | |||
self.pos_emb = self.pos_emb = nn.Embedding(max_len, hidden_size) | |||
self.pos_emb = nn.Embedding(max_len, hidden_size) | |||
else: | |||
self.pos_emb = None | |||
@@ -0,0 +1,157 @@ | |||
import torch | |||
import json | |||
import os | |||
from fastNLP import Vocabulary | |||
from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader | |||
from fastNLP.core import Const as C | |||
import numpy as np | |||
MAX_LEN = 128 | |||
def update_v(vocab, data, field): | |||
data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None) | |||
def to_index(vocab, data, field, name): | |||
def func(x): | |||
try: | |||
return [vocab.to_index(w) for w in x[field]] | |||
except ValueError: | |||
return [vocab.padding_idx for _ in x[field]] | |||
data.apply(func, new_field_name=name) | |||
def load_seqtag(path, files, indexs): | |||
word_h, tag_h = 'words', 'tags' | |||
loader = ConllLoader(headers=[word_h, tag_h], indexes=indexs) | |||
ds_list = [] | |||
for fn in files: | |||
ds_list.append(loader.load(os.path.join(path, fn))) | |||
word_v = Vocabulary(min_freq=2) | |||
tag_v = Vocabulary(unknown=None) | |||
update_v(word_v, ds_list[0], word_h) | |||
update_v(tag_v, ds_list[0], tag_h) | |||
def process_data(ds): | |||
to_index(word_v, ds, word_h, C.INPUT) | |||
to_index(tag_v, ds, tag_h, C.TARGET) | |||
ds.apply(lambda x: x[C.INPUT][:MAX_LEN], new_field_name=C.INPUT) | |||
ds.apply(lambda x: x[C.TARGET][:MAX_LEN], new_field_name=C.TARGET) | |||
ds.apply(lambda x: len(x[word_h]), new_field_name=C.INPUT_LEN) | |||
ds.set_input(C.INPUT, C.INPUT_LEN) | |||
ds.set_target(C.TARGET, C.INPUT_LEN) | |||
for i in range(len(ds_list)): | |||
process_data(ds_list[i]) | |||
return ds_list, word_v, tag_v | |||
def load_sst(path, files): | |||
loaders = [SSTLoader(subtree=sub, fine_grained=True) | |||
for sub in [True, False, False]] | |||
ds_list = [loader.load(os.path.join(path, fn)) | |||
for fn, loader in zip(files, loaders)] | |||
word_v = Vocabulary(min_freq=2) | |||
tag_v = Vocabulary(unknown=None, padding=None) | |||
for ds in ds_list: | |||
ds.apply(lambda x: [w.lower() | |||
for w in x['words']], new_field_name='words') | |||
ds_list[0].drop(lambda x: len(x['words']) < 3) | |||
update_v(word_v, ds_list[0], 'words') | |||
ds_list[0].apply(lambda x: tag_v.add_word( | |||
x['target']), new_field_name=None) | |||
def process_data(ds): | |||
to_index(word_v, ds, 'words', C.INPUT) | |||
ds.apply(lambda x: tag_v.to_index(x['target']), new_field_name=C.TARGET) | |||
ds.apply(lambda x: x[C.INPUT][:MAX_LEN], new_field_name=C.INPUT) | |||
ds.apply(lambda x: len(x['words']), new_field_name=C.INPUT_LEN) | |||
ds.set_input(C.INPUT, C.INPUT_LEN) | |||
ds.set_target(C.TARGET) | |||
for i in range(len(ds_list)): | |||
process_data(ds_list[i]) | |||
return ds_list, word_v, tag_v | |||
def load_snli(path, files): | |||
loader = SNLILoader() | |||
ds_list = [loader.load(os.path.join(path, f)) for f in files] | |||
word_v = Vocabulary(min_freq=2) | |||
tag_v = Vocabulary(unknown=None, padding=None) | |||
for ds in ds_list: | |||
ds.apply(lambda x: [w.lower() | |||
for w in x['words1']], new_field_name='words1') | |||
ds.apply(lambda x: [w.lower() | |||
for w in x['words2']], new_field_name='words2') | |||
update_v(word_v, ds_list[0], 'words1') | |||
update_v(word_v, ds_list[0], 'words2') | |||
ds_list[0].apply(lambda x: tag_v.add_word( | |||
x['target']), new_field_name=None) | |||
def process_data(ds): | |||
to_index(word_v, ds, 'words1', C.INPUTS(0)) | |||
to_index(word_v, ds, 'words2', C.INPUTS(1)) | |||
ds.apply(lambda x: tag_v.to_index(x['target']), new_field_name=C.TARGET) | |||
ds.apply(lambda x: x[C.INPUTS(0)][:MAX_LEN], new_field_name=C.INPUTS(0)) | |||
ds.apply(lambda x: x[C.INPUTS(1)][:MAX_LEN], new_field_name=C.INPUTS(1)) | |||
ds.apply(lambda x: len(x[C.INPUTS(0)]), new_field_name=C.INPUT_LENS(0)) | |||
ds.apply(lambda x: len(x[C.INPUTS(1)]), new_field_name=C.INPUT_LENS(1)) | |||
ds.set_input(C.INPUTS(0), C.INPUTS(1), C.INPUT_LENS(0), C.INPUT_LENS(1)) | |||
ds.set_target(C.TARGET) | |||
for i in range(len(ds_list)): | |||
process_data(ds_list[i]) | |||
return ds_list, word_v, tag_v | |||
class EmbedLoader: | |||
@staticmethod | |||
def parse_glove_line(line): | |||
line = line.split() | |||
if len(line) <= 2: | |||
raise RuntimeError( | |||
"something goes wrong in parsing glove embedding") | |||
return line[0], line[1:] | |||
@staticmethod | |||
def str_list_2_vec(line): | |||
return torch.Tensor(list(map(float, line))) | |||
@staticmethod | |||
def fast_load_embedding(emb_dim, emb_file, vocab): | |||
"""Fast load the pre-trained embedding and combine with the given dictionary. | |||
This loading method uses line-by-line operation. | |||
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding. | |||
:param str emb_file: the pre-trained embedding file path. | |||
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding | |||
:return embedding_matrix: numpy.ndarray | |||
""" | |||
if vocab is None: | |||
raise RuntimeError("You must provide a vocabulary.") | |||
embedding_matrix = np.zeros( | |||
shape=(len(vocab), emb_dim), dtype=np.float32) | |||
hit_flags = np.zeros(shape=(len(vocab),), dtype=int) | |||
with open(emb_file, "r", encoding="utf-8") as f: | |||
startline = f.readline() | |||
if len(startline.split()) > 2: | |||
f.seek(0) | |||
for line in f: | |||
word, vector = EmbedLoader.parse_glove_line(line) | |||
try: | |||
if word in vocab: | |||
vector = EmbedLoader.str_list_2_vec(vector) | |||
if emb_dim != vector.size(0): | |||
continue | |||
embedding_matrix[vocab[word]] = vector | |||
hit_flags[vocab[word]] = 1 | |||
except Exception: | |||
continue | |||
if np.sum(hit_flags) < len(vocab): | |||
# some words from vocab are missing in pre-trained embedding | |||
# we normally sample each dimension | |||
vocab_embed = embedding_matrix[np.where(hit_flags)] | |||
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||
size=(len(vocab) - np.sum(hit_flags), emb_dim)) | |||
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | |||
return embedding_matrix |
@@ -0,0 +1,56 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import numpy as np | |||
from fastNLP.core.losses import LossBase | |||
reduce_func = { | |||
'none': lambda x, mask: x*mask, | |||
'sum': lambda x, mask: (x*mask).sum(), | |||
'mean': lambda x, mask: (x*mask).sum() / mask.sum(), | |||
} | |||
class LabelSmoothCrossEntropy(nn.Module): | |||
def __init__(self, smoothing=0.1, ignore_index=-100, reduction='mean'): | |||
global reduce_func | |||
super().__init__() | |||
if smoothing < 0 or smoothing > 1: | |||
raise ValueError('invalid smoothing value: {}'.format(smoothing)) | |||
self.smoothing = smoothing | |||
self.ignore_index = ignore_index | |||
if reduction not in reduce_func: | |||
raise ValueError('invalid reduce type: {}'.format(reduction)) | |||
self.reduce_func = reduce_func[reduction] | |||
def forward(self, input, target): | |||
input = F.log_softmax(input, dim=1) # [N, C, ...] | |||
smooth_val = self.smoothing / input.size(1) # [N, C, ...] | |||
target_logit = input.new_full(input.size(), fill_value=smooth_val) | |||
target_logit.scatter_(1, target[:, None], 1 - self.smoothing) | |||
result = -(target_logit * input).sum(1) # [N, ...] | |||
mask = (target != self.ignore_index).float() | |||
return self.reduce_func(result, mask) | |||
class SmoothCE(LossBase): | |||
def __init__(self, pred=None, target=None, **kwargs): | |||
super().__init__() | |||
self.loss_fn = LabelSmoothCrossEntropy(**kwargs) | |||
self._init_param_map(pred=pred, target=target) | |||
def get_loss(self, pred, target): | |||
return self.loss_fn(pred, target) | |||
if __name__ == '__main__': | |||
loss_fn = nn.CrossEntropyLoss(ignore_index=0) | |||
sm_loss_fn = LabelSmoothCrossEntropy(smoothing=0, ignore_index=0) | |||
predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0], | |||
[0, 0.9, 0.2, 0.1, 0], | |||
[1, 0.2, 0.7, 0.1, 0]]) | |||
target = torch.tensor([2, 1, 0]) | |||
loss = loss_fn(predict, target) | |||
sm_loss = sm_loss_fn(predict, target) | |||
print(loss, sm_loss) |
@@ -0,0 +1,5 @@ | |||
#python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 & | |||
#python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 & | |||
#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log & | |||
#python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log & | |||
python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log & |
@@ -0,0 +1,214 @@ | |||
from util import get_argparser, set_gpu, set_rng_seeds, add_model_args | |||
from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN | |||
import torch.nn as nn | |||
import torch | |||
import numpy as np | |||
import fastNLP as FN | |||
from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls | |||
from fastNLP.core.const import Const as C | |||
import sys | |||
sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/') | |||
g_model_select = { | |||
'pos': STSeqLabel, | |||
'ner': STSeqLabel, | |||
'cls': STSeqCls, | |||
'nli': STNLICls, | |||
} | |||
g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt', | |||
'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'} | |||
g_args = None | |||
g_model_cfg = None | |||
def get_ptb_pos(): | |||
pos_dir = '/remote-home/yfshao/workdir/datasets/pos' | |||
pos_files = ['train.pos', 'dev.pos', 'test.pos', ] | |||
return load_seqtag(pos_dir, pos_files, [0, 1]) | |||
def get_ctb_pos(): | |||
ctb_dir = '/remote-home/yfshao/workdir/datasets/ctb9_hy' | |||
files = ['train.conllx', 'dev.conllx', 'test.conllx'] | |||
return load_seqtag(ctb_dir, files, [1, 4]) | |||
def get_conll2012_pos(): | |||
path = '/remote-home/yfshao/workdir/datasets/ontonotes/pos' | |||
files = ['ontonotes-conll.train', | |||
'ontonotes-conll.dev', | |||
'ontonotes-conll.conll-2012-test'] | |||
return load_seqtag(path, files, [0, 1]) | |||
def get_conll2012_ner(): | |||
path = '/remote-home/yfshao/workdir/datasets/ontonotes/ner' | |||
files = ['bieso-ontonotes-conll-ner.train', | |||
'bieso-ontonotes-conll-ner.dev', | |||
'bieso-ontonotes-conll-ner.conll-2012-test'] | |||
return load_seqtag(path, files, [0, 1]) | |||
def get_sst(): | |||
path = '/remote-home/yfshao/workdir/datasets/SST' | |||
files = ['train.txt', 'dev.txt', 'test.txt'] | |||
return load_sst(path, files) | |||
def get_snli(): | |||
path = '/remote-home/yfshao/workdir/datasets/nli-data/snli_1.0' | |||
files = ['snli_1.0_train.jsonl', | |||
'snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl'] | |||
return load_snli(path, files) | |||
g_datasets = { | |||
'ptb-pos': get_ptb_pos, | |||
'ctb-pos': get_ctb_pos, | |||
'conll-pos': get_conll2012_pos, | |||
'conll-ner': get_conll2012_ner, | |||
'sst-cls': get_sst, | |||
'snli-nli': get_snli, | |||
} | |||
def load_pretrain_emb(word_v, lang='en'): | |||
print('loading pre-train embeddings') | |||
emb = EmbedLoader.fast_load_embedding(300, g_emb_file_path[lang], word_v) | |||
emb /= np.linalg.norm(emb, axis=1, keepdims=True) | |||
emb = torch.tensor(emb, dtype=torch.float32) | |||
print('embedding mean: {:.6}, std: {:.6}'.format(emb.mean(), emb.std())) | |||
emb[word_v.padding_idx].fill_(0) | |||
return emb | |||
class MyCallback(FN.core.callback.Callback): | |||
def on_train_begin(self): | |||
super(MyCallback, self).on_train_begin() | |||
self.init_lrs = [pg['lr'] for pg in self.optimizer.param_groups] | |||
def on_backward_end(self): | |||
nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0) | |||
def on_step_end(self): | |||
warm_steps = 6000 | |||
# learning rate warm-up & decay | |||
if self.step <= warm_steps: | |||
for lr, pg in zip(self.init_lrs, self.optimizer.param_groups): | |||
pg['lr'] = lr * (self.step / float(warm_steps)) | |||
elif self.step % 3000 == 0: | |||
for pg in self.optimizer.param_groups: | |||
cur_lr = pg['lr'] | |||
pg['lr'] = max(1e-5, cur_lr*g_args.lr_decay) | |||
def train(): | |||
seed = set_rng_seeds(1234) | |||
print('RNG SEED {}'.format(seed)) | |||
print('loading data') | |||
ds_list, word_v, tag_v = g_datasets['{}-{}'.format( | |||
g_args.ds, g_args.task)]() | |||
print(ds_list[0][:2]) | |||
embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en') | |||
g_model_cfg['num_cls'] = len(tag_v) | |||
print(g_model_cfg) | |||
g_model_cfg['init_embed'] = embed | |||
model = g_model_select[g_args.task.lower()](**g_model_cfg) | |||
def init_model(model): | |||
for p in model.parameters(): | |||
if p.size(0) != len(word_v): | |||
nn.init.normal_(p, 0.0, 0.05) | |||
init_model(model) | |||
train_data = ds_list[0] | |||
dev_data = ds_list[2] | |||
test_data = ds_list[1] | |||
print(tag_v.word2idx) | |||
if g_args.task in ['pos', 'ner']: | |||
padding_idx = tag_v.padding_idx | |||
else: | |||
padding_idx = -100 | |||
print('padding_idx ', padding_idx) | |||
loss = FN.CrossEntropyLoss(padding_idx=padding_idx) | |||
metrics = { | |||
'pos': (None, FN.AccuracyMetric()), | |||
'ner': ('f', FN.core.metrics.SpanFPreRecMetric( | |||
tag_vocab=tag_v, encoding_type='bmeso', ignore_labels=[''], )), | |||
'cls': (None, FN.AccuracyMetric()), | |||
'nli': (None, FN.AccuracyMetric()), | |||
} | |||
metric_key, metric = metrics[g_args.task] | |||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
ex_param = [x for x in model.parameters( | |||
) if x.requires_grad and x.size(0) != len(word_v)] | |||
optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | |||
{'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ] | |||
trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||
loss=loss, metrics=metric, metric_key=metric_key, | |||
optimizer=torch.optim.Adam(optim_cfg), | |||
n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=10, validate_every=3000, | |||
device=device, | |||
use_tqdm=False, prefetch=False, | |||
save_path=g_args.log, | |||
callbacks=[MyCallback()]) | |||
trainer.train() | |||
tester = FN.Tester(data=test_data, model=model, metrics=metric, | |||
batch_size=128, device=device) | |||
tester.test() | |||
def test(): | |||
pass | |||
def infer(): | |||
pass | |||
run_select = { | |||
'train': train, | |||
'test': test, | |||
'infer': infer, | |||
} | |||
def main(): | |||
global g_args, g_model_cfg | |||
import signal | |||
def signal_handler(signal, frame): | |||
raise KeyboardInterrupt | |||
signal.signal(signal.SIGINT, signal_handler) | |||
signal.signal(signal.SIGTERM, signal_handler) | |||
parser = get_argparser() | |||
parser.add_argument('--task', choices=['pos', 'ner', 'cls', 'nli']) | |||
parser.add_argument('--mode', choices=['train', 'test', 'infer']) | |||
parser.add_argument('--ds', type=str) | |||
add_model_args(parser) | |||
g_args = parser.parse_args() | |||
print(g_args.__dict__) | |||
set_gpu(g_args.gpu) | |||
g_model_cfg = { | |||
'init_embed': (None, 300), | |||
'num_cls': None, | |||
'hidden_size': g_args.hidden, | |||
'num_layers': 4, | |||
'num_head': g_args.nhead, | |||
'head_dim': g_args.hdim, | |||
'max_len': MAX_LEN, | |||
'cls_hidden_size': 600, | |||
'emb_dropout': 0.3, | |||
'dropout': g_args.drop, | |||
} | |||
run_select[g_args.mode.lower()]() | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,112 @@ | |||
import fastNLP as FN | |||
import argparse | |||
import os | |||
import random | |||
import numpy | |||
import torch | |||
def get_argparser(): | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--lr', type=float, required=True) | |||
parser.add_argument('--w_decay', type=float, required=True) | |||
parser.add_argument('--lr_decay', type=float, required=True) | |||
parser.add_argument('--bsz', type=int, required=True) | |||
parser.add_argument('--ep', type=int, required=True) | |||
parser.add_argument('--drop', type=float, required=True) | |||
parser.add_argument('--gpu', type=str, required=True) | |||
parser.add_argument('--log', type=str, default=None) | |||
return parser | |||
def add_model_args(parser): | |||
parser.add_argument('--nhead', type=int, default=6) | |||
parser.add_argument('--hdim', type=int, default=50) | |||
parser.add_argument('--hidden', type=int, default=300) | |||
return parser | |||
def set_gpu(gpu_str): | |||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str | |||
def set_rng_seeds(seed=None): | |||
if seed is None: | |||
seed = numpy.random.randint(0, 65536) | |||
random.seed(seed) | |||
numpy.random.seed(seed) | |||
torch.random.manual_seed(seed) | |||
torch.cuda.manual_seed_all(seed) | |||
# print('RNG_SEED {}'.format(seed)) | |||
return seed | |||
class TensorboardCallback(FN.Callback): | |||
""" | |||
接受以下一个或多个字符串作为参数: | |||
- "model" | |||
- "loss" | |||
- "metric" | |||
""" | |||
def __init__(self, *options): | |||
super(TensorboardCallback, self).__init__() | |||
args = {"model", "loss", "metric"} | |||
for opt in options: | |||
if opt not in args: | |||
raise ValueError( | |||
"Unrecognized argument {}. Expect one of {}".format(opt, args)) | |||
self.options = options | |||
self._summary_writer = None | |||
self.graph_added = False | |||
def on_train_begin(self): | |||
save_dir = self.trainer.save_path | |||
if save_dir is None: | |||
path = os.path.join( | |||
"./", 'tensorboard_logs_{}'.format(self.trainer.start_time)) | |||
else: | |||
path = os.path.join( | |||
save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time)) | |||
self._summary_writer = SummaryWriter(path) | |||
def on_batch_begin(self, batch_x, batch_y, indices): | |||
if "model" in self.options and self.graph_added is False: | |||
# tesorboardX 这里有大bug,暂时没法画模型图 | |||
# from fastNLP.core.utils import _build_args | |||
# inputs = _build_args(self.trainer.model, **batch_x) | |||
# args = tuple([value for value in inputs.values()]) | |||
# args = args[0] if len(args) == 1 else args | |||
# self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2)) | |||
self.graph_added = True | |||
def on_backward_begin(self, loss): | |||
if "loss" in self.options: | |||
self._summary_writer.add_scalar( | |||
"loss", loss.item(), global_step=self.trainer.step) | |||
if "model" in self.options: | |||
for name, param in self.trainer.model.named_parameters(): | |||
if param.requires_grad: | |||
self._summary_writer.add_scalar( | |||
name + "_mean", param.mean(), global_step=self.trainer.step) | |||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step) | |||
self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(), | |||
global_step=self.trainer.step) | |||
def on_valid_end(self, eval_result, metric_key): | |||
if "metric" in self.options: | |||
for name, metric in eval_result.items(): | |||
for metric_key, metric_val in metric.items(): | |||
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, | |||
global_step=self.trainer.step) | |||
def on_train_end(self): | |||
self._summary_writer.close() | |||
del self._summary_writer | |||
def on_exception(self, exception): | |||
if hasattr(self, "_summary_writer"): | |||
self._summary_writer.close() | |||
del self._summary_writer |