Browse Source

* fix bugs in DataSet & Instance

* add more code comments
* fix tester
* refresh code styles
tags/v0.2.0
FengZiYjun yunfan 6 years ago
parent
commit
c4103561a8
5 changed files with 113 additions and 129 deletions
  1. +1
    -1
      fastNLP/core/batch.py
  2. +70
    -51
      fastNLP/core/dataset.py
  3. +7
    -49
      fastNLP/core/tester.py
  4. +33
    -26
      fastNLP/core/trainer.py
  5. +2
    -2
      fastNLP/core/utils.py

+ 1
- 1
fastNLP/core/batch.py View File

@@ -10,7 +10,7 @@ class Batch(object):

"""

def __init__(self, dataset, batch_size, sampler, as_numpy=False,):
def __init__(self, dataset, batch_size, sampler, as_numpy=False):
"""

:param dataset: a DataSet object


+ 70
- 51
fastNLP/core/dataset.py View File

@@ -1,6 +1,7 @@
import numpy as np

from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance

_READERS = {}

@@ -27,10 +28,10 @@ class DataSet(object):
"""

class Instance(object):
def __init__(self, dataset, idx=-1):
def __init__(self, dataset, idx=-1, **fields):
self.dataset = dataset
self.idx = idx
self.fields = None
self.fields = fields

def __next__(self):
self.idx += 1
@@ -38,6 +39,14 @@ class DataSet(object):
raise StopIteration
return self

def add_field(self, field_name, field):
"""Add a new field to the instance.

:param field_name: str, the name of the field.
:param field:
"""
self.fields[field_name] = field

def __getitem__(self, name):
return self.dataset[name][self.idx]

@@ -47,13 +56,6 @@ class DataSet(object):
self.dataset.add_field(name, new_fields)
self.dataset[name][self.idx] = val

def __getattr__(self, item):
if item == 'fields':
self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()}
return self.fields
else:
raise AttributeError('{} does not exist.'.format(item))

def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
in self.dataset.get_fields().keys()])
@@ -112,14 +114,13 @@ class DataSet(object):
self.field_arrays[name].append(field)

def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False):
"""
"""Add a new field to the DataSet.
:param str name:
:param fields:
:param int padding_val:
:param bool is_input:
:param bool is_target:
:return:
:param str name: the name of the field.
:param fields: a list of int, float, or other objects.
:param int padding_val: integer for padding.
:param bool is_input: whether this field is model input.
:param bool is_target: whether this field is label or target.
"""
if len(self.field_arrays) != 0:
assert len(self) == len(fields)
@@ -127,28 +128,43 @@ class DataSet(object):
is_input=is_input)

def delete_field(self, name):
"""Delete a field based on the field name.

:param str name: the name of the field to be deleted.
"""
self.field_arrays.pop(name)

def get_fields(self):
"""Return all the fields with their names.

:return dict field_arrays: the internal data structure of DataSet.
"""
return self.field_arrays

def __getitem__(self, name):
if isinstance(name, int):
return self.Instance(self, idx=name)
elif isinstance(name, slice):
ds = DataSet()
def __getitem__(self, idx):
"""

:param idx: can be int, slice, or str.
:return: If `idx` is int, return an Instance object.
If `idx` is slice, return a DataSet object.
If `idx` is str, it must be a field name, return the field.

"""
if isinstance(idx, int):
return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays})
elif isinstance(idx, slice):
data_set = DataSet()
for field in self.field_arrays.values():
ds.add_field(name=field.name,
fields=field.content[name],
padding_val=field.padding_val,
need_tensor=field.need_tensor,
is_target=field.is_target)
return ds

elif isinstance(name, str):
return self.field_arrays[name]
data_set.add_field(name=field.name,
fields=field.content[idx],
padding_val=field.padding_val,
is_input=field.is_input,
is_target=field.is_target)
return data_set
elif isinstance(idx, str):
return self.field_arrays[idx]
else:
raise KeyError
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

def __len__(self):
if len(self.field_arrays) == 0:
@@ -208,6 +224,7 @@ class DataSet(object):
pass
try:
reader_cls = _READERS[item]

# add read_*data() support
def _read(*args, **kwargs):
data = reader_cls().load(*args, **kwargs)
@@ -231,6 +248,12 @@ class DataSet(object):
return wrapper

def apply(self, func, new_field_name=None):
"""Apply a function to every instance of the DataSet.

:param func: a function that takes an instance as input.
:param str new_field_name: If not None, results of the function will be stored as a new field.
:return results: returned values of the function over all instances.
"""
results = []
for ins in self:
results.append(func(ins))
@@ -247,28 +270,24 @@ class DataSet(object):
else:
return results

def split(self, test_ratio):
assert isinstance(test_ratio, float)
def split(self, dev_ratio):
"""Split the dataset into training and development(validation) set.

:param float dev_ratio: the ratio of test set in all data.
:return DataSet train_set: the training set
DataSet dev_set: the development set
"""
assert isinstance(dev_ratio, float)
assert 0 < dev_ratio < 1
all_indices = [_ for _ in range(len(self))]
np.random.shuffle(all_indices)
test_indices = all_indices[:int(test_ratio)]
train_indices = all_indices[int(test_ratio):]
test_set = DataSet()
split = int(dev_ratio * len(self))
dev_indices = all_indices[:split]
train_indices = all_indices[split:]
dev_set = DataSet()
train_set = DataSet()
for idx in test_indices:
test_set.append(self[idx])
for idx in dev_indices:
dev_set.append(self[idx])
for idx in train_indices:
train_set.append(self[idx])
return train_set, test_set


if __name__ == '__main__':
from fastNLP.core.instance import Instance

d = DataSet({'a': list('abc')})
_ = d.a
d.apply(lambda x: x['a'])
print(d[1])
import copy
dd = copy.deepcopy(d)
print(dd.a)
return train_set, dev_set

+ 7
- 49
fastNLP/core/tester.py View File

@@ -3,61 +3,19 @@ from collections import defaultdict
import torch

from fastNLP.core.batch import Batch
from fastNLP.core.metrics import Evaluator
from fastNLP.core.sampler import RandomSampler


# logger = create_logger(__name__, "./train_test.log")


class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """

def __init__(self, **kwargs):
"""
:param kwargs: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
"""
def __init__(self, batch_size, evaluator, use_cuda, save_path="./save/", **kwargs):
super(Tester, self).__init__()
"""
"default_args" provides default value for important settings.
The initialization arguments "kwargs" with the same key (name) will override the default value.
"kwargs" must have the same type as "default_args" on corresponding keys.
Otherwise, error will raise.
"""
default_args = {"batch_size": 8,
"use_cuda": False,
"pickle_path": "./save/",
"model_name": "dev_best_model.pkl",
"evaluator": Evaluator()
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
required_args = {}

for req_key in required_args:
if req_key not in kwargs:
raise ValueError("Tester lacks argument {}".format(req_key))

for key in default_args:
if key in kwargs:
if isinstance(kwargs[key], type(default_args[key])):
default_args[key] = kwargs[key]
else:
msg = "Argument %s type mismatch: expected %s while get %s" % (
key, type(default_args[key]), type(kwargs[key]))
raise ValueError(msg)
else:
# Tester doesn't care about extra arguments
pass
# print(default_args)

self.batch_size = default_args["batch_size"]
self.pickle_path = default_args["pickle_path"]
self.use_cuda = default_args["use_cuda"]
self._evaluator = default_args["evaluator"]

self.batch_size = batch_size
self.pickle_path = save_path
self.use_cuda = use_cuda
self._evaluator = evaluator

self._model = None
self.eval_history = [] # evaluation results of all batches
@@ -72,7 +30,7 @@ class Tester(object):
self.mode(network, is_test=True)
self.eval_history.clear()
output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), as_numpy=False)

with torch.no_grad():
for batch_x, batch_y in data_iterator:


+ 33
- 26
fastNLP/core/trainer.py View File

@@ -15,6 +15,8 @@ from fastNLP.core.optimizer import Optimizer
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list

from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args
@@ -78,7 +80,7 @@ class Trainer(object):
epoch = 1
while epoch <= self.n_epochs:

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler())
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)

self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)

@@ -207,9 +209,9 @@ def best_eval_result(self, metrics):
DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2

IGNORE_CHECK_LEVEL=0
WARNING_CHECK_LEVEL=1
STRICT_CHECK_LEVEL=2
IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2

def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
# check get_loss 方法
@@ -220,11 +222,20 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
_syn_model_data(model, batch_x, batch_y)
# forward check
if batch_count==0:
_check_forward_error(model=model, model_func=model.forward, check_level=check_level,
batch_x=batch_x)
if batch_count == 0:
check_res = _check_arg_dict_list(model.forward, batch_x)
_info_str = ''
if len(check_res.missing) > 0:
if check_level == WARNING_CHECK_LEVEL:
for field_name in check_res.missing:
if hasattr(dataset, field_name):
_info_str += "{} "
_info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
_info_str += ""
print("")
if len(check_res.unused) > 0:
if check_level == WARNING_CHECK_LEVEL:
_info_str += ""

refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
@@ -233,10 +244,14 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No

# loss check
if batch_count == 0:
_check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level,
output=output, batch_y=batch_y)
loss_input = _build_args(model.get_loss, **output, **batch_y)
loss = model.get_loss(**loss_input)
_dict = _check_arg_dict_list(model.loss, [output, batch_y])
if len(_dict) != 0:
pass
loss_input = _build_args(model.loss, **output, **batch_y)
loss = model.loss(**loss_input)
if batch_count == 0:
if isinstance(loss, torch.Tensor):
pass

# check loss output
if batch_count == 0:
@@ -248,8 +263,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
model_name, loss.size()
))
loss.backward()
model.zero_grad()
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE:
break
if check_level > IGNORE_CHECK_LEVEL:
print('Finish checking training process.', flush=True)
@@ -407,14 +421,7 @@ if __name__ == '__main__':

# trainer = Trainer(dataset, model)

_check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2)

# _check_forward_error(model=model, model_func=model.forward, check_level=1,
# batch_x=fake_data_dict)

# import inspect
# print(inspect.getfullargspec(model.forward))




if len(_dict) != 0:
pass
refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)

+ 2
- 2
fastNLP/core/utils.py View File

@@ -1,8 +1,8 @@
import _pickle
import os
import inspect
from collections import namedtuple
import os
from collections import Counter
from collections import namedtuple

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)



Loading…
Cancel
Save