* fix a bug in losses._fast_param_map
* add a sampler init parameter to Trainer and adjust the parameter order
* refine codes

tags/v0.2.0^2
@@ -72,10 +72,9 @@ class LossBase(object):
     def _fast_param_map(self, pred_dict, target_dict):
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
-            return pred_dict.values[0], target_dict.values[0]
+            return tuple(pred_dict.values())[0], tuple(target_dict.values())[0]
         return None
 
     def __call__(self, pred_dict, target_dict, check=False):
         """
         :param pred_dict: A dict from forward function of the network.
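For reference, the bug this hunk fixes: `dict.values` without parentheses is the method object itself, so indexing it raises. A minimal repro with toy dicts (not fastNLP objects):

    pred_dict = {"predict": 0.7}
    target_dict = {"y": 1}

    # Old code: pred_dict.values[0]
    #   -> TypeError: 'builtin_function_or_method' object is not subscriptable
    # Fixed code calls values() first, then indexes the materialized view:
    pred = tuple(pred_dict.values())[0]      # 0.7
    target = tuple(target_dict.values())[0]  # 1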
@@ -1,4 +1,3 @@
 import inspect
 import warnings
 from collections import defaultdict
@@ -7,11 +6,12 @@ import numpy as np
 import torch
 
 from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import seq_lens_to_masks
+from fastNLP.core.utils import CheckRes
 
 
 class MetricBase(object):
     def __init__(self):
@@ -59,9 +59,10 @@ class MetricBase(object):
         func_args = [arg for arg in func_spect.args if arg != 'self']
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
-                raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
-                                f"initialization parameters, or change the signature of"
-                                f" {get_func_signature(self.evaluate)}.")
+                raise NameError(
+                    f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
+                    f"initialization parameters, or change the signature of"
+                    f" {get_func_signature(self.evaluate)}.")
 
         # evaluate should not have varargs.
         if func_spect.varargs:
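In isolation, the check that this raise guards: every key of param_map must name a real argument of evaluate(). A standalone sketch using only the standard library (the typo'd mapping is hypothetical):

    import inspect

    def evaluate(self, pred, target): ...

    func_args = [arg for arg in inspect.getfullargspec(evaluate).args if arg != 'self']
    param_map = {"prediction": "output"}  # hypothetical typo: evaluate() has `pred`, not `prediction`
    for func_param, input_param in param_map.items():
        if func_param not in func_args:
            # mirrors the NameError raised above
            raise NameError(f"Parameter `{func_param}` is not in evaluate{inspect.signature(evaluate)}.")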
@@ -113,7 +114,7 @@ class MetricBase(object):
         if not self._checked:
             # 1. check consistence between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
-            func_args = set([arg for arg in func_spect.args if arg!='self'])
+            func_args = set([arg for arg in func_spect.args if arg != 'self'])
             for func_arg, input_arg in self.param_map.items():
                 if func_arg not in func_args:
                     raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")
@@ -121,7 +122,7 @@ class MetricBase(object):
             # 2. only part of the param_map are passed, left are not
             for arg in func_args:
                 if arg not in self.param_map:
-                    self.param_map[arg] = arg #This param does not need mapping.
+                    self.param_map[arg] = arg  # This param does not need mapping.
             self._evaluate_args = func_args
             self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
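What step 2 amounts to, isolated from the class (toy values): any evaluate() argument the user did not remap falls back to an identity mapping, and the reverse map is then used to rename batch fields back to argument names.

    func_args = {"pred", "target", "seq_lens"}
    param_map = {"pred": "output"}        # user remapped only `pred`
    for arg in func_args:
        if arg not in param_map:
            param_map[arg] = arg          # identity mapping for everything else
    # param_map == {'pred': 'output', 'target': 'target', 'seq_lens': 'seq_lens'}
    reverse_param_map = {input_arg: func_arg for func_arg, input_arg in param_map.items()}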
@@ -153,14 +154,14 @@ class MetricBase(object):
                 replaced_missing = list(missing)
                 for idx, func_arg in enumerate(missing):
                     replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
-                        f"in `{self.__class__.__name__}`)"
+                                            f"in `{self.__class__.__name__}`)"
 
                 check_res = CheckRes(missing=replaced_missing,
-                                 unused=check_res.unused,
-                                 duplicated=duplicated,
-                                 required=check_res.required,
-                                 all_needed=check_res.all_needed,
-                                 varargs=check_res.varargs)
+                                     unused=check_res.unused,
+                                     duplicated=duplicated,
+                                     required=check_res.required,
+                                     all_needed=check_res.all_needed,
+                                     varargs=check_res.varargs)
 
             if check_res.missing or check_res.duplicated or check_res.varargs:
                 raise CheckError(check_res=check_res,
@@ -172,6 +173,7 @@ class MetricBase(object):
         return
 
+
 class AccuracyMetric(MetricBase):
     def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
         super().__init__()
@@ -191,7 +193,7 @@ class AccuracyMetric(MetricBase):
         :param target_dict:
         :return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
         """
-        if len(pred_dict)==1 and len(target_dict)==1:
+        if len(pred_dict) == 1 and len(target_dict) == 1:
             pred = list(pred_dict.values())[0]
             target = list(target_dict.values())[0]
             self.evaluate(pred=pred, target=target)
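The fast path only fires when each dict holds exactly one entry, so no name mapping is needed at all (toy tensors; the field names are borrowed from the tests below):

    import torch

    pred_dict = {"predict": torch.tensor([[0.2, 0.8]])}   # single output field
    target_dict = {"y": torch.tensor([1])}                # single target field
    pred = list(pred_dict.values())[0]
    target = list(target_dict.values())[0]
    # with two or more entries on either side, the slower name-mapping path runs instead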
@@ -211,7 +213,7 @@ class AccuracyMetric(MetricBase):
             None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
         :return: dict({'acc': float})
         """
-        #TODO the error reporting here needs changing: the user has no idea what `pred` is. Report the actual value to the user.
+        # TODO the error reporting here needs changing: the user has no idea what `pred` is. Report the actual value to the user.
         if not isinstance(pred, torch.Tensor):
             raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(pred)}.")
@@ -224,14 +226,14 @@ class AccuracyMetric(MetricBase):
                             f"got {type(masks)}.")
         elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
             raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
-                                f"got {type(seq_lens)}.")
+                            f"got {type(seq_lens)}.")
 
         if masks is None and seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
 
-        if pred.size()==target.size():
+        if pred.size() == target.size():
             pass
-        elif len(pred.size())==len(target.size())+1:
+        elif len(pred.size()) == len(target.size()) + 1:
             pred = pred.argmax(dim=-1)
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
@@ -245,18 +247,17 @@ class AccuracyMetric(MetricBase):
             self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
             self.total += torch.sum(masks.float()).item()
         else:
             self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
             self.total += np.prod(list(pred.size()))
 
     def get_metric(self, reset=True):
-        evaluate_result = {'acc': round(self.acc_count/self.total, 6)}
+        evaluate_result = {'acc': round(self.acc_count / self.total, 6)}
         if reset:
             self.acc_count = 0
             self.total = 0
         return evaluate_result
 
 
 def _prepare_metrics(metrics):
     """
@@ -278,7 +279,8 @@ def _prepare_metrics(metrics):
                 raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.")
             _metrics.append(metric)
         else:
-            raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
+            raise TypeError(
+                f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
     elif isinstance(metrics, MetricBase):
         _metrics = [metrics]
     else:
@@ -300,6 +302,7 @@ class Evaluator(object):
         """
         raise NotImplementedError
 
+
 class ClassifyEvaluator(Evaluator):
     def __init__(self):
         super(ClassifyEvaluator, self).__init__()
@@ -335,6 +338,7 @@ class SeqLabelEvaluator(Evaluator):
         accuracy = total_correct / total_count
         return {"accuracy": float(accuracy)}
 
+
 class SeqLabelEvaluator2(Evaluator):
     # the evaluator above is probably wrong
     def __init__(self, seq_lens_field_name='word_seq_origin_len'):
@@ -367,7 +371,7 @@ class SeqLabelEvaluator2(Evaluator):
                     if x_i in self.end_tagidx_set:
                         truth_count += 1
                         for j in range(start, idx_i + 1):
-                            if y_[j]!=x_[j]:
+                            if y_[j] != x_[j]:
                                 flag = False
                                 break
                         if flag:
@@ -380,8 +384,7 @@ class SeqLabelEvaluator2(Evaluator):
         R = corr_count / (float(truth_count) + 1e-6)
         F = 2 * P * R / (P + R + 1e-6)
 
-        return {"P": P, 'R':R, 'F': F}
+        return {"P": P, 'R': R, 'F': F}
 
 
 class SNLIEvaluator(Evaluator):
@@ -563,10 +566,6 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
 
 
-def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
-    raise NotImplementedError
-
-
 def accuracy_topk(y_true, y_prob, k=1):
     """Compute accuracy of y_true matching top-k probable
     labels in y_prob.
@@ -28,11 +28,9 @@ class Trainer(object):
     """
 
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
-                 validate_every=-1,
-                 dev_data=None, use_cuda=False, save_path=None,
-                 optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
-                 metric_key=None):
+    def __init__(self, train_data, model, losser=None, metrics=None, optimizer=Adam(lr=0.01, weight_decay=0),
+                 sampler=RandomSampler(), n_epochs=3, batch_size=32, print_every=50, validate_every=-1, dev_data=None,
+                 use_cuda=False, metric_key=None, save_path=None, check_code_level=0):
         """
         :param DataSet train_data: the training data
@@ -54,7 +52,6 @@ class Trainer(object):
         ::
 
             metric_key="-PPL" # language model gets better as perplexity gets smaller
 
-        :param kwargs:
         """
         super(Trainer, self).__init__()
@@ -105,6 +102,7 @@ class Trainer(object):
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
         self.best_metric_indicator = None
+        self.sampler = sampler
 
         self._model_device = model.parameters().__next__().device
@@ -120,14 +118,9 @@ class Trainer(object):
                                              batch_size=self.batch_size,
                                              use_cuda=self.use_cuda)
 
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
         self.step = 0
         self.start_time = None  # start timestamp
 
-        # print(self.__dict__)
-
     def train(self):
         """Start Training.
@@ -158,7 +151,7 @@ class Trainer(object):
         epoch = 1
         while epoch <= self.n_epochs:
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                   as_numpy=False)
 
             self._train_epoch(data_iterator, self.model, epoch, start)
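With the new parameter, the sampler that feeds each epoch's Batch iterator can be swapped out instead of being hard-coded to RandomSampler(). A hedged usage sketch: the SequentialSampler import path is assumed from fastNLP.core.sampler, and train_set, dev_set and model are as built in the test further below.

    from fastNLP.core.losses import BCELoss
    from fastNLP.core.metrics import AccuracyMetric
    from fastNLP.core.sampler import SequentialSampler
    from fastNLP.core.trainer import Trainer

    trainer = Trainer(train_set, model,
                      losser=BCELoss(pred="predict", target="y"),
                      metrics=AccuracyMetric(pred="predict", target="y"),
                      sampler=SequentialSampler(),  # previously always RandomSampler()
                      n_epochs=10, batch_size=32, dev_data=dev_set)
    trainer.train()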
@@ -10,6 +10,8 @@ import torch
 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                    'varargs'], verbose=False)
 
+
+
 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.
@@ -53,6 +55,7 @@ def pickle_exist(pickle_path, pickle_name):
     else:
         return False
 
+
 def _build_args(func, **kwargs):
     spect = inspect.getfullargspec(func)
     if spect.varkw is not None:
@@ -108,7 +111,7 @@ def _check_arg_dict_list(func, args):
     assert callable(func) and isinstance(arg_dict_list, (list, tuple))
     assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
     spect = inspect.getfullargspec(func)
-    all_args = set([arg for arg in spect.args if arg!='self'])
+    all_args = set([arg for arg in spect.args if arg != 'self'])
    defaults = []
     if spect.defaults is not None:
         defaults = [arg for arg in spect.defaults]
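For reference, what those lines extract for a typical evaluate-style signature, using only the standard library:

    import inspect

    def evaluate(self, pred, target, masks=None, seq_lens=None): ...

    spect = inspect.getfullargspec(evaluate)
    all_args = set(arg for arg in spect.args if arg != 'self')  # {'pred', 'target', 'masks', 'seq_lens'}
    defaults = list(spect.defaults or ())                       # [None, None] -> masks and seq_lens are optional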
@@ -130,6 +133,7 @@ def _check_arg_dict_list(func, args):
                     all_needed=list(all_args),
                     varargs=varargs)
 
+
 def get_func_signature(func):
     """
@@ -153,7 +157,7 @@ def get_func_signature(func):
         class_name = func.__self__.__class__.__name__
         signature = inspect.signature(func)
         signature_str = str(signature)
-        if len(signature_str)>2:
+        if len(signature_str) > 2:
             _self = '(self, '
         else:
             _self = '(self'
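The `len(signature_str) > 2` test distinguishes an empty parameter list from a non-empty one. A standalone illustration of the inputs involved (the final assembled string is fastNLP's own, so treat the last comment as approximate):

    import inspect

    class Demo:
        def evaluate(self, pred, target): ...

    sig = str(inspect.signature(Demo().evaluate))  # '(pred, target)' -- `self` is already bound
    # len('()') == 2, so an empty signature gets '(self' with no trailing comma;
    # here len(sig) > 2, so the method renders as roughly 'Demo.evaluate(self, pred, target)'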
@@ -176,12 +180,13 @@ def _is_function_or_method(func):
         return False
     return True
 
+
 def _check_function_or_method(func):
     if not _is_function_or_method(func):
         raise TypeError(f"{type(func)} is not a method or function.")
 
-def _move_dict_value_to_device(*args, device:torch.device):
+
+def _move_dict_value_to_device(*args, device: torch.device):
     """
 
     move data to model's device, element in *args should be dict. This is a inplace change.
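The helper's contract, reimplemented as a minimal standalone sketch (the function name and body here are illustrative, not fastNLP's exact code):

    import torch

    def move_dict_values_to_device(*args, device: torch.device):
        # in-place: every tensor value in every dict is moved to `device`
        for arg in args:
            assert isinstance(arg, dict)
            for key, value in arg.items():
                if isinstance(value, torch.Tensor):
                    arg[key] = value.to(device)

    batch_x = {"word_seq": torch.zeros(4, 10, dtype=torch.long)}
    batch_y = {"label_seq": torch.zeros(4, dtype=torch.long)}
    move_dict_values_to_device(batch_x, batch_y, device=torch.device("cpu"))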
@@ -206,7 +211,8 @@ class CheckError(Exception):
     CheckError. Used in losses.LossBase, metrics.MetricBase.
     """
 
-    def __init__(self, check_res:CheckRes, func_signature:str):
+    def __init__(self, check_res: CheckRes, func_signature: str):
         errs = [f'The following problems occurred when calling `{func_signature}`']
 
         if check_res.varargs:
@@ -228,8 +234,9 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2
 
-def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
-                         pred_dict:dict, target_dict:dict, dataset, check_level=0):
+
+def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
+                         pred_dict: dict, target_dict: dict, dataset, check_level=0):
     errs = []
     unuseds = []
     _unused_field = []
@@ -268,8 +275,8 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
                         f"target is {list(target_dict.keys())}).")
         if _miss_out_dataset:
             _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                        f"target is {list(target_dict.keys())}) or output it "
-                        f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
+                    f"target is {list(target_dict.keys())}) or output it "
+                    f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
         if _unused_field:
             _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
         suggestions.append(_tmp)
@@ -277,15 +284,15 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}.")
         suggestions.append(f"Delete {check_res.duplicated} in the output of "
-                               f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
+                           f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
 
     if check_level == STRICT_CHECK_LEVEL:
         errs.extend(unuseds)
 
-    if len(errs)>0:
+    if len(errs) > 0:
         errs.insert(0, f'The following problems occurred when calling {func_signature}')
         sugg_str = ""
-        if len(suggestions)>1:
+        if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
                 sugg_str += f'({idx+1}). {sugg}'
         else:
@@ -332,10 +339,10 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
     if check_level == STRICT_CHECK_LEVEL:
         errs.extend(_unused)
 
-    if len(errs)>0:
+    if len(errs) > 0:
         errs.insert(0, f'The following problems occurred when calling {func_signature}')
         sugg_str = ""
-        if len(suggestions)>1:
+        if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
                 sugg_str += f'({idx+1}). {sugg}'
         else:
@@ -357,11 +364,11 @@ def seq_lens_to_masks(seq_lens, float=True):
     :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
     """
     if isinstance(seq_lens, np.ndarray):
-        assert len(np.shape(seq_lens))==1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
+        assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
         assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
         raise NotImplemented
     elif isinstance(seq_lens, torch.LongTensor):
-        assert len(seq_lens.size())==1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
+        assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
         batch_size = seq_lens.size(0)
         max_len = seq_lens.max()
         indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
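The index trick in the LongTensor branch, run standalone. The final comparison is not visible in this hunk, so the `lt` step is an assumption about how the mask is completed:

    import torch

    seq_lens = torch.LongTensor([3, 1, 2])
    batch_size = seq_lens.size(0)
    max_len = seq_lens.max()
    indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
    # indexes: [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
    masks = indexes.lt(seq_lens.unsqueeze(1))  # assumed completion: position < length
    # masks: [[1, 1, 1], [1, 0, 0], [1, 1, 0]]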
@@ -375,4 +382,3 @@ def seq_lens_to_masks(seq_lens, float=True):
         raise NotImplemented
     else:
         raise NotImplemented
-
@@ -31,15 +31,7 @@ class TrainerTestGround(unittest.TestCase):
         model = NaiveClassifier(2, 1)
 
-        trainer = Trainer(train_set, model,
-                          losser=BCELoss(pred="predict", target="y"),
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          n_epochs=10,
-                          batch_size=32,
-                          print_every=10,
-                          validate_every=-1,
-                          dev_data=dev_set,
-                          optimizer=SGD(0.1),
-                          check_code_level=2
-                          )
+        trainer = Trainer(train_set, model, losser=BCELoss(pred="predict", target="y"),
+                          metrics=AccuracyMetric(pred="predict", target="y"), optimizer=SGD(), n_epochs=10,
+                          batch_size=32, print_every=10, validate_every=-1, dev_data=dev_set, check_code_level=2)
         trainer.train()
@@ -71,20 +71,16 @@ class TestTutorial(unittest.TestCase):
         # instantiate a Trainer, pass in the model and the data, and run training
         copy_model = deepcopy(model)
-        overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,
+        overfit_trainer = Trainer(train_data=test_data, model=copy_model,
                                   losser=CrossEntropyLoss(pred="output", target="label_seq"),
-                                  metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                                  save_path="./save",
-                                  batch_size=4,
-                                  n_epochs=10)
+                                  metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
+                                  dev_data=test_data, save_path="./save")
         overfit_trainer.train()
 
-        trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
+        trainer = Trainer(train_data=train_data, model=model,
                           losser=CrossEntropyLoss(pred="output", target="label_seq"),
-                          metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                          save_path="./save",
-                          batch_size=4,
-                          n_epochs=10)
+                          metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
+                          dev_data=test_data, save_path="./save")
         trainer.train()
 
         print('Train finished!')