@@ -24,6 +24,9 @@ class Field(object):
     def __repr__(self):
         return self.contents().__repr__()
 
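+    # convenience constructor: build a field of the same concrete class,
+    # carrying over this field's is_target flag (run.py uses it below to
+    # rewrite token sequences in place)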
+    def new(self, *args, **kwargs):
+        return self.__class__(*args, **kwargs, is_target=self.is_target)
+
 class TextField(Field):
     def __init__(self, text, is_target):
         """
@@ -35,6 +35,9 @@ class Instance(object):
         else:
             raise KeyError("{} not found".format(name))
 
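+    # support dict-style assignment: instance["field_name"] = field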
+    def __setitem__(self, name, field):
+        return self.add_field(name, field)
+
     def get_length(self):
         """Fetch the length of all fields in the instance.
@@ -74,7 +74,7 @@ class Tester(object):
         output_list = []
         truth_list = []
-        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
+        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
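+        # sort each batch by 'word_seq' length: the parser packs sequences with
+        # pack_padded_sequence, which expects decreasing-length order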
         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
@@ -1,6 +1,6 @@
 import os
 import time
-from datetime import timedelta
+from datetime import timedelta, datetime
 
 import torch
 from tensorboardX import SummaryWriter
@@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger
 from fastNLP.saver.model_saver import ModelSaver
 
 logger = create_logger(__name__, "./train_test.log")
+logger.disabled = True
 
 class Trainer(object):
     """Operations of training a model, including data loading, gradient descent, and validation.
@@ -42,7 +42,7 @@ class Trainer(object):
         """
         default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
                         "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
"valid_step": 500, "eval_sort_key": None, | |||
"valid_step": 500, "eval_sort_key": 'acc', | |||
"loss": Loss(None), # used to pass type check | |||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||
"evaluator": Evaluator() | |||
@@ -111,13 +111,17 @@ class Trainer(object):
         else:
             self._model = network
+        print(self._model)
 
         # define Tester over dev data
+        self.dev_data = None
         if self.validate:
             default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
                                   "use_cuda": self.use_cuda, "evaluator": self._evaluator}
-            self.validator = self._create_validator(default_valid_args)
+            if self.validator is None:
+                self.validator = self._create_validator(default_valid_args)
             logger.info("validator defined as {}".format(str(self.validator)))
+            self.dev_data = dev_data
 
         # optimizer and loss
         self.define_optimizer()
@@ -130,7 +134,7 @@ class Trainer(object):
         # main training procedure
         start = time.time()
-        self.start_time = str(start)
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M'))
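+        # human-readable timestamp, used as the suffix of saved model files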
         logger.info("training epochs started " + self.start_time)
 
         epoch, iters = 1, 0
@@ -141,15 +145,17 @@ class Trainer(object):
             # prepare mini-batch iterator
             data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
-                                  use_cuda=self.use_cuda)
+                                  use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
             logger.info("prepared data iterator")
 
             # one forward and backward pass
-            iters += self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
+            iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
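+            # _train_step returns the absolute step counter (see its return
+            # statement), so assign instead of accumulating with +=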
 
             # validation
             if self.validate:
                 self.valid_model()
+            self.save_model(self._model, 'training_model_'+self.start_time)
             epoch += 1
 
     def _train_step(self, data_iterator, network, **kwargs):
         """Training process in one epoch.
@@ -160,13 +166,16 @@ class Trainer(object):
             - epoch: int,
         """
         step = kwargs['step']
+        dev_data = kwargs['dev_data']
         for batch_x, batch_y in data_iterator:
             prediction = self.data_forward(network, batch_x)
             loss = self.get_loss(prediction, batch_y)
             self.grad_backward(loss)
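+            # on ~0.1% of steps, dump per-parameter gradient sums as a cheap
+            # sanity check against vanishing or exploding gradients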
+            if torch.rand(1).item() < 0.001:
+                print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
+                for name, p in self._model.named_parameters():
+                    if p.requires_grad:
+                        print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
             self.update()
             self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
@@ -183,13 +192,14 @@ class Trainer(object):
         return step
 
     def valid_model(self):
-        if dev_data is None:
+        if self.dev_data is None:
             raise RuntimeError(
                 "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
         logger.info("validation started")
-        res = self.validator.test(network, dev_data)
+        res = self.validator.test(self._model, self.dev_data)
         if self.save_best_dev and self.best_eval_result(res):
             logger.info('save best result! {}'.format(res))
+            print('save best result! {}'.format(res))
             self.save_model(self._model, 'best_model_'+self.start_time)
         return res
@@ -282,14 +292,10 @@ class Trainer(object):
         """
         if isinstance(metrics, tuple):
             loss, metrics = metrics
-        else:
-            metrics = validator.metrics
 
         if isinstance(metrics, dict):
             if len(metrics) == 1:
                 accuracy = list(metrics.values())[0]
-            elif self.eval_sort_key is None:
-                raise ValueError('dict format metrics should provide sort key for eval best result')
             else:
                 accuracy = metrics[self.eval_sort_key]
         else:
@@ -199,6 +199,8 @@ class BiaffineParser(GraphParser):
                  word_emb_dim,
                  pos_vocab_size,
                  pos_emb_dim,
+                 word_hid_dim,
+                 pos_hid_dim,
                  rnn_layers,
                  rnn_hidden_size,
                  arc_mlp_size,
@@ -209,10 +211,15 @@ class BiaffineParser(GraphParser):
                  use_greedy_infer=False):
         super(BiaffineParser, self).__init__()
+        rnn_out_size = 2 * rnn_hidden_size
         self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
         self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
+        self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
+        self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
+        self.word_norm = nn.LayerNorm(word_hid_dim)
+        self.pos_norm = nn.LayerNorm(pos_hid_dim)
         if use_var_lstm:
-            self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim,
+            self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
                                 hidden_size=rnn_hidden_size,
                                 num_layers=rnn_layers,
                                 bias=True,
@@ -221,7 +228,7 @@ class BiaffineParser(GraphParser):
                                 hidden_dropout=dropout,
                                 bidirectional=True)
         else:
-            self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim,
+            self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
                                 hidden_size=rnn_hidden_size,
                                 num_layers=rnn_layers,
                                 bias=True,
@@ -229,12 +236,13 @@ class BiaffineParser(GraphParser):
                                 dropout=dropout,
                                 bidirectional=True)
 
-        rnn_out_size = 2 * rnn_hidden_size
         self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
+                                          nn.LayerNorm(arc_mlp_size),
                                           nn.ELU(),
                                           TimestepDropout(p=dropout),)
         self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
         self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
+                                            nn.LayerNorm(label_mlp_size),
                                             nn.ELU(),
                                             TimestepDropout(p=dropout),)
         self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
@@ -242,10 +250,18 @@ class BiaffineParser(GraphParser):
         self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
         self.normal_dropout = nn.Dropout(p=dropout)
         self.use_greedy_infer = use_greedy_infer
-        initial_parameter(self)
-        self.word_norm = nn.LayerNorm(word_emb_dim)
-        self.pos_norm = nn.LayerNorm(pos_emb_dim)
         self.lstm_norm = nn.LayerNorm(rnn_out_size)
+        self.reset_parameters()
+
+    def reset_parameters(self):
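+        # embeddings are skipped (pretrained vectors are loaded separately in
+        # run.py); LayerNorm is reset to identity; all else gets N(0, 0.01)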
+        for m in self.modules():
+            if isinstance(m, nn.Embedding):
+                continue
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            else:
+                for p in m.parameters():
+                    nn.init.normal_(p, 0, 0.01)
 
     def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         """
@@ -262,19 +278,21 @@ class BiaffineParser(GraphParser):
         # prepare embeddings
         batch_size, seq_len = word_seq.shape
         # print('forward {} {}'.format(batch_size, seq_len))
-        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
         # get sequence mask
         seq_mask = len2masks(word_seq_origin_len, seq_len).long()
 
         word = self.normal_dropout(self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
         pos = self.normal_dropout(self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]
+        word, pos = self.word_fc(word), self.pos_fc(pos)
+        word, pos = self.word_norm(word), self.pos_norm(pos)
         x = torch.cat([word, pos], dim=2)  # -> [N,L,C]
         del word, pos
 
         # lstm, extract features
         x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True)
         feat, _ = self.lstm(x)  # -> [N,L,C]
         feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
+        feat = self.lstm_norm(feat)  # normalize after unpacking; LayerNorm cannot take a PackedSequence
 
         # for arc biaffine
         # mlp, reduce dim
@@ -282,6 +300,7 @@ class BiaffineParser(GraphParser):
         arc_head = self.arc_head_mlp(feat)
         label_dep = self.label_dep_mlp(feat)
         label_head = self.label_head_mlp(feat)
+        del feat
 
         # biaffine arc classifier
         arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]
@@ -289,7 +308,7 @@ class BiaffineParser(GraphParser):
         arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
 
         # use gold or predicted arc to predict label
-        if gold_heads is None:
+        if gold_heads is None or not self.training:
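+            # at eval time heads are always re-predicted, even when gold heads
+            # are present in the batch, so dev metrics measure real parsing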
             # use greedy decoding in training
             if self.training or self.use_greedy_infer:
                 heads = self._greedy_decoder(arc_pred, seq_mask)
@@ -301,6 +320,7 @@ class BiaffineParser(GraphParser):
             head_pred = None
             heads = gold_heads
 
+        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
         label_head = label_head[batch_range, heads].contiguous()
         label_pred = self.label_predictor(label_head, label_dep)  # [N, L, num_label]
         res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}
@@ -1,16 +1,14 @@
 [train]
 epochs = -1
-<<<<<<< HEAD
-batch_size = 16
-=======
 batch_size = 32
->>>>>>> update biaffine
 pickle_path = "./save/"
 validate = true
 save_best_dev = true
+eval_sort_key = "UAS"
 use_cuda = true
 model_saved_path = "./save/"
 print_every_step = 20
+use_golden_train=true
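+# when true, gold heads are fed to the parser as an input field during training,
+# so the label classifier is trained on gold arcs (see the dataset loop in run.py)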
 
 [test]
 save_output = true
@@ -26,14 +24,17 @@ word_vocab_size = -1
 word_emb_dim = 100
 pos_vocab_size = -1
 pos_emb_dim = 100
+word_hid_dim = 100
+pos_hid_dim = 100
 rnn_layers = 3
 rnn_hidden_size = 400
 arc_mlp_size = 500
 label_mlp_size = 100
 num_label = -1
 dropout = 0.33
-use_var_lstm=true
+use_var_lstm=false
 use_greedy_infer=false
 
 [optim]
 lr = 2e-3
 weight_decay = 0.0
@@ -6,6 +6,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
 from collections import defaultdict
 import math
 import torch
+import re
 
 from fastNLP.core.trainer import Trainer
 from fastNLP.core.metrics import Evaluator
@@ -55,10 +56,10 @@ class ConlluDataLoader(object):
         return ds
 
     def get_one(self, sample):
-        text = ['<root>']
-        pos_tags = ['<root>']
-        heads = [0]
-        head_tags = ['root']
+        text = []
+        pos_tags = []
+        heads = []
+        head_tags = []
         for w in sample:
             t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
             if t3 == '_':
@@ -96,12 +97,13 @@ class CTBDataLoader(object):
     def convert(self, data):
         dataset = DataSet()
         for sample in data:
-            word_seq = ["<ROOT>"] + sample[0]
-            pos_seq = ["<ROOT>"] + sample[1]
-            heads = [0] + list(map(int, sample[2]))
-            head_tags = ["ROOT"] + sample[3]
+            word_seq = ["<s>"] + sample[0] + ['</s>']
+            pos_seq = ["<s>"] + sample[1] + ['</s>']
+            heads = [0] + list(map(int, sample[2])) + [0]
+            head_tags = ["<s>"] + sample[3] + ['</s>']
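+            # sentences are framed with <s>/</s>; both boundary tokens get head 0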
             dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
                                     pos_seq=TextField(pos_seq, is_target=False),
+                                    gold_heads=SeqLabelField(heads, is_target=False),
                                     head_indices=SeqLabelField(heads, is_target=True),
                                     head_labels=TextField(head_tags, is_target=True)))
         return dataset
@@ -117,7 +119,8 @@ datadir = '/home/yfshao/workdir/parser-data/'
 train_data_name = "train_ctb5.txt"
 dev_data_name = "dev_ctb5.txt"
 test_data_name = "test_ctb5.txt"
-emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
+emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt"
+# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec"
 
 loader = CTBDataLoader()
 cfgfile = './cfg.cfg'
@@ -129,6 +132,10 @@ test_args = ConfigSection()
 model_args = ConfigSection()
 optim_args = ConfigSection()
 ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
+print('train Args:', train_args.data)
+print('test Args:', test_args.data)
+print('optim Args:', optim_args.data)
 
 # Pickle Loader
 def save_data(dirpath, **kwargs):
@@ -151,9 +158,31 @@ def load_data(dirpath):
             datas[name] = _pickle.load(f)
     return datas
 
+def P2(data, field, length):
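+    # keep only instances whose `field` has at least `length` tokens,
+    # filtering `data` in place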
+    ds = [ins for ins in data if ins[field].get_length() >= length]
+    data.clear()
+    data.extend(ds)
+    return ds
+
+def P1(data, field):
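+    # normalize rare token shapes: Latin-alphabet words become 'ENG' and
+    # decimal numbers become 'NUMBER'; the <s>/</s> boundary markers are kept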
+    def reeng(w):
+        return w if w == '<s>' or w == '</s>' or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else 'ENG'
+    def renum(w):
+        return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else 'NUMBER'
+    for ins in data:
+        ori = ins[field].contents()
+        s = list(map(renum, map(reeng, ori)))
+        if s != ori:
+            # print(ori)
+            # print(s)
+            # print()
+            ins[field] = ins[field].new(s)
+    return data
+
 class ParserEvaluator(Evaluator):
-    def __init__(self):
+    def __init__(self, ignore_label):
         super(ParserEvaluator, self).__init__()
+        self.ignore = ignore_label
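+        # label id to leave out of scoring (commonly punctuation, which is
+        # conventionally excluded from UAS/LAS)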
 
     def __call__(self, predict_list, truth_list):
         head_all, label_all, total_all = 0, 0, 0
@@ -174,6 +203,7 @@ class ParserEvaluator(Evaluator):
             label_pred_correct: number of correct predicted labels.
             total_tokens: number of predicted tokens
         """
+        seq_mask *= (head_labels != self.ignore).long()
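+        # zero ignored positions so they count neither as correct nor toward total_tokens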
         head_pred_correct = (head_pred == head_indices).long() * seq_mask
         _, label_preds = torch.max(label_pred, dim=2)
         label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
@@ -181,72 +211,93 @@ class ParserEvaluator(Evaluator):
 try:
     data_dict = load_data(processed_datadir)
-    word_v = data_dict['word_v']
     pos_v = data_dict['pos_v']
     tag_v = data_dict['tag_v']
     train_data = data_dict['train_data']
     dev_data = data_dict['dev_data']
+    test_data = data_dict['test_data']
     print('use saved pickles')
 except Exception as _:
     print('load raw data and preprocess')
-    word_v = Vocabulary(need_default=True, min_freq=2)
-    # use pretrain embedding
     pos_v = Vocabulary(need_default=True)
     tag_v = Vocabulary(need_default=False)
     train_data = loader.load(os.path.join(datadir, train_data_name))
     dev_data = loader.load(os.path.join(datadir, dev_data_name))
     test_data = loader.load(os.path.join(datadir, test_data_name))
-    train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
-    save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)
+    train_data.update_vocab(pos_seq=pos_v, head_labels=tag_v)
+    save_data(processed_datadir, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)
 
-train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
-dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
-train_data.set_origin_len("word_seq")
-dev_data.set_origin_len("word_seq")
-print(train_data[:3])
-print(len(train_data))
-print(len(dev_data))
+embed, word_v = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', None, os.path.join(processed_datadir, 'word_emb.pkl'))
+word_v.unknown_label = "<OOV>"
 
 # Model
 model_args['word_vocab_size'] = len(word_v)
 model_args['pos_vocab_size'] = len(pos_v)
 model_args['num_label'] = len(tag_v)
+model = BiaffineParser(**model_args.data)
+model.reset_parameters()
+
+datasets = (train_data, dev_data, test_data)
+for ds in datasets:
+    # print('====='*30)
+    P1(ds, 'word_seq')
+    P2(ds, 'word_seq', 5)
+    ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
+    ds.set_origin_len('word_seq')
+    if train_args['use_golden_train']:
+        ds.set_target(gold_heads=False)
+    else:
+        ds.set_target(gold_heads=None)
+train_args.data.pop('use_golden_train')
+ignore_label = pos_v['P']
+
+print(test_data[0])
+print(len(train_data))
+print(len(dev_data))
+print(len(test_data))
 
-def train():
+def train(path):
     # Trainer
     trainer = Trainer(**train_args.data)
 
     def _define_optim(obj):
-        obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
+        lr = optim_args.data['lr']
+        embed_params = set(obj._model.word_embedding.parameters())
+        decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters())
+        params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params]
+        obj._optimizer = torch.optim.Adam([
+            {'params': list(embed_params), 'lr': lr * 0.1},
+            {'params': list(decay_params), **optim_args.data},
+            {'params': params}
+        ], lr=lr)
+        obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
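+        # multiplicative decay: a factor of 0.75 every 50k scheduler steps
+        # (stepped once per update in _update below), floored at 0.05 of the base lr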
 
     def _update(obj):
+        # torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0)
+        obj._scheduler.step()
         obj._optimizer.step()
 
     trainer.define_optimizer = lambda: _define_optim(trainer)
     trainer.update = lambda: _update(trainer)
-    trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator()))
+    trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)))
 
-    # Model
-    model = BiaffineParser(**model_args.data)
     # use pretrain embedding
     word_v.unknown_label = "<OOV>"
     embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
     model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
     model.word_embedding.padding_idx = word_v.padding_idx
     model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
     model.pos_embedding.padding_idx = pos_v.padding_idx
     model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)
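+    # zero the padding rows by hand: from_pretrained keeps whatever value the
+    # embedding file supplied there, and setting padding_idx alone does not
+    # clear existing weights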
 
-    try:
-        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
-        print('model parameter loaded!')
-    except Exception as _:
-        print("No saved model. Continue.")
-        pass
+    # try:
+    #     ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
+    #     print('model parameter loaded!')
+    # except Exception as _:
+    #     print("No saved model. Continue.")
+    #     pass
 
     # Start training
     trainer.train(model, train_data, dev_data)
@@ -258,15 +309,15 @@ def train():
     print("Model saved!")
 
-def test():
+def test(path):
     # Tester
-    tester = Tester(**test_args.data, evaluator=ParserEvaluator())
+    tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))
 
     # Model
     model = BiaffineParser(**model_args.data)
 
     try:
-        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
+        ModelLoader.load_pytorch(model, path)
         print('model parameter loaded!')
     except Exception as _:
         print("No saved model. Abort test.")
@@ -284,11 +335,12 @@ if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser(description='Run a Chinese dependency parser')
    parser.add_argument('--mode', help='set the run mode', choices=['train', 'test', 'infer'])
+    parser.add_argument('--path', type=str, default='')
     args = parser.parse_args()
 
     if args.mode == 'train':
-        train()
+        train(args.path)
     elif args.mode == 'test':
-        test()
+        test(args.path)
     elif args.mode == 'infer':
         infer()
     else: