
Updates:

* fix a bug in _fast_param_map in losses
* add a sampler init parameter to Trainer and adjust the parameter order
* refine codes
tags/v0.2.0^2
FengZiYjun 5 years ago
parent
commit
513876d5db
6 changed files with 65 additions and 80 deletions
  1. fastNLP/core/losses.py        +1   -2
  2. fastNLP/core/metrics.py       +28  -29
  3. fastNLP/core/trainer.py       +5   -12
  4. fastNLP/core/utils.py         +22  -16
  5. test/core/test_trainer.py     +3   -11
  6. test/test_tutorial.py         +6   -10

fastNLP/core/losses.py  (+1, -2)

@@ -72,10 +72,9 @@ class LossBase(object):
 
     def _fast_param_map(self, pred_dict, target_dict):
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
-            return pred_dict.values[0], target_dict.values[0]
+            return tuple(pred_dict.values())[0], tuple(target_dict.values())[0]
         return None
 
     def __call__(self, pred_dict, target_dict, check=False):
         """
         :param pred_dict: A dict from forward function of the network.

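The fix above addresses a real crash: dict.values is a bound method, not a sequence, so indexing it raises a TypeError, while calling it and materializing the view makes the first value reachable. A minimal sketch with a plain dict (illustrative, not part of the commit):

    pred_dict = {"pred": 0.7}

    # Old line: dict.values is a method object, so indexing it fails.
    try:
        pred_dict.values[0]
    except TypeError as e:
        print(e)  # 'builtin_function_or_method' object is not subscriptable

    # New line: call values(), materialize the view, then index.
    print(tuple(pred_dict.values())[0])  # 0.7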

fastNLP/core/metrics.py  (+28, -29)

@@ -1,4 +1,3 @@
-
 import inspect
 import warnings
 from collections import defaultdict
@@ -7,11 +6,12 @@ import numpy as np
 import torch
 
 from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import seq_lens_to_masks
-from fastNLP.core.utils import CheckRes
 
 
 class MetricBase(object):
     def __init__(self):
@@ -59,9 +59,10 @@ class MetricBase(object):
         func_args = [arg for arg in func_spect.args if arg != 'self']
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
-                raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
-                                f"initialization parameters, or change the signature of"
-                                f" {get_func_signature(self.evaluate)}.")
+                raise NameError(
+                    f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
+                    f"initialization parameters, or change the signature of"
+                    f" {get_func_signature(self.evaluate)}.")
 
         # evaluate should not have varargs.
         if func_spect.varargs:
@@ -113,7 +114,7 @@ class MetricBase(object):
         if not self._checked:
             # 1. check consistence between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
-            func_args = set([arg for arg in func_spect.args if arg!='self'])
+            func_args = set([arg for arg in func_spect.args if arg != 'self'])
             for func_arg, input_arg in self.param_map.items():
                 if func_arg not in func_args:
                     raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")
@@ -121,7 +122,7 @@ class MetricBase(object):
             # 2. only part of the param_map are passed, left are not
             for arg in func_args:
                 if arg not in self.param_map:
-                    self.param_map[arg] = arg #This param does not need mapping.
+                    self.param_map[arg] = arg  # This param does not need mapping.
             self._evaluate_args = func_args
             self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
 
@@ -153,14 +154,14 @@ class MetricBase(object):
             replaced_missing = list(missing)
             for idx, func_arg in enumerate(missing):
                 replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
-                    f"in `{self.__class__.__name__}`)"
+                                        f"in `{self.__class__.__name__}`)"
 
             check_res = CheckRes(missing=replaced_missing,
-                                unused=check_res.unused,
-                                duplicated=duplicated,
-                                required=check_res.required,
-                                all_needed=check_res.all_needed,
-                                varargs=check_res.varargs)
+                                 unused=check_res.unused,
+                                 duplicated=duplicated,
+                                 required=check_res.required,
+                                 all_needed=check_res.all_needed,
+                                 varargs=check_res.varargs)
 
             if check_res.missing or check_res.duplicated or check_res.varargs:
                 raise CheckError(check_res=check_res,
@@ -172,6 +173,7 @@ class MetricBase(object):
 
             return
 
+
 class AccuracyMetric(MetricBase):
     def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
         super().__init__()
@@ -191,7 +193,7 @@ class AccuracyMetric(MetricBase):
        :param target_dict:
        :return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
        """
-        if len(pred_dict)==1 and len(target_dict)==1:
+        if len(pred_dict) == 1 and len(target_dict) == 1:
            pred = list(pred_dict.values())[0]
            target = list(target_dict.values())[0]
            self.evaluate(pred=pred, target=target)
@@ -211,7 +213,7 @@ class AccuracyMetric(MetricBase):
            None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
        :return: dict({'acc': float})
        """
-        #TODO this error message needs changing: the user does not know what `pred` is, so report the actual value
+        # TODO this error message needs changing: the user does not know what `pred` is, so report the actual value
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(pred)}.")
@@ -224,14 +226,14 @@ class AccuracyMetric(MetricBase):
                            f"got {type(masks)}.")
        elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
-                           f"got {type(seq_lens)}.")
+                            f"got {type(seq_lens)}.")
 
        if masks is None and seq_lens is not None:
            masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
 
-        if pred.size()==target.size():
+        if pred.size() == target.size():
            pass
-        elif len(pred.size())==len(target.size())+1:
+        elif len(pred.size()) == len(target.size()) + 1:
            pred = pred.argmax(dim=-1)
        else:
            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
@@ -245,18 +247,17 @@ class AccuracyMetric(MetricBase):
            self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
            self.total += torch.sum(masks.float()).item()
        else:
-            self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
+            self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
            self.total += np.prod(list(pred.size()))
 
    def get_metric(self, reset=True):
-        evaluate_result = {'acc': round(self.acc_count/self.total, 6)}
+        evaluate_result = {'acc': round(self.acc_count / self.total, 6)}
        if reset:
            self.acc_count = 0
            self.total = 0
        return evaluate_result
-
 
 
def _prepare_metrics(metrics):
    """
 
@@ -278,7 +279,8 @@ def _prepare_metrics(metrics):
                    raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.")
                _metrics.append(metric)
            else:
-                raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
+                raise TypeError(
+                    f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
    elif isinstance(metrics, MetricBase):
        _metrics = [metrics]
    else:
@@ -300,6 +302,7 @@ class Evaluator(object):
        """
        raise NotImplementedError
 
+
class ClassifyEvaluator(Evaluator):
    def __init__(self):
        super(ClassifyEvaluator, self).__init__()
@@ -335,6 +338,7 @@ class SeqLabelEvaluator(Evaluator):
        accuracy = total_correct / total_count
        return {"accuracy": float(accuracy)}
 
+
class SeqLabelEvaluator2(Evaluator):
    # the evaluator above is probably wrong
    def __init__(self, seq_lens_field_name='word_seq_origin_len'):
@@ -367,7 +371,7 @@ class SeqLabelEvaluator2(Evaluator):
                if x_i in self.end_tagidx_set:
                    truth_count += 1
                    for j in range(start, idx_i + 1):
-                        if y_[j]!=x_[j]:
+                        if y_[j] != x_[j]:
                            flag = False
                            break
                    if flag:
@@ -380,8 +384,7 @@ class SeqLabelEvaluator2(Evaluator):
        R = corr_count / (float(truth_count) + 1e-6)
        F = 2 * P * R / (P + R + 1e-6)
 
-        return {"P": P, 'R':R, 'F': F}
-
+        return {"P": P, 'R': R, 'F': F}
 
 
class SNLIEvaluator(Evaluator):
@@ -563,10 +566,6 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
    return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
 
 
-def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
-    raise NotImplementedError
-
-
def accuracy_topk(y_true, y_prob, k=1):
    """Compute accuracy of y_true matching top-k probable
    labels in y_prob.

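For reference, the AccuracyMetric.evaluate logic touched above either uses pred as-is (same shape as target) or argmaxes away one extra class dimension, then counts matches only on unmasked positions. A small sketch of that computation with made-up shapes (illustrative, not from the commit):

    import torch

    pred = torch.randn(2, 5, 4)           # (B, L, num_classes): raw scores
    target = torch.randint(0, 4, (2, 5))  # (B, L): gold label ids
    masks = torch.tensor([[1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 1]]).float()  # 1 marks real tokens

    # pred has one extra dimension, so reduce over classes first, as in the diff.
    if len(pred.size()) == len(target.size()) + 1:
        pred = pred.argmax(dim=-1)

    # Count correct predictions only where the mask is on.
    acc_count = torch.sum(torch.eq(pred, target).float() * masks).item()
    total = torch.sum(masks).item()
    print({'acc': round(acc_count / total, 6)})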

fastNLP/core/trainer.py  (+5, -12)

@@ -28,11 +28,9 @@ class Trainer(object):
 
    """
 
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
-                 validate_every=-1,
-                 dev_data=None, use_cuda=False, save_path=None,
-                 optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
-                 metric_key=None):
+    def __init__(self, train_data, model, losser=None, metrics=None, optimizer=Adam(lr=0.01, weight_decay=0),
+                 sampler=RandomSampler(), n_epochs=3, batch_size=32, print_every=50, validate_every=-1, dev_data=None,
+                 use_cuda=False, metric_key=None, save_path=None, check_code_level=0):
        """
 
        :param DataSet train_data: the training data
@@ -54,7 +52,6 @@ class Trainer(object):
        ::
            metric_key="-PPL" # language model gets better as perplexity gets smaller
 
-        :param kwargs:
 
        """
        super(Trainer, self).__init__()
@@ -105,6 +102,7 @@ class Trainer(object):
        self.print_every = int(print_every)
        self.validate_every = int(validate_every)
        self.best_metric_indicator = None
+        self.sampler = sampler
 
        self._model_device = model.parameters().__next__().device
 
@@ -120,14 +118,9 @@ class Trainer(object):
                              batch_size=self.batch_size,
                              use_cuda=self.use_cuda)
 
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
        self.step = 0
        self.start_time = None  # start timestamp
 
-        # print(self.__dict__)
-
    def train(self):
        """Start Training.
 
@@ -158,7 +151,7 @@ class Trainer(object):
        epoch = 1
        while epoch <= self.n_epochs:
 
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                  as_numpy=False)
 
            self._train_epoch(data_iterator, self.model, epoch, start)

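With the new keyword, callers can choose the batching order instead of always getting RandomSampler. A rough usage sketch following the reordered signature above; the SequentialSampler name and the import paths are assumptions based on the file layout, and train_set, dev_set and model stand in for the fixtures used in test_trainer.py below:

    from fastNLP.core.losses import BCELoss            # assumed import paths
    from fastNLP.core.metrics import AccuracyMetric
    from fastNLP.core.sampler import SequentialSampler
    from fastNLP.core.trainer import Trainer

    # Hypothetical call; train_set, dev_set and model are built as in the test below.
    trainer = Trainer(train_set, model,
                      losser=BCELoss(pred="predict", target="y"),
                      metrics=AccuracyMetric(pred="predict", target="y"),
                      sampler=SequentialSampler(),  # new keyword introduced by this commit
                      n_epochs=10, batch_size=32, dev_data=dev_set)
    trainer.train()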

fastNLP/core/utils.py  (+22, -16)

@@ -10,6 +10,8 @@ import torch
 
 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                    'varargs'], verbose=False)
 
+
 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.
 
@@ -53,6 +55,7 @@ def pickle_exist(pickle_path, pickle_name):
     else:
         return False
 
+
 def _build_args(func, **kwargs):
     spect = inspect.getfullargspec(func)
     if spect.varkw is not None:
@@ -108,7 +111,7 @@ def _check_arg_dict_list(func, args):
     assert callable(func) and isinstance(arg_dict_list, (list, tuple))
     assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
     spect = inspect.getfullargspec(func)
-    all_args = set([arg for arg in spect.args if arg!='self'])
+    all_args = set([arg for arg in spect.args if arg != 'self'])
     defaults = []
     if spect.defaults is not None:
         defaults = [arg for arg in spect.defaults]
@@ -130,6 +133,7 @@ def _check_arg_dict_list(func, args):
                     all_needed=list(all_args),
                     varargs=varargs)
 
+
 def get_func_signature(func):
     """
 
@@ -153,7 +157,7 @@ def get_func_signature(func):
         class_name = func.__self__.__class__.__name__
         signature = inspect.signature(func)
         signature_str = str(signature)
-        if len(signature_str)>2:
+        if len(signature_str) > 2:
             _self = '(self, '
         else:
             _self = '(self'
@@ -176,12 +180,13 @@ def _is_function_or_method(func):
         return False
     return True
 
+
 def _check_function_or_method(func):
     if not _is_function_or_method(func):
         raise TypeError(f"{type(func)} is not a method or function.")
 
 
-def _move_dict_value_to_device(*args, device:torch.device):
+def _move_dict_value_to_device(*args, device: torch.device):
     """
 
     move data to model's device, element in *args should be dict. This is a inplace change.
@@ -206,7 +211,8 @@ class CheckError(Exception):
 
     CheckError. Used in losses.LossBase, metrics.MetricBase.
     """
-    def __init__(self, check_res:CheckRes, func_signature:str):
+
+    def __init__(self, check_res: CheckRes, func_signature: str):
         errs = [f'The following problems occurred when calling `{func_signature}`']
 
         if check_res.varargs:
@@ -228,8 +234,9 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2
 
-def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
-                         pred_dict:dict, target_dict:dict, dataset, check_level=0):
+
+def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
+                         pred_dict: dict, target_dict: dict, dataset, check_level=0):
     errs = []
     unuseds = []
     _unused_field = []
@@ -268,8 +275,8 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
                     f"target is {list(target_dict.keys())}).")
         if _miss_out_dataset:
             _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                   f"target is {list(target_dict.keys())}) or output it "
-                   f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
+                    f"target is {list(target_dict.keys())}) or output it "
+                    f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
         if _unused_field:
             _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
         suggestions.append(_tmp)
@@ -277,15 +284,15 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}.")
         suggestions.append(f"Delete {check_res.duplicated} in the output of "
-                          f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
+                           f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
 
     if check_level == STRICT_CHECK_LEVEL:
         errs.extend(unuseds)
 
-    if len(errs)>0:
+    if len(errs) > 0:
         errs.insert(0, f'The following problems occurred when calling {func_signature}')
         sugg_str = ""
-        if len(suggestions)>1:
+        if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
                 sugg_str += f'({idx+1}). {sugg}'
         else:
@@ -332,10 +339,10 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
     if check_level == STRICT_CHECK_LEVEL:
         errs.extend(_unused)
 
-    if len(errs)>0:
+    if len(errs) > 0:
         errs.insert(0, f'The following problems occurred when calling {func_signature}')
         sugg_str = ""
-        if len(suggestions)>1:
+        if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
                 sugg_str += f'({idx+1}). {sugg}'
         else:
@@ -357,11 +364,11 @@ def seq_lens_to_masks(seq_lens, float=True):
     :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
     """
     if isinstance(seq_lens, np.ndarray):
-        assert len(np.shape(seq_lens))==1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
+        assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
         assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
         raise NotImplemented
     elif isinstance(seq_lens, torch.LongTensor):
-        assert len(seq_lens.size())==1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
+        assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
         batch_size = seq_lens.size(0)
         max_len = seq_lens.max()
         indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
@@ -375,4 +382,3 @@ def seq_lens_to_masks(seq_lens, float=True):
         raise NotImplemented
     else:
         raise NotImplemented
-

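The seq_lens_to_masks branch touched above builds a (B, max_len) grid of positions with torch.arange and turns it into a mask; the comparison against the lengths is not shown in the hunk, so the last two lines below are an assumption about the rest of the function (illustrative sketch only):

    import torch

    seq_lens = torch.LongTensor([3, 5, 2])  # (B,): true sequence lengths
    batch_size = seq_lens.size(0)
    max_len = seq_lens.max()

    # Same construction as in the diff: one row of positions per batch element.
    indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)

    # Assumed continuation: a position is valid while it is below the length.
    masks = (indexes < seq_lens.unsqueeze(1)).float()
    print(masks)
    # tensor([[1., 1., 1., 0., 0.],
    #         [1., 1., 1., 1., 1.],
    #         [1., 1., 0., 0., 0.]])
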
test/core/test_trainer.py  (+3, -11)

@@ -31,15 +31,7 @@ class TrainerTestGround(unittest.TestCase):
 
         model = NaiveClassifier(2, 1)
 
-        trainer = Trainer(train_set, model,
-                          losser=BCELoss(pred="predict", target="y"),
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          n_epochs=10,
-                          batch_size=32,
-                          print_every=10,
-                          validate_every=-1,
-                          dev_data=dev_set,
-                          optimizer=SGD(0.1),
-                          check_code_level=2
-                          )
+        trainer = Trainer(train_set, model, losser=BCELoss(pred="predict", target="y"),
+                          metrics=AccuracyMetric(pred="predict", target="y"), optimizer=SGD(), n_epochs=10,
+                          batch_size=32, print_every=10, validate_every=-1, dev_data=dev_set, check_code_level=2)
         trainer.train()

test/test_tutorial.py  (+6, -10)

@@ -71,20 +71,16 @@ class TestTutorial(unittest.TestCase):
 
         # instantiate a Trainer, pass in the model and data, and start training
         copy_model = deepcopy(model)
-        overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,
+        overfit_trainer = Trainer(train_data=test_data, model=copy_model,
                                   losser=CrossEntropyLoss(pred="output", target="label_seq"),
-                                  metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                                  save_path="./save",
-                                  batch_size=4,
-                                  n_epochs=10)
+                                  metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
+                                  dev_data=test_data, save_path="./save")
         overfit_trainer.train()
 
-        trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
+        trainer = Trainer(train_data=train_data, model=model,
                           losser=CrossEntropyLoss(pred="output", target="label_seq"),
-                          metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                          save_path="./save",
-                          batch_size=4,
-                          n_epochs=10)
+                          metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
+                          dev_data=test_data, save_path="./save")
         trainer.train()
         print('Train finished!')



