@@ -69,7 +69,7 @@ class DataSet(object):
             self.idx = idx

         def __getitem__(self, item):
-            assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
+            assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[self.idx])
             assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
             return self.dataset.field_arrays[item][self.idx]
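The reworded assertion is easiest to read next to a concrete failure. A minimal sketch (field names and values made up, exact repr may differ) of what the message looks like when a non-existent field is requested from an instance view:

```python
from fastNLP.core.dataset import DataSet

# Hypothetical repro: 'y' was never added, so indexing it should trip the assert
# and echo the Instance it was looked up on.
ds = DataSet({"x": [[1, 2], [3, 4]]})
try:
    ds[0]["y"]
except AssertionError as err:
    print(err)  # e.g. no such field:y in Instance {'x': [1, 2]}
```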
@@ -83,7 +83,8 @@ class FieldArray(object):
         elif isinstance(content, list):
             # content is a 1-D list
             if len(content) == 0:
-                raise RuntimeError("Cannot create FieldArray with an empty list.")
+                # the old error is not informative enough.
+                raise RuntimeError("Cannot create FieldArray with an empty list. Or one element in the list is empty.")
             type_set = set([type(item) for item in content])
             if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES:
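A small sketch of the case the extended message is meant to cover; the constructor arguments are assumed from the traceback further down in this diff, and the field name is arbitrary:

```python
from fastNLP.core.fieldarray import FieldArray

# A 2-D field in which one row is empty (e.g. a sentence that split into zero
# words) cannot be type-detected: the "one element in the list is empty" case.
try:
    FieldArray("words", [["hello", "world"], []], is_input=True)
except RuntimeError as err:
    print(err)
```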
@@ -164,11 +165,13 @@ class FieldArray(object):
         # TODO when this fieldArray holds single-value content such as seq_length, no padding is needed; to be discussed further
         if not is_iterable(self.content[0]):
             array = np.array([self.content[i] for i in indices], dtype=self.dtype)
-        else:
+        elif self.dtype in (np.int64, np.float64):
             max_len = max([len(self.content[i]) for i in indices])
             array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)
             for i, idx in enumerate(indices):
                 array[i][:len(self.content[idx])] = self.content[idx]
+        else:  # should only be str
+            array = np.array([self.content[i] for i in indices])
         return array

     def __len__(self):
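The padding branch is easier to follow as a standalone sketch (plain NumPy, not a call into FieldArray itself): numeric 2-D content is padded to the longest row in the batch, while the new `else` branch leaves str content as an unpadded object array.

```python
import numpy as np

content = [[1, 2, 3], [4, 5], [6]]          # variable-length int rows
indices, padding_val = [0, 1, 2], 0
max_len = max(len(content[i]) for i in indices)
array = np.full((len(indices), max_len), padding_val, dtype=np.int64)
for i, idx in enumerate(indices):
    array[i][:len(content[idx])] = content[idx]
print(array)
# [[1 2 3]
#  [4 5 0]
#  [6 0 0]]
```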
@@ -80,7 +80,7 @@ class LossBase(object):
         fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
             return fast_param
         return fast_param
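A pure-Python sketch of what the shortcut is supposed to do (not fastNLP's actual helper): when both dicts hold exactly one value, `pred` must come from the prediction dict and `target` from the target dict, which is what the one-word fix restores.

```python
def fast_param_map(param_map, pred_dict, target_dict):
    fast_param = {}
    if len(param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
        fast_param['pred'] = list(pred_dict.values())[0]
        fast_param['target'] = list(target_dict.values())[0]  # was taken from pred_dict before the fix
    return fast_param

print(fast_param_map({'pred': None, 'target': None},
                     {'output': [0.2, 0.8]}, {'label': [0, 1]}))
# {'pred': [0.2, 0.8], 'target': [0, 1]}
```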
@@ -134,10 +134,11 @@ class LossBase(object):
             # missing
             if not self._checked:
                 check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
-                # only check missing.
+                # replace missing.
                 missing = check_res.missing
                 replaced_missing = list(missing)
                 for idx, func_arg in enumerate(missing):
+                    # Don't delete `` in this information, nor add ``
                     replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                         f"in `{self.__class__.__name__}`)"
@@ -188,7 +189,7 @@ class CrossEntropyLoss(LossBase):
 class L1Loss(LossBase):
     def __init__(self, pred=None, target=None):
         super(L1Loss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.l1_loss(input=pred, target=target)
@@ -197,7 +198,7 @@ class L1Loss(LossBase):
 class BCELoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(BCELoss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.binary_cross_entropy(input=pred, target=target)
@@ -205,7 +206,7 @@ class BCELoss(LossBase):
 class NLLLoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(NLLLoss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.nll_loss(input=pred, target=target)
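All three losses now register keywords that actually exist on `get_loss(pred, target)`. A hedged usage sketch (field names hypothetical): with the fix, a user-supplied key mapping is carried through the param map instead of landing on a non-existent `input` argument.

```python
from fastNLP.core.losses import L1Loss

# Map the model's output key and the DataSet's target key onto get_loss's
# `pred` and `target`; the Trainer/Tester resolve the prediction and target
# dicts through this mapping at loss time.
loss_fn = L1Loss(pred="my_output", target="my_label")
```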
@@ -151,9 +151,11 @@ class MetricBase(object):
             if not self._checked:
                 check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
                 # only check missing.
+                # replace missing.
                 missing = check_res.missing
                 replaced_missing = list(missing)
                 for idx, func_arg in enumerate(missing):
+                    # Don't delete `` in this information, nor add ``
                     replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                         f"in `{self.__class__.__name__}`)"
@@ -2,7 +2,7 @@ import os
 import time
 from datetime import datetime
 from datetime import timedelta
-from tqdm import tqdm
+from tqdm.autonotebook import tqdm

 import torch
 from tensorboardX import SummaryWriter
@@ -23,7 +23,6 @@ from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
-from fastNLP.core.utils import _relocate_pbar


 class Trainer(object):
     """Main Training Loop
@@ -45,7 +44,7 @@ class Trainer(object):
     :param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
     :param DataSet dev_data: the validation data
     :param use_cuda:
-    :param str save_path: file path to save models
+    :param save_path: file path to save models
     :param Optimizer optimizer: an optimizer object
     :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.
         `ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means
@@ -149,7 +148,7 @@ class Trainer(object):
         self._mode(self.model, is_test=False)

         self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
-        print("training epochs started " + self.start_time)
+        print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
                 def __getattr__(self, item):
@@ -172,12 +171,12 @@ class Trainer(object):
         del self._summary_writer

     def _tqdm_train(self):
+        self.step = 0
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                               as_numpy=False)
         total_steps = data_iterator.num_batches*self.n_epochs
         epoch = 1
-        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', desc="Epoch {}/{}"
-                .format(epoch, self.n_epochs), leave=False, dynamic_ncols=True) as pbar:
+        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             ava_loss = 0
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
@@ -195,28 +194,26 @@ class Trainer(object):
                     # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                     if (self.step+1) % self.print_every == 0:
-                        pbar.update(self.print_every)
-                        pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss/self.print_every))
+                        pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss / self.print_every))
                         ava_loss = 0
+                    pbar.update(1)
                     self.step += 1
                     if self.validate_every > 0 and self.step % self.validate_every == 0 \
                             and self.dev_data is not None:
                         eval_res = self._do_validation()
                         eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                    self.tester._format_eval_results(eval_res)
-                        pbar = _relocate_pbar(pbar, print_str=eval_str)
+                        pbar.write(eval_str)
                 if self.validate_every < 0 and self.dev_data:
                     eval_res = self._do_validation()
                     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                self.tester._format_eval_results(eval_res)
-                    pbar = _relocate_pbar(pbar, print_str=eval_str)
+                    pbar.write(eval_str)
                 if epoch!=self.n_epochs:
                     data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                           as_numpy=False)
         pbar.close()

     def _print_train(self):
         """

@@ -264,9 +261,6 @@ class Trainer(object):
                     self._do_validation()
             epoch += 1

     def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
@@ -258,29 +258,48 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
     if _unused_param:
         unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward

+    module_name = ''
     if check_res.missing:
         errs.append(f"\tmissing param: {check_res.missing}")
-        _miss_in_dataset = []
-        _miss_out_dataset = []
+        import re
+        mapped_missing = []
+        unmapped_missing = []
+        input_func_map = {}
         for _miss in check_res.missing:
+            fun_arg, module_name = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
             if '(' in _miss:
                 # if they are like 'SomeParam(assign to xxx)'
                 _miss = _miss.split('(')[0]
-            if _miss in dataset:
-                _miss_in_dataset.append(_miss)
+            input_func_map[_miss] = fun_arg
+            if fun_arg == _miss:
+                unmapped_missing.append(_miss)
             else:
-                _miss_out_dataset.append(_miss)
+                mapped_missing.append(_miss)

-        if _miss_in_dataset:
-            suggestions.append(f"You might need to set {_miss_in_dataset} as target(Right now "
-                               f"target is {list(target_dict.keys())}).")
-        if _miss_out_dataset:
-            _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                    f"target has {list(target_dict.keys())}) or output it "
-                    f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).")
-            # if _unused_field:
-            #     _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
-            suggestions.append(_tmp)
+        for _miss in mapped_missing:
+            if _miss in dataset:
+                suggestions.append(f"Set {_miss} as target.")
+            else:
+                _tmp = ''
+                if check_res.unused:
+                    _tmp = f"Check key assignment for `{input_func_map[_miss]}` when initialize {module_name}."
+                if _tmp:
+                    _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
+                else:
+                    _tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
+                suggestions.append(_tmp)
+        for _miss in unmapped_missing:
+            if _miss in dataset:
+                suggestions.append(f"Set {_miss} as target.")
+            else:
+                _tmp = ''
+                if check_res.unused:
+                    _tmp = f"Specify your assignment for `{input_func_map[_miss]}` when initialize {module_name}."
+                if _tmp:
+                    _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
+                else:
+                    _tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
+                suggestions.append(_tmp)

     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}.")
@@ -297,17 +316,23 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
-                sugg_str += f'({idx+1}). {sugg}'
-                if idx>0:
-                    sugg_str += '\t\t\t'
+                sugg_str += f'({idx+1}). {sugg}\n'
+            sugg_str = sugg_str[:-1]
         else:
             sugg_str += suggestions[0]
+        errs.append(f'\ttarget field: {list(target_dict.keys())}')
+        errs.append(f'\tparam from {prev_func_signature}: {list(pred_dict.keys())}')
         err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
         raise NameError(err_str)
     if check_res.unused:
         if check_level == WARNING_CHECK_LEVEL:
-            _unused_warn = f'{check_res.unused} is not used by {func_signature}.'
+            if not module_name:
+                module_name = func_signature.split('.')[0]
+            _unused_warn = f'{check_res.unused} is not used by {module_name}.'
             warnings.warn(message=_unused_warn)
 def _check_forward_error(forward_func, batch_x, dataset, check_level):
     check_res = _check_arg_dict_list(forward_func, batch_x)
     func_signature = get_func_signature(forward_func)

@@ -402,40 +427,3 @@ def seq_mask(seq_len, max_len):
     seq_len = seq_len.view(-1, 1).long()   # [batch_size, 1]
     seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len]
     return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
-
-
-def _relocate_pbar(pbar:tqdm, print_str:str):
-    """
-    When using tqdm, you cannot print. If you print, the tqdm will duplicate. By using this function, print_str will
-        show above tqdm.
-    :param pbar: tqdm
-    :param print_str:
-    :return:
-    """
-    params = ['desc', 'total', 'leave', 'file', 'ncols', 'mininterval', 'maxinterval', 'miniters', 'ascii', 'disable',
-              'unit', 'unit_scale', 'dynamic_ncols', 'smoothing', 'bar_format', 'initial', 'position', 'postfix', 'unit_divisor',
-              'gui']
-    attr_map = {'file': 'fp', 'initial':'n', 'position':'pos'}
-    param_dict = {}
-    for param in params:
-        attr_name = param
-        if param in attr_map:
-            attr_name = attr_map[param]
-        value = getattr(pbar, attr_name)
-        if attr_name == 'pos':
-            value = abs(value)
-        param_dict[param] = value
-    pbar.close()
-    avg_time = pbar.avg_time
-    start_t = pbar.start_t
-    print(print_str)
-    pbar = tqdm(**param_dict)
-    pbar.start_t = start_t
-    pbar.avg_time = avg_time
-    pbar.sp(pbar.__repr__())
-    return pbar
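The removed helper tore the bar down and rebuilt it just to print a line above it. A minimal sketch of the replacement behaviour used in `_tqdm_train` above: tqdm's own `write()` prints a message without corrupting the bar, so the hand-rolled relocation is no longer needed (the loop bounds and message below are illustrative only).

```python
import time
from tqdm.autonotebook import tqdm

with tqdm(total=10, leave=False, dynamic_ncols=True) as pbar:
    for step in range(10):
        time.sleep(0.01)
        if step == 5:
            # prints above the bar instead of breaking it
            pbar.write("Epoch 1/1. Step:5/10. AccuracyMetric: acc=0.5")
        pbar.update(1)
```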
@@ -1,4 +1,4 @@
 numpy>=1.14.2
 torch>=0.4.0
 tensorboardX
-tqdm
+tqdm>=4.28.1
@@ -142,9 +142,16 @@ class TestDataSet(unittest.TestCase):
         def split_sent(ins):
             return ins['raw_sentence'].split()

         dataset = DataSet.read_csv('../../sentence.csv', headers=('raw_sentence', 'label'), sep='\t')
-        dataset.apply(split_sent, new_field_name='words')
+        dataset.drop(lambda x:len(x['raw_sentence'].split())==0)
+        dataset.apply(split_sent, new_field_name='words', is_input=True)
         # print(dataset)

+    def test_add_field(self):
+        ds = DataSet({"x": [3, 4]})
+        ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True)
+        # ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y')
+        print(ds)
+
     def test_save_load(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         ds.save("./my_ds.pkl")
@@ -4,6 +4,64 @@ data_name = "pku_training.utf8"
 pickle_path = "data_for_tests"

+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+import time
+
+from fastNLP.core.utils import CheckError
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+from fastNLP.core.losses import BCELoss
+from fastNLP.core.losses import CrossEntropyLoss
+from fastNLP.core.metrics import AccuracyMetric
+from fastNLP.core.optimizer import SGD
+from fastNLP.core.tester import Tester
+from fastNLP.models.base_model import NaiveClassifier
+
+
+def prepare_fake_dataset():
+    mean = np.array([-3, -3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    mean = np.array([3, 3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    return data_set
+
+
+def prepare_fake_dataset2(*args, size=100):
+    ys = np.random.randint(4, size=100, dtype=np.int64)
+    data = {'y': ys}
+    for arg in args:
+        data[arg] = np.random.randn(size, 5)
+    return DataSet(data=data)
+
+
 class TestTester(unittest.TestCase):
     def test_case_1(self):
-        pass
+        # Check that the error message reminds the user correctly.
+        # Redundant parameters are passed in here so that they show up as duplicates.
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.rename_field('x_unused', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+
+            def forward(self, x1, x2):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                time.sleep(0.1)
+                # loss = F.cross_entropy(x, y)
+                return {'preds': x}
+
+        model = Model()
+        tester = Tester(
+            data=dataset,
+            model=model,
+            metrics=AccuracyMetric())
+        tester.test()
@@ -3,7 +3,7 @@ import unittest

 import numpy as np
 import torch.nn.functional as F
 from torch import nn
+import time

 from fastNLP.core.utils import CheckError
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
@@ -212,8 +212,8 @@ class TrainerTestGround(unittest.TestCase):
         # Redundant parameters are passed in here so that they show up as duplicates.
         dataset = prepare_fake_dataset2('x1', 'x_unused')
         dataset.rename_field('x_unused', 'x2')
-        dataset.set_input('x1', 'x2', 'y')
-        dataset.set_target('x1', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')

         class Model(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -222,8 +222,9 @@ class TrainerTestGround(unittest.TestCase):
                 x1 = self.fc(x1)
                 x2 = self.fc(x2)
                 x = x1 + x2
+                time.sleep(0.1)
                 # loss = F.cross_entropy(x, y)
-                return {'pred': x}
+                return {'preds': x}

         model = Model()
         trainer = Trainer(
@@ -12,7 +12,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 3, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -34,17 +34,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"8529\n" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"from fastNLP import DataSet\n", | "from fastNLP import DataSet\n", | ||||
"from fastNLP import Instance\n", | "from fastNLP import Instance\n", | ||||
@@ -56,20 +48,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", | |||||
"'label': 1}\n", | |||||
"{'raw_sentence': -LRB- Tries -RRB- to parody a genre that 's already a joke in the United States .,\n", | |||||
"'label': 1}\n" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# 使用数字索引[k],获取第k个样本\n", | "# 使用数字索引[k],获取第k个样本\n", | ||||
"print(dataset[0])\n", | "print(dataset[0])\n", | ||||
@@ -90,21 +71,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'raw_sentence': fake data,\n", | |||||
"'label': 0}" | |||||
] | |||||
}, | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# DataSet.append(Instance)加入新数据\n", | "# DataSet.append(Instance)加入新数据\n", | ||||
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n", | "dataset.append(Instance(raw_sentence='fake data', label='0'))\n", | ||||
@@ -121,18 +90,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", | |||||
"'label': 1}\n" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# 将所有数字转为小写\n", | "# 将所有数字转为小写\n", | ||||
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", | "dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", | ||||
@@ -141,18 +101,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", | |||||
"'label': 1}\n" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# label转int\n", | "# label转int\n", | ||||
"dataset.apply(lambda x: int(x['label']), new_field_name='label')\n", | "dataset.apply(lambda x: int(x['label']), new_field_name='label')\n", | ||||
@@ -161,28 +112,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 9, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"ename": "RuntimeError", | |||||
"evalue": "Cannot create FieldArray with an empty list.", | |||||
"output_type": "error", | |||||
"traceback": [ | |||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||||
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |||||
"\u001b[0;32m<ipython-input-9-d70cf5545af4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msplit_sent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mins\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'raw_sentence'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_sent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_field_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'words'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/dataset.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, func, new_field_name, **kwargs)\u001b[0m\n\u001b[1;32m 265\u001b[0m **extra_param)\n\u001b[1;32m 266\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_field\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_field_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfields\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mextra_param\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/dataset.py\u001b[0m in \u001b[0;36madd_field\u001b[0;34m(self, name, fields, padding_val, is_input, is_target)\u001b[0m\n\u001b[1;32m 158\u001b[0m f\"Dataset size {len(self)} != field size {len(fields)}\")\n\u001b[1;32m 159\u001b[0m self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target,\n\u001b[0;32m--> 160\u001b[0;31m is_input=is_input)\n\u001b[0m\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdelete_field\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, content, padding_val, is_target, is_input)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_input\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mis_input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_target\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_target\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mis_target\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36mis_input\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mis_input\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mis_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_map_to_np_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m_type_detection\u001b[0;34m(self, content)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# strict check 2-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Please provide 2-D list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_set\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mint\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# strict check 2-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Please provide 2-D list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_set\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mint\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m_type_detection\u001b[0;34m(self, content)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;31m# content is a 1-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Cannot create FieldArray with an empty list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;31mRuntimeError\u001b[0m: Cannot create FieldArray with an empty list." | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# 使用空格分割句子\n", | "# 使用空格分割句子\n", | ||||
"def split_sent(ins):\n", | "def split_sent(ins):\n", | ||||
@@ -193,20 +125,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 17, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", | |||||
"'label': 1,\n", | |||||
"'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n", | |||||
"'seq_len': 37}\n" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# 增加长度信息\n", | "# 增加长度信息\n", | ||||
"dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n", | "dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n", | ||||
@@ -223,17 +144,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 19, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"38\n" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"dataset.drop(lambda x: x['seq_len'] <= 3)\n", | "dataset.drop(lambda x: x['seq_len'] <= 3)\n", | ||||
"print(len(dataset))" | "print(len(dataset))" | ||||
@@ -250,7 +163,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 20, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -264,18 +177,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 21, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"27\n", | |||||
"11" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# 分出测试集、训练集\n", | "# 分出测试集、训练集\n", | ||||
"\n", | "\n", | ||||
@@ -296,20 +200,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 22, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': that the chuck norris `` grenade gag '' occurs about 7 times during windtalkers is a good indication of how serious-minded the film is .,\n", | |||||
"'label': 2,\n", | |||||
"'words': [6, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 8, 24, 1, 5, 1, 1, 2, 15, 10, 3],\n", | |||||
"'seq_len': 25}\n" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"from fastNLP import Vocabulary\n", | "from fastNLP import Vocabulary\n", | ||||
"\n", | "\n", | ||||
@@ -336,36 +229,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 23, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"CNNText(\n", | |||||
" (embed): Embedding(\n", | |||||
" (embed): Embedding(32, 50, padding_idx=0)\n", | |||||
" (dropout): Dropout(p=0.0)\n", | |||||
" )\n", | |||||
" (conv_pool): ConvMaxpool(\n", | |||||
" (convs): ModuleList(\n", | |||||
" (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n", | |||||
" (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n", | |||||
" (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n", | |||||
" )\n", | |||||
" )\n", | |||||
" (dropout): Dropout(p=0.1)\n", | |||||
" (fc): Linear(\n", | |||||
" (linear): Linear(in_features=12, out_features=5, bias=True)\n", | |||||
" )\n", | |||||
")" | |||||
] | |||||
}, | |||||
"execution_count": 23, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"from fastNLP.models import CNNText\n", | "from fastNLP.models import CNNText\n", | ||||
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n", | "model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n", | ||||
@@ -432,7 +298,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 25, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -469,7 +335,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 26, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -492,7 +358,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 27, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -501,94 +367,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 30, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2018-12-04 22:51:24\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.407407\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.518519\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.481481\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.592593\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# 实例化Trainer,传入模型和数据,进行训练\n", | "# 实例化Trainer,传入模型和数据,进行训练\n", | ||||
"# 先在test_data拟合\n", | "# 先在test_data拟合\n", | ||||
@@ -604,101 +385,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 31, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2018-12-04 22:52:01\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.222222\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.259259\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.296296\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.259259\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" \r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Train finished!\n" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# 用train_data训练,在test_data验证\n", | "# 用train_data训练,在test_data验证\n", | ||||
"trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n", | "trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n", | ||||
@@ -713,19 +402,9 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 33, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.259259\n", | |||||
"{'AccuracyMetric': {'acc': 0.259259}}\n" | |||||
] | |||||
} | |||||
], | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"# 调用Tester在test_data上评价效果\n", | "# 调用Tester在test_data上评价效果\n", | ||||
"from fastNLP import Tester\n", | "from fastNLP import Tester\n", | ||||