
update trainer

tags/v0.2.0
yunfan 6 years ago
commit 3192c9ac66
7 changed files with 157 additions and 72 deletions
  1. fastNLP/core/field.py                  +3   -0
  2. fastNLP/core/instance.py               +3   -0
  3. fastNLP/core/tester.py                 +1   -1
  4. fastNLP/core/trainer.py                +20  -14
  5. fastNLP/models/biaffine_parser.py      +30  -10
  6. reproduction/Biaffine_parser/cfg.cfg   +6   -5
  7. reproduction/Biaffine_parser/run.py    +94  -42

fastNLP/core/field.py  +3 -0

@@ -24,6 +24,9 @@ class Field(object):
     def __repr__(self):
         return self.contents().__repr__()

+    def new(self, *args, **kwargs):
+        return self.__class__(*args, **kwargs, is_target=self.is_target)
+
 class TextField(Field):
     def __init__(self, text, is_target):
         """

fastNLP/core/instance.py  +3 -0

@@ -35,6 +35,9 @@ class Instance(object):
         else:
             raise KeyError("{} not found".format(name))

+    def __setitem__(self, name, field):
+        return self.add_field(name, field)
+
     def get_length(self):
         """Fetch the length of all fields in the instance.



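Together, the two one-line additions above (Field.new and Instance.__setitem__) let a preprocessing script swap out a field on an existing Instance while keeping its is_target flag, which is how the new P1 pass in run.py (later in this commit) rewrites word sequences in place. A minimal sketch, assuming the fastNLP 0.2 TextField/Instance constructors shown elsewhere in this diff; the token contents are made up:

    # Illustrative only: exercises the Field.new() and Instance.__setitem__ added above.
    from fastNLP.core.field import TextField
    from fastNLP.core.instance import Instance

    ins = Instance(word_seq=TextField(["I", "saw", "42", "ducks"], is_target=False))
    normalized = ["I", "saw", "NUMBER", "ducks"]
    # new() builds another TextField carrying the same is_target flag;
    # item assignment is shorthand for ins.add_field("word_seq", ...).
    ins["word_seq"] = ins["word_seq"].new(normalized)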

fastNLP/core/tester.py  +1 -1

@@ -74,7 +74,7 @@ class Tester(object):
         output_list = []
         truth_list = []

-        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
+        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')

         with torch.no_grad():
             for batch_x, batch_y in data_iterator:

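Both the Tester here and the Trainer below now build their Batch iterators with sort_in_batch=True, sort_key='word_seq', presumably so that each mini-batch arrives ordered by sequence length before BiaffineParser.forward, which later in this commit starts packing the padded batch for the LSTM. A minimal standalone sketch of why that ordering matters, using plain PyTorch with made-up shapes:

    # Illustrative only: pack_padded_sequence expects lengths in descending order
    # (older PyTorch has no enforce_sorted escape hatch), which is exactly what
    # sorting inside the batch guarantees.
    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    x = torch.zeros(3, 5, 8)           # 3 padded sequences, max length 5, feature dim 8
    lengths = torch.tensor([5, 4, 2])  # longest first, as the sorted batch provides
    packed = pack_padded_sequence(x, lengths, batch_first=True)
    out, out_lengths = pad_packed_sequence(packed, batch_first=True)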

fastNLP/core/trainer.py  +20 -14

@@ -1,6 +1,6 @@
 import os
 import time
-from datetime import timedelta
+from datetime import timedelta, datetime

 import torch
 from tensorboardX import SummaryWriter
@@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger
 from fastNLP.saver.model_saver import ModelSaver

 logger = create_logger(__name__, "./train_test.log")
+logger.disabled = True

 class Trainer(object):
     """Operations of training a model, including data loading, gradient descent, and validation.
@@ -42,7 +42,7 @@ class Trainer(object):
         """
         default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
                         "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
-                        "valid_step": 500, "eval_sort_key": None,
+                        "valid_step": 500, "eval_sort_key": 'acc',
                         "loss": Loss(None),  # used to pass type check
                         "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
                         "evaluator": Evaluator()
@@ -111,13 +111,17 @@ class Trainer(object):
         else:
             self._model = network

+        print(self._model)
+
         # define Tester over dev data
+        self.dev_data = None
         if self.validate:
             default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
                                   "use_cuda": self.use_cuda, "evaluator": self._evaluator}
             if self.validator is None:
                 self.validator = self._create_validator(default_valid_args)
             logger.info("validator defined as {}".format(str(self.validator)))
+            self.dev_data = dev_data

         # optimizer and loss
         self.define_optimizer()
@@ -130,7 +134,7 @@ class Trainer(object):

         # main training procedure
         start = time.time()
-        self.start_time = str(start)
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M'))

         logger.info("training epochs started " + self.start_time)
         epoch, iters = 1, 0
@@ -141,15 +145,17 @@ class Trainer(object):

             # prepare mini-batch iterator
             data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
-                                  use_cuda=self.use_cuda)
+                                  use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
             logger.info("prepared data iterator")

             # one forward and backward pass
-            iters += self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
+            iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)

             # validation
             if self.validate:
                 self.valid_model()
+            self.save_model(self._model, 'training_model_'+self.start_time)
+            epoch += 1

     def _train_step(self, data_iterator, network, **kwargs):
         """Training process in one epoch.
@@ -160,13 +166,16 @@ class Trainer(object):
         - epoch: int,
         """
         step = kwargs['step']
+        dev_data = kwargs['dev_data']
         for batch_x, batch_y in data_iterator:
             prediction = self.data_forward(network, batch_x)

             loss = self.get_loss(prediction, batch_y)
             self.grad_backward(loss)
+            if torch.rand(1).item() < 0.001:
+                print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
+                for name, p in self._model.named_parameters():
+                    if p.requires_grad:
+                        print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
             self.update()
             self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
@@ -183,13 +192,14 @@ class Trainer(object):
         return step

     def valid_model(self):
-        if dev_data is None:
+        if self.dev_data is None:
             raise RuntimeError(
                 "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
         logger.info("validation started")
-        res = self.validator.test(network, dev_data)
+        res = self.validator.test(self._model, self.dev_data)
         if self.save_best_dev and self.best_eval_result(res):
             logger.info('save best result! {}'.format(res))
+            print('save best result! {}'.format(res))
             self.save_model(self._model, 'best_model_'+self.start_time)
         return res
@@ -282,14 +292,10 @@ class Trainer(object):
         """
         if isinstance(metrics, tuple):
             loss, metrics = metrics
-        else:
-            metrics = validator.metrics

         if isinstance(metrics, dict):
             if len(metrics) == 1:
                 accuracy = list(metrics.values())[0]
-            elif self.eval_sort_key is None:
-                raise ValueError('dict format metrics should provide sort key for eval best result')
             else:
                 accuracy = metrics[self.eval_sort_key]
         else:

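With eval_sort_key now defaulting to 'acc' instead of None, a Trainer whose evaluator returns a dict of metrics picks the entry named by eval_sort_key when deciding whether to save the best model (the reproduction config below sets it to "UAS"). A minimal sketch of a caller relying on that; keyword names are taken from default_args above and from cfg.cfg, and the values are illustrative only:

    # Illustrative only: keyword names come from default_args / cfg.cfg in this commit.
    from fastNLP.core.trainer import Trainer

    trainer = Trainer(epochs=10,
                      batch_size=32,
                      validate=True,
                      save_best_dev=True,
                      eval_sort_key="UAS",   # which dict metric decides the "best" checkpoint
                      pickle_path="./save/")
    # trainer.train(model, train_data, dev_data)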

fastNLP/models/biaffine_parser.py  +30 -10

@@ -199,6 +199,8 @@ class BiaffineParser(GraphParser):
                  word_emb_dim,
                  pos_vocab_size,
                  pos_emb_dim,
+                 word_hid_dim,
+                 pos_hid_dim,
                  rnn_layers,
                  rnn_hidden_size,
                  arc_mlp_size,
@@ -209,10 +211,15 @@ class BiaffineParser(GraphParser):
                  use_greedy_infer=False):

         super(BiaffineParser, self).__init__()
+        rnn_out_size = 2 * rnn_hidden_size
         self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
         self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
+        self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
+        self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
+        self.word_norm = nn.LayerNorm(word_hid_dim)
+        self.pos_norm = nn.LayerNorm(pos_hid_dim)
         if use_var_lstm:
-            self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim,
+            self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
                                 hidden_size=rnn_hidden_size,
                                 num_layers=rnn_layers,
                                 bias=True,
@@ -221,7 +228,7 @@ class BiaffineParser(GraphParser):
                                 hidden_dropout=dropout,
                                 bidirectional=True)
         else:
-            self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim,
+            self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
                                 hidden_size=rnn_hidden_size,
                                 num_layers=rnn_layers,
                                 bias=True,
@@ -229,12 +236,13 @@ class BiaffineParser(GraphParser):
                                 dropout=dropout,
                                 bidirectional=True)

-        rnn_out_size = 2 * rnn_hidden_size
         self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
+                                          nn.LayerNorm(arc_mlp_size),
                                           nn.ELU(),
                                           TimestepDropout(p=dropout),)
         self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
         self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
+                                            nn.LayerNorm(label_mlp_size),
                                             nn.ELU(),
                                             TimestepDropout(p=dropout),)
         self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
@@ -242,10 +250,18 @@ class BiaffineParser(GraphParser):
         self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
         self.normal_dropout = nn.Dropout(p=dropout)
         self.use_greedy_infer = use_greedy_infer
-        initial_parameter(self)
-        self.word_norm = nn.LayerNorm(word_emb_dim)
-        self.pos_norm = nn.LayerNorm(pos_emb_dim)
-        self.lstm_norm = nn.LayerNorm(rnn_out_size)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for m in self.modules():
+            if isinstance(m, nn.Embedding):
+                continue
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            else:
+                for p in m.parameters():
+                    nn.init.normal_(p, 0, 0.01)

     def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         """
@@ -262,19 +278,21 @@ class BiaffineParser(GraphParser):
         # prepare embeddings
         batch_size, seq_len = word_seq.shape
         # print('forward {} {}'.format(batch_size, seq_len))
-        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)

         # get sequence mask
         seq_mask = len2masks(word_seq_origin_len, seq_len).long()

         word = self.normal_dropout(self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
         pos = self.normal_dropout(self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]
+        word, pos = self.word_fc(word), self.pos_fc(pos)
         word, pos = self.word_norm(word), self.pos_norm(pos)
         x = torch.cat([word, pos], dim=2)  # -> [N,L,C]
+        del word, pos

         # lstm, extract features
+        x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True)
         feat, _ = self.lstm(x)  # -> [N,L,C]
-        feat = self.lstm_norm(feat)
+        feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)

         # for arc biaffine
         # mlp, reduce dim
@@ -282,6 +300,7 @@ class BiaffineParser(GraphParser):
         arc_head = self.arc_head_mlp(feat)
         label_dep = self.label_dep_mlp(feat)
         label_head = self.label_head_mlp(feat)
+        del feat

         # biaffine arc classifier
         arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]
@@ -289,7 +308,7 @@ class BiaffineParser(GraphParser):
         arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)

         # use gold or predicted arc to predict label
-        if gold_heads is None:
+        if gold_heads is None or not self.training:
             # use greedy decoding in training
             if self.training or self.use_greedy_infer:
                 heads = self._greedy_decoder(arc_pred, seq_mask)
@@ -301,6 +320,7 @@ class BiaffineParser(GraphParser):
             head_pred = None
             heads = gold_heads

+        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
         label_head = label_head[batch_range, heads].contiguous()
         label_pred = self.label_predictor(label_head, label_dep)  # [N, L, num_label]
         res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}

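The parser now projects word and POS embeddings through word_fc/pos_fc into word_hid_dim/pos_hid_dim before the (now packed) LSTM, so both new sizes must be supplied at construction time, which run.py does via BiaffineParser(**model_args.data). A minimal sketch with made-up vocabulary and label counts; the keyword names come from the signature and the cfg.cfg section in this diff:

    # Illustrative only: vocab sizes and num_label are placeholders, real values
    # come from the vocabularies built in run.py.
    from fastNLP.models.biaffine_parser import BiaffineParser

    model = BiaffineParser(word_vocab_size=10000, word_emb_dim=100,
                           pos_vocab_size=50, pos_emb_dim=100,
                           word_hid_dim=100, pos_hid_dim=100,
                           rnn_layers=3, rnn_hidden_size=400,
                           arc_mlp_size=500, label_mlp_size=100,
                           num_label=40, dropout=0.33,
                           use_var_lstm=False, use_greedy_infer=False)
    model.reset_parameters()  # the new initialisation; run.py also calls it explicitly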

reproduction/Biaffine_parser/cfg.cfg  +6 -5

@@ -1,16 +1,14 @@
 [train]
 epochs = -1
-<<<<<<< HEAD
-batch_size = 16
-=======
 batch_size = 32
->>>>>>> update biaffine
 pickle_path = "./save/"
 validate = true
 save_best_dev = true
 eval_sort_key = "UAS"
 use_cuda = true
 model_saved_path = "./save/"
+print_every_step = 20
+use_golden_train=true

 [test]
 save_output = true
@@ -26,14 +24,17 @@ word_vocab_size = -1
 word_emb_dim = 100
 pos_vocab_size = -1
 pos_emb_dim = 100
+word_hid_dim = 100
+pos_hid_dim = 100
 rnn_layers = 3
 rnn_hidden_size = 400
 arc_mlp_size = 500
 label_mlp_size = 100
 num_label = -1
 dropout = 0.33
-use_var_lstm=true
+use_var_lstm=false
 use_greedy_infer=false

 [optim]
 lr = 2e-3
+weight_decay = 0.0

reproduction/Biaffine_parser/run.py  +94 -42

@@ -6,6 +6,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
 from collections import defaultdict
 import math
 import torch
+import re

 from fastNLP.core.trainer import Trainer
 from fastNLP.core.metrics import Evaluator
@@ -55,10 +56,10 @@ class ConlluDataLoader(object):
         return ds

     def get_one(self, sample):
-        text = ['<root>']
-        pos_tags = ['<root>']
-        heads = [0]
-        head_tags = ['root']
+        text = []
+        pos_tags = []
+        heads = []
+        head_tags = []
         for w in sample:
             t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
             if t3 == '_':
@@ -96,12 +97,13 @@ class CTBDataLoader(object):
     def convert(self, data):
         dataset = DataSet()
         for sample in data:
-            word_seq = ["<ROOT>"] + sample[0]
-            pos_seq = ["<ROOT>"] + sample[1]
-            heads = [0] + list(map(int, sample[2]))
-            head_tags = ["ROOT"] + sample[3]
+            word_seq = ["<s>"] + sample[0] + ['</s>']
+            pos_seq = ["<s>"] + sample[1] + ['</s>']
+            heads = [0] + list(map(int, sample[2])) + [0]
+            head_tags = ["<s>"] + sample[3] + ['</s>']
             dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
                                     pos_seq=TextField(pos_seq, is_target=False),
+                                    gold_heads=SeqLabelField(heads, is_target=False),
                                     head_indices=SeqLabelField(heads, is_target=True),
                                     head_labels=TextField(head_tags, is_target=True)))
         return dataset
@@ -117,7 +119,8 @@ datadir = '/home/yfshao/workdir/parser-data/'
 train_data_name = "train_ctb5.txt"
 dev_data_name = "dev_ctb5.txt"
 test_data_name = "test_ctb5.txt"
-emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
+emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt"
+# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec"
 loader = CTBDataLoader()

 cfgfile = './cfg.cfg'
@@ -129,6 +132,10 @@ test_args = ConfigSection()
 model_args = ConfigSection()
 optim_args = ConfigSection()
 ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
+print('trainre Args:', train_args.data)
+print('test Args:', test_args.data)
+print('optim Args:', optim_args.data)


 # Pickle Loader
 def save_data(dirpath, **kwargs):
@@ -151,9 +158,31 @@ def load_data(dirpath):
             datas[name] = _pickle.load(f)
     return datas

+def P2(data, field, length):
+    ds = [ins for ins in data if ins[field].get_length() >= length]
+    data.clear()
+    data.extend(ds)
+    return ds
+
+def P1(data, field):
+    def reeng(w):
+        return w if w == '<s>' or w == '</s>' or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else 'ENG'
+    def renum(w):
+        return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else 'NUMBER'
+    for ins in data:
+        ori = ins[field].contents()
+        s = list(map(renum, map(reeng, ori)))
+        if s != ori:
+            # print(ori)
+            # print(s)
+            # print()
+            ins[field] = ins[field].new(s)
+    return data
+
 class ParserEvaluator(Evaluator):
-    def __init__(self):
+    def __init__(self, ignore_label):
         super(ParserEvaluator, self).__init__()
+        self.ignore = ignore_label

     def __call__(self, predict_list, truth_list):
         head_all, label_all, total_all = 0, 0, 0
@@ -174,6 +203,7 @@ class ParserEvaluator(Evaluator):
             label_pred_correct: number of correct predicted labels.
             total_tokens: number of predicted tokens
         """
+        seq_mask *= (head_labels != self.ignore).long()
         head_pred_correct = (head_pred == head_indices).long() * seq_mask
         _, label_preds = torch.max(label_pred, dim=2)
         label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
@@ -181,72 +211,93 @@

 try:
     data_dict = load_data(processed_datadir)
-    word_v = data_dict['word_v']
     pos_v = data_dict['pos_v']
     tag_v = data_dict['tag_v']
     train_data = data_dict['train_data']
     dev_data = data_dict['dev_data']
+    test_data = data_dict['test_datas']
     print('use saved pickles')

 except Exception as _:
     print('load raw data and preprocess')
-    word_v = Vocabulary(need_default=True, min_freq=2)
+    # use pretrain embedding
     pos_v = Vocabulary(need_default=True)
     tag_v = Vocabulary(need_default=False)
     train_data = loader.load(os.path.join(datadir, train_data_name))
     dev_data = loader.load(os.path.join(datadir, dev_data_name))
     test_data = loader.load(os.path.join(datadir, test_data_name))
-    train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
-    save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)
+    train_data.update_vocab(pos_seq=pos_v, head_labels=tag_v)
+    save_data(processed_datadir, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)

-train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
-dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
-train_data.set_origin_len("word_seq")
-dev_data.set_origin_len("word_seq")
+embed, word_v = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', None, os.path.join(processed_datadir, 'word_emb.pkl'))
+word_v.unknown_label = "<OOV>"

-print(train_data[:3])
-print(len(train_data))
-print(len(dev_data))
+# Model
 model_args['word_vocab_size'] = len(word_v)
 model_args['pos_vocab_size'] = len(pos_v)
 model_args['num_label'] = len(tag_v)

+model = BiaffineParser(**model_args.data)
+model.reset_parameters()
+
+datasets = (train_data, dev_data, test_data)
+for ds in datasets:
+    # print('====='*30)
+    P1(ds, 'word_seq')
+    P2(ds, 'word_seq', 5)
+    ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
+    ds.set_origin_len('word_seq')
+    if train_args['use_golden_train']:
+        ds.set_target(gold_heads=False)
+    else:
+        ds.set_target(gold_heads=None)
+train_args.data.pop('use_golden_train')
+ignore_label = pos_v['P']
+
+print(test_data[0])
+print(len(train_data))
+print(len(dev_data))
+print(len(test_data))



-def train():
+def train(path):
     # Trainer
     trainer = Trainer(**train_args.data)

     def _define_optim(obj):
-        obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
+        lr = optim_args.data['lr']
+        embed_params = set(obj._model.word_embedding.parameters())
+        decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters())
+        params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params]
+        obj._optimizer = torch.optim.Adam([
+            {'params': list(embed_params), 'lr':lr*0.1},
+            {'params': list(decay_params), **optim_args.data},
+            {'params': params}
+        ], lr=lr)
         obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))

     def _update(obj):
+        # torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0)
         obj._scheduler.step()
         obj._optimizer.step()

     trainer.define_optimizer = lambda: _define_optim(trainer)
     trainer.update = lambda: _update(trainer)
-    trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator()))
+    trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)))

-    # Model
-    model = BiaffineParser(**model_args.data)
-
-    # use pretrain embedding
-    word_v.unknown_label = "<OOV>"
-    embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
     model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)

     model.word_embedding.padding_idx = word_v.padding_idx
     model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
     model.pos_embedding.padding_idx = pos_v.padding_idx
     model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

-    try:
-        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
-        print('model parameter loaded!')
-    except Exception as _:
-        print("No saved model. Continue.")
-        pass
+    # try:
+    #     ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
+    #     print('model parameter loaded!')
+    # except Exception as _:
+    #     print("No saved model. Continue.")
+    #     pass

     # Start training
     trainer.train(model, train_data, dev_data)
@@ -258,15 +309,15 @@ def train():
     print("Model saved!")


-def test():
+def test(path):
     # Tester
-    tester = Tester(**test_args.data, evaluator=ParserEvaluator())
+    tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))

     # Model
     model = BiaffineParser(**model_args.data)

     try:
-        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
+        ModelLoader.load_pytorch(model, path)
         print('model parameter loaded!')
     except Exception as _:
         print("No saved model. Abort test.")
@@ -284,11 +335,12 @@ if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
     parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
+    parser.add_argument('--path', type=str, default='')
     args = parser.parse_args()
     if args.mode == 'train':
-        train()
+        train(args.path)
     elif args.mode == 'test':
-        test()
+        test(args.path)
     elif args.mode == 'infer':
         infer()
     else:

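The new P1/P2 helpers above do the word-level normalisation and length filtering for all three datasets: P1 maps purely alphabetic tokens to 'ENG' and numeric tokens to 'NUMBER' (leaving the <s>/</s> markers alone), and P2 drops instances whose word_seq is shorter than a minimum length. A standalone sketch of the same regex rules applied to a plain token list, outside of fastNLP; the example sentence is made up:

    # Illustrative only: reproduces the reeng/renum rules from P1 on a plain list.
    import re

    def normalize(tokens):
        out = []
        for w in tokens:
            if w in ('<s>', '</s>'):
                out.append(w)                                 # keep boundary markers
            elif re.search(r'^([a-zA-Z]+[\.\-]*)+$', w):
                out.append('ENG')                             # collapse Latin-script words
            elif re.search(r'^[0-9]+\.?[0-9]*$', w):
                out.append('NUMBER')                          # collapse numbers
            else:
                out.append(w)
        return out

    print(normalize(['<s>', 'Apple', '发布', '了', 'iOS', '12.1', '</s>']))
    # -> ['<s>', 'ENG', '发布', '了', 'ENG', 'NUMBER', '</s>']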