Browse Source

check code init

tags/v0.2.0
yh yunfan 6 years ago
parent
commit
ce3b002263
6 changed files with 89 additions and 28 deletions
  1. +3
    -8
      fastNLP/core/batch.py
  2. +1
    -12
      fastNLP/core/dataset.py
  3. +3
    -3
      fastNLP/core/fieldarray.py
  4. +75
    -0
      fastNLP/core/trainer.py
  5. +7
    -4
      fastNLP/core/utils.py
  6. +0
    -1
      fastNLP/core/vocabulary.py

+ 3
- 8
fastNLP/core/batch.py View File

@@ -9,20 +9,17 @@ class Batch(object):


""" """


def __init__(self, dataset, batch_size, sampler, as_numpy=False, use_cuda=False):
def __init__(self, dataset, batch_size, sampler, as_numpy=False,):
""" """


:param dataset: a DataSet object :param dataset: a DataSet object
:param batch_size: int, the size of the batch :param batch_size: int, the size of the batch
:param sampler: a Sampler object :param sampler: a Sampler object
:param use_cuda: bool, whether to use GPU

""" """
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
self.sampler = sampler self.sampler = sampler
self.as_numpy = as_numpy self.as_numpy = as_numpy
self.use_cuda = use_cuda
self.idx_list = None self.idx_list = None
self.curidx = 0 self.curidx = 0


@@ -53,15 +50,13 @@ class Batch(object):
indices = self.idx_list[self.curidx:endidx] indices = self.idx_list[self.curidx:endidx]


for field_name, field in self.dataset.get_fields().items(): for field_name, field in self.dataset.get_fields().items():
if field.need_tensor:
if field.is_target or field.is_input:
batch = field.get(indices) batch = field.get(indices)
if not self.as_numpy: if not self.as_numpy:
batch = torch.from_numpy(batch) batch = torch.from_numpy(batch)
if self.use_cuda:
batch = batch.cuda()
if field.is_target: if field.is_target:
batch_y[field_name] = batch batch_y[field_name] = batch
else:
if field.is_input:
batch_x[field_name] = batch batch_x[field_name] = batch


self.curidx = endidx self.curidx = endidx


+ 1
- 12
fastNLP/core/dataset.py View File

@@ -189,26 +189,15 @@ class DataSet(object):
self.field_arrays[name].is_target = val self.field_arrays[name].is_target = val
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
self._set_need_tensor(**fields)
return self return self


def set_input(self, **fields): def set_input(self, **fields):
for name, val in fields.items(): for name, val in fields.items():
if name in self.field_arrays: if name in self.field_arrays:
assert isinstance(val, bool) assert isinstance(val, bool)
self.field_arrays[name].is_target = not val
self.field_arrays[name].is_input = val
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
self._set_need_tensor(**fields)
return self

def _set_need_tensor(self, **kwargs):
for name, val in kwargs.items():
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].need_tensor = val
else:
raise KeyError
return self return self


def __getattr__(self, item): def __getattr__(self, item):


+ 3
- 3
fastNLP/core/fieldarray.py View File

@@ -2,12 +2,12 @@ import numpy as np




class FieldArray(object): class FieldArray(object):
def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):
self.name = name self.name = name
self.content = content self.content = content
self.padding_val = padding_val self.padding_val = padding_val
self.is_target = is_target self.is_target = is_target
self.need_tensor = need_tensor
self.is_input = is_input
self.dtype = None self.dtype = None


def __repr__(self): def __repr__(self):
@@ -27,7 +27,7 @@ class FieldArray(object):
def get(self, idxes): def get(self, idxes):
if isinstance(idxes, int): if isinstance(idxes, int):
return self.content[idxes] return self.content[idxes]
assert self.need_tensor is True
assert self.is_input is True or self.is_target is True
batch_size = len(idxes) batch_size = len(idxes)
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下
if isinstance(self.content[0], int) or isinstance(self.content[0], float): if isinstance(self.content[0], int) or isinstance(self.content[0], float):


+ 75
- 0
fastNLP/core/trainer.py View File

@@ -9,6 +9,7 @@ from fastNLP.core.loss import Loss
from fastNLP.core.metrics import Evaluator from fastNLP.core.metrics import Evaluator
from fastNLP.core.optimizer import Optimizer from fastNLP.core.optimizer import Optimizer
from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester from fastNLP.core.tester import Tester




@@ -194,3 +195,77 @@ def best_eval_result(self, metrics):
return True return True
else: else:
return False return False


from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args

DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2

IGNORE_CHECK_LEVEL=0
WARNING_CHECK_LEVEL=1
STRICT_CHECK_LEVEL=2


def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
# check loss 方法
if not hasattr(model, 'get_loss'):
raise AttributeError("{} has to have a 'get_loss' function.".format(type(model)))

batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
if batch_count==0:
check_res = _check_arg_dict_list(model.forward, batch_x)
_info_str = ''
if len(check_res.missing)>0:
if check_level == WARNING_CHECK_LEVEL:
for field_name in check_res.missing:
if hasattr(dataset, field_name):
_info_str += "{} "
_info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
_info_str += ""
print("")
if len(check_res.unused)>0:
if check_level == WARNING_CHECK_LEVEL:
_info_str += ""

refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
if batch_count == 0:
_dict = _check_arg_dict_list(model.loss, [output, batch_y])
if len(_dict)!=0:
pass
loss_input = _build_args(model.loss, **output, **batch_y)
loss = model.loss(**loss_input)
if batch_count == 0:
if isinstance(loss, torch.Tensor):
pass

loss.backward()

if batch_count+1>=DEFAULT_CHECK_BATCH_SIZE:
break

dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
if dev_data is not None:
if not hasattr(model, 'evaluate'):
raise AttributeError("If {} wants to do evaluation, {} has to have a 'evaluate' function. Or you can set"
"dev_data to 'None'."
.format(type(model), type(model)))

for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
if batch_count == 0:
_dict = _check_arg_dict_list(model.evaluate, [output, batch_y])

if len(_dict)!=0:
pass
refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)







+ 7
- 4
fastNLP/core/utils.py View File

@@ -1,6 +1,11 @@
import _pickle import _pickle
import os import os
import inspect import inspect
from collections import namedtuple
from collections import Counter

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True)



def save_pickle(obj, pickle_path, file_name): def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file. """Save an object into a pickle file.
@@ -45,7 +50,7 @@ def pickle_exist(pickle_path, pickle_name):
else: else:
return False return False


def build_args(func, **kwargs):
def _build_args(func, **kwargs):
spect = inspect.getfullargspec(func) spect = inspect.getfullargspec(func)
if spect.varkw is not None: if spect.varkw is not None:
return kwargs return kwargs
@@ -55,11 +60,9 @@ def build_args(func, **kwargs):
output.update({name: val for name, val in kwargs.items() if name in needed_args}) output.update({name: val for name, val in kwargs.items() if name in needed_args})
return output return output


from collections import namedtuple, Counter
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True)


# check args # check args
def check_arg_dict_list(func, args):
def _check_arg_dict_list(func, args):
if isinstance(args, dict): if isinstance(args, dict):
arg_dict_list = [args] arg_dict_list = [args]
else: else:


+ 0
- 1
fastNLP/core/vocabulary.py View File

@@ -60,7 +60,6 @@ class Vocabulary(object):
""" """
self.word_count.update(word_lst) self.word_count.update(word_lst)



def add(self, word): def add(self, word):
self.word_count[word] += 1 self.word_count[word] += 1




Loading…
Cancel
Save