# Conflicts: # fastNLP/core/callback.py # fastNLP/io/dataset_loader.pytags/v0.4.10
@@ -249,6 +249,11 @@ class GradientClipCallback(Callback): | |||||
self.parameters = parameters | self.parameters = parameters | ||||
self.clip_value = clip_value | self.clip_value = clip_value | ||||
def on_backward_end(self, model): | |||||
if self.parameters is None: | |||||
self.clip_fun(model.parameters(), self.clip_value) | |||||
else: | |||||
self.clip_fun(self.parameters, self.clip_value) | |||||
def on_backward_end(self): | def on_backward_end(self): | ||||
self.clip_fun(self.model.parameters(), self.clip_value) | self.clip_fun(self.model.parameters(), self.clip_value) | ||||
@@ -305,7 +310,6 @@ class LRScheduler(Callback): | |||||
def on_epoch_begin(self): | def on_epoch_begin(self): | ||||
self.scheduler.step() | self.scheduler.step() | ||||
print("scheduler step ", "lr=", self.optimizer.param_groups[0]["lr"]) | |||||
class ControlC(Callback): | class ControlC(Callback): | ||||
@@ -16,6 +16,69 @@ from fastNLP.core.vocabulary import Vocabulary | |||||
class MetricBase(object): | class MetricBase(object): | ||||
"""Base class for all metrics. | """Base class for all metrics. | ||||
所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 | |||||
evaluate(xxx)中传入的是一个batch的数据。 | |||||
get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值 | |||||
以分类问题中,Accuracy计算为例 | |||||
假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy | |||||
class Model(nn.Module): | |||||
def __init__(xxx): | |||||
# do something | |||||
def forward(self, xxx): | |||||
# do something | |||||
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes | |||||
假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target | |||||
对应的AccMetric可以按如下的定义 | |||||
# version1, 只使用这一次 | |||||
class AccMetric(MetricBase): | |||||
def __init__(self): | |||||
super().__init__() | |||||
# 根据你的情况自定义指标 | |||||
self.corr_num = 0 | |||||
self.total = 0 | |||||
def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value | |||||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||||
self.total += label.size(0) | |||||
self.corr_num += label.eq(pred).sum().item() | |||||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||||
acc = self.corr_num/self.total | |||||
if reset: # 是否清零以便重新计算 | |||||
self.corr_num = 0 | |||||
self.total = 0 | |||||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||||
# version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred | |||||
class AccMetric(MetricBase): | |||||
def __init__(self, label=None, pred=None): | |||||
# 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, | |||||
# acc_metric = AccMetric(label='y', pred='pred_y')即可。 | |||||
# 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 | |||||
# 应的的值 | |||||
super().__init__() | |||||
self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 | |||||
# 如果没有注册该则效果与version1就是一样的 | |||||
# 根据你的情况自定义指标 | |||||
self.corr_num = 0 | |||||
self.total = 0 | |||||
def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 | |||||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||||
self.total += label.size(0) | |||||
self.corr_num += label.eq(pred).sum().item() | |||||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||||
acc = self.corr_num/self.total | |||||
if reset: # 是否清零以便重新计算 | |||||
self.corr_num = 0 | |||||
self.total = 0 | |||||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||||
``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. | ``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. | ||||
``pred_dict`` is the output of ``forward()`` or prediction function of a model. | ``pred_dict`` is the output of ``forward()`` or prediction function of a model. | ||||
``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. | ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. | ||||
@@ -24,7 +87,6 @@ class MetricBase(object): | |||||
1. whether self.evaluate has varargs, which is not supported. | 1. whether self.evaluate has varargs, which is not supported. | ||||
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. | 2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. | ||||
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. | 3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. | ||||
4. whether params in ``pred_dict``, ``target_dict`` are not used by evaluate.(Might cause warning) | |||||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | ||||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | ||||
@@ -296,6 +358,8 @@ class AccuracyMetric(MetricBase): | |||||
def bmes_tag_to_spans(tags, ignore_labels=None): | def bmes_tag_to_spans(tags, ignore_labels=None): | ||||
""" | """ | ||||
给定一个tags的lis,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 | |||||
返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (左闭右开区间) | |||||
:param tags: List[str], | :param tags: List[str], | ||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | :param ignore_labels: List[str], 在该list中的label将被忽略 | ||||
@@ -315,13 +379,45 @@ def bmes_tag_to_spans(tags, ignore_labels=None): | |||||
else: | else: | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
prev_bmes_tag = bmes_tag | prev_bmes_tag = bmes_tag | ||||
return [(span[0], (span[1][0], span[1][1])) | |||||
return [(span[0], (span[1][0], span[1][1]+1)) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
] | |||||
def bmeso_tag_to_spans(tags, ignore_labels=None): | |||||
""" | |||||
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 | |||||
返回[('singer', (1, 4))] (左闭右开区间) | |||||
:param tags: List[str], | |||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||||
""" | |||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||||
spans = [] | |||||
prev_bmes_tag = None | |||||
for idx, tag in enumerate(tags): | |||||
tag = tag.lower() | |||||
bmes_tag, label = tag[:1], tag[2:] | |||||
if bmes_tag in ('b', 's'): | |||||
spans.append((label, [idx, idx])) | |||||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: | |||||
spans[-1][1][1] = idx | |||||
elif bmes_tag == 'o': | |||||
pass | |||||
else: | |||||
spans.append((label, [idx, idx])) | |||||
prev_bmes_tag = bmes_tag | |||||
return [(span[0], (span[1][0], span[1][1]+1)) | |||||
for span in spans | for span in spans | ||||
if span[0] not in ignore_labels | if span[0] not in ignore_labels | ||||
] | ] | ||||
def bio_tag_to_spans(tags, ignore_labels=None): | def bio_tag_to_spans(tags, ignore_labels=None): | ||||
""" | """ | ||||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||||
返回[('singer', (1, 4))] (左闭右开区间) | |||||
:param tags: List[str], | :param tags: List[str], | ||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | :param ignore_labels: List[str], 在该list中的label将被忽略 | ||||
@@ -343,7 +439,7 @@ def bio_tag_to_spans(tags, ignore_labels=None): | |||||
else: | else: | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
prev_bio_tag = bio_tag | prev_bio_tag = bio_tag | ||||
return [(span[0], (span[1][0], span[1][1])) | |||||
return [(span[0], (span[1][0], span[1][1]+1)) | |||||
for span in spans | for span in spans | ||||
if span[0] not in ignore_labels | if span[0] not in ignore_labels | ||||
] | ] | ||||
@@ -352,6 +448,8 @@ def bio_tag_to_spans(tags, ignore_labels=None): | |||||
class SpanFPreRecMetric(MetricBase): | class SpanFPreRecMetric(MetricBase): | ||||
""" | """ | ||||
在序列标注问题中,以span的方式计算F, pre, rec. | 在序列标注问题中,以span的方式计算F, pre, rec. | ||||
比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | |||||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||||
最后得到的metric结果为 | 最后得到的metric结果为 | ||||
{ | { | ||||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | 'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | ||||
@@ -390,8 +488,7 @@ class SpanFPreRecMetric(MetricBase): | |||||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | ||||
""" | """ | ||||
encoding_type = encoding_type.lower() | encoding_type = encoding_type.lower() | ||||
if encoding_type not in ('bio', 'bmes'): | |||||
raise ValueError("Only support 'bio' or 'bmes' type.") | |||||
if not isinstance(tag_vocab, Vocabulary): | if not isinstance(tag_vocab, Vocabulary): | ||||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | ||||
if f_type not in ('micro', 'macro'): | if f_type not in ('micro', 'macro'): | ||||
@@ -402,6 +499,11 @@ class SpanFPreRecMetric(MetricBase): | |||||
self.tag_to_span_func = bmes_tag_to_spans | self.tag_to_span_func = bmes_tag_to_spans | ||||
elif self.encoding_type == 'bio': | elif self.encoding_type == 'bio': | ||||
self.tag_to_span_func = bio_tag_to_spans | self.tag_to_span_func = bio_tag_to_spans | ||||
elif self.encoding_type == 'bmeso': | |||||
self.tag_to_span_func = bmeso_tag_to_spans | |||||
else: | |||||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | |||||
self.ignore_labels = ignore_labels | self.ignore_labels = ignore_labels | ||||
self.f_type = f_type | self.f_type = f_type | ||||
self.beta = beta | self.beta = beta | ||||
@@ -73,6 +73,7 @@ class BucketSampler(BaseSampler): | |||||
total_sample_num = len(seq_lens) | total_sample_num = len(seq_lens) | ||||
bucket_indexes = [] | bucket_indexes = [] | ||||
assert total_sample_num>=self.num_buckets, "The number of samples is smaller than the number of buckets." | |||||
num_sample_per_bucket = total_sample_num // self.num_buckets | num_sample_per_bucket = total_sample_num // self.num_buckets | ||||
for i in range(self.num_buckets): | for i in range(self.num_buckets): | ||||
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) | bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) | ||||
@@ -205,7 +205,7 @@ class Trainer(object): | |||||
except (CallbackException, KeyboardInterrupt) as e: | except (CallbackException, KeyboardInterrupt) as e: | ||||
self.callback_manager.on_exception(e) | self.callback_manager.on_exception(e) | ||||
if self.dev_data is not None: | |||||
if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | |||||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | ||||
self.tester._format_eval_results(self.best_dev_perf),) | self.tester._format_eval_results(self.best_dev_perf),) | ||||
results['best_eval'] = self.best_dev_perf | results['best_eval'] = self.best_dev_perf | ||||
@@ -373,7 +373,7 @@ class Trainer(object): | |||||
else: | else: | ||||
model.cpu() | model.cpu() | ||||
torch.save(model, model_path) | torch.save(model, model_path) | ||||
model.cuda() | |||||
model.to(self._model_device) | |||||
def _load_model(self, model, model_name, only_param=False): | def _load_model(self, model, model_name, only_param=False): | ||||
# 返回bool值指示是否成功reload模型 | # 返回bool值指示是否成功reload模型 | ||||
@@ -44,10 +44,14 @@ class Vocabulary(object): | |||||
:param int max_size: set the max number of words in Vocabulary. Default: None | :param int max_size: set the max number of words in Vocabulary. Default: None | ||||
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | ||||
:param padding: str, padding的字符,默认为<pad>。如果设置为None,则vocabulary中不考虑padding,为None的情况多在为label建立 | |||||
Vocabulary的情况。 | |||||
:param unknown: str, unknown的字符,默认为<unk>。如果设置为None,则vocabulary中不考虑unknown,为None的情况多在为label建立 | |||||
Vocabulary的情况。 | |||||
""" | """ | ||||
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): | |||||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||||
self.max_size = max_size | self.max_size = max_size | ||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
self.word_count = Counter() | self.word_count = Counter() | ||||
@@ -97,9 +101,9 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.word2idx = {} | self.word2idx = {} | ||||
if self.padding is not None: | if self.padding is not None: | ||||
self.word2idx[self.padding] = 0 | |||||
self.word2idx[self.padding] = len(self.word2idx) | |||||
if self.unknown is not None: | if self.unknown is not None: | ||||
self.word2idx[self.unknown] = 1 | |||||
self.word2idx[self.unknown] = len(self.word2idx) | |||||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | ||||
words = self.word_count.most_common(max_size) | words = self.word_count.most_common(max_size) | ||||
@@ -839,6 +839,15 @@ class SSTLoader(DataSetLoader): | |||||
self.tag_v = tag_v | self.tag_v = tag_v | ||||
def load(self, path): | def load(self, path): | ||||
""" | |||||
:param path: str,存储数据的路径 | |||||
:return: DataSet。内含field有'words', 'pos_tags', 'heads', 'labels'(parser的label) | |||||
类似于拥有以下结构, 一行为一个instance(sample) | |||||
words pos_tags heads labels | |||||
['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...] | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
datas = [] | datas = [] | ||||
for l in f: | for l in f: | ||||
@@ -25,7 +25,7 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | ||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | ||||
:param encoding_type: str, 支持"bio", "bmes"。 | |||||
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | |||||
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | :return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | ||||
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | ||||
start_idx=len(id2label), end_idx=len(id2label)+1。 | start_idx=len(id2label), end_idx=len(id2label)+1。 | ||||
@@ -62,7 +62,7 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||||
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | ||||
""" | """ | ||||
:param encoding_type: str, 支持"BIO", "BMES"。 | |||||
:param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 | |||||
:param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | :param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | ||||
:param from_label: str, 比如"PER", "LOC"等label | :param from_label: str, 比如"PER", "LOC"等label | ||||
:param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | :param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | ||||
@@ -127,6 +127,18 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) | |||||
return to_tag in ['b', 's', 'end'] | return to_tag in ['b', 's', 'end'] | ||||
else: | else: | ||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | ||||
elif encoding_type == 'bmeso': | |||||
if from_tag == 'start': | |||||
return to_tag in ['b', 's', 'o'] | |||||
elif from_tag == 'b': | |||||
return to_tag in ['m', 'e'] and from_label==to_label | |||||
elif from_tag == 'm': | |||||
return to_tag in ['m', 'e'] and from_label==to_label | |||||
elif from_tag in ['e', 's', 'o']: | |||||
return to_tag in ['b', 's', 'end', 'o'] | |||||
else: | |||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||||
else: | else: | ||||
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | ||||
@@ -7,20 +7,24 @@ from fastNLP.modules.utils import initial_parameter | |||||
class MLP(nn.Module): | class MLP(nn.Module): | ||||
"""Multilayer Perceptrons as a decoder | """Multilayer Perceptrons as a decoder | ||||
:param list size_layer: list of int, define the size of MLP layers. | |||||
:param str activation: str or function, the activation function for hidden layers. | |||||
:param list size_layer: list of int, define the size of MLP layers. layer的层数为 len(size_layer) - 1 | |||||
:param str or list activation: str or function or a list, the activation function for hidden layers. | |||||
:param str or function output_activation : str or function, the activation function for output layer | |||||
:param str initial_method: the name of initialization method. | :param str initial_method: the name of initialization method. | ||||
:param float dropout: the probability of dropout. | :param float dropout: the probability of dropout. | ||||
.. note:: | .. note:: | ||||
There is no activation function applying on output layer. | |||||
隐藏层的激活函数通过activation定义。一个str/function或者一个str/function的list可以被传入activation。 | |||||
如果只传入了一个str/function,那么所有隐藏层的激活函数都由这个str/function定义; | |||||
如果传入了一个str/function的list,那么每一个隐藏层的激活函数由这个list中对应的元素定义,其中list的长度为隐藏层数。 | |||||
输出层的激活函数由output_activation定义,默认值为None,此时输出层没有激活函数。 | |||||
""" | """ | ||||
def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0): | |||||
def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): | |||||
super(MLP, self).__init__() | super(MLP, self).__init__() | ||||
self.hiddens = nn.ModuleList() | self.hiddens = nn.ModuleList() | ||||
self.output = None | self.output = None | ||||
self.output_activation = output_activation | |||||
for i in range(1, len(size_layer)): | for i in range(1, len(size_layer)): | ||||
if i + 1 == len(size_layer): | if i + 1 == len(size_layer): | ||||
self.output = nn.Linear(size_layer[i-1], size_layer[i]) | self.output = nn.Linear(size_layer[i-1], size_layer[i]) | ||||
@@ -33,25 +37,47 @@ class MLP(nn.Module): | |||||
'relu': nn.ReLU(), | 'relu': nn.ReLU(), | ||||
'tanh': nn.Tanh(), | 'tanh': nn.Tanh(), | ||||
} | } | ||||
if activation in actives: | |||||
self.hidden_active = actives[activation] | |||||
elif callable(activation): | |||||
self.hidden_active = activation | |||||
if not isinstance(activation, list): | |||||
activation = [activation] * (len(size_layer) - 2) | |||||
elif len(activation) == len(size_layer) - 2: | |||||
pass | |||||
else: | else: | ||||
raise ValueError("should set activation correctly: {}".format(activation)) | |||||
raise ValueError( | |||||
f"the length of activation function list except {len(size_layer) - 2} but got {len(activation)}!") | |||||
self.hidden_active = [] | |||||
for func in activation: | |||||
if callable(activation): | |||||
self.hidden_active.append(activation) | |||||
elif func.lower() in actives: | |||||
self.hidden_active.append(actives[func]) | |||||
else: | |||||
raise ValueError("should set activation correctly: {}".format(activation)) | |||||
if self.output_activation is not None: | |||||
if callable(self.output_activation): | |||||
pass | |||||
elif self.output_activation.lower() in actives: | |||||
self.output_activation = actives[self.output_activation] | |||||
else: | |||||
raise ValueError("should set activation correctly: {}".format(activation)) | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x): | def forward(self, x): | ||||
for layer in self.hiddens: | |||||
x = self.dropout(self.hidden_active(layer(x))) | |||||
x = self.dropout(self.output(x)) | |||||
for layer, func in zip(self.hiddens, self.hidden_active): | |||||
x = self.dropout(func(layer(x))) | |||||
x = self.output(x) | |||||
if self.output_activation is not None: | |||||
x = self.output_activation(x) | |||||
x = self.dropout(x) | |||||
return x | return x | ||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
net1 = MLP([5, 10, 5]) | net1 = MLP([5, 10, 5]) | ||||
net2 = MLP([5, 10, 5], 'tanh') | net2 = MLP([5, 10, 5], 'tanh') | ||||
for net in [net1, net2]: | |||||
net3 = MLP([5, 6, 7, 8, 5], 'tanh') | |||||
net4 = MLP([5, 6, 7, 8, 5], 'relu', output_activation='tanh') | |||||
net5 = MLP([5, 6, 7, 8, 5], ['tanh', 'relu', 'tanh'], 'tanh') | |||||
for net in [net1, net2, net3, net4, net5]: | |||||
x = torch.randn(5, 5) | x = torch.randn(5, 5) | ||||
y = net(x) | y = net(x) | ||||
print(x) | print(x) | ||||