From 29eab18b78a813eed76515325eaf0f3bffca1eb7 Mon Sep 17 00:00:00 2001 From: yh Date: Wed, 6 Feb 2019 22:26:10 +0800 Subject: [PATCH 1/5] =?UTF-8?q?1.=20CRF=E5=A2=9E=E5=8A=A0=E6=94=AF?= =?UTF-8?q?=E6=8C=81bmeso=E7=B1=BB=E5=9E=8B=E7=9A=84tag=202.=20vocabulary?= =?UTF-8?q?=E4=B8=AD=E5=A2=9E=E5=8A=A0=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/metrics.py | 46 +++++++++++++++++++++++++++++++--- fastNLP/core/vocabulary.py | 10 +++++--- fastNLP/io/dataset_loader.py | 8 ++++++ fastNLP/modules/decoder/CRF.py | 16 ++++++++++-- 4 files changed, 71 insertions(+), 9 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index dfb20480..8b51e23c 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -296,6 +296,8 @@ class AccuracyMetric(MetricBase): def bmes_tag_to_spans(tags, ignore_labels=None): """ + 给定一个tags的lis,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 + 返回[('', (0, 1)), ('singer', (1, 2)), ('singer', (2, 3)), ('singer', (3, 4)), ('', (4, 5)), ('', (5, 6))] :param tags: List[str], :param ignore_labels: List[str], 在该list中的label将被忽略 @@ -315,13 +317,45 @@ def bmes_tag_to_spans(tags, ignore_labels=None): else: spans.append((label, [idx, idx])) prev_bmes_tag = bmes_tag - return [(span[0], (span[1][0], span[1][1])) + return [(span[0], (span[1][0], span[1][1]+1)) + for span in spans + if span[0] not in ignore_labels + ] + +def bmeso_tag_to_spans(tags, ignore_labels=None): + """ + 给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 + 返回[('singer', (1, 2)), ('singer', (2, 3)), ('singer', (3, 4))] + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bmes_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bmes_tag, label = tag[:1], tag[2:] + if bmes_tag in ('b', 's'): + spans.append((label, [idx, idx])) + elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: + spans[-1][1][1] = idx + elif bmes_tag == 'o': + pass + else: + spans.append((label, [idx, idx])) + prev_bmes_tag = bmes_tag + return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels ] def bio_tag_to_spans(tags, ignore_labels=None): """ + 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 + 返回[('singer', (1, 4))] (特别注意这是左闭右开区间) :param tags: List[str], :param ignore_labels: List[str], 在该list中的label将被忽略 @@ -343,7 +377,7 @@ def bio_tag_to_spans(tags, ignore_labels=None): else: spans.append((label, [idx, idx])) prev_bio_tag = bio_tag - return [(span[0], (span[1][0], span[1][1])) + return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels ] @@ -390,8 +424,7 @@ class SpanFPreRecMetric(MetricBase): 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 """ encoding_type = encoding_type.lower() - if encoding_type not in ('bio', 'bmes'): - raise ValueError("Only support 'bio' or 'bmes' type.") + if not isinstance(tag_vocab, Vocabulary): raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) if f_type not in ('micro', 'macro'): @@ -402,6 +435,11 @@ class SpanFPreRecMetric(MetricBase): self.tag_to_span_func = bmes_tag_to_spans elif self.encoding_type == 'bio': self.tag_to_span_func = bio_tag_to_spans + elif self.encoding_type == 'bmeso': + self.tag_to_span_func = bmeso_tag_to_spans + else: + raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") + self.ignore_labels = ignore_labels self.f_type = f_type self.beta = beta diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 50a79d24..987a3527 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -44,10 +44,14 @@ class Vocabulary(object): :param int max_size: set the max number of words in Vocabulary. Default: None :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None + :param padding: str, padding的字符,默认为。如果设置为None,则vocabulary中不考虑padding,为None的情况多在为label建立 + Vocabulary的情况。 + :param unknown: str, unknown的字符,默认为。如果设置为None,则vocabulary中不考虑unknown,为None的情况多在为label建立 + Vocabulary的情况。 """ - def __init__(self, max_size=None, min_freq=None, unknown='', padding=''): + def __init__(self, max_size=None, min_freq=None, padding='', unknown=''): self.max_size = max_size self.min_freq = min_freq self.word_count = Counter() @@ -97,9 +101,9 @@ class Vocabulary(object): """ self.word2idx = {} if self.padding is not None: - self.word2idx[self.padding] = 0 + self.word2idx[self.padding] = len(self.word2idx) if self.unknown is not None: - self.word2idx[self.unknown] = 1 + self.word2idx[self.unknown] = len(self.word2idx) max_size = min(self.max_size, len(self.word_count)) if self.max_size else None words = self.word_count.most_common(max_size) diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 1fcdb7d9..09fce24f 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -877,6 +877,14 @@ class ConllPOSReader(object): class ConllxDataLoader(object): def load(self, path): + """ + + :param path: str,存储数据的路径 + :return: DataSet。内含field有'words', 'pos_tags', 'heads', 'labels'(parser的label) + 类似于拥有以下结构, 一行为一个instance(sample) + words pos_tags heads labels + ['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...] + """ datalist = [] with open(path, 'r', encoding='utf-8') as f: sample = [] diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index d7db3bf9..e1b68e7a 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -25,7 +25,7 @@ def allowed_transitions(id2label, encoding_type='bio'): :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 - :param encoding_type: str, 支持"bio", "bmes"。 + :param encoding_type: str, 支持"bio", "bmes", "bmeso"。 :return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). start_idx=len(id2label), end_idx=len(id2label)+1。 @@ -62,7 +62,7 @@ def allowed_transitions(id2label, encoding_type='bio'): def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): """ - :param encoding_type: str, 支持"BIO", "BMES"。 + :param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 :param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag :param from_label: str, 比如"PER", "LOC"等label :param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag @@ -127,6 +127,18 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) return to_tag in ['b', 's', 'end'] else: raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) + elif encoding_type == 'bmeso': + if from_tag == 'start': + return to_tag in ['b', 's', 'o'] + elif from_tag == 'b': + return to_tag in ['m', 'e'] and from_label==to_label + elif from_tag == 'm': + return to_tag in ['m', 'e'] and from_label==to_label + elif from_tag in ['e', 's', 'o']: + return to_tag in ['b', 's', 'end', 'o'] + else: + raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) + else: raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) From 5eb126dbcd300650bd4effccc9061fa67abe2c9c Mon Sep 17 00:00:00 2001 From: yh Date: Sat, 9 Feb 2019 13:47:13 +0800 Subject: [PATCH 2/5] =?UTF-8?q?BucketSampler=E5=A2=9E=E5=8A=A0=E4=B8=80?= =?UTF-8?q?=E6=9D=A1=E9=94=99=E8=AF=AF=E6=A3=80=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index 67ec2a8d..4a523f10 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -73,6 +73,7 @@ class BucketSampler(BaseSampler): total_sample_num = len(seq_lens) bucket_indexes = [] + assert total_sample_num>=self.num_buckets, "The number of samples is smaller than the number of buckets." num_sample_per_bucket = total_sample_num // self.num_buckets for i in range(self.num_buckets): bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) From 3ea7de16732c14ddeed4655669a4be89241c9c99 Mon Sep 17 00:00:00 2001 From: yh Date: Thu, 14 Feb 2019 13:18:50 +0800 Subject: [PATCH 3/5] =?UTF-8?q?1.=E4=BF=AE=E6=94=B9ClipGradientCallback?= =?UTF-8?q?=E7=9A=84bug=EF=BC=9B=E5=88=A0=E9=99=A4LRSchedulerCallback?= =?UTF-8?q?=E4=B8=AD=E7=9A=84print=EF=BC=8C=E4=B9=8B=E5=90=8E=E5=BA=94?= =?UTF-8?q?=E8=AF=A5=E4=BC=A0=E5=85=A5pbar=E8=BF=9B=E8=A1=8C=E6=89=93?= =?UTF-8?q?=E5=8D=B0;2.=E5=A2=9E=E5=8A=A0MLP=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 6 ++++-- fastNLP/modules/decoder/MLP.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index b1a480cc..d941c235 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -248,7 +248,10 @@ class GradientClipCallback(Callback): self.clip_value = clip_value def on_backward_end(self, model): - self.clip_fun(model.parameters(), self.clip_value) + if self.parameters is None: + self.clip_fun(model.parameters(), self.clip_value) + else: + self.clip_fun(self.parameters, self.clip_value) class CallbackException(BaseException): @@ -306,7 +309,6 @@ class LRScheduler(Callback): def on_epoch_begin(self, cur_epoch, total_epoch): self.scheduler.step() - print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"]) class ControlC(Callback): diff --git a/fastNLP/modules/decoder/MLP.py b/fastNLP/modules/decoder/MLP.py index c9198859..b76fdab7 100644 --- a/fastNLP/modules/decoder/MLP.py +++ b/fastNLP/modules/decoder/MLP.py @@ -7,7 +7,7 @@ from fastNLP.modules.utils import initial_parameter class MLP(nn.Module): """Multilayer Perceptrons as a decoder - :param list size_layer: list of int, define the size of MLP layers. + :param list size_layer: list of int, define the size of MLP layers. layer的层数为(len(size_layer)-1)//2 + 1 :param str activation: str or function, the activation function for hidden layers. :param str initial_method: the name of initialization method. :param float dropout: the probability of dropout. From ee677d5d550a0b947dbd127094c6d5aa02a23e6a Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 17 Feb 2019 02:12:33 +0800 Subject: [PATCH 4/5] update MLP module --- fastNLP/modules/decoder/MLP.py | 54 +++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/fastNLP/modules/decoder/MLP.py b/fastNLP/modules/decoder/MLP.py index b76fdab7..d75f6b48 100644 --- a/fastNLP/modules/decoder/MLP.py +++ b/fastNLP/modules/decoder/MLP.py @@ -7,20 +7,24 @@ from fastNLP.modules.utils import initial_parameter class MLP(nn.Module): """Multilayer Perceptrons as a decoder - :param list size_layer: list of int, define the size of MLP layers. layer的层数为(len(size_layer)-1)//2 + 1 - :param str activation: str or function, the activation function for hidden layers. + :param list size_layer: list of int, define the size of MLP layers. layer的层数为 len(size_layer) - 1 + :param str or list activation: str or function or a list, the activation function for hidden layers. + :param str or function output_activation : str or function, the activation function for output layer :param str initial_method: the name of initialization method. :param float dropout: the probability of dropout. .. note:: - There is no activation function applying on output layer. - + 隐藏层的激活函数通过activation定义。一个str/function或者一个str/function的list可以被传入activation。 + 如果只传入了一个str/function,那么所有隐藏层的激活函数都由这个str/function定义; + 如果传入了一个str/function的list,那么每一个隐藏层的激活函数由这个list中对应的元素定义,其中list的长度为隐藏层数。 + 输出层的激活函数由output_activation定义,默认值为None,此时输出层没有激活函数。 """ - def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0): + def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): super(MLP, self).__init__() self.hiddens = nn.ModuleList() self.output = None + self.output_activation = output_activation for i in range(1, len(size_layer)): if i + 1 == len(size_layer): self.output = nn.Linear(size_layer[i-1], size_layer[i]) @@ -33,25 +37,47 @@ class MLP(nn.Module): 'relu': nn.ReLU(), 'tanh': nn.Tanh(), } - if activation in actives: - self.hidden_active = actives[activation] - elif callable(activation): - self.hidden_active = activation + if not isinstance(activation, list): + activation = [activation] * (len(size_layer) - 2) + elif len(activation) == len(size_layer) - 2: + pass else: - raise ValueError("should set activation correctly: {}".format(activation)) + raise ValueError( + f"the length of activation function list except {len(size_layer) - 2} but got {len(activation)}!") + self.hidden_active = [] + for func in activation: + if callable(activation): + self.hidden_active.append(activation) + elif func.lower() in actives: + self.hidden_active.append(actives[func]) + else: + raise ValueError("should set activation correctly: {}".format(activation)) + if self.output_activation is not None: + if callable(self.output_activation): + pass + elif self.output_activation.lower() in actives: + self.output_activation = actives[self.output_activation] + else: + raise ValueError("should set activation correctly: {}".format(activation)) initial_parameter(self, initial_method) def forward(self, x): - for layer in self.hiddens: - x = self.dropout(self.hidden_active(layer(x))) - x = self.dropout(self.output(x)) + for layer, func in zip(self.hiddens, self.hidden_active): + x = self.dropout(func(layer(x))) + x = self.output(x) + if self.output_activation is not None: + x = self.output_activation(x) + x = self.dropout(x) return x if __name__ == '__main__': net1 = MLP([5, 10, 5]) net2 = MLP([5, 10, 5], 'tanh') - for net in [net1, net2]: + net3 = MLP([5, 6, 7, 8, 5], 'tanh') + net4 = MLP([5, 6, 7, 8, 5], 'relu', output_activation='tanh') + net5 = MLP([5, 6, 7, 8, 5], ['tanh', 'relu', 'tanh'], 'tanh') + for net in [net1, net2, net3, net4, net5]: x = torch.randn(5, 5) y = net(x) print(x) From 8d4f26bbd9cc6a43c1c98cf4ae79c44d59749f4e Mon Sep 17 00:00:00 2001 From: yh Date: Sun, 17 Feb 2019 14:16:19 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E5=A2=9E=E5=8A=A0metric=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=EF=BC=9B=E4=BF=AE=E6=94=B9trainer=20save=E8=BF=87=E7=A8=8B?= =?UTF-8?q?=E4=B8=AD=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/metrics.py | 72 ++++++++++++++++++++++++++++++++++++++--- fastNLP/core/trainer.py | 4 +-- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 8b51e23c..54fde815 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -16,6 +16,69 @@ from fastNLP.core.vocabulary import Vocabulary class MetricBase(object): """Base class for all metrics. + 所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 + evaluate(xxx)中传入的是一个batch的数据。 + get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值 + 以分类问题中,Accuracy计算为例 + 假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy + class Model(nn.Module): + def __init__(xxx): + # do something + def forward(self, xxx): + # do something + return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes + 假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target + 对应的AccMetric可以按如下的定义 + # version1, 只使用这一次 + class AccMetric(MetricBase): + def __init__(self): + super().__init__() + + # 根据你的情况自定义指标 + self.corr_num = 0 + self.total = 0 + + def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value + # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric + self.total += label.size(0) + self.corr_num += label.eq(pred).sum().item() + + def get_metric(self, reset=True): # 在这里定义如何计算metric + acc = self.corr_num/self.total + if reset: # 是否清零以便重新计算 + self.corr_num = 0 + self.total = 0 + return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 + + + # version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred + class AccMetric(MetricBase): + def __init__(self, label=None, pred=None): + # 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, + # acc_metric = AccMetric(label='y', pred='pred_y')即可。 + # 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 + # 应的的值 + super().__init__() + self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 + # 如果没有注册该则效果与version1就是一样的 + + # 根据你的情况自定义指标 + self.corr_num = 0 + self.total = 0 + + def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 + # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric + self.total += label.size(0) + self.corr_num += label.eq(pred).sum().item() + + def get_metric(self, reset=True): # 在这里定义如何计算metric + acc = self.corr_num/self.total + if reset: # 是否清零以便重新计算 + self.corr_num = 0 + self.total = 0 + return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 + + ``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. ``pred_dict`` is the output of ``forward()`` or prediction function of a model. ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. @@ -24,7 +87,6 @@ class MetricBase(object): 1. whether self.evaluate has varargs, which is not supported. 2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. 3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. - 4. whether params in ``pred_dict``, ``target_dict`` are not used by evaluate.(Might cause warning) Besides, before passing params into self.evaluate, this function will filter out params from output_dict and target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering @@ -297,7 +359,7 @@ class AccuracyMetric(MetricBase): def bmes_tag_to_spans(tags, ignore_labels=None): """ 给定一个tags的lis,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 - 返回[('', (0, 1)), ('singer', (1, 2)), ('singer', (2, 3)), ('singer', (3, 4)), ('', (4, 5)), ('', (5, 6))] + 返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (左闭右开区间) :param tags: List[str], :param ignore_labels: List[str], 在该list中的label将被忽略 @@ -325,7 +387,7 @@ def bmes_tag_to_spans(tags, ignore_labels=None): def bmeso_tag_to_spans(tags, ignore_labels=None): """ 给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 - 返回[('singer', (1, 2)), ('singer', (2, 3)), ('singer', (3, 4))] + 返回[('singer', (1, 4))] (左闭右开区间) :param tags: List[str], :param ignore_labels: List[str], 在该list中的label将被忽略 @@ -355,7 +417,7 @@ def bmeso_tag_to_spans(tags, ignore_labels=None): def bio_tag_to_spans(tags, ignore_labels=None): """ 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 - 返回[('singer', (1, 4))] (特别注意这是左闭右开区间) + 返回[('singer', (1, 4))] (左闭右开区间) :param tags: List[str], :param ignore_labels: List[str], 在该list中的label将被忽略 @@ -386,6 +448,8 @@ def bio_tag_to_spans(tags, ignore_labels=None): class SpanFPreRecMetric(MetricBase): """ 在序列标注问题中,以span的方式计算F, pre, rec. + 比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) + ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 最后得到的metric结果为 { 'f': xxx, # 这里使用f考虑以后可以计算f_beta值 diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index ddd35b28..5381fc5d 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -202,7 +202,7 @@ class Trainer(object): except (CallbackException, KeyboardInterrupt) as e: self.callback_manager.on_exception(e, self.model) - if self.dev_data is not None: + if self.dev_data is not None and hasattr(self, 'best_dev_perf'): print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + self.tester._format_eval_results(self.best_dev_perf),) results['best_eval'] = self.best_dev_perf @@ -367,7 +367,7 @@ class Trainer(object): else: model.cpu() torch.save(model, model_path) - model.cuda() + model.to(self._model_device) def _load_model(self, model, model_name, only_param=False): # 返回bool值指示是否成功reload模型