
Update

* fix bug in DataSet.split
* fix bugs in FieldArray, to allow content as a list
* fix bug in losses check
* ...
FengZiYjun committed 5 years ago · commit 125c2718e4 · tags/v0.2.0^2
10 changed files with 129 additions and 53 deletions
 1. fastNLP/core/dataset.py       +6   -0
 2. fastNLP/core/fieldarray.py    +19  -4
 3. fastNLP/core/losses.py        +6   -5
 4. fastNLP/core/metrics.py       +6   -5
 5. fastNLP/core/tester.py        +15  -16
 6. fastNLP/core/trainer.py       +6   -3
 7. fastNLP/core/utils.py         +3   -3
 8. fastNLP/models/base_model.py  +14  -4
 9. test/core/test_loss.py        +11  -10
10. test/core/test_trainer.py     +43  -3

fastNLP/core/dataset.py  (+6 -0)

@@ -260,6 +260,12 @@ class DataSet(object):
             dev_set.append(self[idx])
         for idx in train_indices:
             train_set.append(self[idx])
+        for field_name in self.field_arrays:
+            train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input
+            train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target
+            dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input
+            dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target
+
         return train_set, dev_set

     @classmethod
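Before this hunk, DataSet.split returned child sets whose FieldArrays had lost their is_input/is_target flags. A minimal sketch of the fixed behavior, using the DataSet/Instance API shown in test/core/test_trainer.py below:

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance

    ds = DataSet([Instance(x=[1.0, 2.0], y=[0.0]) for _ in range(10)])
    ds.set_input("x", flag=True)
    ds.set_target("y", flag=True)

    train_set, dev_set = ds.split(0.3)
    # With this fix, the flags carry over to both child sets.
    assert train_set.field_arrays["x"].is_input
    assert dev_set.field_arrays["y"].is_target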


fastNLP/core/fieldarray.py  (+19 -4)

@@ -11,7 +11,7 @@ class FieldArray(object):
     """

     :param str name: the name of the FieldArray
-    :param list content: a list of int, float, or other objects.
+    :param list content: a list of int, float, or a list of list.
     :param int padding_val: the integer for padding. Default: 0.
     :param bool is_target: If True, this FieldArray is used to compute loss.
     :param bool is_input: If True, this FieldArray is used to the model input.
@@ -26,7 +26,14 @@ class FieldArray(object):

     @staticmethod
     def _type_detection(content):
-        type_set = set([type(item) for item in content])
+
+        if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list):
+            # 2-D list
+            # TODO: refactor
+            type_set = set([type(item) for item in content[0]])
+        else:
+            # 1-D list
+            type_set = set([type(item) for item in content])
         if len(type_set) == 1 and any(basic_type in type_set for basic_type in (str, int, float)):
             return type_set.pop()
         elif len(type_set) == 2 and float in type_set and int in type_set:
@@ -48,7 +55,7 @@ class FieldArray(object):
     def append(self, val):
         """Add a new item to the tail of FieldArray.

-        :param val: int, float, or str.
+        :param val: int, float, str, or a list of them.
         """
         val_type = type(val)
         if val_type is int and self.pytype is float:
@@ -60,9 +67,17 @@ class FieldArray(object):
                 self.content[idx] = float(self.content[idx])
             self.pytype = float
             self.dtype = self._map_to_np_type(self.pytype)
+
+        elif val_type is list:
+            if len(val) == 0:
+                raise ValueError("Cannot append an empty list.")
+            else:
+                if type(val[0]) != self.pytype:
+                    raise ValueError(
+                        "Cannot append a list of {}-type value into a {}-tpye FieldArray.".
+                        format(type(val[0]), self.pytype))
         elif val_type != self.pytype:
             raise ValueError("Cannot append a {}-type value into a {}-tpye FieldArray.".format(val_type, self.pytype))
         self.content.append(val)

     def __getitem__(self, indices):
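A hedged sketch of what the new code paths accept. The keyword names follow the docstring above; the exact constructor behavior at this commit is an assumption:

    from fastNLP.core.fieldarray import FieldArray

    # 2-D content: _type_detection now inspects content[0]
    fa = FieldArray("words", [[1, 2, 3], [4, 5]], is_input=True)

    fa.append([6, 7])        # list whose elements match the detected type: accepted
    # fa.append([])          # raises ValueError: cannot append an empty list
    # fa.append(["a", "b"])  # raises ValueError: str elements into an int-typed FieldArray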


fastNLP/core/losses.py  (+6 -5)

@@ -3,11 +3,11 @@ import torch.nn.functional as F

 from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import CheckRes
+from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_function_or_method
 from fastNLP.core.utils import _get_arg_list
 from fastNLP.core.utils import _map_args
 from fastNLP.core.utils import get_func_signature
-from fastNLP.core.utils import _build_args
-from fastNLP.core.utils import _check_function_or_method


 class LossBase(object):
@@ -71,7 +71,8 @@ class LossBase(object):

         if len(duplicated) > 0 or len(missing) > 0:
             raise CheckError(
-                CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[]),
+                CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[],
+                         varargs=varargs),
                 func_signature=get_func_signature(self.get_loss)
             )
@@ -90,9 +91,9 @@ class LossBase(object):
         return loss


-class NewLoss(LossBase):
+class LossFunc(LossBase):
     def __init__(self, func, key_map=None, **kwargs):
-        super(NewLoss, self).__init__()
+        super(LossFunc, self).__init__()
         _check_function_or_method(func)
         if key_map is not None:
             if not isinstance(key_map, dict):
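The rename from NewLoss to LossFunc changes the public API. A usage sketch mirroring test/core/test_loss.py below; mapping F.nll_loss's input/target parameters through key_map the same way the tests map func's a/b is an assumption:

    import torch
    import torch.nn.functional as F
    from fastNLP.core.losses import LossFunc

    loss_func = LossFunc(F.nll_loss, {"input": "predict", "target": "truth"})
    predict = torch.randn(5, 3)
    truth = torch.LongTensor([1, 0, 1, 2, 1])
    loss = loss_func({"predict": predict}, {"truth": truth})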


fastNLP/core/metrics.py  (+6 -5)

@@ -1,17 +1,18 @@
-import warnings
 import inspect
+import warnings
 from collections import defaultdict

 import numpy as np
 import torch

-from fastNLP.core.utils import get_func_signature
-from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_arg_dict_list
+from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import seq_lens_to_masks


 class MetricBase(object):
     def __init__(self):
         self.param_map = {}  # key is param in function, value is input param.
@@ -46,7 +47,7 @@ class MetricBase(object):
             if value is None:
                 self.param_map[key] = key
                 continue
-            if isinstance(value, str):
+            if not isinstance(value, str):
                 raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
             self.param_map[key] = value
             value_counter[value].add(key)
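The corrected condition now rejects non-string values in the parameter map instead of rejecting strings. A small sketch of the intended behavior, assuming AccuracyMetric routes its keyword arguments through MetricBase's param-map setup shown above:

    from fastNLP.core.metrics import AccuracyMetric

    metric = AccuracyMetric(pred="predict", target="y")  # str values: accepted

    # A non-str mapping such as AccuracyMetric(pred=123, target="y")
    # should now raise TypeError("in pred=123, value must be `str` ...").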


fastNLP/core/tester.py  (+15 -16)

@@ -1,18 +1,18 @@
import itertools
from collections import defaultdict from collections import defaultdict


import torch import torch
from torch import nn from torch import nn


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.utils import CheckError from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import _check_loss_evaluate from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature



class Tester(object): class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """
@@ -27,16 +27,6 @@ class Tester(object):


self.metrics = _prepare_metrics(metrics) self.metrics = _prepare_metrics(metrics)


# check predict
if hasattr(self._model, 'predict'):
self._predict_func = self._model.predict
if not callable(self._predict_func):
_model_name = model.__class__.__name__
raise TypeError(f"`{_model_name}.predict` must be callable to be used "
f"for evaluation, not `{type(self._predict_func)}`.")
else:
self._predict_func = self._model.forward

self.data = data self.data = data
if torch.cuda.is_available() and self.use_cuda: if torch.cuda.is_available() and self.use_cuda:
self._model = model.cuda() self._model = model.cuda()
@@ -45,9 +35,18 @@ class Tester(object):
self.use_cuda = use_cuda self.use_cuda = use_cuda
self.batch_size = batch_size self.batch_size = batch_size
self.verbose = verbose self.verbose = verbose

self._model_device = model.parameters().__next__().device self._model_device = model.parameters().__next__().device


# check predict
if hasattr(self._model, 'predict'):
self._predict_func = self._model.predict
if not callable(self._predict_func):
_model_name = model.__class__.__name__
raise TypeError(f"`{_model_name}.predict` must be callable to be used "
f"for evaluation, not `{type(self._predict_func)}`.")
else:
self._predict_func = self._model.forward

def test(self): def test(self):
# turn on the testing mode; clean up the history # turn on the testing mode; clean up the history
network = self._model network = self._model


fastNLP/core/trainer.py  (+6 -3)

@@ -80,8 +80,9 @@ class Trainer(object):
         # parse metric_key
         # increase_better is True. It means the exp result gets better if the indicator increases.
         # It is true by default.
-        self.increase_better = False if metric_key[0] == "-" else True
+        self.increase_better = True
         if metric_key is not None:
+            self.increase_better = False if metric_key[0] == "-" else True
             self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
         else:
             self.metric_key = None
@@ -208,10 +209,12 @@ class Trainer(object):
     def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
-            self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
+            pass
+            # self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
         if self.save_path is not None and self._better_eval_result(res):
+            metric_key = self.metric_key if self.metric_key is not None else "None"
             self._save_model(self.model,
-                             "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
+                             "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))

     def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.


fastNLP/core/utils.py  (+3 -3)

@@ -5,9 +5,8 @@ import warnings
 from collections import Counter
 from collections import namedtuple

-import torch
 import numpy as np
+import torch

 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                    'varargs'], verbose=False)
@@ -266,7 +265,8 @@ def _check_forward_error(forward_func, batch_x, check_level):
     if check_res.varargs:
         errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
     if check_res.missing:
-        errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}.")
+        errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}. "
+                    f"Please set {check_res.missing} as input.")
     if check_res.unused:
         _unused = [f"\tunused param: {check_res.unused}"]
         if check_level == STRICT_CHECK_LEVEL:


fastNLP/models/base_model.py  (+14 -4)

@@ -1,7 +1,5 @@
 import torch

-from fastNLP.core.trainer import Trainer
-

 class BaseModel(torch.nn.Module):
     """Base PyTorch model for all models.
@@ -11,8 +9,20 @@ class BaseModel(torch.nn.Module):
         super(BaseModel, self).__init__()

     def fit(self, train_data, dev_data=None, **train_args):
-        trainer = Trainer(**train_args)
-        trainer.train(self, train_data, dev_data)
+        raise NotImplementedError

     def predict(self, *args, **kwargs):
         raise NotImplementedError
+
+
+class LinearClassifier(BaseModel):
+    def __init__(self, in_feature_dim, out_feature_dim):
+        super(LinearClassifier, self).__init__()
+        self.linear = torch.nn.Linear(in_feature_dim, out_feature_dim)
+        self.softmax = torch.nn.Softmax()
+
+    def forward(self, x):
+        return {"predict": self.softmax(self.linear(x))}
+
+    def predict(self, x):
+        return {"predict": self.softmax(self.linear(x))}

test/core/test_loss.py  (+11 -10)

@@ -16,7 +16,8 @@ class TestLoss(unittest.TestCase):

         # loss_func = loss.Loss("nll")
         print(callable(tc.nn.NLLLoss))
-        loss_func = loss.NewLoss(F.nll_loss)
+
+        loss_func = loss.LossFunc(F.nll_loss)

         nll_loss = loss.NLLLoss()
@@ -330,36 +331,36 @@ class TestLoss(unittest.TestCase):
             c = kwargs['c']
             return (a + b) * c

         import torch
-        from fastNLP.core.losses import LossBase, NewLoss
+        from fastNLP.core.losses import LossFunc

-        get_loss = NewLoss(func, {'a': 'predict', 'b': 'truth'})
+        get_loss = LossFunc(func, {'a': 'predict', 'b': 'truth'})
         predict = torch.randn(5, 3)
         truth = torch.LongTensor([1, 0, 1, 2, 1])
         loss1 = get_loss({'predict': predict}, {'truth': truth})
-        get_loss_2 = NewLoss(func2, {'a': 'predict'})
+        get_loss_2 = LossFunc(func2, {'a': 'predict'})
         loss2 = get_loss_2({'predict': predict}, {'truth': truth})
-        get_loss_3 = NewLoss(func3)
+        get_loss_3 = LossFunc(func3)
         loss3 = get_loss_3({'predict': predict}, {'truth': truth})
         print(loss1, loss2, loss3)
         assert loss1 == loss2 and loss1 == loss3

-        get_loss_4 = NewLoss(func4)
+        get_loss_4 = LossFunc(func4)
         loss4 = get_loss_4({'a': 1, 'b': 3}, {})
         print(loss4)
         assert loss4 == (1 + 3) * 2

-        get_loss_5 = NewLoss(func4)
+        get_loss_5 = LossFunc(func4)
         loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4})
         print(loss5)
         assert loss5 == (1 + 3) * 4

-        get_loss_6 = NewLoss(func6)
+        get_loss_6 = LossFunc(func6)
         loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4})
         print(loss6)
         assert loss6 == (1 + 3) * 4

-        get_loss_7 = NewLoss(func6, c='cc')
+        get_loss_7 = LossFunc(func6, c='cc')
         loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4})
         print(loss7)
         assert loss7 == (1 + 3) * 4


test/core/test_trainer.py  (+43 -3)

@@ -1,7 +1,47 @@
 import unittest

+import numpy as np
+import torch

-class TestTrainer(unittest.TestCase):
-    def test_case_1(self):
-        pass
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+from fastNLP.core.losses import LossFunc
+from fastNLP.core.metrics import AccuracyMetric
+from fastNLP.core.optimizer import SGD
+from fastNLP.core.trainer import Trainer
+from fastNLP.models.base_model import LinearClassifier
+
+
+class TrainerTestGround(unittest.TestCase):
+    def test_case(self):
+        mean = np.array([-3, -3])
+        cov = np.array([[1, 0], [0, 1]])
+        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+        mean = np.array([3, 3])
+        cov = np.array([[1, 0], [0, 1]])
+        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+
+        data_set.set_input("x", flag=True)
+        data_set.set_target("y", flag=True)
+
+        train_set, dev_set = data_set.split(0.3)
+
+        model = LinearClassifier(2, 1)
+
+        trainer = Trainer(train_set, model,
+                          losser=LossFunc(torch.nn.functional.binary_cross_entropy,
+                                          key_map={"target": "y", "input": "predict"}),
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          n_epochs=10,
+                          batch_size=32,
+                          print_every=10,
+                          validate_every=-1,
+                          dev_data=dev_set,
+                          optimizer=SGD(0.001),
+                          check_code_level=2)
+        trainer.train()
