
Fix some bugs; adjust callbacks

tags/v0.4.10
yh 5 years ago
commit 4d1721ffe3
10 changed files with 87 additions and 136 deletions
  1. fastNLP/core/batch.py (+4, -2)
  2. fastNLP/core/callback.py (+19, -32)
  3. fastNLP/core/dataset.py (+14, -18)
  4. fastNLP/core/fieldarray.py (+17, -0)
  5. fastNLP/core/metrics.py (+6, -66)
  6. fastNLP/core/trainer.py (+20, -11)
  7. setup.py (+2, -2)
  8. test/automl/test_enas.py (+1, -1)
  9. test/core/test_dataset.py (+2, -2)
  10. test/test_tutorials.py (+2, -2)

fastNLP/core/batch.py (+4, -2)

@@ -14,15 +14,17 @@ class Batch(object):


     :param DataSet dataset: a DataSet object
     :param int batch_size: the size of the batch
-    :param Sampler sampler: a Sampler object
+    :param Sampler sampler: a Sampler object. If None, use fastNLP.sampler.RandomSampler
     :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.
     :param bool prefetch: If True, use multiprocessing to fetch next batch when training.
     :param str or torch.device device: the batch's device, if as_numpy is True, device is ignored.
     """

-    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False):
+    def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False):
         self.dataset = dataset
         self.batch_size = batch_size
+        if sampler is None:
+            sampler = RandomSampler()
         self.sampler = sampler
         self.as_numpy = as_numpy
         self.idx_list = None
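
This hunk replaces a mutable default (sampler=RandomSampler()) with a None sentinel. A Python default argument is evaluated once, at definition time, so every Batch previously shared one RandomSampler instance. A minimal self-contained sketch of the difference (illustrative names, not fastNLP code):

    class Sampler:
        def __init__(self):
            self.state = []                      # any per-instance state

    def shared_default(sampler=Sampler()):       # default built once, at def time
        return sampler

    def sentinel_default(sampler=None):          # default built at each call
        if sampler is None:
            sampler = Sampler()
        return sampler

    assert shared_default() is shared_default()          # one object, shared by all calls
    assert sentinel_default() is not sentinel_default()  # a fresh object per call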


fastNLP/core/callback.py (+19, -32)

@@ -17,37 +17,37 @@ class Callback(object):
         super(Callback, self).__init__()
         self.trainer = None  # reassigned inside the Trainer

-        # read-only callback attributes
-        self._n_epochs = None
-        self._n_steps = None
-        self._batch_size = None
-        self._model = None
-        self._pbar = None
-        self._optimizer = None
-
     @property
     def n_epochs(self):
-        return self._n_epochs
+        return self.trainer.n_epochs
+
+    @property
+    def epoch(self):
+        return self.trainer.epoch

     @property
     def n_steps(self):
-        return self._n_steps
+        return self.trainer.n_steps
+
+    @property
+    def step(self):
+        return self.trainer.step

     @property
     def batch_size(self):
-        return self._batch_size
+        return self.trainer.batch_size

     @property
     def model(self):
-        return self._model
+        return self.trainer.model

     @property
     def pbar(self):
-        return self._pbar
+        return self.trainer.pbar

     @property
     def optimizer(self):
-        return self._optimizer
+        return self.trainer.optimizer

     def on_train_begin(self):
         # before the main training loop
@@ -82,13 +82,14 @@ class Callback(object):
     def on_valid_begin(self):
         pass

-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         """
         Called after each evaluation on the dev set; eval_result is passed in.

         :param eval_result: Dict[str: Dict[str: float]], the evaluation results
         :param metric_key: str
-        :param optimizer:
+        :param optimizer: optimizer passed to trainer
+        :param is_better_eval: bool, whether the current dev result is better than the previous best
         :return:
         """
         pass
@@ -145,11 +146,10 @@ class CallbackManager(Callback):


""" """


def __init__(self, env, attr, callbacks=None):
def __init__(self, env, callbacks=None):
""" """


:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself.
:param dict attr: read-only attributes for all callbacks
:param Callback callbacks: :param Callback callbacks:
""" """
super(CallbackManager, self).__init__() super(CallbackManager, self).__init__()
@@ -170,19 +170,6 @@ class CallbackManager(Callback):
         for callback in self.callbacks:
             setattr(callback, env_name, env_val)  # Callback.trainer

-        self.set_property(**attr)
-
-    def set_property(self, **kwargs):
-        """Set the read-only properties of all callbacks
-
-        :param kwargs:
-        :return:
-        """
-        for callback in self.callbacks:
-            for k, v in kwargs.items():
-                setattr(callback, "_" + k, v)
-
     @transfer
     def on_train_begin(self):
         pass
@@ -220,7 +207,7 @@ class CallbackManager(Callback):
         pass

     @transfer
-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         pass

     @transfer
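
Net effect of the callback changes: the per-callback shadow attributes (_n_epochs, _pbar, ...) and CallbackManager.set_property are gone, every property now reads live state straight off self.trainer, and on_valid_end gains an is_better_eval flag. A sketch of a callback written against the new interface (the hook signature and properties come from this diff; the patience-tracking logic itself is a hypothetical example):

    from fastNLP.core.callback import Callback

    class PatienceLogger(Callback):
        """Hypothetical callback: count evaluations that passed without improvement."""

        def __init__(self, patience=5):
            super(PatienceLogger, self).__init__()
            self.patience = patience
            self.bad_evals = 0

        def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
            # is_better_eval is the flag added in this commit
            self.bad_evals = 0 if is_better_eval else self.bad_evals + 1
            # step is now a live proxy of the Trainer; no set_property call needed
            print("step {}: {} evals without improvement".format(self.step, self.bad_evals))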


fastNLP/core/dataset.py (+14, -18)

@@ -90,7 +90,7 @@ class DataSet(object):
             data_set = DataSet()
             for field in self.field_arrays.values():
                 data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder,
-                                   is_input=field.is_input, is_target=field.is_target)
+                                   is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
             return data_set
         elif isinstance(idx, str):
             if idx not in self:
@@ -313,16 +313,23 @@ class DataSet(object):
         else:
             return results

-    def drop(self, func):
+    def drop(self, func, inplace=True):
         """Drop instances if a condition holds.

         :param func: a function that takes an Instance object as input, and returns bool.
             The instance will be dropped if the function returns True.
+        :param inplace: bool, whether to drop in place. Otherwise a new dataset will be returned.

         """
-        results = [ins for ins in self._inner_iter() if not func(ins)]
-        for name, old_field in self.field_arrays.items():
-            self.field_arrays[name].content = [ins[name] for ins in results]
+        if inplace:
+            results = [ins for ins in self._inner_iter() if not func(ins)]
+            for name, old_field in self.field_arrays.items():
+                self.field_arrays[name].content = [ins[name] for ins in results]
+        else:
+            results = [ins for ins in self if not func(ins)]
+            data = DataSet(results)
+            for field_name, field in self.field_arrays.items():
+                data.field_arrays[field_name].to(field)
+            return data

     def split(self, dev_ratio):
         """Split the dataset into training and development(validation) set.
@@ -346,19 +353,8 @@ class DataSet(object):
         for idx in train_indices:
             train_set.append(self[idx])
         for field_name in self.field_arrays:
-            train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input
-            train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target
-            train_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder
-            train_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype
-            train_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype
-            train_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim
-
-            dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input
-            dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target
-            dev_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder
-            dev_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype
-            dev_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype
-            dev_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim
+            train_set.field_arrays[field_name].to(self.field_arrays[field_name])
+            dev_set.field_arrays[field_name].to(self.field_arrays[field_name])

         return train_set, dev_set
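
drop now follows the familiar inplace convention. A short usage sketch (import path as in the tests below; the inplace=False behavior assumes drop returns the new dataset, i.e. the return data noted above):

    from fastNLP.core.dataset import DataSet

    ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})

    ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)   # mutates ds in place; 20 instances remain

    # inplace=False leaves ds untouched and returns a filtered copy whose fields keep their
    # is_input/is_target/padder settings, carried over via the new FieldArray.to.
    filtered = ds.drop(lambda ins: len(ins["y"]) < 3, inplace=False)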




fastNLP/core/fieldarray.py (+17, -0)

@@ -383,6 +383,23 @@ class FieldArray(object):
""" """
return len(self.content) return len(self.content)


def to(self, other):
"""
将other的属性复制给本fieldarray(必须通过fieldarray类型). 包含 is_input, is_target, padder, dtype, pytype, content_dim
ignore_type

:param other: FieldArray
:return:
"""
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other))

self.is_input = other.is_input
self.is_target = other.is_target
self.padder = other.padder
self.dtype = other.dtype
self.pytype = other.pytype
self.content_dim = other.content_dim
self.ignore_type = other.ignore_type


def is_iterable(content): def is_iterable(content):
try: try:
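
FieldArray.to centralizes the attribute copy that split() previously spelled out six assignments at a time (see the dataset.py hunk above). A minimal usage sketch (the constructor arguments are assumptions about the FieldArray API, not taken from this diff):

    from fastNLP.core.fieldarray import FieldArray

    src = FieldArray("words", [[1, 2], [3, 4, 5]], is_input=True)
    dst = FieldArray("words", [[6], [7, 8]])

    dst.to(src)        # copies is_input, is_target, padder, dtype, pytype, content_dim, ignore_type
    assert dst.is_input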


fastNLP/core/metrics.py (+6, -66)

@@ -91,7 +91,6 @@ class MetricBase(object):
     Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
     target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
     will be conducted.)
-    However, in some cases where type check is not necessary, ``_fast_param_map`` will be used.

     """
     def __init__(self):
@@ -146,21 +145,6 @@ class MetricBase(object):
     def get_metric(self, reset=True):
         raise NotImplemented

-    def _fast_param_map(self, pred_dict, target_dict):
-        """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
-        such as pred_dict has one element, target_dict has one element
-
-        :param pred_dict:
-        :param target_dict:
-        :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
-        """
-        fast_param = {}
-        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
-            fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(target_dict.values())[0]
-            return fast_param
-        return fast_param
-
     def __call__(self, pred_dict, target_dict):
         """


@@ -172,7 +156,6 @@ class MetricBase(object):
         Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
         target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
         will be conducted.)
-        This function also support _fast_param_map.
         :param pred_dict: usually the output of forward or prediction function
         :param target_dict: usually features set as target..
         :return:
@@ -180,11 +163,6 @@ class MetricBase(object):
         if not callable(self.evaluate):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

-        fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
-        if fast_param:
-            self.evaluate(**fast_param)
-            return
-
         if not self._checked:
             # 1. check consistence between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
@@ -262,41 +240,6 @@ class AccuracyMetric(MetricBase):
         self.total = 0
         self.acc_count = 0

-    def _fast_param_map(self, pred_dict, target_dict):
-        """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
-        such as pred_dict has one element, target_dict has one element
-
-        :param pred_dict:
-        :param target_dict:
-        :return: dict, if dict is not None, pass it to self.evaluate. Otherwise do mapping.
-        """
-        fast_param = {}
-        targets = list(target_dict.values())
-        if len(targets) == 1 and isinstance(targets[0], torch.Tensor):
-            if len(pred_dict) == 1:
-                pred = list(pred_dict.values())[0]
-                fast_param['pred'] = pred
-            elif len(pred_dict) == 2:
-                pred1 = list(pred_dict.values())[0]
-                pred2 = list(pred_dict.values())[1]
-                if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
-                    return fast_param
-                if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1:
-                    seq_lens = pred1
-                    pred = pred2
-                elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1:
-                    seq_lens = pred2
-                    pred = pred1
-                else:
-                    return fast_param
-                fast_param['pred'] = pred
-                fast_param['seq_lens'] = seq_lens
-            else:
-                return fast_param
-            fast_param['target'] = targets[0]
-        # TODO need to make sure they all have same batch_size
-        return fast_param
-
     def evaluate(self, pred, target, seq_lens=None):
         """


@@ -321,7 +264,7 @@ class AccuracyMetric(MetricBase):
f"got {type(seq_lens)}.") f"got {type(seq_lens)}.")


if seq_lens is not None: if seq_lens is not None:
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
masks = seq_lens_to_masks(seq_lens=seq_lens).long()
else: else:
masks = None masks = None


@@ -334,14 +277,12 @@ class AccuracyMetric(MetricBase):
f"size:{pred.size()}, target should have size: {pred.size()} or " f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.") f"{pred.size()[:-1]}, got {target.size()}.")


pred = pred.float()
target = target.float()


if masks is not None: if masks is not None:
self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
self.total += torch.sum(masks.float()).item()
self.acc_count += torch.sum(torch.eq(pred, target) * masks).item()
self.total += torch.sum(masks).item()
else: else:
self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
self.acc_count += torch.sum(torch.eq(pred, target)).item()
self.total += np.prod(list(pred.size())) self.total += np.prod(list(pred.size()))


def get_metric(self, reset=True): def get_metric(self, reset=True):
@@ -350,7 +291,7 @@ class AccuracyMetric(MetricBase):
         :param bool reset: whether to recount next time.
         :return evaluate_result: {"acc": float}
         """
-        evaluate_result = {'acc': round(self.acc_count / self.total, 6)}
+        evaluate_result = {'acc': round(float(self.acc_count) / (self.total + 1e-12), 6)}
         if reset:
             self.acc_count = 0
             self.total = 0
@@ -441,8 +382,7 @@ def bio_tag_to_spans(tags, ignore_labels=None):
             prev_bio_tag = bio_tag
     return [(span[0], (span[1][0], span[1][1]+1))
             for span in spans
-            if span[0] not in ignore_labels
-            ]
+            if span[0] not in ignore_labels]


 class SpanFPreRecMetric(MetricBase):
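
The evaluate change drops the float casts and counts matches under an integer mask, and get_metric guards the division with a small epsilon. A standalone sketch of the same computation in plain PyTorch (the mask construction stands in for fastNLP's seq_lens_to_masks helper; the batch values are hypothetical):

    import torch

    # Hypothetical batch: 2 sequences padded to length 4, true lengths 3 and 2.
    pred = torch.tensor([[1, 0, 2, 9], [1, 1, 9, 9]])
    target = torch.tensor([[1, 0, 1, 0], [1, 1, 0, 0]])
    seq_lens = torch.tensor([3, 2])

    # Equivalent of seq_lens_to_masks(seq_lens).long(): 1 inside each sequence, 0 on padding.
    masks = (torch.arange(pred.size(1))[None, :] < seq_lens[:, None]).long()

    acc_count = torch.sum(torch.eq(pred, target).long() * masks).item()  # correct non-pad positions
    total = torch.sum(masks).item()                                      # all non-pad positions
    acc = round(float(acc_count) / (total + 1e-12), 6)                   # epsilon avoids division by zero
    print(acc)  # 4 correct of 5 masked positions -> 0.8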


fastNLP/core/trainer.py (+20, -11)

@@ -34,7 +34,7 @@ class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=None,
                  check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True,
-                 use_cuda=False, callbacks=None):
+                 use_cuda=False, callbacks=None, update_every=1):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -62,6 +62,8 @@ class Trainer(object):
         :param bool use_tqdm: whether to use tqdm to show train progress.
         :param callbacks: List[Callback]. Callbacks that hook into the training loop; e.g. early stopping or
             negative sampling can be implemented via the callback mechanism.
+        :param update_every: int, update the gradients once every this many steps, for gradient accumulation:
+            e.g. if an effective batch_size of 128 does not fit in memory, set batch_size=32 and update_every=4.
         """
         super(Trainer, self).__init__()


@@ -76,6 +78,10 @@ class Trainer(object):
         if metrics and (dev_data is None):
             raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

+        # check update_every
+        assert update_every >= 1, "update_every must be no less than 1."
+        self.update_every = int(update_every)
+
         # check save_path
         if not (save_path is None or isinstance(save_path, str)):
             raise ValueError("save_path can only be None or `str`.")
@@ -144,11 +150,9 @@ class Trainer(object):
         self.start_time = None  # start timestamp

         self.callback_manager = CallbackManager(env={"trainer": self},
-                                                attr={"n_epochs": self.n_epochs, "n_steps": self.step,
-                                                      "batch_size": self.batch_size, "model": self.model,
-                                                      "optimizer": self.optimizer},
                                                 callbacks=callbacks)

     def train(self, load_best_model=True):
         """


@@ -241,7 +245,6 @@ class Trainer(object):
         avg_loss = 0
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                               prefetch=self.prefetch)
-        self.callback_manager.set_property(pbar=pbar)
         for epoch in range(1, self.n_epochs+1):
             pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
             # early stopping
@@ -257,6 +260,7 @@ class Trainer(object):
                 self.callback_manager.on_loss_begin(batch_y, prediction)
                 loss = self._compute_loss(prediction, batch_y)
                 avg_loss += loss.item()
+                loss = loss / self.update_every

                 # Is loss NaN or inf? requires_grad = False
                 self.callback_manager.on_backward_begin(loss, self.model)
@@ -267,8 +271,9 @@ class Trainer(object):
                 self.callback_manager.on_step_end(self.optimizer)

                 if (self.step+1) % self.print_every == 0:
+                    avg_loss = avg_loss / self.print_every
                     if self.use_tqdm:
-                        print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
+                        print_output = "loss:{0:<6.5f}".format(avg_loss)
                         pbar.update(self.print_every)
                     else:
                         end = time.time()
@@ -286,8 +291,8 @@ class Trainer(object):
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
                     eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                 total_steps) + \
-                              self.tester._format_eval_results(eval_res)
-                    pbar.write(eval_str)
+                               self.tester._format_eval_results(eval_res)
+                    pbar.write(eval_str + '\n')

                 # ================= mini-batch end ==================== #


@@ -301,6 +306,7 @@ class Trainer(object):
         self.callback_manager.on_valid_begin()
         res = self.tester.test()

+        is_better_eval = False
         if self._better_eval_result(res):
             if self.save_path is not None:
                 self._save_model(self.model,
@@ -310,8 +316,9 @@ class Trainer(object):
             self.best_dev_perf = res
             self.best_dev_epoch = epoch
             self.best_dev_step = step
+            is_better_eval = True
         # get validation results; adjust optimizer
-        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer)
+        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
         return res

     def _mode(self, model, is_test=False):
@@ -330,7 +337,8 @@ class Trainer(object):
"""Perform weight update on a model. """Perform weight update on a model.


""" """
self.optimizer.step()
if (self.step+1)%self.update_every==0:
self.optimizer.step()


def _data_forward(self, network, x): def _data_forward(self, network, x):
x = _build_args(network.forward, **x) x = _build_args(network.forward, **x)
@@ -346,7 +354,8 @@ class Trainer(object):


         For PyTorch, just do "loss.backward()"
         """
-        self.model.zero_grad()
+        if self.step % self.update_every == 0:
+            self.model.zero_grad()
         loss.backward()

     def _compute_loss(self, predict, truth):
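
Taken together, the update_every changes implement standard gradient accumulation: the loss is scaled by 1/update_every, gradients are zeroed only at the start of each accumulation window, and optimizer.step() fires only at its end. A standalone sketch of the same pattern in plain PyTorch (model and data are hypothetical):

    import torch

    model = torch.nn.Linear(10, 1)                     # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    update_every = 4                                   # effective batch = 4 mini-batches

    for step in range(16):                             # stand-in for the batch iterator
        x, y = torch.randn(32, 10), torch.randn(32, 1)
        if step % update_every == 0:                   # window start: clear old grads
            model.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y) / update_every  # scale so grads average
        loss.backward()                                # grads accumulate across the window
        if (step + 1) % update_every == 0:             # window end: apply the update
            optimizer.step()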


setup.py (+2, -2)

@@ -13,12 +13,12 @@ with open('requirements.txt', encoding='utf-8') as f:


 setup(
     name='FastNLP',
-    version='0.1.1',
+    version='0.4.0',
     description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
     long_description=readme,
     license=license,
     author='FudanNLP',
-    python_requires='>=3.5',
+    python_requires='>=3.6',
     packages=find_packages(),
     install_requires=reqs.strip().split('\n'),
 )

test/automl/test_enas.py (+1, -1)

@@ -35,7 +35,7 @@ class TestENAS(unittest.TestCase):
         print(dataset[0])

         # filter instances with DataSet.drop(func)
-        dataset.drop(lambda x: x['seq_len'] <= 3)
+        dataset.drop(lambda x: x['seq_len'] <= 3, inplace=True)
         print(len(dataset))

         # set which fields of the DataSet should be converted to tensors


test/core/test_dataset.py (+2, -2)

@@ -125,7 +125,7 @@ class TestDataSetMethods(unittest.TestCase):


     def test_drop(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
-        ds.drop(lambda ins: len(ins["y"]) < 3)
+        ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
         self.assertEqual(len(ds), 20)

     def test_contains(self):
@@ -169,7 +169,7 @@ class TestDataSetMethods(unittest.TestCase):


         dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'),
                                    sep='\t')
-        dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)
+        dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True)
         dataset.apply(split_sent, new_field_name='words', is_input=True)
         # print(dataset)




test/test_tutorials.py (+2, -2)

@@ -35,7 +35,7 @@ class TestTutorial(unittest.TestCase):
         print(dataset[0])

         # filter instances with DataSet.drop(func)
-        dataset.drop(lambda x: x['seq_len'] <= 3)
+        dataset.drop(lambda x: x['seq_len'] <= 3, inplace=True)
         print(len(dataset))

         # set which fields of the DataSet should be converted to tensors
@@ -296,7 +296,7 @@ class TestTutorial(unittest.TestCase):


         # filter data
         origin_data_set_len = len(data_set)
-        data_set.drop(lambda x: len(x['premise']) <= 6)
+        data_set.drop(lambda x: len(x['premise']) <= 6, inplace=True)
         origin_data_set_len, len(data_set)

         # In[17]:

