# Conflicts: # fastNLP/core/callback.py # fastNLP/io/dataset_loader.pytags/v0.4.10
@@ -249,6 +249,11 @@ class GradientClipCallback(Callback): | |||
self.parameters = parameters | |||
self.clip_value = clip_value | |||
def on_backward_end(self, model=None):
    """Clip gradients after the backward pass has finished.

    Two definitions of this hook previously sat side by side; the later
    one shadowed the earlier one, so an explicitly configured
    ``self.parameters`` was never used. This single definition honors
    ``self.parameters`` and supports both calling conventions.

    :param model: optional model whose parameters are clipped when
        ``self.parameters`` is None. Falls back to ``self.model`` when
        omitted.
    """
    if self.parameters is not None:
        # The user restricted clipping to an explicit set of parameters.
        targets = self.parameters
    else:
        owner = self.model if model is None else model
        targets = owner.parameters()
    self.clip_fun(targets, self.clip_value)
@@ -305,7 +310,6 @@ class LRScheduler(Callback): | |||
def on_epoch_begin(self): | |||
self.scheduler.step() | |||
print("scheduler step ", "lr=", self.optimizer.param_groups[0]["lr"]) | |||
class ControlC(Callback): | |||
@@ -16,6 +16,69 @@ from fastNLP.core.vocabulary import Vocabulary | |||
class MetricBase(object): | |||
"""Base class for all metrics. | |||
所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 | |||
evaluate(xxx)中传入的是一个batch的数据。 | |||
get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值 | |||
以分类问题中,Accuracy计算为例 | |||
假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy | |||
class Model(nn.Module): | |||
def __init__(xxx): | |||
# do something | |||
def forward(self, xxx): | |||
# do something | |||
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes | |||
假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target | |||
对应的AccMetric可以按如下的定义 | |||
# version1, 只使用这一次 | |||
class AccMetric(MetricBase): | |||
def __init__(self): | |||
super().__init__() | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
# version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred | |||
class AccMetric(MetricBase): | |||
def __init__(self, label=None, pred=None): | |||
# 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, | |||
# acc_metric = AccMetric(label='y', pred='pred_y')即可。 | |||
# 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 | |||
# 应的值 | |||
super().__init__() | |||
self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 | |||
# 如果没有注册该则效果与version1就是一样的 | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. | |||
``pred_dict`` is the output of ``forward()`` or prediction function of a model. | |||
``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. | |||
@@ -24,7 +87,6 @@ class MetricBase(object): | |||
1. whether self.evaluate has varargs, which is not supported. | |||
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. | |||
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. | |||
4. whether params in ``pred_dict``, ``target_dict`` are not used by evaluate.(Might cause warning) | |||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | |||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | |||
@@ -296,6 +358,8 @@ class AccuracyMetric(MetricBase): | |||
def bmes_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
给定一个tags的list,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 | |||
返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (左闭右开区间) | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
@@ -315,13 +379,45 @@ def bmes_tag_to_spans(tags, ignore_labels=None): | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bmes_tag = bmes_tag | |||
return [(span[0], (span[1][0], span[1][1])) | |||
return [(span[0], (span[1][0], span[1][1]+1)) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
def bmeso_tag_to_spans(tags, ignore_labels=None):
    """Convert a BMESO tag sequence into labelled spans.

    For example ``['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']``
    yields ``[('singer', (1, 4))]``. Tags are lower-cased before parsing,
    so the returned labels are lower-case. Intervals are half-open:
    start inclusive, end exclusive.

    :param tags: List[str], one tag per token, e.g. ``'B-singer'`` or ``'O'``.
    :param ignore_labels: List[str], labels whose spans are dropped from
        the result.
    :return: List[Tuple[str, Tuple[int, int]]], i.e. ``[(label, (start, end))]``.
    """
    skipped = set(ignore_labels) if ignore_labels else set()
    # Each record is [label, first_index, last_index] with an inclusive end.
    collected = []
    previous_prefix = None
    for position, raw_tag in enumerate(tags):
        lowered = raw_tag.lower()
        prefix, label = lowered[:1], lowered[2:]
        # An M/E tag extends the open span only when it directly follows
        # a B/M tag carrying the same label.
        continues_span = (
            prefix in ('m', 'e')
            and previous_prefix in ('b', 'm')
            and label == collected[-1][0]
        )
        if continues_span:
            collected[-1][2] = position
        elif prefix != 'o':
            # B/S tags, dangling M/E tags and unknown prefixes all open a
            # new single-token span; only 'O' contributes nothing.
            collected.append([label, position, position])
        previous_prefix = prefix
    return [
        (label, (first, last + 1))
        for label, first, last in collected
        if label not in skipped
    ]
def bio_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
给定一个tags的list,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||
返回[('singer', (1, 4))] (左闭右开区间) | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
@@ -343,7 +439,7 @@ def bio_tag_to_spans(tags, ignore_labels=None): | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bio_tag = bio_tag | |||
return [(span[0], (span[1][0], span[1][1])) | |||
return [(span[0], (span[1][0], span[1][1]+1)) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
@@ -352,6 +448,8 @@ def bio_tag_to_spans(tags, ignore_labels=None): | |||
class SpanFPreRecMetric(MetricBase): | |||
""" | |||
在序列标注问题中,以span的方式计算F, pre, rec. | |||
比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | |||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||
最后得到的metric结果为 | |||
{ | |||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||
@@ -390,8 +488,7 @@ class SpanFPreRecMetric(MetricBase): | |||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
""" | |||
encoding_type = encoding_type.lower() | |||
if encoding_type not in ('bio', 'bmes'): | |||
raise ValueError("Only support 'bio' or 'bmes' type.") | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
if f_type not in ('micro', 'macro'): | |||
@@ -402,6 +499,11 @@ class SpanFPreRecMetric(MetricBase): | |||
self.tag_to_span_func = bmes_tag_to_spans | |||
elif self.encoding_type == 'bio': | |||
self.tag_to_span_func = bio_tag_to_spans | |||
elif self.encoding_type == 'bmeso': | |||
self.tag_to_span_func = bmeso_tag_to_spans | |||
else: | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
self.beta = beta | |||
@@ -73,6 +73,7 @@ class BucketSampler(BaseSampler): | |||
total_sample_num = len(seq_lens) | |||
bucket_indexes = [] | |||
assert total_sample_num>=self.num_buckets, "The number of samples is smaller than the number of buckets." | |||
num_sample_per_bucket = total_sample_num // self.num_buckets | |||
for i in range(self.num_buckets): | |||
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) | |||
@@ -205,7 +205,7 @@ class Trainer(object): | |||
except (CallbackException, KeyboardInterrupt) as e: | |||
self.callback_manager.on_exception(e) | |||
if self.dev_data is not None: | |||
if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | |||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
self.tester._format_eval_results(self.best_dev_perf),) | |||
results['best_eval'] = self.best_dev_perf | |||
@@ -373,7 +373,7 @@ class Trainer(object): | |||
else: | |||
model.cpu() | |||
torch.save(model, model_path) | |||
model.cuda() | |||
model.to(self._model_device) | |||
def _load_model(self, model, model_name, only_param=False): | |||
# 返回bool值指示是否成功reload模型 | |||
@@ -44,10 +44,14 @@ class Vocabulary(object): | |||
:param int max_size: set the max number of words in Vocabulary. Default: None | |||
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | |||
:param padding: str, padding的字符,默认为<pad>。如果设置为None,则vocabulary中不考虑padding,为None的情况多在为label建立 | |||
Vocabulary的情况。 | |||
:param unknown: str, unknown的字符,默认为<unk>。如果设置为None,则vocabulary中不考虑unknown,为None的情况多在为label建立 | |||
Vocabulary的情况。 | |||
""" | |||
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): | |||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||
self.max_size = max_size | |||
self.min_freq = min_freq | |||
self.word_count = Counter() | |||
@@ -97,9 +101,9 @@ class Vocabulary(object): | |||
""" | |||
self.word2idx = {} | |||
if self.padding is not None: | |||
self.word2idx[self.padding] = 0 | |||
self.word2idx[self.padding] = len(self.word2idx) | |||
if self.unknown is not None: | |||
self.word2idx[self.unknown] = 1 | |||
self.word2idx[self.unknown] = len(self.word2idx) | |||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
words = self.word_count.most_common(max_size) | |||
@@ -839,6 +839,15 @@ class SSTLoader(DataSetLoader): | |||
self.tag_v = tag_v | |||
def load(self, path): | |||
""" | |||
:param path: str,存储数据的路径 | |||
:return: DataSet。内含field有'words', 'pos_tags', 'heads', 'labels'(parser的label) | |||
类似于拥有以下结构, 一行为一个instance(sample) | |||
words pos_tags heads labels | |||
['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...] | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
datas = [] | |||
for l in f: | |||
@@ -25,7 +25,7 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 | |||
:param encoding_type: str, 支持"bio", "bmes"。 | |||
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | |||
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | |||
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | |||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
@@ -62,7 +62,7 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
""" | |||
:param encoding_type: str, 支持"BIO", "BMES"。 | |||
:param encoding_type: str, 支持"BIO", "BMES", "BMESO"。 | |||
:param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
:param from_label: str, 比如"PER", "LOC"等label | |||
:param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
@@ -127,6 +127,18 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) | |||
return to_tag in ['b', 's', 'end'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | |||
elif encoding_type == 'bmeso': | |||
if from_tag == 'start': | |||
return to_tag in ['b', 's', 'o'] | |||
elif from_tag == 'b': | |||
return to_tag in ['m', 'e'] and from_label==to_label | |||
elif from_tag == 'm': | |||
return to_tag in ['m', 'e'] and from_label==to_label | |||
elif from_tag in ['e', 's', 'o']: | |||
return to_tag in ['b', 's', 'end', 'o'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||
else: | |||
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | |||
@@ -7,20 +7,24 @@ from fastNLP.modules.utils import initial_parameter | |||
class MLP(nn.Module): | |||
"""Multilayer Perceptrons as a decoder | |||
:param list size_layer: list of int, define the size of MLP layers. | |||
:param str activation: str or function, the activation function for hidden layers. | |||
:param list size_layer: list of int, define the size of MLP layers. layer的层数为 len(size_layer) - 1 | |||
:param str or list activation: str or function or a list, the activation function for hidden layers. | |||
:param str or function output_activation : str or function, the activation function for output layer | |||
:param str initial_method: the name of initialization method. | |||
:param float dropout: the probability of dropout. | |||
.. note:: | |||
There is no activation function applying on output layer. | |||
隐藏层的激活函数通过activation定义。一个str/function或者一个str/function的list可以被传入activation。 | |||
如果只传入了一个str/function,那么所有隐藏层的激活函数都由这个str/function定义; | |||
如果传入了一个str/function的list,那么每一个隐藏层的激活函数由这个list中对应的元素定义,其中list的长度为隐藏层数。 | |||
输出层的激活函数由output_activation定义,默认值为None,此时输出层没有激活函数。 | |||
""" | |||
def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0): | |||
def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): | |||
super(MLP, self).__init__() | |||
self.hiddens = nn.ModuleList() | |||
self.output = None | |||
self.output_activation = output_activation | |||
for i in range(1, len(size_layer)): | |||
if i + 1 == len(size_layer): | |||
self.output = nn.Linear(size_layer[i-1], size_layer[i]) | |||
@@ -33,25 +37,47 @@ class MLP(nn.Module): | |||
'relu': nn.ReLU(), | |||
'tanh': nn.Tanh(), | |||
} | |||
if activation in actives: | |||
self.hidden_active = actives[activation] | |||
elif callable(activation): | |||
self.hidden_active = activation | |||
if not isinstance(activation, list): | |||
activation = [activation] * (len(size_layer) - 2) | |||
elif len(activation) == len(size_layer) - 2: | |||
pass | |||
else: | |||
raise ValueError("should set activation correctly: {}".format(activation)) | |||
raise ValueError( | |||
f"the length of activation function list except {len(size_layer) - 2} but got {len(activation)}!") | |||
# Resolve the activation of every hidden layer. Each entry of `activation`
# is either a callable (used as-is) or the name of an entry in `actives`.
self.hidden_active = []
for func in activation:
    # Test the individual element: `activation` itself is a list here and
    # is never callable, which used to send callables into `.lower()`.
    if callable(func):
        self.hidden_active.append(func)
    elif isinstance(func, str) and func.lower() in actives:
        # Index with the lowered name so the lookup matches the membership test.
        self.hidden_active.append(actives[func.lower()])
    else:
        raise ValueError("should set activation correctly: {}".format(activation))
# Resolve the optional output-layer activation under the same rules.
if self.output_activation is not None:
    if callable(self.output_activation):
        pass
    elif isinstance(self.output_activation, str) and self.output_activation.lower() in actives:
        self.output_activation = actives[self.output_activation.lower()]
    else:
        # Report the offending output activation, not the hidden-layer list.
        raise ValueError("should set activation correctly: {}".format(self.output_activation))
initial_parameter(self, initial_method) | |||
def forward(self, x): | |||
for layer in self.hiddens: | |||
x = self.dropout(self.hidden_active(layer(x))) | |||
x = self.dropout(self.output(x)) | |||
for layer, func in zip(self.hiddens, self.hidden_active): | |||
x = self.dropout(func(layer(x))) | |||
x = self.output(x) | |||
if self.output_activation is not None: | |||
x = self.output_activation(x) | |||
x = self.dropout(x) | |||
return x | |||
if __name__ == '__main__': | |||
net1 = MLP([5, 10, 5]) | |||
net2 = MLP([5, 10, 5], 'tanh') | |||
for net in [net1, net2]: | |||
net3 = MLP([5, 6, 7, 8, 5], 'tanh') | |||
net4 = MLP([5, 6, 7, 8, 5], 'relu', output_activation='tanh') | |||
net5 = MLP([5, 6, 7, 8, 5], ['tanh', 'relu', 'tanh'], 'tanh') | |||
for net in [net1, net2, net3, net4, net5]: | |||
x = torch.randn(5, 5) | |||
y = net(x) | |||
print(x) | |||