Browse Source

check code init

tags/v0.2.0
yh yunfan 6 years ago
parent
commit
ce3b002263
6 changed files with 89 additions and 28 deletions
  1. +3
    -8
      fastNLP/core/batch.py
  2. +1
    -12
      fastNLP/core/dataset.py
  3. +3
    -3
      fastNLP/core/fieldarray.py
  4. +75
    -0
      fastNLP/core/trainer.py
  5. +7
    -4
      fastNLP/core/utils.py
  6. +0
    -1
      fastNLP/core/vocabulary.py

+ 3
- 8
fastNLP/core/batch.py View File

@@ -9,20 +9,17 @@ class Batch(object):

"""

def __init__(self, dataset, batch_size, sampler, as_numpy=False):
    """Initialize a batch iterator over a DataSet.

    :param dataset: a DataSet object to draw batches from
    :param batch_size: int, the size of the batch
    :param sampler: a Sampler object deciding the iteration order
    :param as_numpy: bool, if True yield numpy arrays; otherwise torch tensors

    """
    self.dataset = dataset
    self.batch_size = batch_size
    self.sampler = sampler
    self.as_numpy = as_numpy
    # idx_list/curidx track iteration state; set up lazily at iteration start
    self.idx_list = None
    self.curidx = 0

@@ -53,15 +50,13 @@ class Batch(object):
indices = self.idx_list[self.curidx:endidx]

for field_name, field in self.dataset.get_fields().items():
if field.need_tensor:
if field.is_target or field.is_input:
batch = field.get(indices)
if not self.as_numpy:
batch = torch.from_numpy(batch)
if self.use_cuda:
batch = batch.cuda()
if field.is_target:
batch_y[field_name] = batch
else:
if field.is_input:
batch_x[field_name] = batch

self.curidx = endidx


+ 1
- 12
fastNLP/core/dataset.py View File

@@ -189,26 +189,15 @@ class DataSet(object):
self.field_arrays[name].is_target = val
else:
raise KeyError("{} is not a valid field name.".format(name))
self._set_need_tensor(**fields)
return self

def set_input(self, **fields):
for name, val in fields.items():
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].is_target = not val
self.field_arrays[name].is_input = val
else:
raise KeyError("{} is not a valid field name.".format(name))
self._set_need_tensor(**fields)
return self

def _set_need_tensor(self, **kwargs):
for name, val in kwargs.items():
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].need_tensor = val
else:
raise KeyError
return self

def __getattr__(self, item):


+ 3
- 3
fastNLP/core/fieldarray.py View File

@@ -2,12 +2,12 @@ import numpy as np


class FieldArray(object):
    """The content of one field of a DataSet, stored as a plain list.

    A field can be flagged as a model input (``is_input``), a training
    target (``is_target``), both, or neither.
    """

    def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):
        """
        :param name: str, the field's name
        :param content: list, one entry per instance in the DataSet
        :param padding_val: int, value used to pad variable-length entries
        :param is_target: bool, whether this field is a training target
        :param is_input: bool, whether this field is fed to the model
        """
        self.name = name
        self.content = content
        self.padding_val = padding_val
        self.is_target = is_target
        self.is_input = is_input
        # dtype is resolved lazily from content when batches are built
        self.dtype = None

def __repr__(self):
@@ -27,7 +27,7 @@ class FieldArray(object):
def get(self, idxes):
if isinstance(idxes, int):
return self.content[idxes]
assert self.need_tensor is True
assert self.is_input is True or self.is_target is True
batch_size = len(idxes)
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下
if isinstance(self.content[0], int) or isinstance(self.content[0], float):


+ 75
- 0
fastNLP/core/trainer.py View File

@@ -9,6 +9,7 @@ from fastNLP.core.loss import Loss
from fastNLP.core.metrics import Evaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester


@@ -194,3 +195,77 @@ def best_eval_result(self, metrics):
return True
else:
return False


from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args

# Samples per batch and number of batches used during the sanity check.
DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2

# Verbosity levels for _check_code.
IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2


def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
    """Dry-run a few forward/backward passes to surface interface mismatches before real training.

    :param dataset: a DataSet object, the training data used for the dry run
    :param model: the model to check; must define ``get_loss``, and ``evaluate``
        when ``dev_data`` is given
    :param batch_size: int, capped at DEFAULT_CHECK_BATCH_SIZE
    :param dev_data: a DataSet object or None; when given, the evaluation path
        (``model.evaluate``) is exercised as well
    :param check_level: IGNORE/WARNING/STRICT_CHECK_LEVEL, controls how argument
        mismatches are reported
    :raises AttributeError: if the model lacks ``get_loss``, or lacks ``evaluate``
        while ``dev_data`` is given
    """
    # check that the loss method exists
    if not hasattr(model, 'get_loss'):
        raise AttributeError("{} has to have a 'get_loss' function.".format(type(model)))

    batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
        if batch_count == 0:
            # Compare the first batch's fields against model.forward's signature.
            check_res = _check_arg_dict_list(model.forward, batch_x)
            if len(check_res.missing) > 0 and check_level == WARNING_CHECK_LEVEL:
                # was: "{}" placeholders were accumulated but never formatted, and
                # only print("") ran; actually report the offending field names
                print("Missing argument: [{}] needed by '{}.forward' is not presented in the input.".format(
                    ", ".join(check_res.missing), type(model).__name__))
            if len(check_res.unused) > 0 and check_level == WARNING_CHECK_LEVEL:
                print("Unused field: [{}] is not used by '{}.forward'.".format(
                    ", ".join(check_res.unused), type(model).__name__))

        refined_batch_x = _build_args(model.forward, **batch_x)
        output = model(**refined_batch_x)
        if batch_count == 0:
            _dict = _check_arg_dict_list(model.loss, [output, batch_y])
            if len(_dict) != 0:
                # TODO(review): mismatch between model.loss and its inputs is
                # detected here but not yet reported
                pass
        loss_input = _build_args(model.loss, **output, **batch_y)
        loss = model.loss(**loss_input)
        if batch_count == 0:
            if isinstance(loss, torch.Tensor):
                # TODO(review): probably meant to flag the case where loss is
                # NOT a tensor — confirm intent before inverting the condition
                pass

        loss.backward()

        # was: compared against DEFAULT_CHECK_BATCH_SIZE; the number of check
        # batches is a distinct constant (same value, so behavior is unchanged)
        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
            break

    if dev_data is not None:
        if not hasattr(model, 'evaluate'):
            # was: "...set""dev_data..." — the two string pieces concatenated
            # without a space, producing "setdev_data" in the message
            raise AttributeError("If {} wants to do evaluation, {} has to have a 'evaluate' function. Or you can set "
                                 "dev_data to 'None'."
                                 .format(type(model), type(model)))

        # was: the dev Batch iterated over the training `dataset`; it must
        # iterate over `dev_data`, and only be built when dev_data is given
        dev_batch = Batch(dataset=dev_data, batch_size=batch_size, sampler=SequentialSampler())
        for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
            refined_batch_x = _build_args(model.forward, **batch_x)
            output = model(**refined_batch_x)
            if batch_count == 0:
                # was: model.evaluate was checked against the stale `output`
                # left over from the training loop; use the dev output instead
                _dict = _check_arg_dict_list(model.evaluate, [output, batch_y])
                if len(_dict) != 0:
                    # TODO(review): mismatch detected but not yet reported
                    pass
            # mirror the training loop: a couple of batches suffice for a check
            if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
                break







+ 7
- 4
fastNLP/core/utils.py View File

@@ -1,6 +1,11 @@
import _pickle
import os
import inspect
from collections import namedtuple
from collections import Counter

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True)


def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file.
@@ -45,7 +50,7 @@ def pickle_exist(pickle_path, pickle_name):
else:
return False

def build_args(func, **kwargs):
def _build_args(func, **kwargs):
spect = inspect.getfullargspec(func)
if spect.varkw is not None:
return kwargs
@@ -55,11 +60,9 @@ def build_args(func, **kwargs):
output.update({name: val for name, val in kwargs.items() if name in needed_args})
return output

from collections import namedtuple, Counter
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True)

# check args
def check_arg_dict_list(func, args):
def _check_arg_dict_list(func, args):
if isinstance(args, dict):
arg_dict_list = [args]
else:


+ 0
- 1
fastNLP/core/vocabulary.py View File

@@ -60,7 +60,6 @@ class Vocabulary(object):
"""
self.word_count.update(word_lst)


def add(self, word):
self.word_count[word] += 1



Loading…
Cancel
Save