diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index 48e451f6..8720bf1b 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -28,6 +28,12 @@ class Field(object): """ raise NotImplementedError + def __repr__(self): + return self.contents().__repr__() + + def new(self, *args, **kwargs): + return self.__class__(*args, **kwargs, is_target=self.is_target) + class TextField(Field): def __init__(self, name, text, is_target): """ diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index a4eca1aa..50787fd1 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -35,6 +35,9 @@ class Instance(object): else: raise KeyError("{} not found".format(name)) + def __setitem__(self, name, field): + return self.add_field(name, field) + def get_length(self): """Fetch the length of all fields in the instance. @@ -82,3 +85,6 @@ class Instance(object): name, field_name = origin_len tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()]) return tensor_x, tensor_y + + def __repr__(self): + return self.fields.__repr__() \ No newline at end of file diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 24aac951..4c0cfb41 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -17,9 +17,9 @@ class Tester(object): """ super(Tester, self).__init__() """ - "default_args" provides default value for important settings. - The initialization arguments "kwargs" with the same key (name) will override the default value. - "kwargs" must have the same type as "default_args" on corresponding keys. + "default_args" provides default value for important settings. + The initialization arguments "kwargs" with the same key (name) will override the default value. + "kwargs" must have the same type as "default_args" on corresponding keys. Otherwise, error will raise. """ default_args = {"batch_size": 8, @@ -29,8 +29,8 @@ class Tester(object): "evaluator": Evaluator() } """ - "required_args" is the collection of arguments that users must pass to Trainer explicitly. - This is used to warn users of essential settings in the training. + "required_args" is the collection of arguments that users must pass to Trainer explicitly. + This is used to warn users of essential settings in the training. Specially, "required_args" does not have default value, so they have nothing to do with "default_args". """ required_args = {} @@ -74,16 +74,19 @@ class Tester(object): output_list = [] truth_list = [] - data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) + data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq') - for batch_x, batch_y in data_iterator: - with torch.no_grad(): + with torch.no_grad(): + for batch_x, batch_y in data_iterator: prediction = self.data_forward(network, batch_x) - output_list.append(prediction) - truth_list.append(batch_y) - eval_results = self.evaluate(output_list, truth_list) + output_list.append(prediction) + truth_list.append(batch_y) + eval_results = self.evaluate(output_list, truth_list) print("[tester] {}".format(self.print_eval_results(eval_results))) logger.info("[tester] {}".format(self.print_eval_results(eval_results))) + self.mode(network, is_test=False) + self.metrics = eval_results + return eval_results def mode(self, model, is_test=False): """Train mode or Test mode. This is for PyTorch currently. 
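A minimal sketch of the evaluation shape the Tester hunk above moves to: the whole batch loop sits under a single torch.no_grad(), predictions and gold batches are accumulated, and the metrics are computed once before the network is switched back to train mode. The model, iterator, and metric function here are placeholders, not fastNLP APIs.

import torch

def evaluate(model, data_iterator, metric_fn):
    """One pass over the dev iterator without building autograd graphs."""
    model.eval()                      # mode(network, is_test=True)
    outputs, truths = [], []
    with torch.no_grad():             # single context around the whole loop, as in the patch
        for batch_x, batch_y in data_iterator:
            outputs.append(model(**batch_x))
            truths.append(batch_y)
        results = metric_fn(outputs, truths)
    model.train()                     # mode(network, is_test=False), restored after testing
    return results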
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index a180b10d..d1881297 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -1,6 +1,6 @@ import os import time -from datetime import timedelta +from datetime import timedelta, datetime import torch from tensorboardX import SummaryWriter @@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger from fastNLP.saver.model_saver import ModelSaver logger = create_logger(__name__, "./train_test.log") - +logger.disabled = True class Trainer(object): """Operations of training a model, including data loading, gradient descent, and validation. @@ -35,20 +35,21 @@ class Trainer(object): super(Trainer, self).__init__() """ - "default_args" provides default value for important settings. - The initialization arguments "kwargs" with the same key (name) will override the default value. - "kwargs" must have the same type as "default_args" on corresponding keys. + "default_args" provides default value for important settings. + The initialization arguments "kwargs" with the same key (name) will override the default value. + "kwargs" must have the same type as "default_args" on corresponding keys. Otherwise, error will raise. """ default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, + "valid_step": 500, "eval_sort_key": 'acc', "loss": Loss(None), # used to pass type check "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), "evaluator": Evaluator() } """ - "required_args" is the collection of arguments that users must pass to Trainer explicitly. - This is used to warn users of essential settings in the training. + "required_args" is the collection of arguments that users must pass to Trainer explicitly. + This is used to warn users of essential settings in the training. Specially, "required_args" does not have default value, so they have nothing to do with "default_args". 
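The docstring above describes how "kwargs" overrides "default_args" under a type check while "required_args" must always be supplied explicitly; a hedged sketch of that merging idea (not the exact fastNLP implementation):

def merge_args(default_args, required_args, **kwargs):
    # required settings have no default and must be passed explicitly
    for name in required_args:
        if name not in kwargs:
            raise ValueError("missing required argument: {}".format(name))
    merged = dict(default_args)
    for name, value in kwargs.items():
        if name in merged:
            # an override has to keep the type of the default value
            if not isinstance(value, type(merged[name])):
                raise TypeError("{} expects {}, got {}".format(
                    name, type(merged[name]).__name__, type(value).__name__))
            merged[name] = value
        # unknown keys are silently ignored, as Trainer/Tester do
    return merged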
""" required_args = {} @@ -70,16 +71,20 @@ class Trainer(object): else: # Trainer doesn't care about extra arguments pass - print(default_args) + print("Training Args {}".format(default_args)) + logger.info("Training Args {}".format(default_args)) - self.n_epochs = default_args["epochs"] - self.batch_size = default_args["batch_size"] + self.n_epochs = int(default_args["epochs"]) + self.batch_size = int(default_args["batch_size"]) self.pickle_path = default_args["pickle_path"] self.validate = default_args["validate"] self.save_best_dev = default_args["save_best_dev"] self.use_cuda = default_args["use_cuda"] self.model_name = default_args["model_name"] - self.print_every_step = default_args["print_every_step"] + self.print_every_step = int(default_args["print_every_step"]) + self.valid_step = int(default_args["valid_step"]) + if self.validate is not None: + assert self.valid_step > 0 self._model = None self._loss_func = default_args["loss"].get() # return a pytorch loss function or None @@ -89,6 +94,8 @@ class Trainer(object): self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') self._graph_summaried = False self._best_accuracy = 0.0 + self.eval_sort_key = default_args['eval_sort_key'] + self.validator = None def train(self, network, train_data, dev_data=None): """General Training Procedure @@ -104,12 +111,17 @@ class Trainer(object): else: self._model = network + print(self._model) + # define Tester over dev data + self.dev_data = None if self.validate: default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, "use_cuda": self.use_cuda, "evaluator": self._evaluator} - validator = self._create_validator(default_valid_args) - logger.info("validator defined as {}".format(str(validator))) + if self.validator is None: + self.validator = self._create_validator(default_valid_args) + logger.info("validator defined as {}".format(str(self.validator))) + self.dev_data = dev_data # optimizer and loss self.define_optimizer() @@ -117,29 +129,33 @@ class Trainer(object): self.define_loss() logger.info("loss function defined as {}".format(str(self._loss_func))) + # turn on network training mode + self.mode(network, is_test=False) + # main training procedure start = time.time() - logger.info("training epochs started") - for epoch in range(1, self.n_epochs + 1): + self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) + print("training epochs started " + self.start_time) + logger.info("training epochs started " + self.start_time) + epoch, iters = 1, 0 + while(1): + if self.n_epochs != -1 and epoch > self.n_epochs: + break logger.info("training epoch {}".format(epoch)) - # turn on network training mode - self.mode(network, is_test=False) # prepare mini-batch iterator data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), - use_cuda=self.use_cuda) + use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq') logger.info("prepared data iterator") # one forward and backward pass - self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch) + iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data) # validation if self.validate: - if dev_data is None: - raise RuntimeError( - "self.validate is True in trainer, but dev_data is None. 
Please provide the validation data.") - logger.info("validation started") - validator.test(network, dev_data) + self.valid_model() + self.save_model(self._model, 'training_model_'+self.start_time) + epoch += 1 def _train_step(self, data_iterator, network, **kwargs): """Training process in one epoch. @@ -149,13 +165,17 @@ class Trainer(object): - start: time.time(), the starting time of this step. - epoch: int, """ - step = 0 + step = kwargs['step'] for batch_x, batch_y in data_iterator: - prediction = self.data_forward(network, batch_x) loss = self.get_loss(prediction, batch_y) self.grad_backward(loss) + # if torch.rand(1).item() < 0.001: + # print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step)) + # for name, p in self._model.named_parameters(): + # if p.requires_grad: + # print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item())) self.update() self._summary_writer.add_scalar("loss", loss.item(), global_step=step) @@ -166,7 +186,22 @@ class Trainer(object): kwargs["epoch"], step, loss.data, diff) print(print_output) logger.info(print_output) + if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0: + self.valid_model() step += 1 + return step + + def valid_model(self): + if self.dev_data is None: + raise RuntimeError( + "self.validate is True in trainer, but dev_data is None. Please provide the validation data.") + logger.info("validation started") + res = self.validator.test(self._model, self.dev_data) + if self.save_best_dev and self.best_eval_result(res): + logger.info('save best result! {}'.format(res)) + print('save best result! {}'.format(res)) + self.save_model(self._model, 'best_model_'+self.start_time) + return res def mode(self, model, is_test=False): """Train mode or Test mode. This is for PyTorch currently. @@ -180,11 +215,17 @@ class Trainer(object): else: model.train() - def define_optimizer(self): + def define_optimizer(self, optim=None): """Define framework-specific optimizer specified by the models. """ - self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters()) + if optim is not None: + # optimizer constructed by user + self._optimizer = optim + elif self._optimizer is None: + # optimizer constructed by proto + self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters()) + return self._optimizer def update(self): """Perform weight update on a model. @@ -217,6 +258,8 @@ class Trainer(object): :param truth: ground truth label vector :return: a scalar """ + if isinstance(predict, dict) and isinstance(truth, dict): + return self._loss_func(**predict, **truth) if len(truth) > 1: raise NotImplementedError("Not ready to handle multi-labels.") truth = list(truth.values())[0] if len(truth) > 0 else None @@ -241,13 +284,23 @@ class Trainer(object): raise ValueError("Please specify a loss function.") logger.info("The model didn't define loss, use Trainer's loss.") - def best_eval_result(self, validator): + def best_eval_result(self, metrics): """Check if the current epoch yields better validation results. :param validator: a Tester instance :return: bool, True means current results on dev set is the best. 
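The new branch in get_loss above simply unpacks both the prediction dict and the truth dict as keyword arguments of the loss function; a self-contained toy illustrating that calling convention (tensor shapes and names here are illustrative only):

import torch
import torch.nn.functional as F

def toy_loss(arc_pred, label_pred, head_indices, head_labels, **_):
    # extra keys the caller passes but the loss does not need fall into **_
    return F.cross_entropy(arc_pred, head_indices) + F.cross_entropy(label_pred, head_labels)

predict = {"arc_pred": torch.randn(4, 10), "label_pred": torch.randn(4, 7)}
truth = {"head_indices": torch.randint(10, (4,)), "head_labels": torch.randint(7, (4,))}
loss = toy_loss(**predict, **truth)   # same call shape as self._loss_func(**predict, **truth)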
""" - loss, accuracy = validator.metrics + if isinstance(metrics, tuple): + loss, metrics = metrics + + if isinstance(metrics, dict): + if len(metrics) == 1: + accuracy = list(metrics.values())[0] + else: + accuracy = metrics[self.eval_sort_key] + else: + accuracy = metrics + if accuracy > self._best_accuracy: self._best_accuracy = accuracy return True @@ -268,6 +321,8 @@ class Trainer(object): def _create_validator(self, valid_args): raise NotImplementedError + def set_validator(self, validor): + self.validator = validor class SeqLabelTrainer(Trainer): """Trainer for Sequence Labeling diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 26d2e837..0e8e77cd 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -51,6 +51,12 @@ class Vocabulary(object): self.min_freq = min_freq self.word_count = {} self.has_default = need_default + if self.has_default: + self.padding_label = DEFAULT_PADDING_LABEL + self.unknown_label = DEFAULT_UNKNOWN_LABEL + else: + self.padding_label = None + self.unknown_label = None self.word2idx = None self.idx2word = None @@ -77,12 +83,10 @@ class Vocabulary(object): """ if self.has_default: self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) - self.padding_label = DEFAULT_PADDING_LABEL - self.unknown_label = DEFAULT_UNKNOWN_LABEL + self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL) + self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL) else: self.word2idx = {} - self.padding_label = None - self.unknown_label = None words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True) if self.min_freq is not None: @@ -114,7 +118,7 @@ class Vocabulary(object): if w in self.word2idx: return self.word2idx[w] elif self.has_default: - return self.word2idx[DEFAULT_UNKNOWN_LABEL] + return self.word2idx[self.unknown_label] else: raise ValueError("word {} not in vocabulary".format(w)) @@ -134,6 +138,11 @@ class Vocabulary(object): return None return self.word2idx[self.unknown_label] + def __setattr__(self, name, val): + self.__dict__[name] = val + if name in self.__dict__ and name in ["unknown_label", "padding_label"]: + self.word2idx = None + @property @check_build_vocab def padding_idx(self): diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py index 91be0215..4ba121dd 100644 --- a/fastNLP/loader/dataset_loader.py +++ b/fastNLP/loader/dataset_loader.py @@ -87,7 +87,6 @@ class DataSetLoader(BaseLoader): """ raise NotImplementedError - @DataSet.set_reader('read_raw') class RawDataSetLoader(DataSetLoader): def __init__(self): @@ -103,7 +102,6 @@ class RawDataSetLoader(DataSetLoader): def convert(self, data): return convert_seq_dataset(data) - @DataSet.set_reader('read_pos') class POSDataSetLoader(DataSetLoader): """Dataset Loader for POS Tag datasets. 
@@ -173,7 +171,6 @@ class POSDataSetLoader(DataSetLoader): """ return convert_seq2seq_dataset(data) - @DataSet.set_reader('read_tokenize') class TokenizeDataSetLoader(DataSetLoader): """ @@ -233,7 +230,6 @@ class TokenizeDataSetLoader(DataSetLoader): def convert(self, data): return convert_seq2seq_dataset(data) - @DataSet.set_reader('read_class') class ClassDataSetLoader(DataSetLoader): """Loader for classification data sets""" @@ -272,7 +268,6 @@ class ClassDataSetLoader(DataSetLoader): def convert(self, data): return convert_seq2tag_dataset(data) - @DataSet.set_reader('read_conll') class ConllLoader(DataSetLoader): """loader for conll format files""" @@ -314,7 +309,6 @@ class ConllLoader(DataSetLoader): def convert(self, data): pass - @DataSet.set_reader('read_lm') class LMDataSetLoader(DataSetLoader): """Language Model Dataset Loader @@ -351,7 +345,6 @@ class LMDataSetLoader(DataSetLoader): def convert(self, data): pass - @DataSet.set_reader('read_people_daily') class PeopleDailyCorpusLoader(DataSetLoader): """ diff --git a/fastNLP/loader/embed_loader.py b/fastNLP/loader/embed_loader.py index 2f61830f..415cb1b9 100644 --- a/fastNLP/loader/embed_loader.py +++ b/fastNLP/loader/embed_loader.py @@ -17,8 +17,8 @@ class EmbedLoader(BaseLoader): def _load_glove(emb_file): """Read file as a glove embedding - file format: - embeddings are split by line, + file format: + embeddings are split by line, for one embedding, word and numbers split by space Example:: @@ -33,7 +33,7 @@ class EmbedLoader(BaseLoader): if len(line) > 0: emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) return emb - + @staticmethod def _load_pretrain(emb_file, emb_type): """Read txt data from embedding file and convert to np.array as pre-trained embedding diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index a2a00a29..7e0a9cec 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -16,10 +16,9 @@ def mst(scores): https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 """ length = scores.shape[0] - min_score = -np.inf - mask = np.zeros((length, length)) - np.fill_diagonal(mask, -np.inf) - scores = scores + mask + min_score = scores.min() - 1 + eye = np.eye(length) + scores = scores * (1 - eye) + min_score * eye heads = np.argmax(scores, axis=1) heads[0] = 0 tokens = np.arange(1, length) @@ -126,6 +125,8 @@ class GraphParser(nn.Module): def _greedy_decoder(self, arc_matrix, seq_mask=None): _, seq_len, _ = arc_matrix.shape matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) + flip_mask = (seq_mask == 0).byte() + matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf) _, heads = torch.max(matrix, dim=2) if seq_mask is not None: heads *= seq_mask.long() @@ -135,8 +136,15 @@ class GraphParser(nn.Module): batch_size, seq_len, _ = arc_matrix.shape matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) ans = matrix.new_zeros(batch_size, seq_len).long() + lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len + batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) + seq_mask[batch_idx, lens-1] = 0 for i, graph in enumerate(matrix): - ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) + len_i = lens[i] + if len_i == seq_len: + ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) + else: + ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device) if seq_mask is not 
None: ans *= seq_mask.long() return ans @@ -175,14 +183,19 @@ class LabelBilinear(nn.Module): def __init__(self, in1_features, in2_features, num_label, bias=True): super(LabelBilinear, self).__init__() self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) - self.lin1 = nn.Linear(in1_features, num_label, bias=False) - self.lin2 = nn.Linear(in2_features, num_label, bias=False) + self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) def forward(self, x1, x2): output = self.bilinear(x1, x2) - output += self.lin1(x1) + self.lin2(x2) + output += self.lin(torch.cat([x1, x2], dim=2)) return output +def len2masks(origin_len, max_len): + if origin_len.dim() <= 1: + origin_len = origin_len.unsqueeze(1) # [batch_size, 1] + seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device) # [max_len,] + seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0)) # [batch_size, max_len] + return seq_mask class BiaffineParser(GraphParser): """Biaffine Dependency Parser implemantation. @@ -194,6 +207,8 @@ class BiaffineParser(GraphParser): word_emb_dim, pos_vocab_size, pos_emb_dim, + word_hid_dim, + pos_hid_dim, rnn_layers, rnn_hidden_size, arc_mlp_size, @@ -204,10 +219,15 @@ class BiaffineParser(GraphParser): use_greedy_infer=False): super(BiaffineParser, self).__init__() + rnn_out_size = 2 * rnn_hidden_size self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) + self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) + self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) + self.word_norm = nn.LayerNorm(word_hid_dim) + self.pos_norm = nn.LayerNorm(pos_hid_dim) if use_var_lstm: - self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim, + self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, hidden_size=rnn_hidden_size, num_layers=rnn_layers, bias=True, @@ -216,7 +236,7 @@ class BiaffineParser(GraphParser): hidden_dropout=dropout, bidirectional=True) else: - self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim, + self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, hidden_size=rnn_hidden_size, num_layers=rnn_layers, bias=True, @@ -224,21 +244,35 @@ class BiaffineParser(GraphParser): dropout=dropout, bidirectional=True) - rnn_out_size = 2 * rnn_hidden_size self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), - nn.ELU()) + nn.LayerNorm(arc_mlp_size), + nn.ELU(), + TimestepDropout(p=dropout),) self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), - nn.ELU()) + nn.LayerNorm(label_mlp_size), + nn.ELU(), + TimestepDropout(p=dropout),) self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) self.normal_dropout = nn.Dropout(p=dropout) - self.timestep_dropout = TimestepDropout(p=dropout) self.use_greedy_infer = use_greedy_infer - initial_parameter(self) + self.reset_parameters() + self.explore_p = 0.2 + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Embedding): + continue + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.weight, 0.1) + nn.init.constant_(m.bias, 0) + else: + for p in m.parameters(): + nn.init.normal_(p, 0, 0.1) - def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_): + def forward(self, word_seq, 
pos_seq, word_seq_origin_len, gold_heads=None, **_): """ :param word_seq: [batch_size, seq_len] sequence of word's indices :param pos_seq: [batch_size, seq_len] sequence of word's indices @@ -253,32 +287,35 @@ class BiaffineParser(GraphParser): # prepare embeddings batch_size, seq_len = word_seq.shape # print('forward {} {}'.format(batch_size, seq_len)) - batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) # get sequence mask - seq_mask = seq_mask.long() + seq_mask = len2masks(word_seq_origin_len, seq_len).long() word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] + word, pos = self.word_fc(word), self.pos_fc(pos) + word, pos = self.word_norm(word), self.pos_norm(pos) x = torch.cat([word, pos], dim=2) # -> [N,L,C] + del word, pos # lstm, extract features + x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True) feat, _ = self.lstm(x) # -> [N,L,C] + feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) # for arc biaffine # mlp, reduce dim - arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat)) - arc_head = self.timestep_dropout(self.arc_head_mlp(feat)) - label_dep = self.timestep_dropout(self.label_dep_mlp(feat)) - label_head = self.timestep_dropout(self.label_head_mlp(feat)) + arc_dep = self.arc_dep_mlp(feat) + arc_head = self.arc_head_mlp(feat) + label_dep = self.label_dep_mlp(feat) + label_head = self.label_head_mlp(feat) + del feat # biaffine arc classifier arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] - flip_mask = (seq_mask == 0) - arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) # use gold or predicted arc to predict label - if gold_heads is None: + if gold_heads is None or not self.training: # use greedy decoding in training if self.training or self.use_greedy_infer: heads = self._greedy_decoder(arc_pred, seq_mask) @@ -286,9 +323,15 @@ class BiaffineParser(GraphParser): heads = self._mst_decoder(arc_pred, seq_mask) head_pred = heads else: - head_pred = None - heads = gold_heads + assert self.training # must be training mode + if torch.rand(1).item() < self.explore_p: + heads = self._greedy_decoder(arc_pred, seq_mask) + head_pred = heads + else: + head_pred = None + heads = gold_heads + batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) label_head = label_head[batch_range, heads].contiguous() label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} @@ -301,7 +344,7 @@ class BiaffineParser(GraphParser): Compute loss. 
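The forward changes above build the mask from word_seq_origin_len and feed the LSTM a packed sequence so padding never enters the recurrence; a generic sketch of that length-mask plus pack/pad pattern, assuming batches sorted by length as sort_in_batch=True provides:

import torch
from torch import nn

def lengths_to_mask(lengths, max_len):
    # [batch] -> [batch, max_len], True where a real token sits (cf. len2masks)
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True, bidirectional=True)
x = torch.randn(3, 5, 8)                     # batch of 3, padded to length 5
lengths = torch.tensor([5, 4, 2])            # descending, as the sorted batch guarantees
mask = lengths_to_mask(lengths, x.size(1))   # [3, 5] boolean mask

packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
feat, _ = lstm(packed)
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)   # back to [3, 5, 32]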
:param arc_pred: [batch_size, seq_len, seq_len] - :param label_pred: [batch_size, seq_len, seq_len] + :param label_pred: [batch_size, seq_len, n_tags] :param head_indices: [batch_size, seq_len] :param head_labels: [batch_size, seq_len] :param seq_mask: [batch_size, seq_len] @@ -309,10 +352,13 @@ class BiaffineParser(GraphParser): """ batch_size, seq_len, _ = arc_pred.shape - arc_logits = F.log_softmax(arc_pred, dim=2) + flip_mask = (seq_mask == 0) + _arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred) + _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) + arc_logits = F.log_softmax(_arc_pred, dim=2) label_logits = F.log_softmax(label_pred, dim=2) - batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1) - child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0) + batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) + child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0) arc_loss = arc_logits[batch_index, child_index, head_indices] label_loss = label_logits[batch_index, child_index, head_labels] @@ -320,45 +366,8 @@ class BiaffineParser(GraphParser): label_loss = label_loss[:, 1:] float_mask = seq_mask[:, 1:].float() - length = (seq_mask.sum() - batch_size).float() - arc_nll = -(arc_loss*float_mask).sum() / length - label_nll = -(label_loss*float_mask).sum() / length + arc_nll = -(arc_loss*float_mask).mean() + label_nll = -(label_loss*float_mask).mean() return arc_nll + label_nll - def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs): - """ - Evaluate the performance of prediction. - - :return dict: performance results. - head_pred_corrct: number of correct predicted heads. - label_pred_correct: number of correct predicted labels. - total_tokens: number of predicted tokens - """ - if 'head_pred' in kwargs: - head_pred = kwargs['head_pred'] - elif self.use_greedy_infer: - head_pred = self._greedy_decoder(arc_pred, seq_mask) - else: - head_pred = self._mst_decoder(arc_pred, seq_mask) - - head_pred_correct = (head_pred == head_indices).long() * seq_mask - _, label_preds = torch.max(label_pred, dim=2) - label_pred_correct = (label_preds == head_labels).long() * head_pred_correct - return {"head_pred_correct": head_pred_correct.sum(dim=1), - "label_pred_correct": label_pred_correct.sum(dim=1), - "total_tokens": seq_mask.sum(dim=1)} - - def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_): - """ - Compute the metrics of model - - :param head_pred_corrct: number of correct predicted heads. - :param label_pred_correct: number of correct predicted labels. 
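The rewritten loss above masks padded positions with -inf before log_softmax and then picks, for every child token, the log-probability of its gold head through broadcasted index tensors; a small self-contained illustration of that indexing step (shapes as in the patch, values random):

import torch
import torch.nn.functional as F

batch_size, seq_len = 2, 4
arc_pred = torch.randn(batch_size, seq_len, seq_len)           # child-to-head scores
head_indices = torch.randint(seq_len, (batch_size, seq_len))   # gold head of every child

arc_logits = F.log_softmax(arc_pred, dim=2)
batch_index = torch.arange(batch_size, dtype=torch.long).unsqueeze(1)   # [batch, 1]
child_index = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)      # [1, seq_len]
# broadcasting yields one log-prob per (sentence, child): shape [batch, seq_len]
arc_loss = arc_logits[batch_index, child_index, head_indices]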
- :param total_tokens: number of predicted tokens - :return dict: the metrics results - UAS: the head predicted accuracy - LAS: the label predicted accuracy - """ - return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100, - "LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100} diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index 5cdc77c9..69c5fdf6 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -1,5 +1,6 @@ import torch - +from torch import nn +import math from fastNLP.modules.utils import mask_softmax @@ -17,3 +18,44 @@ class Attention(torch.nn.Module): def _atten_forward(self, query, memory): raise NotImplementedError + +class DotAtte(nn.Module): + def __init__(self, key_size, value_size): + super(DotAtte, self).__init__() + self.key_size = key_size + self.value_size = value_size + self.scale = math.sqrt(key_size) + + def forward(self, Q, K, V, seq_mask=None): + """ + + :param Q: [batch, seq_len, key_size] + :param K: [batch, seq_len, key_size] + :param V: [batch, seq_len, value_size] + :param seq_mask: [batch, seq_len] + """ + output = torch.matmul(Q, K.transpose(1, 2)) / self.scale + if seq_mask is not None: + output.masked_fill_(seq_mask.lt(1), -float('inf')) + output = nn.functional.softmax(output, dim=2) + return torch.matmul(output, V) + +class MultiHeadAtte(nn.Module): + def __init__(self, input_size, output_size, key_size, value_size, num_atte): + super(MultiHeadAtte, self).__init__() + self.in_linear = nn.ModuleList() + for i in range(num_atte * 3): + out_feat = key_size if (i % 3) != 2 else value_size + self.in_linear.append(nn.Linear(input_size, out_feat)) + self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)]) + self.out_linear = nn.Linear(value_size * num_atte, output_size) + + def forward(self, Q, K, V, seq_mask=None): + heads = [] + for i in range(len(self.attes)): + j = i * 3 + qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V) + headi = self.attes[i](qi, ki, vi, seq_mask) + heads.append(headi) + output = torch.cat(heads, dim=2) + return self.out_linear(output) diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py new file mode 100644 index 00000000..46badcfe --- /dev/null +++ b/fastNLP/modules/encoder/transformer.py @@ -0,0 +1,32 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from ..aggregator.attention import MultiHeadAtte +from ..other_modules import LayerNormalization + +class TransformerEncoder(nn.Module): + class SubLayer(nn.Module): + def __init__(self, input_size, output_size, key_size, value_size, num_atte): + super(TransformerEncoder.SubLayer, self).__init__() + self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte) + self.norm1 = LayerNormalization(output_size) + self.ffn = nn.Sequential(nn.Linear(output_size, output_size), + nn.ReLU(), + nn.Linear(output_size, output_size)) + self.norm2 = LayerNormalization(output_size) + + def forward(self, input, seq_mask): + attention = self.atte(input) + norm_atte = self.norm1(attention + input) + output = self.ffn(norm_atte) + return self.norm2(output + norm_atte) + + def __init__(self, num_layers, **kargs): + super(TransformerEncoder, self).__init__() + self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)]) + + def forward(self, x, seq_mask=None): + return self.layers(x, seq_mask) + + diff 
--git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py index 16bd4172..f4a37cf4 100644 --- a/fastNLP/modules/encoder/variational_rnn.py +++ b/fastNLP/modules/encoder/variational_rnn.py @@ -101,14 +101,14 @@ class VarRNNBase(nn.Module): mask_x = input.new_ones((batch_size, self.input_size)) mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) - mask_h = input.new_ones((batch_size, self.hidden_size)) + mask_h_ones = input.new_ones((batch_size, self.hidden_size)) nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) - nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True) hidden_list = [] for layer in range(self.num_layers): output_list = [] + mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) for direction in range(self.num_directions): input_x = input if direction == 0 else flip(input, [0]) idx = self.num_directions * layer + direction diff --git a/fastNLP/modules/other_modules.py b/fastNLP/modules/other_modules.py index ea1423be..5cd10e7e 100644 --- a/fastNLP/modules/other_modules.py +++ b/fastNLP/modules/other_modules.py @@ -31,12 +31,12 @@ class GroupNorm(nn.Module): class LayerNormalization(nn.Module): """ Layer normalization module """ - def __init__(self, d_hid, eps=1e-3): + def __init__(self, layer_size, eps=1e-3): super(LayerNormalization, self).__init__() self.eps = eps - self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True) - self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True) + self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True)) + self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True)) def forward(self, z): if z.size(1) == 1: @@ -44,9 +44,8 @@ class LayerNormalization(nn.Module): mu = torch.mean(z, keepdim=True, dim=-1) sigma = torch.std(z, keepdim=True, dim=-1) - ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps) - ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out) - + ln_out = (z - mu) / (sigma + self.eps) + ln_out = ln_out * self.a_2 + self.b_2 return ln_out diff --git a/reproduction/Biaffine_parser/cfg.cfg b/reproduction/Biaffine_parser/cfg.cfg index 946e4c51..8ee6f5fe 100644 --- a/reproduction/Biaffine_parser/cfg.cfg +++ b/reproduction/Biaffine_parser/cfg.cfg @@ -1,37 +1,40 @@ [train] -epochs = 50 +epochs = -1 batch_size = 16 pickle_path = "./save/" validate = true -save_best_dev = false +save_best_dev = true +eval_sort_key = "UAS" use_cuda = true model_saved_path = "./save/" -task = "parse" - +print_every_step = 20 +use_golden_train=true [test] save_output = true validate_in_training = true save_dev_input = false save_loss = true -batch_size = 16 +batch_size = 64 pickle_path = "./save/" use_cuda = true -task = "parse" [model] word_vocab_size = -1 word_emb_dim = 100 pos_vocab_size = -1 pos_emb_dim = 100 +word_hid_dim = 100 +pos_hid_dim = 100 rnn_layers = 3 rnn_hidden_size = 400 arc_mlp_size = 500 label_mlp_size = 100 num_label = -1 dropout = 0.33 -use_var_lstm=true +use_var_lstm=false use_greedy_infer=false [optim] lr = 2e-3 +weight_decay = 5e-5 diff --git a/reproduction/Biaffine_parser/run.py b/reproduction/Biaffine_parser/run.py index cc8e54ad..45668066 100644 --- a/reproduction/Biaffine_parser/run.py +++ b/reproduction/Biaffine_parser/run.py @@ -6,15 +6,17 @@ 
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) from collections import defaultdict import math import torch +import re from fastNLP.core.trainer import Trainer +from fastNLP.core.metrics import Evaluator from fastNLP.core.instance import Instance from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.dataset import DataSet from fastNLP.core.batch import Batch from fastNLP.core.sampler import SequentialSampler from fastNLP.core.field import TextField, SeqLabelField -from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle +from fastNLP.core.preprocess import load_pickle from fastNLP.core.tester import Tester from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.loader.model_loader import ModelLoader @@ -22,15 +24,18 @@ from fastNLP.loader.embed_loader import EmbedLoader from fastNLP.models.biaffine_parser import BiaffineParser from fastNLP.saver.model_saver import ModelSaver +BOS = '' +EOS = '' +UNK = '' +NUM = '' +ENG = '' + # not in the file's dir if len(os.path.dirname(__file__)) != 0: os.chdir(os.path.dirname(__file__)) -class MyDataLoader(object): - def __init__(self, pickle_path): - self.pickle_path = pickle_path - - def load(self, path, word_v=None, pos_v=None, headtag_v=None): +class ConlluDataLoader(object): + def load(self, path): datalist = [] with open(path, 'r', encoding='utf-8') as f: sample = [] @@ -49,23 +54,18 @@ class MyDataLoader(object): for sample in datalist: # print(sample) res = self.get_one(sample) - if word_v is not None: - word_v.update(res[0]) - pos_v.update(res[1]) - headtag_v.update(res[3]) ds.append(Instance(word_seq=TextField(res[0], is_target=False), pos_seq=TextField(res[1], is_target=False), head_indices=SeqLabelField(res[2], is_target=True), - head_labels=TextField(res[3], is_target=True), - seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False))) + head_labels=TextField(res[3], is_target=True))) return ds def get_one(self, sample): - text = [''] - pos_tags = [''] - heads = [0] - head_tags = ['root'] + text = [] + pos_tags = [] + heads = [] + head_tags = [] for w in sample: t1, t2, t3, t4 = w[1], w[3], w[6], w[7] if t3 == '_': @@ -76,17 +76,60 @@ class MyDataLoader(object): head_tags.append(t4) return (text, pos_tags, heads, head_tags) - def index_data(self, dataset, word_v, pos_v, tag_v): - dataset.index_field('word_seq', word_v) - dataset.index_field('pos_seq', pos_v) - dataset.index_field('head_labels', tag_v) +class CTBDataLoader(object): + def load(self, data_path): + with open(data_path, "r", encoding="utf-8") as f: + lines = f.readlines() + data = self.parse(lines) + return self.convert(data) + + def parse(self, lines): + """ + [ + [word], [pos], [head_index], [head_tag] + ] + """ + sample = [] + data = [] + for i, line in enumerate(lines): + line = line.strip() + if len(line) == 0 or i+1 == len(lines): + data.append(list(map(list, zip(*sample)))) + sample = [] + else: + sample.append(line.split()) + return data + + def convert(self, data): + dataset = DataSet() + for sample in data: + word_seq = [BOS] + sample[0] + [EOS] + pos_seq = [BOS] + sample[1] + [EOS] + heads = [0] + list(map(int, sample[2])) + [0] + head_tags = [BOS] + sample[3] + [EOS] + dataset.append(Instance(word_seq=TextField(word_seq, is_target=False), + pos_seq=TextField(pos_seq, is_target=False), + gold_heads=SeqLabelField(heads, is_target=False), + head_indices=SeqLabelField(heads, is_target=True), + head_labels=TextField(head_tags, is_target=True))) + return dataset # datadir = 
"/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" -datadir = "/home/yfshao/UD_English-EWT" +# datadir = "/home/yfshao/UD_English-EWT" +# train_data_name = "en_ewt-ud-train.conllu" +# dev_data_name = "en_ewt-ud-dev.conllu" +# emb_file_name = '/home/yfshao/glove.6B.100d.txt' +# loader = ConlluDataLoader() + +datadir = '/home/yfshao/workdir/parser-data/' +train_data_name = "train_ctb5.txt" +dev_data_name = "dev_ctb5.txt" +test_data_name = "test_ctb5.txt" +emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" +# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" +loader = CTBDataLoader() + cfgfile = './cfg.cfg' -train_data_name = "en_ewt-ud-train.conllu" -dev_data_name = "en_ewt-ud-dev.conllu" -emb_file_name = '/home/yfshao/glove.6B.100d.txt' processed_datadir = './save' # Config Loader @@ -95,8 +138,12 @@ test_args = ConfigSection() model_args = ConfigSection() optim_args = ConfigSection() ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) +print('trainre Args:', train_args.data) +print('test Args:', test_args.data) +print('optim Args:', optim_args.data) -# Data Loader + +# Pickle Loader def save_data(dirpath, **kwargs): import _pickle if not os.path.exists(dirpath): @@ -117,38 +164,57 @@ def load_data(dirpath): datas[name] = _pickle.load(f) return datas -class MyTester(object): - def __init__(self, batch_size, use_cuda=False, **kwagrs): - self.batch_size = batch_size - self.use_cuda = use_cuda - - def test(self, model, dataset): - self.model = model.cuda() if self.use_cuda else model - self.model.eval() - batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda) - eval_res = defaultdict(list) - i = 0 - for batch_x, batch_y in batchiter: - with torch.no_grad(): - pred_y = self.model(**batch_x) - eval_one = self.model.evaluate(**pred_y, **batch_y) - i += self.batch_size - for eval_name, tensor in eval_one.items(): - eval_res[eval_name].append(tensor) - tmp = {} - for eval_name, tensorlist in eval_res.items(): - tmp[eval_name] = torch.cat(tensorlist, dim=0) - - self.res = self.model.metrics(**tmp) - - def show_metrics(self): - s = "" - for name, val in self.res.items(): - s += '{}: {:.2f}\t'.format(name, val) - return s - - -loader = MyDataLoader('') +def P2(data, field, length): + ds = [ins for ins in data if ins[field].get_length() >= length] + data.clear() + data.extend(ds) + return ds + +def P1(data, field): + def reeng(w): + return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG + def renum(w): + return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM + for ins in data: + ori = ins[field].contents() + s = list(map(renum, map(reeng, ori))) + if s != ori: + # print(ori) + # print(s) + # print() + ins[field] = ins[field].new(s) + return data + +class ParserEvaluator(Evaluator): + def __init__(self, ignore_label): + super(ParserEvaluator, self).__init__() + self.ignore = ignore_label + + def __call__(self, predict_list, truth_list): + head_all, label_all, total_all = 0, 0, 0 + for pred, truth in zip(predict_list, truth_list): + head, label, total = self.evaluate(**pred, **truth) + head_all += head + label_all += label + total_all += total + + return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all} + + def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_): + """ + Evaluate the performance of prediction. + + :return : performance results. 
+ head_pred_corrct: number of correct predicted heads. + label_pred_correct: number of correct predicted labels. + total_tokens: number of predicted tokens + """ + seq_mask *= (head_labels != self.ignore).long() + head_pred_correct = (head_pred == head_indices).long() * seq_mask + _, label_preds = torch.max(label_pred, dim=2) + label_pred_correct = (label_preds == head_labels).long() * head_pred_correct + return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item() + try: data_dict = load_data(processed_datadir) word_v = data_dict['word_v'] @@ -156,62 +222,90 @@ try: tag_v = data_dict['tag_v'] train_data = data_dict['train_data'] dev_data = data_dict['dev_data'] + test_data = data_dict['test_data'] print('use saved pickles') except Exception as _: print('load raw data and preprocess') + # use pretrain embedding word_v = Vocabulary(need_default=True, min_freq=2) + word_v.unknown_label = UNK pos_v = Vocabulary(need_default=True) tag_v = Vocabulary(need_default=False) - train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v) + train_data = loader.load(os.path.join(datadir, train_data_name)) dev_data = loader.load(os.path.join(datadir, dev_data_name)) - save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) + test_data = loader.load(os.path.join(datadir, test_data_name)) + train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) + datasets = (train_data, dev_data, test_data) + save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data) -loader.index_data(train_data, word_v, pos_v, tag_v) -loader.index_data(dev_data, word_v, pos_v, tag_v) -print(len(train_data)) -print(len(dev_data)) -ep = train_args['epochs'] -train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep +embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) + +print(len(word_v)) +print(embed.size()) + +# Model model_args['word_vocab_size'] = len(word_v) model_args['pos_vocab_size'] = len(pos_v) model_args['num_label'] = len(tag_v) +model = BiaffineParser(**model_args.data) +model.reset_parameters() +datasets = (train_data, dev_data, test_data) +for ds in datasets: + ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) + ds.set_origin_len('word_seq') +if train_args['use_golden_train']: + train_data.set_target(gold_heads=False) +else: + train_data.set_target(gold_heads=None) +train_args.data.pop('use_golden_train') +ignore_label = pos_v['P'] + +print(test_data[0]) +print(len(train_data)) +print(len(dev_data)) +print(len(test_data)) + + -def train(): +def train(path): # Trainer trainer = Trainer(**train_args.data) def _define_optim(obj): - obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) - obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4)) + lr = optim_args.data['lr'] + embed_params = set(obj._model.word_embedding.parameters()) + decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters()) + params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params] + obj._optimizer = torch.optim.Adam([ + {'params': list(embed_params), 'lr':lr*0.1}, + {'params': list(decay_params), **optim_args.data}, + 
{'params': params} + ], lr=lr, betas=(0.9, 0.9)) + obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05)) def _update(obj): + # torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0) obj._scheduler.step() obj._optimizer.step() trainer.define_optimizer = lambda: _define_optim(trainer) trainer.update = lambda: _update(trainer) - trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth) - trainer._create_validator = lambda x: MyTester(**test_args.data) - - # Model - model = BiaffineParser(**model_args.data) + trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))) - # use pretrain embedding - embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) model.word_embedding.padding_idx = word_v.padding_idx model.word_embedding.weight.data[word_v.padding_idx].fill_(0) model.pos_embedding.padding_idx = pos_v.padding_idx model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) - try: - ModelLoader.load_pytorch(model, "./save/saved_model.pkl") - print('model parameter loaded!') - except Exception as _: - print("No saved model. Continue.") - pass + # try: + # ModelLoader.load_pytorch(model, "./save/saved_model.pkl") + # print('model parameter loaded!') + # except Exception as _: + # print("No saved model. Continue.") + # pass # Start training trainer.train(model, train_data, dev_data) @@ -223,24 +317,27 @@ def train(): print("Model saved!") -def test(): +def test(path): # Tester - tester = MyTester(**test_args.data) + tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)) # Model model = BiaffineParser(**model_args.data) - + model.eval() try: - ModelLoader.load_pytorch(model, "./save/saved_model.pkl") + ModelLoader.load_pytorch(model, path) print('model parameter loaded!') except Exception as _: print("No saved model. Abort test.") raise # Start training + print("Testing Train data") + tester.test(model, train_data) + print("Testing Dev data") tester.test(model, dev_data) - print(tester.show_metrics()) - print("Testing finished!") + print("Testing Test data") + tester.test(model, test_data) @@ -248,11 +345,12 @@ if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) + parser.add_argument('--path', type=str, default='') args = parser.parse_args() if args.mode == 'train': - train() + train(args.path) elif args.mode == 'test': - test() + test(args.path) elif args.mode == 'infer': infer() else:
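With the argparse changes above, the run mode is selected with --mode and, for evaluation, a saved model file is passed with --path (the file name below is a placeholder, not a path from the patch); the test mode then reports ParserEvaluator metrics on the train, dev, and test sets in turn.

python run.py --mode train
python run.py --mode test --path ./save/<saved_model_file>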