diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 3dbea8eb..57171e25 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -69,7 +69,7 @@ class DataSet(object):
             self.idx = idx

         def __getitem__(self, item):
-            assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
+            assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[self.idx])
             assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
             return self.dataset.field_arrays[item][self.idx]
diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py
index e1d7a032..5167be35 100644
--- a/fastNLP/core/fieldarray.py
+++ b/fastNLP/core/fieldarray.py
@@ -83,7 +83,8 @@ class FieldArray(object):
         elif isinstance(content, list):
             # content is a 1-D list
             if len(content) == 0:
-                raise RuntimeError("Cannot create FieldArray with an empty list.")
+                # the old error message was not informative enough
+                raise RuntimeError("Cannot create FieldArray with an empty list, or with a list that contains an empty element.")
             type_set = set([type(item) for item in content])

             if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES:
@@ -164,11 +165,13 @@ class FieldArray(object):
         # TODO when this FieldArray holds scalar content such as seq_length, no padding is needed; to be discussed further
         if not is_iterable(self.content[0]):
             array = np.array([self.content[i] for i in indices], dtype=self.dtype)
-        else:
+        elif self.dtype in (np.int64, np.float64):
             max_len = max([len(self.content[i]) for i in indices])
             array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)
             for i, idx in enumerate(indices):
                 array[i][:len(self.content[idx])] = self.content[idx]
+        else:  # should only be str
+            array = np.array([self.content[i] for i in indices])
         return array

     def __len__(self):
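# ---- illustrative sketch, not part of the patch ------------------------------
# How FieldArray.get() behaves after the hunk above: int64/float64 content is
# padded with padding_val up to the longest sequence in the batch, while str
# content is returned unpadded. pad_batch is a hypothetical stand-in for the
# FieldArray internals; dtype=object is used so that modern NumPy accepts
# ragged string lists.
import numpy as np

def pad_batch(content, indices, padding_val=0, dtype=np.int64):
    if dtype in (np.int64, np.float64):
        max_len = max(len(content[i]) for i in indices)
        array = np.full((len(indices), max_len), padding_val, dtype=dtype)
        for i, idx in enumerate(indices):
            array[i, :len(content[idx])] = content[idx]
        return array
    # str (or other) content: no padding, just collect the raw lists
    return np.array([content[i] for i in indices], dtype=object)

print(pad_batch([[1, 2], [3, 4, 5]], [0, 1]))                 # [[1 2 0] [3 4 5]]
print(pad_batch([['a'], ['b', 'c']], [0, 1], dtype=np.str_))  # ragged, unpadded
# -------------------------------------------------------------------------------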
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 2a9e89cd..a4976540 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -80,7 +80,7 @@ class LossBase(object):
         fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
             return fast_param
         return fast_param

@@ -134,10 +134,11 @@ class LossBase(object):
        # missing
        if not self._checked:
            check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
-           # only check missing.
+           # replace missing.
            missing = check_res.missing
            replaced_missing = list(missing)
            for idx, func_arg in enumerate(missing):
+               # don't delete the backticks (``) in this message, and don't add new ones:
+               # _check_loss_evaluate parses them back out with a regex
                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                    f"in `{self.__class__.__name__}`)"

@@ -188,7 +189,7 @@ class CrossEntropyLoss(LossBase):
 class L1Loss(LossBase):
     def __init__(self, pred=None, target=None):
         super(L1Loss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.l1_loss(input=pred, target=target)
@@ -197,7 +198,7 @@ class L1Loss(LossBase):
 class BCELoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(BCELoss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.binary_cross_entropy(input=pred, target=target)
@@ -205,7 +206,7 @@ class BCELoss(LossBase):
 class NLLLoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(NLLLoss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.nll_loss(input=pred, target=target)
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index f8279d0a..d97ba699 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -151,9 +151,11 @@ class MetricBase(object):
        if not self._checked:
            check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
-           # only check missing.
+           # replace missing.
            missing = check_res.missing
            replaced_missing = list(missing)
            for idx, func_arg in enumerate(missing):
+               # don't delete the backticks (``) in this message, and don't add new ones:
+               # _check_loss_evaluate parses them back out with a regex
                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                    f"in `{self.__class__.__name__}`)"
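# ---- illustrative sketch, not part of the patch ------------------------------
# Why the comments above insist on the backticks: _check_loss_evaluate (see the
# utils.py hunk below) recovers the function argument and the class name from
# the "missing param" message with a lookbehind/lookahead regex, so the format
# "name(assign to `arg` in `ClassName`)" is load-bearing.
import re

msg = "pred(assign to `pred` in `CrossEntropyLoss`)"
fun_arg, module_name = re.findall(r"(?<=`)[a-zA-Z0-9]*?(?=`)", msg)
assert (fun_arg, module_name) == ('pred', 'CrossEntropyLoss')
# -------------------------------------------------------------------------------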
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 13a3490a..8f676279 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -2,7 +2,7 @@ import os
 import time
 from datetime import datetime
 from datetime import timedelta
-from tqdm import tqdm
+from tqdm.autonotebook import tqdm

 import torch
 from tensorboardX import SummaryWriter
@@ -23,7 +23,6 @@ from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
-from fastNLP.core.utils import _relocate_pbar

 class Trainer(object):
     """Main Training Loop
@@ -45,7 +44,7 @@ class Trainer(object):
     :param int validate_every: step interval to do next validation. Default: -1 (validate every epoch).
     :param DataSet dev_data: the validation data
     :param use_cuda: whether to run on GPU (CUDA)
-    :param str save_path: file path to save models
+    :param save_path: file path to save models
     :param Optimizer optimizer: an optimizer object
     :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore, 1: warning, 2: strict.
         `ignore` will not check unused fields; `warning` will warn if some fields are not used; `strict` means an error will be raised if some fields are not used.
@@ -149,7 +148,7 @@ class Trainer(object):
         self._mode(self.model, is_test=False)

         self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
-        print("training epochs started " + self.start_time)
+        print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
                 def __getattr__(self, item):
@@ -172,12 +171,12 @@ class Trainer(object):
         del self._summary_writer

     def _tqdm_train(self):
+        self.step = 0
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
         total_steps = data_iterator.num_batches * self.n_epochs
         epoch = 1
-        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', desc="Epoch {}/{}"
-                .format(epoch, self.n_epochs), leave=False, dynamic_ncols=True) as pbar:
+        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             ava_loss = 0
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
@@ -195,28 +194,26 @@ class Trainer(object):
                     # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                 if (self.step+1) % self.print_every == 0:
-                    pbar.update(self.print_every)
-                    pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss/self.print_every))
+                    pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss / self.print_every))
                     ava_loss = 0
-
+                pbar.update(1)
                 self.step += 1
                 if self.validate_every > 0 and self.step % self.validate_every == 0 \
                         and self.dev_data is not None:
                     eval_res = self._do_validation()
                     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                self.tester._format_eval_results(eval_res)
-                    pbar = _relocate_pbar(pbar, print_str=eval_str)
+                    pbar.write(eval_str)
             if self.validate_every < 0 and self.dev_data:
                 eval_res = self._do_validation()
                 eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                            self.tester._format_eval_results(eval_res)
-                pbar = _relocate_pbar(pbar, print_str=eval_str)
+                pbar.write(eval_str)
             if epoch != self.n_epochs:
                 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
         pbar.close()

-
     def _print_train(self):
         """
@@ -264,9 +261,6 @@ class Trainer(object):
                 self._do_validation()
             epoch += 1
-
-
-
     def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 0019b022..0e2bba07 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -258,29 +258,48 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
     if _unused_param:
         unuseds.append(f"\tunused param: {_unused_param}")  # output from predict or forward

+    module_name = ''
     if check_res.missing:
         errs.append(f"\tmissing param: {check_res.missing}")
-        _miss_in_dataset = []
-        _miss_out_dataset = []
+        import re
+        mapped_missing = []
+        unmapped_missing = []
+        input_func_map = {}
         for _miss in check_res.missing:
+            fun_arg, module_name = re.findall(r"(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
             if '(' in _miss:
                 # if they are like 'SomeParam(assign to xxx)'
                 _miss = _miss.split('(')[0]
-            if _miss in dataset:
-                _miss_in_dataset.append(_miss)
+            input_func_map[_miss] = fun_arg
+            if fun_arg == _miss:
+                unmapped_missing.append(_miss)
             else:
-                _miss_out_dataset.append(_miss)
+                mapped_missing.append(_miss)

-        if _miss_in_dataset:
-            suggestions.append(f"You might need to set {_miss_in_dataset} as target(Right now "
-                               f"target is {list(target_dict.keys())}).")
-        if _miss_out_dataset:
-            _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                    f"target has {list(target_dict.keys())}) or output it "
-                    f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).")
-            # if _unused_field:
-            #     _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
-            suggestions.append(_tmp)
+        for _miss in mapped_missing:
+            if _miss in dataset:
+                suggestions.append(f"Set {_miss} as target.")
+            else:
+                _tmp = ''
+                if check_res.unused:
+                    _tmp = f"Check the key assignment for `{input_func_map[_miss]}` when initializing {module_name}."
+                if _tmp:
+                    _tmp += f' Or provide {_miss} in DataSet or in the output of {prev_func_signature}.'
+                else:
+                    _tmp = f'Provide {_miss} in DataSet or in the output of {prev_func_signature}.'
+                suggestions.append(_tmp)
+        for _miss in unmapped_missing:
+            if _miss in dataset:
+                suggestions.append(f"Set {_miss} as target.")
+            else:
+                _tmp = ''
+                if check_res.unused:
+                    _tmp = f"Specify your assignment for `{input_func_map[_miss]}` when initializing {module_name}."
+                if _tmp:
+                    _tmp += f' Or provide {_miss} in DataSet or in the output of {prev_func_signature}.'
+                else:
+                    _tmp = f'Provide {_miss} in DataSet or in the output of {prev_func_signature}.'
+                suggestions.append(_tmp)

     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}.")
@@ -297,17 +316,23 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
-                sugg_str += f'({idx+1}). {sugg}'
+                if idx > 0:
+                    sugg_str += '\t\t\t'
+                sugg_str += f'({idx+1}). {sugg}\n'
+            sugg_str = sugg_str[:-1]
         else:
             sugg_str += suggestions[0]
+        errs.append(f'\ttarget field: {list(target_dict.keys())}')
+        errs.append(f'\tparam from {prev_func_signature}: {list(pred_dict.keys())}')
         err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
         raise NameError(err_str)

    if check_res.unused:
        if check_level == WARNING_CHECK_LEVEL:
-           _unused_warn = f'{check_res.unused} is not used by {func_signature}.'
+           if not module_name:
+               module_name = func_signature.split('.')[0]
+           _unused_warn = f'{check_res.unused} is not used by {module_name}.'
            warnings.warn(message=_unused_warn)

-
 def _check_forward_error(forward_func, batch_x, dataset, check_level):
     check_res = _check_arg_dict_list(forward_func, batch_x)
     func_signature = get_func_signature(forward_func)
@@ -402,40 +427,3 @@ def seq_mask(seq_len, max_len):
     seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
     seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
     return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
-
-
-def _relocate_pbar(pbar:tqdm, print_str:str):
-    """
-
-    When using tqdm, you cannot print. If you print, the tqdm will duplicate. By using this function, print_str will
-    show above tqdm.
-    :param pbar: tqdm
-    :param print_str:
-    :return:
-    """
-
-    params = ['desc', 'total', 'leave', 'file', 'ncols', 'mininterval', 'maxinterval', 'miniters', 'ascii', 'disable',
-              'unit', 'unit_scale', 'dynamic_ncols', 'smoothing', 'bar_format', 'initial', 'position', 'postfix', 'unit_divisor',
-              'gui']
-
-    attr_map = {'file': 'fp', 'initial':'n', 'position':'pos'}
-
-    param_dict = {}
-    for param in params:
-        attr_name = param
-        if param in attr_map:
-            attr_name = attr_map[param]
-        value = getattr(pbar, attr_name)
-        if attr_name == 'pos':
-            value = abs(value)
-        param_dict[param] = value
-
-    pbar.close()
-    avg_time = pbar.avg_time
-    start_t = pbar.start_t
-    print(print_str)
-    pbar = tqdm(**param_dict)
-    pbar.start_t = start_t
-    pbar.avg_time = avg_time
-    pbar.sp(pbar.__repr__())
-    return pbar
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 60ab7849..45c84bc2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 numpy>=1.14.2
 torch>=0.4.0
 tensorboardX
-tqdm
\ No newline at end of file
+tqdm>=4.28.1
\ No newline at end of file
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index 493a740c..fe58b2f2 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -142,9 +142,16 @@ class TestDataSet(unittest.TestCase):
         def split_sent(ins):
             return ins['raw_sentence'].split()
         dataset = DataSet.read_csv('../../sentence.csv', headers=('raw_sentence', 'label'), sep='\t')
-        dataset.apply(split_sent, new_field_name='words')
+        dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)
+        dataset.apply(split_sent, new_field_name='words', is_input=True)
         # print(dataset)

+    def test_add_field(self):
+        ds = DataSet({"x": [3, 4]})
+        ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True)
+        # ds.apply(lambda x: [x['x']] * 3, is_input=True, is_target=True, new_field_name='y')
+        print(ds)
+
     def test_save_load(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         ds.save("./my_ds.pkl")
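# ---- illustrative sketch, not part of the patch ------------------------------
# Why the test above now drops blank sentences before apply(): splitting an
# empty raw_sentence yields an empty list, which trips the FieldArray
# RuntimeError patched at the top of this diff. The two-instance DataSet here
# is made up; the calls mirror the ones used in the test.
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

ds = DataSet([Instance(raw_sentence='hello world', label='0'),
              Instance(raw_sentence='', label='1')])
ds.drop(lambda x: len(x['raw_sentence'].split()) == 0)  # remove the blank line
ds.apply(lambda x: x['raw_sentence'].split(), new_field_name='words', is_input=True)
# -------------------------------------------------------------------------------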
diff --git a/test/core/test_tester.py b/test/core/test_tester.py
index 68143f7b..99a8000e 100644
--- a/test/core/test_tester.py
+++ b/test/core/test_tester.py
@@ -4,6 +4,64 @@
 data_name = "pku_training.utf8"
 pickle_path = "data_for_tests"

+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+import time
+from fastNLP.core.utils import CheckError
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+from fastNLP.core.losses import BCELoss
+from fastNLP.core.losses import CrossEntropyLoss
+from fastNLP.core.metrics import AccuracyMetric
+from fastNLP.core.optimizer import SGD
+from fastNLP.core.tester import Tester
+from fastNLP.models.base_model import NaiveClassifier
+
+def prepare_fake_dataset():
+    mean = np.array([-3, -3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    mean = np.array([3, 3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    return data_set
+
+
+def prepare_fake_dataset2(*args, size=100):
+    ys = np.random.randint(4, size=100, dtype=np.int64)
+    data = {'y': ys}
+    for arg in args:
+        data[arg] = np.random.randn(size, 5)
+    return DataSet(data=data)
+
 class TestTester(unittest.TestCase):
     def test_case_1(self):
-        pass
+        # check that the error message correctly alerts the user
+        # pass in an extra parameter here so that it shows up as a duplicate
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.rename_field('x_unused', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                time.sleep(0.1)
+                # loss = F.cross_entropy(x, y)
+                return {'preds': x}
+
+        model = Model()
+        tester = Tester(
+            data=dataset,
+            model=model,
+            metrics=AccuracyMetric())
+        tester.test()
diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py
index 38fb6e0e..a69438ae 100644
--- a/test/core/test_trainer.py
+++ b/test/core/test_trainer.py
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch.nn.functional as F
 from torch import nn
-
+import time
 from fastNLP.core.utils import CheckError
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
@@ -212,8 +212,8 @@ class TrainerTestGround(unittest.TestCase):
         # pass in an extra parameter here so that it shows up as a duplicate
         dataset = prepare_fake_dataset2('x1', 'x_unused')
         dataset.rename_field('x_unused', 'x2')
-        dataset.set_input('x1', 'x2', 'y')
-        dataset.set_target('x1', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')
         class Model(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -222,8 +222,9 @@ class TrainerTestGround(unittest.TestCase):
                 x1 = self.fc(x1)
                 x2 = self.fc(x2)
                 x = x1 + x2
+                time.sleep(0.1)
                 # loss = F.cross_entropy(x, y)
-                return {'pred': x}
+                return {'preds': x}

         model = Model()
         trainer = Trainer(
diff --git a/tutorials/fastnlp_tutorial_1204.ipynb b/tutorials/fastnlp_tutorial_1204.ipynb index 1fa1adca..8d896bf2 100644 --- a/tutorials/fastnlp_tutorial_1204.ipynb +++ b/tutorials/fastnlp_tutorial_1204.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -34,17 +34,9 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "8529\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from fastNLP import DataSet\n", "from fastNLP import Instance\n", @@ -56,20 +48,9 @@ }, { "cell_type": "code",
"execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", - "'label': 1}\n", - "{'raw_sentence': -LRB- Tries -RRB- to parody a genre that 's already a joke in the United States .,\n", - "'label': 1}\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# 使用数字索引[k],获取第k个样本\n", "print(dataset[0])\n", @@ -90,21 +71,9 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'raw_sentence': fake data,\n", - "'label': 0}" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# DataSet.append(Instance)加入新数据\n", "dataset.append(Instance(raw_sentence='fake data', label='0'))\n", @@ -121,18 +90,9 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", - "'label': 1}\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# 将所有数字转为小写\n", "dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", @@ -141,18 +101,9 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", - "'label': 1}\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# label转int\n", "dataset.apply(lambda x: int(x['label']), new_field_name='label')\n", @@ -161,28 +112,9 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "Cannot create FieldArray with an empty list.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msplit_sent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mins\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'raw_sentence'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_sent\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mnew_field_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'words'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/dataset.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, func, new_field_name, **kwargs)\u001b[0m\n\u001b[1;32m 265\u001b[0m **extra_param)\n\u001b[1;32m 266\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_field\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_field_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfields\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mextra_param\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/dataset.py\u001b[0m in \u001b[0;36madd_field\u001b[0;34m(self, name, fields, padding_val, is_input, is_target)\u001b[0m\n\u001b[1;32m 158\u001b[0m f\"Dataset size {len(self)} != field size {len(fields)}\")\n\u001b[1;32m 159\u001b[0m self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target,\n\u001b[0;32m--> 160\u001b[0;31m is_input=is_input)\n\u001b[0m\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdelete_field\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, content, padding_val, is_target, is_input)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_input\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mis_input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_target\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_target\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mis_target\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36mis_input\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 46\u001b[0m 
\u001b[0;34m@\u001b[0m\u001b[0mis_input\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mis_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_map_to_np_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m_type_detection\u001b[0;34m(self, content)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# strict check 2-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Please provide 2-D list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_set\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mint\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m 
\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# strict check 2-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Please provide 2-D list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_set\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mint\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m_type_detection\u001b[0;34m(self, content)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;31m# content is a 1-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Cannot create FieldArray with an empty list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: Cannot create FieldArray with an empty list." 
- ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# 使用空格分割句子\n", "def split_sent(ins):\n", @@ -193,20 +125,9 @@ }, { "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", - "'label': 1,\n", - "'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n", - "'seq_len': 37}\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# 增加长度信息\n", "dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n", @@ -223,17 +144,9 @@ }, { "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "38\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "dataset.drop(lambda x: x['seq_len'] <= 3)\n", "print(len(dataset))" @@ -250,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -264,18 +177,9 @@ }, { "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "27\n", - "11" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# 分出测试集、训练集\n", "\n", @@ -296,20 +200,9 @@ }, { "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'raw_sentence': that the chuck norris `` grenade gag '' occurs about 7 times during windtalkers is a good indication of how serious-minded the film is .,\n", - "'label': 2,\n", - "'words': [6, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 8, 24, 1, 5, 1, 1, 2, 15, 10, 3],\n", - "'seq_len': 25}\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from fastNLP import Vocabulary\n", "\n", @@ -336,36 +229,9 @@ }, { "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "CNNText(\n", - " (embed): Embedding(\n", - " (embed): Embedding(32, 50, padding_idx=0)\n", - " (dropout): Dropout(p=0.0)\n", - " )\n", - " (conv_pool): ConvMaxpool(\n", - " (convs): ModuleList(\n", - " (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n", - " (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n", - " (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n", - " )\n", - " )\n", - " (dropout): Dropout(p=0.1)\n", - " (fc): Linear(\n", - " (linear): Linear(in_features=12, out_features=5, bias=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from fastNLP.models import CNNText\n", "model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n", @@ -432,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": 
[], "source": [ @@ -469,7 +335,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -492,7 +358,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -501,94 +367,9 @@ }, { "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "training epochs started 2018-12-04 22:51:24\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.407407\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.518519\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.481481\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.592593\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# 实例化Trainer,传入模型和数据,进行训练\n", "# 先在test_data拟合\n", @@ -604,101 +385,9 @@ }, { "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "training epochs started 2018-12-04 22:52:01\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.222222\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.259259\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.296296\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 5/5. Step:5/5. 
AccuracyMetric: acc=0.259259\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train finished!\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Train on train_data, validate on test_data\n", "trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n", @@ -713,19 +402,9 @@ }, { "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[tester] \n", - "AccuracyMetric: acc=0.259259\n", - "{'AccuracyMetric': {'acc': 0.259259}}\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Use Tester to evaluate performance on test_data\n", "from fastNLP import Tester\n",
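# ---- illustrative sketch, not part of the patch ------------------------------
# The evaluation flow the (truncated) tutorial cell above sets up, mirroring
# test_tester.py earlier in this diff; `dataset` and `model` stand in for the
# objects prepared in the preceding notebook cells.
from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.tester import Tester

tester = Tester(data=dataset, model=model, metrics=AccuracyMetric())
eval_results = tester.test()  # returns e.g. {'AccuracyMetric': {'acc': 0.259259}}
print(eval_results)
# -------------------------------------------------------------------------------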