
* fix bugs in DataSet & Instance

* add more code comments
* fix tester
* refresh code styles
tags/v0.2.0
FengZiYjun (yunfan) committed 6 years ago
commit c4103561a8
5 changed files with 113 additions and 129 deletions

  1. +1  -1   fastNLP/core/batch.py
  2. +70 -51  fastNLP/core/dataset.py
  3. +7  -49  fastNLP/core/tester.py
  4. +33 -26  fastNLP/core/trainer.py
  5. +2  -2   fastNLP/core/utils.py

+1 -1  fastNLP/core/batch.py

@@ -10,7 +10,7 @@ class Batch(object):

     """

-    def __init__(self, dataset, batch_size, sampler, as_numpy=False,):
+    def __init__(self, dataset, batch_size, sampler, as_numpy=False):
         """

         :param dataset: a DataSet object
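For context, a minimal sketch of how the updated constructor would be called; the toy fields and values below are assumed purely for illustration:

    from fastNLP.core.batch import Batch
    from fastNLP.core.dataset import DataSet
    from fastNLP.core.sampler import RandomSampler

    # toy DataSet assumed for illustration
    data = DataSet()
    data.add_field("words", [[1, 2, 3], [4, 5, 6], [7, 8, 9]], is_input=True)
    data.add_field("label", [0, 1, 0], is_target=True)

    # the trailing comma is gone from the signature; as_numpy presumably selects
    # numpy arrays over torch tensors for the yielded batches
    for batch_x, batch_y in Batch(data, batch_size=2, sampler=RandomSampler(), as_numpy=False):
        print(batch_x, batch_y)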


+70 -51  fastNLP/core/dataset.py

@@ -1,6 +1,7 @@
 import numpy as np

 from fastNLP.core.fieldarray import FieldArray
+from fastNLP.core.instance import Instance

 _READERS = {}


@@ -27,10 +28,10 @@ class DataSet(object):
     """

     class Instance(object):
-        def __init__(self, dataset, idx=-1):
+        def __init__(self, dataset, idx=-1, **fields):
             self.dataset = dataset
             self.idx = idx
-            self.fields = None
+            self.fields = fields

         def __next__(self):
             self.idx += 1
@@ -38,6 +39,14 @@ class DataSet(object):
                 raise StopIteration
             return self

+        def add_field(self, field_name, field):
+            """Add a new field to the instance.
+
+            :param field_name: str, the name of the field.
+            :param field:
+            """
+            self.fields[field_name] = field
+
         def __getitem__(self, name):
             return self.dataset[name][self.idx]

@@ -47,13 +56,6 @@ class DataSet(object):
                 self.dataset.add_field(name, new_fields)
             self.dataset[name][self.idx] = val

-        def __getattr__(self, item):
-            if item == 'fields':
-                self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()}
-                return self.fields
-            else:
-                raise AttributeError('{} does not exist.'.format(item))
-
         def __repr__(self):
             return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
                               in self.dataset.get_fields().keys()])
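With the lazy __getattr__ removed, the Instance view now receives its field values up front through **fields. A small self-contained sketch of the resulting behaviour (field names and values are assumed for illustration):

    from fastNLP.core.dataset import DataSet

    ds = DataSet()
    ds.add_field("words", [[1, 2], [3, 4]], is_input=True)
    ds.add_field("label", [0, 1], is_target=True)

    ins = ds[0]                  # integer indexing builds a DataSet.Instance view
    print(ins.fields)            # plain dict, filled in __init__ instead of lazily
    print(ins["words"])          # item access still reads through to the parent DataSet
    ins["label"] = 1             # item assignment writes back into the parent DataSet
    ins.add_field("note", "x")   # new helper: only updates this view's fields dict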
@@ -112,14 +114,13 @@ class DataSet(object):
             self.field_arrays[name].append(field)

     def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False):
-        """
-        :param str name:
-        :param fields:
-        :param int padding_val:
-        :param bool is_input:
-        :param bool is_target:
-        :return:
+        """Add a new field to the DataSet.
+
+        :param str name: the name of the field.
+        :param fields: a list of int, float, or other objects.
+        :param int padding_val: integer for padding.
+        :param bool is_input: whether this field is model input.
+        :param bool is_target: whether this field is label or target.
         """
         if len(self.field_arrays) != 0:
             assert len(self) == len(fields)
@@ -127,28 +128,43 @@ class DataSet(object):
                                         is_input=is_input)

     def delete_field(self, name):
+        """Delete a field based on the field name.
+
+        :param str name: the name of the field to be deleted.
+        """
         self.field_arrays.pop(name)

     def get_fields(self):
+        """Return all the fields with their names.
+
+        :return dict field_arrays: the internal data structure of DataSet.
+        """
         return self.field_arrays

-    def __getitem__(self, name):
-        if isinstance(name, int):
-            return self.Instance(self, idx=name)
-        elif isinstance(name, slice):
-            ds = DataSet()
+    def __getitem__(self, idx):
+        """
+
+        :param idx: can be int, slice, or str.
+        :return: If `idx` is int, return an Instance object.
+                 If `idx` is slice, return a DataSet object.
+                 If `idx` is str, it must be a field name, return the field.
+
+        """
+        if isinstance(idx, int):
+            return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays})
+        elif isinstance(idx, slice):
+            data_set = DataSet()
             for field in self.field_arrays.values():
-                ds.add_field(name=field.name,
-                             fields=field.content[name],
-                             padding_val=field.padding_val,
-                             need_tensor=field.need_tensor,
-                             is_target=field.is_target)
-            return ds
-
-        elif isinstance(name, str):
-            return self.field_arrays[name]
+                data_set.add_field(name=field.name,
+                                   fields=field.content[idx],
+                                   padding_val=field.padding_val,
+                                   is_input=field.is_input,
+                                   is_target=field.is_target)
+            return data_set
+        elif isinstance(idx, str):
+            return self.field_arrays[idx]
         else:
-            raise KeyError
+            raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

     def __len__(self):
         if len(self.field_arrays) == 0:
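A hedged usage sketch of the three lookup modes the rewritten __getitem__ dispatches on (field names and values are assumed):

    from fastNLP.core.dataset import DataSet

    ds = DataSet()
    ds.add_field("words", [[1, 2], [3, 4], [5, 6]], is_input=True)
    ds.add_field("label", [0, 1, 0], is_target=True)

    ins = ds[0]          # int   -> an Instance view over a single row
    sub = ds[0:2]        # slice -> a new DataSet built from the sliced field contents
    arr = ds["label"]    # str   -> the underlying FieldArray for that field
    # any other key type raises KeyError("Unrecognized type ...")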
@@ -208,6 +224,7 @@ class DataSet(object):
             pass
         try:
             reader_cls = _READERS[item]
+
             # add read_*data() support
             def _read(*args, **kwargs):
                 data = reader_cls().load(*args, **kwargs)
@@ -231,6 +248,12 @@ class DataSet(object):
         return wrapper

     def apply(self, func, new_field_name=None):
+        """Apply a function to every instance of the DataSet.
+
+        :param func: a function that takes an instance as input.
+        :param str new_field_name: If not None, results of the function will be stored as a new field.
+        :return results: returned values of the function over all instances.
+        """
         results = []
         for ins in self:
             results.append(func(ins))
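A small sketch of apply with and without new_field_name, assuming the behaviour described in the new docstring (field names and values are made up):

    from fastNLP.core.dataset import DataSet

    ds = DataSet()
    ds.add_field("words", [[1, 2], [3, 4, 5]], is_input=True)

    # without new_field_name: just collect the returned values
    lengths = ds.apply(lambda ins: len(ins["words"]))

    # with new_field_name: the results are also stored as a new field
    ds.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
    print(ds["seq_len"])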
@@ -247,28 +270,24 @@ class DataSet(object):
         else:
             return results

-    def split(self, test_ratio):
-        assert isinstance(test_ratio, float)
+    def split(self, dev_ratio):
+        """Split the dataset into training and development(validation) set.
+
+        :param float dev_ratio: the ratio of test set in all data.
+        :return DataSet train_set: the training set
+                DataSet dev_set: the development set
+        """
+        assert isinstance(dev_ratio, float)
+        assert 0 < dev_ratio < 1
         all_indices = [_ for _ in range(len(self))]
         np.random.shuffle(all_indices)
-        test_indices = all_indices[:int(test_ratio)]
-        train_indices = all_indices[int(test_ratio):]
-        test_set = DataSet()
+        split = int(dev_ratio * len(self))
+        dev_indices = all_indices[:split]
+        train_indices = all_indices[split:]
+        dev_set = DataSet()
         train_set = DataSet()
-        for idx in test_indices:
-            test_set.append(self[idx])
+        for idx in dev_indices:
+            dev_set.append(self[idx])
         for idx in train_indices:
             train_set.append(self[idx])
-        return train_set, test_set
-
-
-if __name__ == '__main__':
-    from fastNLP.core.instance import Instance
-
-    d = DataSet({'a': list('abc')})
-    _ = d.a
-    d.apply(lambda x: x['a'])
-    print(d[1])
-    import copy
-    dd = copy.deepcopy(d)
-    print(dd.a)
+        return train_set, dev_set
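The old implementation sliced by int(test_ratio), which truncates any ratio below 1 to 0 and left the held-out set empty; the rewrite converts the ratio into a count first. A minimal sketch of the corrected call (field name and values assumed):

    from fastNLP.core.dataset import DataSet

    ds = DataSet()
    ds.add_field("label", [0, 1, 0, 1, 0, 1, 0, 1, 0, 1], is_target=True)

    train_set, dev_set = ds.split(0.2)   # dev_ratio must be a float strictly between 0 and 1
    print(len(train_set), len(dev_set))  # 8 and 2 after shuffling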

+7 -49  fastNLP/core/tester.py

@@ -3,61 +3,19 @@ from collections import defaultdict
 import torch

 from fastNLP.core.batch import Batch
-from fastNLP.core.metrics import Evaluator
 from fastNLP.core.sampler import RandomSampler


-# logger = create_logger(__name__, "./train_test.log")


 class Tester(object):
     """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """

-    def __init__(self, **kwargs):
-        """
-        :param kwargs: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
-        """
+    def __init__(self, batch_size, evaluator, use_cuda, save_path="./save/", **kwargs):
         super(Tester, self).__init__()
-        """
-        "default_args" provides default value for important settings.
-        The initialization arguments "kwargs" with the same key (name) will override the default value.
-        "kwargs" must have the same type as "default_args" on corresponding keys.
-        Otherwise, error will raise.
-        """
-        default_args = {"batch_size": 8,
-                        "use_cuda": False,
-                        "pickle_path": "./save/",
-                        "model_name": "dev_best_model.pkl",
-                        "evaluator": Evaluator()
-                        }
-        """
-        "required_args" is the collection of arguments that users must pass to Trainer explicitly.
-        This is used to warn users of essential settings in the training.
-        Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
-        """
-        required_args = {}
-
-        for req_key in required_args:
-            if req_key not in kwargs:
-                raise ValueError("Tester lacks argument {}".format(req_key))
-
-        for key in default_args:
-            if key in kwargs:
-                if isinstance(kwargs[key], type(default_args[key])):
-                    default_args[key] = kwargs[key]
-                else:
-                    msg = "Argument %s type mismatch: expected %s while get %s" % (
-                        key, type(default_args[key]), type(kwargs[key]))
-                    raise ValueError(msg)
-            else:
-                # Tester doesn't care about extra arguments
-                pass
-        # print(default_args)
-
-        self.batch_size = default_args["batch_size"]
-        self.pickle_path = default_args["pickle_path"]
-        self.use_cuda = default_args["use_cuda"]
-        self._evaluator = default_args["evaluator"]
+        self.batch_size = batch_size
+        self.pickle_path = save_path
+        self.use_cuda = use_cuda
+        self._evaluator = evaluator

         self._model = None
         self.eval_history = []  # evaluation results of all batches
@@ -72,7 +30,7 @@ class Tester(object):
         self.mode(network, is_test=True)
         self.eval_history.clear()
         output, truths = defaultdict(list), defaultdict(list)
-        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
+        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), as_numpy=False)

         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
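The Tester now takes its settings as explicit constructor arguments instead of a default_args dict. A minimal sketch of the new call, with the metric, model, and dev data assumed to be prepared elsewhere:

    from fastNLP.core.metrics import Evaluator
    from fastNLP.core.tester import Tester

    # Evaluator() stands in for a task-specific metric object
    tester = Tester(batch_size=8, evaluator=Evaluator(), use_cuda=False, save_path="./save/")
    # evaluation itself would then be run on a trained model and a dev DataSet,
    # through the test/eval entry point not shown in this diff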


+33 -26  fastNLP/core/trainer.py

@@ -15,6 +15,8 @@ from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_arg_dict_list

 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _build_args
@@ -78,7 +80,7 @@
         epoch = 1
         while epoch <= self.n_epochs:

-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler())
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)

             self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)


@@ -207,9 +209,9 @@ def best_eval_result(self, metrics):
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2

-IGNORE_CHECK_LEVEL=0
-WARNING_CHECK_LEVEL=1
-STRICT_CHECK_LEVEL=2
+IGNORE_CHECK_LEVEL = 0
+WARNING_CHECK_LEVEL = 1
+STRICT_CHECK_LEVEL = 2


 def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
     # check the get_loss method
@@ -220,11 +222,20 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
     batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
-        _syn_model_data(model, batch_x, batch_y)
-        # forward check
-        if batch_count==0:
-            _check_forward_error(model=model, model_func=model.forward, check_level=check_level,
-                                 batch_x=batch_x)
+        if batch_count == 0:
+            check_res = _check_arg_dict_list(model.forward, batch_x)
+            _info_str = ''
+            if len(check_res.missing) > 0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    for field_name in check_res.missing:
+                        if hasattr(dataset, field_name):
+                            _info_str += "{} "
+                    _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
+                    _info_str += ""
+                    print("")
+            if len(check_res.unused) > 0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    _info_str += ""

         refined_batch_x = _build_args(model.forward, **batch_x)
         output = model(**refined_batch_x)
@@ -233,10 +244,14 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):

         # loss check
         if batch_count == 0:
-            _check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level,
-                                 output=output, batch_y=batch_y)
-        loss_input = _build_args(model.get_loss, **output, **batch_y)
-        loss = model.get_loss(**loss_input)
+            _dict = _check_arg_dict_list(model.loss, [output, batch_y])
+            if len(_dict) != 0:
+                pass
+        loss_input = _build_args(model.loss, **output, **batch_y)
+        loss = model.loss(**loss_input)
+        if batch_count == 0:
+            if isinstance(loss, torch.Tensor):
+                pass

         # check loss output
         if batch_count == 0:
@@ -248,8 +263,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
                 model_name, loss.size()
             ))
         loss.backward()
-        model.zero_grad()
-        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
+        if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE:
             break
     if check_level > IGNORE_CHECK_LEVEL:
         print('Finish checking training process.', flush=True)
@@ -407,14 +421,7 @@ if __name__ == '__main__':

     # trainer = Trainer(dataset, model)

-    _check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2)
-
-    # _check_forward_error(model=model, model_func=model.forward, check_level=1,
-    #                      batch_x=fake_data_dict)
-
-    # import inspect
-    # print(inspect.getfullargspec(model.forward))
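The forward check now goes through _check_arg_dict_list, which compares the arguments of model.forward against the fields available in the batch and reports what is missing or unused. As an illustration of the idea only (not fastNLP's actual implementation), such a signature-versus-fields check can be sketched like this:

    import inspect

    def missing_and_unused(func, provided):
        """Compare a function's named parameters with the keys of a provided dict."""
        params = [p for p in inspect.signature(func).parameters if p != "self"]
        missing = [p for p in params if p not in provided]
        unused = [k for k in provided if k not in params]
        return missing, unused

    def forward(self, word_seq, seq_len):
        ...

    print(missing_and_unused(forward, {"word_seq": [[1, 2]], "extra_field": [0]}))
    # -> (['seq_len'], ['extra_field'])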





+2 -2  fastNLP/core/utils.py

@@ -1,8 +1,8 @@
 import _pickle
-import os
 import inspect
-from collections import namedtuple
+import os
 from collections import Counter
+from collections import namedtuple

 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
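The import block is simply re-sorted here; CheckRes itself is a plain namedtuple, so caller code such as the trainer's checks can read results by field name. A small illustration with made-up values (the verbose flag is omitted, since newer Python versions dropped it):

    from collections import namedtuple

    CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'])

    res = CheckRes(missing=['seq_len'], unused=['extra_field'], duplicated=[],
                   required=['word_seq', 'seq_len'], all_needed=['word_seq', 'seq_len'])
    if len(res.missing) > 0:
        print("Missing argument: [{}]".format(", ".join(res.missing)))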



