@@ -6,16 +6,39 @@ | |||||
 |  | ||||
[](http://fastnlp.readthedocs.io/?badge=latest) | [](http://fastnlp.readthedocs.io/?badge=latest) | ||||
fastNLP is a modular Natural Language Processing system based on PyTorch, for fast development of NLP tools. It divides the NLP model based on deep learning into different modules. These modules fall into 4 categories: encoder, interaction, aggregation and decoder, while each category contains different implemented modules. Encoder modules encode the input into some abstract representation, interaction modules make the information in the representation interact with each other, aggregation modules aggregate and reduce information, and decoder modules decode the representation into the output. Most current NLP models could be built on these modules, which vastly simplifies the process of developing NLP models. The architecture of fastNLP is as the figure below: | |||||
FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models. | |||||
 | |||||
 | |||||
A deep learning NLP model is the composition of three types of modules: | |||||
<table> | |||||
<tr> | |||||
<td><b> module type </b></td> | |||||
<td><b> functionality </b></td> | |||||
<td><b> example </b></td> | |||||
</tr> | |||||
<tr> | |||||
<td> encoder </td> | |||||
<td> encode the input into some abstract representation </td> | |||||
<td> embedding, RNN, CNN, transformer </td>
</tr> | |||||
<tr> | |||||
<td> aggregator </td> | |||||
<td> aggregate and reduce information </td> | |||||
<td> self-attention, max-pooling </td> | |||||
</tr> | |||||
<tr> | |||||
<td> decoder </td> | |||||
<td> decode the representation into the output </td> | |||||
<td> MLP, CRF </td> | |||||
</tr>
</table>
For example: | |||||
 | |||||
## Requirements | ## Requirements | ||||
- numpy>=1.14.2 | - numpy>=1.14.2 | ||||
- torch>=0.4.0 | - torch>=0.4.0 | ||||
- torchvision>=0.1.8 | |||||
- tensorboardX | - tensorboardX | ||||
@@ -39,12 +62,12 @@ pip install fastNLP | |||||
<td> an open-source NLP library </td> | <td> an open-source NLP library </td> | ||||
</tr> | </tr> | ||||
<tr> | <tr> | ||||
<td><b> fastNLP.core </b></td> | |||||
<td> trainer, tester, predictor </td> | |||||
<td><b> fastNLP.api </b></td> | |||||
<td> APIs for end-to-end prediction </td> | |||||
</tr> | </tr> | ||||
<tr> | <tr> | ||||
<td><b> fastNLP.loader </b></td> | |||||
<td> all kinds of loaders/readers </td> | |||||
<td><b> fastNLP.core </b></td> | |||||
<td> data representation & train/test procedure </td>
</tr> | </tr> | ||||
<tr> | <tr> | ||||
<td><b> fastNLP.models </b></td> | <td><b> fastNLP.models </b></td> | ||||
@@ -55,11 +78,7 @@ pip install fastNLP | |||||
<td> a collection of PyTorch sub-models/components/wheels </td> | <td> a collection of PyTorch sub-models/components/wheels </td> | ||||
</tr> | </tr> | ||||
<tr> | <tr> | ||||
<td><b> fastNLP.saver </b></td> | |||||
<td> all kinds of savers/writers </td> | |||||
</tr> | |||||
<tr> | |||||
<td><b> fastNLP.fastnlp </b></td> | |||||
<td> a high-level interface for prediction </td> | |||||
<td><b> fastNLP.io </b></td> | |||||
<td> readers & savers </td> | |||||
</tr> | </tr> | ||||
</table> | </table> |
@@ -1 +1,2 @@ | |||||
# FastNLP Quick Tutorial | |||||
# FastNLP Quick Tutorial | |||||
@@ -1,5 +1,3 @@ | |||||
import torch | |||||
import hashlib | import hashlib | ||||
import os | import os | ||||
import re | import re | ||||
@@ -7,6 +5,8 @@ import shutil | |||||
import sys | import sys | ||||
import tempfile | import tempfile | ||||
import torch | |||||
try: | try: | ||||
from requests.utils import urlparse | from requests.utils import urlparse | ||||
from requests import get as urlopen | from requests import get as urlopen | ||||
@@ -132,7 +132,3 @@ if tqdm is None: | |||||
sys.stderr.write('\n') | sys.stderr.write('\n') | ||||
if __name__ == '__main__': | |||||
pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl', model_dir='.') | |||||
print(type(pipeline)) |
@@ -1,14 +1,15 @@ | |||||
import torch | |||||
from collections import defaultdict | |||||
import re | import re | ||||
from collections import defaultdict | |||||
import torch | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
from fastNLP.core.vocabulary import Vocabulary | |||||
class Processor: | |||||
class Processor(object): | |||||
def __init__(self, field_name, new_added_field_name): | def __init__(self, field_name, new_added_field_name): | ||||
self.field_name = field_name | self.field_name = field_name | ||||
if new_added_field_name is None: | if new_added_field_name is None: | ||||
@@ -17,7 +18,7 @@ class Processor: | |||||
self.new_added_field_name = new_added_field_name | self.new_added_field_name = new_added_field_name | ||||
def process(self, *args, **kwargs): | def process(self, *args, **kwargs): | ||||
pass | |||||
raise NotImplementedError | |||||
def __call__(self, *args, **kwargs): | def __call__(self, *args, **kwargs): | ||||
return self.process(*args, **kwargs) | return self.process(*args, **kwargs) | ||||
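A concrete processor only needs to override `process`. Below is a hedged sketch of a custom subclass; the `LowerCaseProcessor` name and the field names are made up for illustration, and it relies on `DataSet.apply`, with `DataSet` already imported in this module.

```python
class LowerCaseProcessor(Processor):
    """Hypothetical example: store a lower-cased copy of a string field."""

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
        # DataSet.apply writes the results into `new_added_field_name`
        dataset.apply(lambda ins: ins[self.field_name].lower(),
                      new_field_name=self.new_added_field_name)
        return dataset

# usage (hypothetical field names); calling the processor invokes process()
# ds = LowerCaseProcessor("raw_sentence", "sentence")(ds)
```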
@@ -132,13 +133,14 @@ class Num2TagProcessor(Processor): | |||||
class IndexerProcessor(Processor): | class IndexerProcessor(Processor): | ||||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False): | |||||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | |||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | ||||
super(IndexerProcessor, self).__init__(field_name, new_added_field_name) | super(IndexerProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.vocab = vocab | self.vocab = vocab | ||||
self.delete_old_field = delete_old_field | self.delete_old_field = delete_old_field | ||||
self.is_input = is_input | |||||
def set_vocab(self, vocab): | def set_vocab(self, vocab): | ||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | ||||
@@ -146,13 +148,14 @@ class IndexerProcessor(Processor): | |||||
self.vocab = vocab | self.vocab = vocab | ||||
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | for ins in dataset: | ||||
tokens = ins[self.field_name] | tokens = ins[self.field_name] | ||||
index = [self.vocab.to_index(token) for token in tokens] | index = [self.vocab.to_index(token) for token in tokens] | ||||
ins[self.new_added_field_name] = index | ins[self.new_added_field_name] = index | ||||
dataset._set_need_tensor(**{self.new_added_field_name: True}) | |||||
if self.is_input: | |||||
dataset.set_input(self.new_added_field_name) | |||||
if self.delete_old_field: | if self.delete_old_field: | ||||
dataset.delete_field(self.field_name) | dataset.delete_field(self.field_name) | ||||
@@ -161,6 +164,9 @@ class IndexerProcessor(Processor): | |||||
class VocabProcessor(Processor): | class VocabProcessor(Processor): | ||||
"""Build vocabulary with a field in the data set. | |||||
""" | |||||
def __init__(self, field_name): | def __init__(self, field_name): | ||||
super(VocabProcessor, self).__init__(field_name, None) | super(VocabProcessor, self).__init__(field_name, None) | ||||
self.vocab = Vocabulary() | self.vocab = Vocabulary() | ||||
@@ -178,17 +184,20 @@ class VocabProcessor(Processor): | |||||
class SeqLenProcessor(Processor): | class SeqLenProcessor(Processor): | ||||
def __init__(self, field_name, new_added_field_name='seq_lens'): | |||||
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | |||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.is_input = is_input | |||||
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | for ins in dataset: | ||||
length = len(ins[self.field_name]) | length = len(ins[self.field_name]) | ||||
ins[self.new_added_field_name] = length | ins[self.new_added_field_name] = length | ||||
dataset._set_need_tensor(**{self.new_added_field_name: True}) | |||||
if self.is_input: | |||||
dataset.set_input(self.new_added_field_name) | |||||
return dataset | return dataset | ||||
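The processors above can be chained into a small preprocessing pipeline. A hedged sketch follows: the field names (`words`, `word_seq`, `seq_lens`) are illustrative, `train_data` is assumed to be a `DataSet` whose `words` field holds token lists, and `VocabProcessor.process` is assumed to accept a `DataSet` (its body is not shown in this diff).

```python
# Hedged pipeline sketch; processors modify the DataSet in place.
vocab_proc = VocabProcessor("words")
vocab_proc(train_data)                      # assumed to build the vocabulary from the "words" field
IndexerProcessor(vocab_proc.vocab, "words", "word_seq")(train_data)  # adds "word_seq", marks it as input
SeqLenProcessor("word_seq", "seq_lens")(train_data)                  # adds "seq_lens", marks it as input
```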
class ModelProcessor(Processor): | class ModelProcessor(Processor): | ||||
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | ||||
""" | """ | ||||
@@ -238,6 +247,7 @@ class ModelProcessor(Processor): | |||||
device = torch.device(device) | device = torch.device(device) | ||||
self.model.to(device) | self.model.to(device) | ||||
class Index2WordProcessor(Processor): | class Index2WordProcessor(Processor): | ||||
def __init__(self, vocab, field_name, new_added_field_name): | def __init__(self, vocab, field_name, new_added_field_name): | ||||
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | ||||
@@ -251,26 +261,28 @@ class Index2WordProcessor(Processor): | |||||
class SetTensorProcessor(Processor): | class SetTensorProcessor(Processor): | ||||
# TODO: remove it. It is strange. | |||||
def __init__(self, field_dict, default=False): | def __init__(self, field_dict, default=False): | ||||
super(SetTensorProcessor, self).__init__(None, None) | super(SetTensorProcessor, self).__init__(None, None) | ||||
self.field_dict = field_dict | self.field_dict = field_dict | ||||
self.default = default | self.default = default | ||||
def process(self, dataset): | def process(self, dataset): | ||||
set_dict = {name: self.default for name in dataset.get_fields().keys()} | |||||
set_dict = {name: self.default for name in dataset.get_all_fields().keys()} | |||||
set_dict.update(self.field_dict) | set_dict.update(self.field_dict) | ||||
dataset._set_need_tensor(**set_dict) | dataset._set_need_tensor(**set_dict) | ||||
return dataset | return dataset | ||||
class SetIsTargetProcessor(Processor): | class SetIsTargetProcessor(Processor): | ||||
# TODO: remove it.
def __init__(self, field_dict, default=False): | def __init__(self, field_dict, default=False): | ||||
super(SetIsTargetProcessor, self).__init__(None, None) | super(SetIsTargetProcessor, self).__init__(None, None) | ||||
self.field_dict = field_dict | self.field_dict = field_dict | ||||
self.default = default | self.default = default | ||||
def process(self, dataset): | def process(self, dataset): | ||||
set_dict = {name: self.default for name in dataset.get_fields().keys()} | |||||
set_dict = {name: self.default for name in dataset.get_all_fields().keys()} | |||||
set_dict.update(self.field_dict) | set_dict.update(self.field_dict) | ||||
dataset.set_target(**set_dict) | dataset.set_target(**set_dict) | ||||
return dataset | return dataset |
@@ -1,11 +1,13 @@ | |||||
from .batch import Batch | from .batch import Batch | ||||
from .dataset import DataSet | |||||
# from .dataset import DataSet | |||||
from .fieldarray import FieldArray | from .fieldarray import FieldArray | ||||
from .instance import Instance | from .instance import Instance | ||||
from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator | |||||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | |||||
from .metrics import AccuracyMetric | |||||
from .optimizer import Optimizer, SGD, Adam | |||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler | from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler | ||||
from .tester import Tester | from .tester import Tester | ||||
from .trainer import Trainer | from .trainer import Trainer | ||||
from .vocabulary import Vocabulary | from .vocabulary import Vocabulary | ||||
from .optimizer import Optimizer | |||||
from .loss import Loss | |||||
from ..io.dataset_loader import DataSet | |||||
@@ -1,3 +1,4 @@ | |||||
import numpy as np | |||||
import torch | import torch | ||||
@@ -25,6 +26,7 @@ class Batch(object): | |||||
self.as_numpy = as_numpy | self.as_numpy = as_numpy | ||||
self.idx_list = None | self.idx_list = None | ||||
self.curidx = 0 | self.curidx = 0 | ||||
self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0) | |||||
def __iter__(self): | def __iter__(self): | ||||
self.idx_list = self.sampler(self.dataset) | self.idx_list = self.sampler(self.dataset) | ||||
@@ -41,11 +43,11 @@ class Batch(object): | |||||
indices = self.idx_list[self.curidx:endidx] | indices = self.idx_list[self.curidx:endidx] | ||||
for field_name, field in self.dataset.get_fields().items(): | |||||
for field_name, field in self.dataset.get_all_fields().items(): | |||||
if field.is_target or field.is_input: | if field.is_target or field.is_input: | ||||
batch = field.get(indices) | batch = field.get(indices) | ||||
if not self.as_numpy: | if not self.as_numpy: | ||||
batch = torch.from_numpy(batch) | |||||
batch = to_tensor(batch, field.dtype) | |||||
if field.is_target: | if field.is_target: | ||||
batch_y[field_name] = batch | batch_y[field_name] = batch | ||||
if field.is_input: | if field.is_input: | ||||
@@ -54,3 +56,14 @@ class Batch(object): | |||||
self.curidx = endidx | self.curidx = endidx | ||||
return batch_x, batch_y | return batch_x, batch_y | ||||
def __len__(self): | |||||
return self.num_batches | |||||
def to_tensor(batch, dtype): | |||||
if dtype in (int, np.int8, np.int16, np.int32, np.int64): | |||||
batch = torch.LongTensor(batch) | |||||
if dtype in (float, np.float32, np.float64): | |||||
batch = torch.FloatTensor(batch) | |||||
return batch |
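A hedged sketch of how the updated `Batch` is consumed. The constructor signature `Batch(dataset, batch_size, sampler, as_numpy=False)` is inferred from the attributes set in `__init__` above and may differ from the actual definition; `train_data` is assumed to be a `DataSet` with input/target fields already set.

```python
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler

# Hedged sketch; constructor arguments are inferred, not confirmed by this diff.
data_iterator = Batch(train_data, batch_size=16, sampler=SequentialSampler())
print(len(data_iterator))                 # number of batches, from the new __len__
for batch_x, batch_y in data_iterator:
    # batch_x collects is_input fields, batch_y collects is_target fields;
    # both are padded and converted to torch tensors via to_tensor()
    pass
```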
@@ -1,24 +1,11 @@ | |||||
import _pickle as pickle | |||||
import numpy as np | import numpy as np | ||||
from copy import copy | |||||
from fastNLP.core.fieldarray import FieldArray | from fastNLP.core.fieldarray import FieldArray | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
_READERS = {} | |||||
def construct_dataset(sentences): | |||||
"""Construct a data set from a list of sentences. | |||||
:param sentences: list of list of str | |||||
:return dataset: a DataSet object | |||||
""" | |||||
dataset = DataSet() | |||||
for sentence in sentences: | |||||
instance = Instance() | |||||
instance['raw_sentence'] = sentence | |||||
dataset.append(instance) | |||||
return dataset | |||||
from fastNLP.core.utils import get_func_signature | |||||
from fastNLP.io.base_loader import DataLoaderRegister | |||||
class DataSet(object): | class DataSet(object): | ||||
@@ -28,45 +15,13 @@ class DataSet(object): | |||||
""" | """ | ||||
class Instance(object): | |||||
def __init__(self, dataset, idx=-1, **fields): | |||||
self.dataset = dataset | |||||
self.idx = idx | |||||
self.fields = fields | |||||
def __next__(self): | |||||
self.idx += 1 | |||||
if self.idx >= len(self.dataset): | |||||
raise StopIteration | |||||
return copy(self) | |||||
def add_field(self, field_name, field): | |||||
"""Add a new field to the instance. | |||||
:param field_name: str, the name of the field. | |||||
:param field: | |||||
""" | |||||
self.fields[field_name] = field | |||||
def __getitem__(self, name): | |||||
return self.dataset[name][self.idx] | |||||
def __setitem__(self, name, val): | |||||
if name not in self.dataset: | |||||
new_fields = [None] * len(self.dataset) | |||||
self.dataset.add_field(name, new_fields) | |||||
self.dataset[name][self.idx] = val | |||||
def __repr__(self): | |||||
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name | |||||
in self.dataset.get_fields().keys()]) | |||||
def __init__(self, data=None): | def __init__(self, data=None): | ||||
""" | """ | ||||
:param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field. | |||||
All values must be of the same length. | |||||
If it is a list, it must be a list of Instance objects. | |||||
:param data: a dict or a list. | |||||
If `data` is a dict, the key is the name of a FieldArray and the value is the FieldArray. All values | |||||
must be of the same length. | |||||
If `data` is a list, it must be a list of Instance objects. | |||||
""" | """ | ||||
self.field_arrays = {} | self.field_arrays = {} | ||||
if data is not None: | if data is not None: | ||||
@@ -89,14 +44,95 @@ class DataSet(object): | |||||
return item in self.field_arrays | return item in self.field_arrays | ||||
def __iter__(self): | def __iter__(self): | ||||
return self.Instance(self) | |||||
def iter_func(): | |||||
for idx in range(len(self)): | |||||
yield self[idx] | |||||
return iter_func() | |||||
def _inner_iter(self): | |||||
class Iter_ptr: | |||||
def __init__(self, dataset, idx): | |||||
self.dataset = dataset | |||||
self.idx = idx | |||||
def __getitem__(self, item): | |||||
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ | |||||
self.idx]) | |||||
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | |||||
return self.dataset.field_arrays[item][self.idx] | |||||
def __repr__(self): | |||||
return self.dataset[self.idx].__repr__() | |||||
def inner_iter_func(): | |||||
for idx in range(len(self)): | |||||
yield Iter_ptr(self, idx) | |||||
return inner_iter_func() | |||||
def __getitem__(self, idx): | |||||
"""Fetch Instance(s) at the `idx` position(s) in the dataset. | |||||
Notice: This method returns a copy of the actual instance(s). Any change to the returned value would not modify | |||||
the origin instance(s) of the DataSet. | |||||
If you want to make in-place changes to all Instances, use `apply` method. | |||||
:param idx: can be int or slice. | |||||
:return: If `idx` is int, return an Instance object. | |||||
If `idx` is slice, return a DataSet object. | |||||
""" | |||||
if isinstance(idx, int): | |||||
return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays}) | |||||
elif isinstance(idx, slice): | |||||
if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): | |||||
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self)-1}") | |||||
data_set = DataSet() | |||||
for field in self.field_arrays.values(): | |||||
data_set.add_field(name=field.name, | |||||
fields=field.content[idx], | |||||
padding_val=field.padding_val, | |||||
is_input=field.is_input, | |||||
is_target=field.is_target) | |||||
return data_set | |||||
else: | |||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||||
def __getattr__(self, item): | |||||
# Not tested. Don't use !! | |||||
if item == "field_arrays": | |||||
raise AttributeError | |||||
if isinstance(item, str) and item in self.field_arrays: | |||||
return self.field_arrays[item] | |||||
try: | |||||
reader = DataLoaderRegister.get_reader(item) | |||||
return reader | |||||
except AttributeError: | |||||
raise | |||||
def _convert_ins(self, ins_list): | |||||
if isinstance(ins_list, list): | |||||
for ins in ins_list: | |||||
self.append(ins) | |||||
def __setstate__(self, state): | |||||
self.__dict__ = state | |||||
def __getstate__(self): | |||||
return self.__dict__ | |||||
def __len__(self): | |||||
"""Fetch the length of the dataset. | |||||
:return int length: | |||||
""" | |||||
if len(self.field_arrays) == 0: | |||||
return 0 | |||||
field = iter(self.field_arrays.values()).__next__() | |||||
return len(field) | |||||
def __inner_repr__(self): | |||||
if len(self) < 20: | |||||
return ",\n".join([ins.__repr__() for ins in self]) | |||||
else: | else: | ||||
self.append(ins_list) | |||||
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__() | |||||
def __repr__(self): | |||||
return "DataSet(" + self.__inner_repr__() + ")" | |||||
def append(self, ins): | def append(self, ins): | ||||
"""Add an instance to the DataSet. | """Add an instance to the DataSet. | ||||
@@ -125,7 +161,9 @@ class DataSet(object): | |||||
:param bool is_target: whether this field is label or target. | :param bool is_target: whether this field is label or target. | ||||
""" | """ | ||||
if len(self.field_arrays) != 0: | if len(self.field_arrays) != 0: | ||||
assert len(self) == len(fields) | |||||
if len(self) != len(fields): | |||||
raise RuntimeError(f"The field to append must have the same size as dataset. " | |||||
f"Dataset size {len(self)} != field size {len(fields)}") | |||||
self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, | self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, | ||||
is_input=is_input) | is_input=is_input) | ||||
@@ -136,146 +174,121 @@ class DataSet(object): | |||||
""" | """ | ||||
self.field_arrays.pop(name) | self.field_arrays.pop(name) | ||||
def get_fields(self): | |||||
def get_field(self, field_name): | |||||
if field_name not in self.field_arrays: | |||||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | |||||
return self.field_arrays[field_name] | |||||
def get_all_fields(self): | |||||
"""Return all the fields with their names. | """Return all the fields with their names. | ||||
:return dict field_arrays: the internal data structure of DataSet. | :return dict field_arrays: the internal data structure of DataSet. | ||||
""" | """ | ||||
return self.field_arrays | return self.field_arrays | ||||
def __getitem__(self, idx): | |||||
""" | |||||
:param idx: can be int, slice, or str. | |||||
:return: If `idx` is int, return an Instance object. | |||||
If `idx` is slice, return a DataSet object. | |||||
If `idx` is str, it must be a field name, return the field. | |||||
""" | |||||
if isinstance(idx, int): | |||||
return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays}) | |||||
elif isinstance(idx, slice): | |||||
data_set = DataSet() | |||||
for field in self.field_arrays.values(): | |||||
data_set.add_field(name=field.name, | |||||
fields=field.content[idx], | |||||
padding_val=field.padding_val, | |||||
is_input=field.is_input, | |||||
is_target=field.is_target) | |||||
return data_set | |||||
elif isinstance(idx, str): | |||||
return self.field_arrays[idx] | |||||
else: | |||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||||
def __len__(self): | |||||
if len(self.field_arrays) == 0: | |||||
return 0 | |||||
field = iter(self.field_arrays.values()).__next__() | |||||
return len(field) | |||||
def get_length(self): | def get_length(self): | ||||
"""The same as __len__ | |||||
"""Fetch the length of the dataset. | |||||
:return int length: | |||||
""" | """ | ||||
return len(self) | return len(self) | ||||
def rename_field(self, old_name, new_name): | def rename_field(self, old_name, new_name): | ||||
"""rename a field | |||||
"""Rename a field. | |||||
:param str old_name: | |||||
:param str new_name: | |||||
""" | """ | ||||
if old_name in self.field_arrays: | if old_name in self.field_arrays: | ||||
self.field_arrays[new_name] = self.field_arrays.pop(old_name) | self.field_arrays[new_name] = self.field_arrays.pop(old_name) | ||||
self.field_arrays[new_name].name = new_name | |||||
else: | else: | ||||
raise KeyError("{} is not a valid name. ".format(old_name)) | |||||
raise KeyError("DataSet has no field named {}.".format(old_name)) | |||||
def set_target(self, **fields): | |||||
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. | |||||
def set_target(self, *field_names, flag=True): | |||||
"""Change the target flag of these fields. | |||||
:param key-value pairs for field-name and `is_target` value(True, False). | |||||
:param field_names: a sequence of str, indicating field names | |||||
:param bool flag: Set these fields as target if True. Unset them if False. | |||||
""" | """ | ||||
for name, val in fields.items(): | |||||
for name in field_names: | |||||
if name in self.field_arrays: | if name in self.field_arrays: | ||||
assert isinstance(val, bool) | |||||
self.field_arrays[name].is_target = val | |||||
self.field_arrays[name].is_target = flag | |||||
else: | else: | ||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
return self | |||||
def set_input(self, **fields): | |||||
for name, val in fields.items(): | |||||
def set_input(self, *field_name, flag=True): | |||||
"""Set the input flag of these fields. | |||||
:param field_name: a sequence of str, indicating field names. | |||||
:param bool flag: Set these fields as input if True. Unset them if False. | |||||
""" | |||||
for name in field_name: | |||||
if name in self.field_arrays: | if name in self.field_arrays: | ||||
assert isinstance(val, bool) | |||||
self.field_arrays[name].is_input = val | |||||
self.field_arrays[name].is_input = flag | |||||
else: | else: | ||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
return self | |||||
def get_input_name(self): | def get_input_name(self): | ||||
"""Get all field names with `is_input` as True. | |||||
:return list field_names: a list of str | |||||
""" | |||||
return [name for name, field in self.field_arrays.items() if field.is_input] | return [name for name, field in self.field_arrays.items() if field.is_input] | ||||
def get_target_name(self): | def get_target_name(self): | ||||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||||
def __getattr__(self, item): | |||||
# block infinite recursion for copy, pickle | |||||
if item == '__setstate__': | |||||
raise AttributeError(item) | |||||
try: | |||||
return self.field_arrays.__getitem__(item) | |||||
except KeyError: | |||||
pass | |||||
try: | |||||
reader_cls = _READERS[item] | |||||
# add read_*data() support | |||||
def _read(*args, **kwargs): | |||||
data = reader_cls().load(*args, **kwargs) | |||||
self.extend(data) | |||||
return self | |||||
"""Get all field names with `is_target` as True. | |||||
return _read | |||||
except KeyError: | |||||
raise AttributeError('{} does not exist.'.format(item)) | |||||
@classmethod | |||||
def set_reader(cls, method_name): | |||||
"""decorator to add dataloader support | |||||
:return list field_names: a list of str | |||||
""" | """ | ||||
assert isinstance(method_name, str) | |||||
def wrapper(read_cls): | |||||
_READERS[method_name] = read_cls | |||||
return read_cls | |||||
return wrapper | |||||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||||
def apply(self, func, new_field_name=None): | |||||
def apply(self, func, new_field_name=None, **kwargs): | |||||
"""Apply a function to every instance of the DataSet. | """Apply a function to every instance of the DataSet. | ||||
:param func: a function that takes an instance as input. | :param func: a function that takes an instance as input. | ||||
:param str new_field_name: If not None, results of the function will be stored as a new field. | :param str new_field_name: If not None, results of the function will be stored as a new field. | ||||
:return results: returned values of the function over all instances. | |||||
:param **kwargs: Accepted keyword arguments:
    (1) is_input: bool. Ignored if new_field_name is None. If True, the new field is set as model input.
    (2) is_target: bool. Ignored if new_field_name is None. If True, the new field is set as target.
:return results: if new_field_name is not passed, the values returned by the function over all instances.
""" | """ | ||||
results = [func(ins) for ins in self] | |||||
results = [func(ins) for ins in self._inner_iter()] | |||||
if len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||||
extra_param = {} | |||||
if 'is_input' in kwargs: | |||||
extra_param['is_input'] = kwargs['is_input'] | |||||
if 'is_target' in kwargs: | |||||
extra_param['is_target'] = kwargs['is_target'] | |||||
if new_field_name is not None: | if new_field_name is not None: | ||||
if new_field_name in self.field_arrays: | if new_field_name in self.field_arrays: | ||||
# overwrite the field, keep same attributes | # overwrite the field, keep same attributes | ||||
old_field = self.field_arrays[new_field_name] | old_field = self.field_arrays[new_field_name] | ||||
if 'is_input' not in extra_param: | |||||
extra_param['is_input'] = old_field.is_input | |||||
if 'is_target' not in extra_param: | |||||
extra_param['is_target'] = old_field.is_target | |||||
self.add_field(name=new_field_name, | self.add_field(name=new_field_name, | ||||
fields=results, | fields=results, | ||||
padding_val=old_field.padding_val, | padding_val=old_field.padding_val, | ||||
is_input=old_field.is_input, | |||||
is_target=old_field.is_target) | |||||
**extra_param) | |||||
else: | else: | ||||
self.add_field(name=new_field_name, fields=results) | |||||
self.add_field(name=new_field_name, fields=results, **extra_param) | |||||
else: | else: | ||||
return results | return results | ||||
def drop(self, func): | def drop(self, func): | ||||
results = [ins for ins in self if not func(ins)] | |||||
"""Drop instances if a condition holds. | |||||
:param func: a function that takes an Instance object as input, and returns bool. | |||||
The instance will be dropped if the function returns True. | |||||
""" | |||||
results = [ins for ins in self._inner_iter() if not func(ins)] | |||||
for name, old_field in self.field_arrays.items(): | for name, old_field in self.field_arrays.items(): | ||||
self.field_arrays[name].content = [ins[name] for ins in results] | self.field_arrays[name].content = [ins[name] for ins in results] | ||||
# print(self.field_arrays[name]) | |||||
def split(self, dev_ratio): | def split(self, dev_ratio): | ||||
"""Split the dataset into training and development(validation) set. | """Split the dataset into training and development(validation) set. | ||||
@@ -297,30 +310,81 @@ class DataSet(object): | |||||
dev_set.append(self[idx]) | dev_set.append(self[idx]) | ||||
for idx in train_indices: | for idx in train_indices: | ||||
train_set.append(self[idx]) | train_set.append(self[idx]) | ||||
for field_name in self.field_arrays: | |||||
train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||||
train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||||
dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||||
dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||||
return train_set, dev_set | return train_set, dev_set | ||||
@classmethod | @classmethod | ||||
def read_csv(cls, csv_path, headers=None, sep='\t', dropna=True): | |||||
with open(csv_path, 'r') as f: | |||||
def read_csv(cls, csv_path, headers=None, sep=",", dropna=True): | |||||
"""Load data from a CSV file and return a DataSet object. | |||||
:param str csv_path: path to the CSV file | |||||
:param List[str] or Tuple[str] headers: headers of the CSV file | |||||
:param str sep: delimiter in CSV file. Default: "," | |||||
:param bool dropna: If True, drop rows that have less entries than headers. | |||||
:return DataSet dataset: | |||||
""" | |||||
with open(csv_path, "r") as f: | |||||
start_idx = 0 | start_idx = 0 | ||||
if headers is None: | if headers is None: | ||||
headers = f.readline().rstrip('\r\n') | headers = f.readline().rstrip('\r\n') | ||||
headers = headers.split(sep) | headers = headers.split(sep) | ||||
start_idx += 1 | start_idx += 1 | ||||
else: | else: | ||||
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(type(headers)) | |||||
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format( | |||||
type(headers)) | |||||
_dict = {} | _dict = {} | ||||
for col in headers: | for col in headers: | ||||
_dict[col] = [] | _dict[col] = [] | ||||
for line_idx, line in enumerate(f, start_idx): | for line_idx, line in enumerate(f, start_idx): | ||||
contents = line.split(sep) | |||||
if len(contents)!=len(headers): | |||||
contents = line.rstrip('\r\n').split(sep) | |||||
if len(contents) != len(headers): | |||||
if dropna: | if dropna: | ||||
continue | continue | ||||
else: | else: | ||||
#TODO change error type | |||||
raise ValueError("Line {} has {} parts, while header has {} parts."\ | |||||
.format(line_idx, len(contents), len(headers))) | |||||
# TODO change error type | |||||
raise ValueError("Line {} has {} parts, while header has {} parts." \ | |||||
.format(line_idx, len(contents), len(headers))) | |||||
for header, content in zip(headers, contents): | for header, content in zip(headers, contents): | ||||
_dict[header].append(content) | _dict[header].append(content) | ||||
return cls(_dict) | return cls(_dict) | ||||
# def read_pos(self): | |||||
# return DataLoaderRegister.get_reader('read_pos') | |||||
def save(self, path): | |||||
"""Save the DataSet object as pickle. | |||||
:param str path: the path to the pickle | |||||
""" | |||||
with open(path, 'wb') as f: | |||||
pickle.dump(self, f) | |||||
@staticmethod | |||||
def load(path): | |||||
"""Load a DataSet object from pickle. | |||||
:param str path: the path to the pickle | |||||
:return DataSet data_set: | |||||
""" | |||||
with open(path, 'rb') as f: | |||||
return pickle.load(f) | |||||
def construct_dataset(sentences): | |||||
"""Construct a data set from a list of sentences. | |||||
:param sentences: list of list of str | |||||
:return dataset: a DataSet object | |||||
""" | |||||
dataset = DataSet() | |||||
for sentence in sentences: | |||||
instance = Instance() | |||||
instance['raw_sentence'] = sentence | |||||
dataset.append(instance) | |||||
return dataset |
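Putting the reworked `DataSet` API together, a hedged usage sketch; the CSV path, the assumption that the file has no header row, and all field names are illustrative.

```python
from fastNLP.core.dataset import DataSet

# Hedged sketch; "train.csv" is assumed to have two comma-separated columns and no header row.
ds = DataSet.read_csv("train.csv", headers=("raw_sentence", "label"), sep=",")
ds.apply(lambda ins: ins["raw_sentence"].split(), new_field_name="words", is_input=True)
ds.apply(lambda ins: int(ins["label"]), new_field_name="label", is_target=True)
ds.drop(lambda ins: len(ins["words"]) == 0)  # drop empty sentences
train_ds, dev_ds = ds.split(0.1)             # hold out 10% as the dev set
print(len(train_ds), train_ds[0])            # __len__ and single-Instance indexing
```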
@@ -6,35 +6,150 @@ class FieldArray(object): | |||||
It is the basic element of DataSet class. | It is the basic element of DataSet class. | ||||
""" | """ | ||||
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): | |||||
def __init__(self, name, content, padding_val=0, is_target=None, is_input=None): | |||||
""" | """ | ||||
:param str name: the name of the FieldArray | :param str name: the name of the FieldArray | ||||
:param list content: a list of int, float, or other objects. | |||||
:param list content: a list of int, float, str or np.ndarray, or a list of lists of these, or a np.ndarray.
:param int padding_val: the integer for padding. Default: 0. | :param int padding_val: the integer for padding. Default: 0. | ||||
:param bool is_target: If True, this FieldArray is used to compute loss. | :param bool is_target: If True, this FieldArray is used to compute loss. | ||||
:param bool is_input: If True, this FieldArray is used as the model input.
""" | """ | ||||
self.name = name | self.name = name | ||||
if isinstance(content, list): | |||||
content = content | |||||
elif isinstance(content, np.ndarray): | |||||
content = content.tolist() # convert np.ndarray into 2-D list | |||||
else: | |||||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | |||||
self.content = content | self.content = content | ||||
self.padding_val = padding_val | self.padding_val = padding_val | ||||
self.is_target = is_target | |||||
self.is_input = is_input | |||||
# TODO: auto detect dtype | |||||
self.dtype = None | |||||
self._is_target = None | |||||
self._is_input = None | |||||
self.BASIC_TYPES = (int, float, str, np.ndarray) | |||||
self.is_2d_list = False | |||||
self.pytype = None # int, float, str, or np.ndarray | |||||
self.dtype = None # np.int64, np.float64, np.str | |||||
if is_input is not None: | |||||
self.is_input = is_input | |||||
if is_target is not None: | |||||
self.is_target = is_target | |||||
@property | |||||
def is_input(self): | |||||
return self._is_input | |||||
@is_input.setter | |||||
def is_input(self, value): | |||||
if value is True: | |||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
self._is_input = value | |||||
@property | |||||
def is_target(self): | |||||
return self._is_target | |||||
@is_target.setter | |||||
def is_target(self, value): | |||||
if value is True: | |||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
self._is_target = value | |||||
def _type_detection(self, content): | |||||
""" | |||||
:param content: a list of int, float, str or np.ndarray, or a list of lists of these.
:return type: one of int, float, str, np.ndarray | |||||
""" | |||||
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | |||||
# content is a 2-D list | |||||
if not all(isinstance(_, list) for _ in content): # strict check 2-D list | |||||
raise TypeError("Please provide 2-D list.") | |||||
type_set = set([self._type_detection(x) for x in content]) | |||||
if len(type_set) == 2 and int in type_set and float in type_set: | |||||
type_set = {float} | |||||
elif len(type_set) > 1: | |||||
raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||||
self.is_2d_list = True | |||||
return type_set.pop() | |||||
elif isinstance(content, list): | |||||
# content is a 1-D list | |||||
if len(content) == 0: | |||||
# the old error is not informative enough. | |||||
raise RuntimeError("Cannot create FieldArray with an empty list. Or one element in the list is empty.") | |||||
type_set = set([type(item) for item in content]) | |||||
if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES: | |||||
return type_set.pop() | |||||
elif len(type_set) == 2 and float in type_set and int in type_set: | |||||
# up-cast int to float | |||||
return float | |||||
else: | |||||
raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) | |||||
else: | |||||
raise TypeError("Cannot create FieldArray with type {}".format(type(content))) | |||||
@staticmethod | |||||
def _map_to_np_type(basic_type): | |||||
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} | |||||
return type_mapping[basic_type] | |||||
def __repr__(self): | def __repr__(self): | ||||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | ||||
def append(self, val): | def append(self, val): | ||||
"""Add a new item to the tail of FieldArray. | |||||
:param val: int, float, str, or a list of these.
""" | |||||
if self.is_target is True or self.is_input is True: | |||||
# only check type when used as target or input | |||||
val_type = type(val) | |||||
if val_type == list: # shape check | |||||
if self.is_2d_list is False: | |||||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||||
if len(val) == 0: | |||||
raise RuntimeError("Cannot append an empty list.") | |||||
val_list_type = set([type(_) for _ in val]) # type check | |||||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||||
# up-cast int to float | |||||
val_type = float | |||||
elif len(val_list_type) == 1: | |||||
val_type = val_list_type.pop() | |||||
else: | |||||
raise TypeError("Cannot append a list of {}".format(val_list_type)) | |||||
else: | |||||
if self.is_2d_list is True: | |||||
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||||
if val_type == float and self.pytype == int: | |||||
# up-cast | |||||
self.pytype = float | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
elif val_type == int and self.pytype == float: | |||||
pass | |||||
elif val_type == self.pytype: | |||||
pass | |||||
else: | |||||
raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||||
self.content.append(val) | self.content.append(val) | ||||
def __getitem__(self, name): | |||||
return self.get(name) | |||||
def __getitem__(self, indices): | |||||
return self.get(indices) | |||||
def __setitem__(self, name, val): | |||||
assert isinstance(name, int) | |||||
self.content[name] = val | |||||
def __setitem__(self, idx, val): | |||||
assert isinstance(idx, int) | |||||
self.content[idx] = val | |||||
def get(self, indices): | def get(self, indices): | ||||
"""Fetch instances based on indices. | """Fetch instances based on indices. | ||||
@@ -44,29 +159,32 @@ class FieldArray(object): | |||||
""" | """ | ||||
if isinstance(indices, int): | if isinstance(indices, int): | ||||
return self.content[indices] | return self.content[indices] | ||||
assert self.is_input is True or self.is_target is True | |||||
if self.is_input is False and self.is_target is False: | |||||
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | |||||
batch_size = len(indices) | batch_size = len(indices) | ||||
# TODO: when this FieldArray holds single values such as seq_length, padding is unnecessary; to be discussed further
if not isiterable(self.content[0]): | |||||
if self.dtype is None: | |||||
self.dtype = np.int64 if isinstance(self.content[0], int) else np.double | |||||
if not is_iterable(self.content[0]): | |||||
array = np.array([self.content[i] for i in indices], dtype=self.dtype) | array = np.array([self.content[i] for i in indices], dtype=self.dtype) | ||||
else: | |||||
if self.dtype is None: | |||||
self.dtype = np.int64 | |||||
elif self.dtype in (np.int64, np.float64): | |||||
max_len = max([len(self.content[i]) for i in indices]) | max_len = max([len(self.content[i]) for i in indices]) | ||||
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | ||||
for i, idx in enumerate(indices): | for i, idx in enumerate(indices): | ||||
array[i][:len(self.content[idx])] = self.content[idx] | array[i][:len(self.content[idx])] = self.content[idx] | ||||
else: # should only be str | |||||
array = np.array([self.content[i] for i in indices]) | |||||
return array | return array | ||||
def __len__(self): | def __len__(self): | ||||
"""Returns the size of FieldArray. | |||||
:return int length: | |||||
""" | |||||
return len(self.content) | return len(self.content) | ||||
def isiterable(content): | |||||
def is_iterable(content): | |||||
try: | try: | ||||
_ = (e for e in content) | _ = (e for e in content) | ||||
except TypeError: | except TypeError: | ||||
return False | return False | ||||
return True | |||||
return True |
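A hedged sketch of the new type detection and padding behaviour of `FieldArray`, based only on the code above.

```python
from fastNLP.core.fieldarray import FieldArray

fa = FieldArray("word_seq", [[1, 2, 3], [4, 5]], padding_val=0, is_input=True)
print(fa.pytype, fa.dtype)    # int and numpy.int64, detected when is_input is set
fa.append([6, 7, 8, 9])       # appending a list is allowed because this is a 2-D FieldArray
print(fa.get([0, 2]))         # rows padded with 0 up to the longest selected sequence
```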
@@ -1,5 +1,3 @@ | |||||
class Instance(object): | class Instance(object): | ||||
"""An Instance is an example of data. It is the collection of Fields. | """An Instance is an example of data. It is the collection of Fields. | ||||
@@ -33,4 +31,5 @@ class Instance(object): | |||||
return self.add_field(name, field) | return self.add_field(name, field) | ||||
def __repr__(self): | def __repr__(self): | ||||
return self.fields.__repr__() | |||||
return "{" + ",\n".join( | |||||
"\'" + field_name + "\': " + str(self.fields[field_name]) for field_name in self.fields) + "}" |
@@ -1,196 +0,0 @@ | |||||
import torch | |||||
def squash(predict , truth , **kwargs): | |||||
'''To reshape tensors in order to fit Loss functions in pytorch | |||||
:param predict : Tensor, model output | |||||
:param truth : Tensor, truth from dataset | |||||
:param **kwargs : extra arguments | |||||
:return predict , truth: predict & truth after processing | |||||
''' | |||||
return predict.view(-1 , predict.size()[-1]) , truth.view(-1,) | |||||
def unpad(predict , truth , **kwargs): | |||||
'''To process padded sequence output to get true loss | |||||
Using pack_padded_sequence() method | |||||
This method contains squash() | |||||
:param predict : Tensor, [batch_size , max_len , tag_size] | |||||
:param truth : Tensor, [batch_size , max_len] | |||||
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist | |||||
kwargs["lens"] : list or LongTensor, [batch_size] | |||||
the i-th element is true lengths of i-th sequence | |||||
:return predict , truth: predict & truth after processing | |||||
''' | |||||
if kwargs.get("lens") is None: | |||||
return predict , truth | |||||
lens = torch.LongTensor(kwargs["lens"]) | |||||
lens , idx = torch.sort(lens , descending = True) | |||||
predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx] , lens , batch_first = True).data | |||||
truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx] , lens , batch_first = True).data | |||||
return predict , truth | |||||
def unpad_mask(predict , truth , **kwargs): | |||||
'''To process padded sequence output to get true loss | |||||
Using mask() method | |||||
This method contains squash() | |||||
:param predict : Tensor, [batch_size , max_len , tag_size] | |||||
:param truth : Tensor, [batch_size , max_len] | |||||
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist | |||||
kwargs["lens"] : list or LongTensor, [batch_size] | |||||
the i-th element is true lengths of i-th sequence | |||||
:return predict , truth: predict & truth after processing | |||||
''' | |||||
if kwargs.get("lens") is None: | |||||
return predict , truth | |||||
mas = make_mask(kwargs["lens"] , truth.size()[1]) | |||||
return mask(predict , truth , mask = mas) | |||||
def mask(predict , truth , **kwargs): | |||||
'''To select specific elements from Tensor | |||||
This method contains squash() | |||||
:param predict : Tensor, [batch_size , max_len , tag_size] | |||||
:param truth : Tensor, [batch_size , max_len] | |||||
:param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist | |||||
kwargs["mask"] : ByteTensor, [batch_size , max_len] | |||||
the mask Tensor , the position that is 1 will be selected | |||||
:return predict , truth: predict & truth after processing | |||||
''' | |||||
if kwargs.get("mask") is None: | |||||
return predict , truth | |||||
mask = kwargs["mask"] | |||||
predict , truth = squash(predict , truth) | |||||
mask = mask.view(-1,) | |||||
predict = torch.masked_select(predict.permute(1,0) , mask).view(predict.size()[-1] , -1).permute(1,0) | |||||
truth = torch.masked_select(truth , mask) | |||||
return predict , truth | |||||
def make_mask(lens , tar_len): | |||||
'''to generate a mask that select [:lens[i]] for i-th element | |||||
embezzle from fastNLP.models.sequence_modeling.seq_mask | |||||
:param lens : list or LongTensor, [batch_size] | |||||
:param tar_len : int | |||||
:return mask : ByteTensor | |||||
''' | |||||
lens = torch.LongTensor(lens) | |||||
mask = [torch.ge(lens, i + 1) for i in range(tar_len)] | |||||
mask = torch.stack(mask, 1) | |||||
return mask | |||||
#map string to function. Just for more elegant using | |||||
method_dict = { | |||||
"squash" : squash, | |||||
"unpad" : unpad, | |||||
"unpad_mask" : unpad_mask, | |||||
"mask" : mask, | |||||
} | |||||
loss_function_name = { | |||||
"L1Loss".lower() : torch.nn.L1Loss, | |||||
"BCELoss".lower() : torch.nn.BCELoss, | |||||
"MSELoss".lower() : torch.nn.MSELoss, | |||||
"NLLLoss".lower() : torch.nn.NLLLoss, | |||||
"KLDivLoss".lower() : torch.nn.KLDivLoss, | |||||
"NLLLoss2dLoss".lower() : torch.nn.NLLLoss2d, #every name should end with "loss" | |||||
"SmoothL1Loss".lower() : torch.nn.SmoothL1Loss, | |||||
"SoftMarginLoss".lower() : torch.nn.SoftMarginLoss, | |||||
"PoissonNLLLoss".lower() : torch.nn.PoissonNLLLoss, | |||||
"MultiMarginLoss".lower() : torch.nn.MultiMarginLoss, | |||||
"CrossEntropyLoss".lower() : torch.nn.CrossEntropyLoss, | |||||
"BCEWithLogitsLoss".lower() : torch.nn.BCEWithLogitsLoss, | |||||
"MarginRankingLoss".lower() : torch.nn.MarginRankingLoss, | |||||
"TripletMarginLoss".lower() : torch.nn.TripletMarginLoss, | |||||
"HingeEmbeddingLoss".lower() : torch.nn.HingeEmbeddingLoss, | |||||
"CosineEmbeddingLoss".lower() : torch.nn.CosineEmbeddingLoss, | |||||
"MultiLabelMarginLoss".lower() : torch.nn.MultiLabelMarginLoss, | |||||
"MultiLabelSoftMarginLoss".lower() : torch.nn.MultiLabelSoftMarginLoss, | |||||
} | |||||
class Loss(object): | |||||
'''a Loss object is a callable object represents loss functions | |||||
''' | |||||
def __init__(self , loss_name , pre_pro = [squash], **kwargs): | |||||
''' | |||||
:param loss_name: str or None , the name of loss function | |||||
:param pre_pro : list of function or str, methods to reform parameters before calculating loss | |||||
the strings will be auto translated to pre-defined functions | |||||
:param **kwargs: kwargs for torch loss function | |||||
pre_pro funcsions should have three arguments: predict, truth, **arg | |||||
predict and truth is the necessary parameters in loss function | |||||
kwargs is the extra parameters passed-in when calling loss function | |||||
pre_pro functions should return two objects, respectively predict and truth that after processed | |||||
''' | |||||
if loss_name is None: | |||||
# this is useful when Trainer.__init__ performs type check | |||||
self._loss = None | |||||
else: | |||||
if not isinstance(loss_name, str): | |||||
raise NotImplementedError | |||||
else: | |||||
self._loss = self._get_loss(loss_name , **kwargs) | |||||
self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro] | |||||
def add_pre_pro(self , func): | |||||
'''add a pre_pro function | |||||
:param func: a function or str, methods to reform parameters before calculating loss | |||||
the strings will be auto translated to pre-defined functions | |||||
''' | |||||
if not callable(func): | |||||
func = method_dict.get(func) | |||||
if func is None: | |||||
return | |||||
self.pre_pro.append(func) | |||||
@staticmethod | |||||
def _get_loss(loss_name , **kwargs): | |||||
'''Get loss function from torch | |||||
:param loss_name: str, the name of loss function | |||||
:param **kwargs: kwargs for torch loss function | |||||
:return: A callable loss function object | |||||
''' | |||||
loss_name = loss_name.strip().lower() | |||||
loss_name = "".join(loss_name.split("_")) | |||||
if len(loss_name) < 4 or loss_name[-4 : ] != "loss": | |||||
loss_name += "loss" | |||||
return loss_function_name[loss_name](**kwargs) | |||||
def get(self): | |||||
'''This method exists just for make some existing codes run error-freely | |||||
''' | |||||
return self | |||||
def __call__(self , predict , truth , **kwargs): | |||||
'''call a loss function | |||||
predict and truth will be processed by pre_pro methods in order of addition | |||||
:param predict : Tensor, model output | |||||
:param truth : Tensor, truth from dataset | |||||
:param **kwargs : extra arguments, pass to pre_pro functions | |||||
for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens | |||||
''' | |||||
for f in self.pre_pro: | |||||
if f is None: | |||||
continue | |||||
predict , truth = f(predict , truth , **kwargs) | |||||
return self._loss(predict , truth) |
@@ -0,0 +1,358 @@ | |||||
import inspect | |||||
from collections import defaultdict | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import CheckRes | |||||
from fastNLP.core.utils import _build_args | |||||
from fastNLP.core.utils import _check_arg_dict_list | |||||
from fastNLP.core.utils import _check_function_or_method | |||||
from fastNLP.core.utils import get_func_signature | |||||
class LossBase(object): | |||||
def __init__(self): | |||||
self.param_map = {} | |||||
self._checked = False | |||||
def get_loss(self, *args, **kwargs): | |||||
raise NotImplementedError | |||||
def _init_param_map(self, key_map=None, **kwargs): | |||||
"""Check the validity of key_map and other param map. Add these into self.param_map | |||||
:param key_map: dict | |||||
:param kwargs: | |||||
:return: None | |||||
""" | |||||
value_counter = defaultdict(set) | |||||
if key_map is not None: | |||||
if not isinstance(key_map, dict): | |||||
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | |||||
for key, value in key_map.items(): | |||||
if value is None: | |||||
self.param_map[key] = key | |||||
continue | |||||
if not isinstance(key, str): | |||||
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | |||||
if not isinstance(value, str): | |||||
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | |||||
self.param_map[key] = value | |||||
value_counter[value].add(key) | |||||
for key, value in kwargs.items(): | |||||
if value is None: | |||||
self.param_map[key] = key | |||||
continue | |||||
if not isinstance(value, str): | |||||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | |||||
self.param_map[key] = value | |||||
value_counter[value].add(key) | |||||
for value, key_set in value_counter.items(): | |||||
if len(key_set) > 1: | |||||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | |||||
# check consistency between signature and param_map
func_spect = inspect.getfullargspec(self.get_loss) | |||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||||
for func_param, input_param in self.param_map.items(): | |||||
if func_param not in func_args: | |||||
raise NameError( | |||||
f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the " | |||||
f"initialization parameters, or change its signature.") | |||||
# evaluate should not have varargs. | |||||
if func_spect.varargs: | |||||
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | |||||
f"positional argument.).") | |||||
def _fast_param_map(self, pred_dict, target_dict): | |||||
""" | |||||
Only used as an inner function. When pred_dict and target_dict are unambiguous, e.g. each contains
exactly one element, users do not need to pass a key_map.
:param pred_dict:
:param target_dict:
:return: dict. If it is not empty, pass it directly to self.get_loss. Otherwise fall back to key mapping.
""" | |||||
fast_param = {} | |||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | |||||
fast_param['target'] = list(target_dict.values())[0] | |||||
return fast_param | |||||
return fast_param | |||||
def __call__(self, pred_dict, target_dict, check=False): | |||||
""" | |||||
:param pred_dict: A dict from forward function of the network. | |||||
:param target_dict: A dict from DataSet.batch_y. | |||||
:param check: Boolean. Force to check the mapping functions when it is running. | |||||
:return: | |||||
""" | |||||
fast_param = self._fast_param_map(pred_dict, target_dict) | |||||
if fast_param: | |||||
loss = self.get_loss(**fast_param) | |||||
return loss | |||||
if not self._checked: | |||||
# 1. check consistency between signature and param_map
func_spect = inspect.getfullargspec(self.get_loss) | |||||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | |||||
for func_arg, input_arg in self.param_map.items(): | |||||
if func_arg not in func_args: | |||||
raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.") | |||||
# 2. only part of the param_map are passed, left are not | |||||
for arg in func_args: | |||||
if arg not in self.param_map: | |||||
self.param_map[arg] = arg # This param does not need mapping. | |||||
self._evaluate_args = func_args | |||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | |||||
# need to wrap inputs in dict. | |||||
mapped_pred_dict = {} | |||||
mapped_target_dict = {} | |||||
duplicated = [] | |||||
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): | |||||
not_duplicate_flag = 0 | |||||
if input_arg in self._reverse_param_map: | |||||
mapped_arg = self._reverse_param_map[input_arg] | |||||
not_duplicate_flag += 1 | |||||
else: | |||||
mapped_arg = input_arg | |||||
if input_arg in pred_dict: | |||||
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] | |||||
not_duplicate_flag += 1 | |||||
if input_arg in target_dict: | |||||
mapped_target_dict[mapped_arg] = target_dict[input_arg] | |||||
not_duplicate_flag += 1 | |||||
if not_duplicate_flag == 3: | |||||
duplicated.append(input_arg) | |||||
# missing | |||||
if not self._checked: | |||||
check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict]) | |||||
# replace missing. | |||||
missing = check_res.missing | |||||
replaced_missing = list(missing) | |||||
for idx, func_arg in enumerate(missing): | |||||
                # Keep the backtick quoting in this message exactly as it is; don't delete or add backticks.
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||||
f"in `{self.__class__.__name__}`)" | |||||
check_res = CheckRes(missing=replaced_missing, | |||||
unused=check_res.unused, | |||||
duplicated=duplicated, | |||||
required=check_res.required, | |||||
all_needed=check_res.all_needed, | |||||
varargs=check_res.varargs) | |||||
if check_res.missing or check_res.duplicated or check_res.varargs: | |||||
raise CheckError(check_res=check_res, | |||||
func_signature=get_func_signature(self.get_loss)) | |||||
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) | |||||
loss = self.get_loss(**refined_args) | |||||
self._checked = True | |||||
return loss | |||||
class LossFunc(LossBase): | |||||
"""A wrapper of user-provided loss function. | |||||
""" | |||||
def __init__(self, func, key_map=None, **kwargs): | |||||
""" | |||||
:param func: a callable object, such as a function. | |||||
        :param dict key_map: a dict mapping the parameter names of `func` to the keys used in the
            prediction/target dicts.
        :param kwargs: the same mappings given as keyword arguments, e.g. pred="output".
""" | |||||
super(LossFunc, self).__init__() | |||||
_check_function_or_method(func) | |||||
if key_map is not None: | |||||
if not isinstance(key_map, dict): | |||||
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") | |||||
self.param_map = key_map | |||||
if len(kwargs) > 0: | |||||
for key, val in kwargs.items(): | |||||
self.param_map.update({key: val}) | |||||
self.get_loss = func | |||||
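# --- Usage sketch (illustrative only; the field names "output" and "label" are
# hypothetical, not part of fastNLP) -------------------------------------------
def _loss_func_usage_example():
    """A minimal sketch of wrapping a plain function with LossFunc.

    The keyword arguments map the wrapped function's parameters (`pred`, `target`)
    to the keys produced by the model's forward dict and by the DataSet targets.
    """
    import torch
    import torch.nn.functional as F

    def my_loss(pred, target):
        return F.cross_entropy(pred, target)

    loss = LossFunc(my_loss, pred="output", target="label")
    logits = torch.randn(4, 3)            # pretend model output
    labels = torch.tensor([0, 2, 1, 0])   # pretend gold labels
    return loss({"output": logits}, {"label": labels})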
class CrossEntropyLoss(LossBase): | |||||
def __init__(self, pred=None, target=None, padding_idx=-100): | |||||
        # TODO: add some checks. When computing F.cross_entropy, if pred is (16, 10, 4), target should
        # TODO: logically have shape (16, 10), but in practice (16, 4) is required.
super(CrossEntropyLoss, self).__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
self.padding_idx = padding_idx | |||||
def get_loss(self, pred, target): | |||||
return F.cross_entropy(input=pred, target=target, | |||||
ignore_index=self.padding_idx) | |||||
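# --- Usage sketch (illustrative; "output", "seq_len" and "label" are hypothetical keys) ---
def _cross_entropy_loss_example():
    """Sketch: `pred`/`target` are remapped onto the keys the model and DataSet actually use;
    extra keys in the prediction dict (here "seq_len") are filtered out before get_loss is called."""
    import torch
    loss = CrossEntropyLoss(pred="output", target="label")
    pred_dict = {"output": torch.randn(8, 5), "seq_len": torch.randint(1, 6, (8,))}
    target_dict = {"label": torch.randint(0, 5, (8,))}
    return loss(pred_dict, target_dict)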
class L1Loss(LossBase): | |||||
def __init__(self, pred=None, target=None): | |||||
super(L1Loss, self).__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
def get_loss(self, pred, target): | |||||
return F.l1_loss(input=pred, target=target) | |||||
class BCELoss(LossBase): | |||||
def __init__(self, pred=None, target=None): | |||||
super(BCELoss, self).__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
def get_loss(self, pred, target): | |||||
return F.binary_cross_entropy(input=pred, target=target) | |||||
class NLLLoss(LossBase): | |||||
def __init__(self, pred=None, target=None): | |||||
super(NLLLoss, self).__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
def get_loss(self, pred, target): | |||||
return F.nll_loss(input=pred, target=target) | |||||
class LossInForward(LossBase): | |||||
def __init__(self, loss_key='loss'): | |||||
super().__init__() | |||||
if not isinstance(loss_key, str): | |||||
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") | |||||
self.loss_key = loss_key | |||||
def get_loss(self, **kwargs): | |||||
if self.loss_key not in kwargs: | |||||
check_res = CheckRes(missing=[self.loss_key + f"(assign to `{self.loss_key}` " \ | |||||
f"in `{self.__class__.__name__}`"], | |||||
unused=[], | |||||
duplicated=[], | |||||
required=[], | |||||
all_needed=[], | |||||
varargs=[]) | |||||
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) | |||||
return kwargs[self.loss_key] | |||||
def __call__(self, pred_dict, target_dict, check=False): | |||||
loss = self.get_loss(**pred_dict) | |||||
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | |||||
if not isinstance(loss, torch.Tensor): | |||||
raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}") | |||||
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") | |||||
return loss | |||||
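# --- Usage sketch (illustrative only) -------------------------------------------
def _loss_in_forward_example():
    """Sketch: when forward() already computes the loss and returns it under the key
    'loss', LossInForward (the default when Trainer's `loss` is None, see
    _prepare_losser below) simply picks it out of the prediction dict."""
    import torch
    loss = LossInForward()                                   # loss_key defaults to 'loss'
    pred_dict = {"loss": torch.tensor(0.5), "pred": torch.randn(2, 3)}
    return loss(pred_dict, target_dict={})                   # -> tensor(0.5000)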
def _prepare_losser(losser): | |||||
if losser is None: | |||||
losser = LossInForward() | |||||
return losser | |||||
elif isinstance(losser, LossBase): | |||||
return losser | |||||
else: | |||||
raise TypeError(f"Type of loss should be `fastNLP.LossBase`, got {type(losser)}") | |||||
def squash(predict, truth, **kwargs): | |||||
"""To reshape tensors in order to fit loss functions in pytorch | |||||
:param predict : Tensor, model output | |||||
:param truth : Tensor, truth from dataset | |||||
:param **kwargs : extra arguments | |||||
:return predict , truth: predict & truth after processing | |||||
""" | |||||
return predict.view(-1, predict.size()[-1]), truth.view(-1, ) | |||||
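# --- Usage sketch (illustrative only) -------------------------------------------
def _squash_example():
    """Sketch: flatten (batch_size, max_len, tag_size) predictions and
    (batch_size, max_len) labels so they fit losses such as F.cross_entropy."""
    import torch
    predict = torch.randn(2, 5, 7)                 # [batch_size, max_len, tag_size]
    truth = torch.randint(0, 7, (2, 5))            # [batch_size, max_len]
    predict, truth = squash(predict, truth)
    return predict.shape, truth.shape              # torch.Size([10, 7]), torch.Size([10])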
def unpad(predict, truth, **kwargs): | |||||
"""To process padded sequence output to get true loss | |||||
Using pack_padded_sequence() method | |||||
This method contains squash() | |||||
:param predict : Tensor, [batch_size , max_len , tag_size] | |||||
:param truth : Tensor, [batch_size , max_len] | |||||
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist | |||||
kwargs["lens"] : list or LongTensor, [batch_size] | |||||
the i-th element is true lengths of i-th sequence | |||||
:return predict , truth: predict & truth after processing | |||||
""" | |||||
if kwargs.get("lens") is None: | |||||
return predict, truth | |||||
lens = torch.LongTensor(kwargs["lens"]) | |||||
lens, idx = torch.sort(lens, descending=True) | |||||
predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx], lens, batch_first=True).data | |||||
truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx], lens, batch_first=True).data | |||||
return predict, truth | |||||
def unpad_mask(predict, truth, **kwargs): | |||||
"""To process padded sequence output to get true loss | |||||
Using mask() method | |||||
This method contains squash() | |||||
:param predict : Tensor, [batch_size , max_len , tag_size] | |||||
:param truth : Tensor, [batch_size , max_len] | |||||
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist | |||||
kwargs["lens"] : list or LongTensor, [batch_size] | |||||
the i-th element is true lengths of i-th sequence | |||||
:return predict , truth: predict & truth after processing | |||||
""" | |||||
if kwargs.get("lens") is None: | |||||
return predict, truth | |||||
mas = make_mask(kwargs["lens"], truth.size()[1]) | |||||
return mask(predict, truth, mask=mas) | |||||
def mask(predict, truth, **kwargs): | |||||
"""To select specific elements from Tensor | |||||
This method contains squash() | |||||
:param predict : Tensor, [batch_size , max_len , tag_size] | |||||
:param truth : Tensor, [batch_size , max_len] | |||||
:param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist | |||||
kwargs["mask"] : ByteTensor, [batch_size , max_len] | |||||
the mask Tensor , the position that is 1 will be selected | |||||
:return predict , truth: predict & truth after processing | |||||
""" | |||||
if kwargs.get("mask") is None: | |||||
return predict, truth | |||||
mask = kwargs["mask"] | |||||
predict, truth = squash(predict, truth) | |||||
mask = mask.view(-1, ) | |||||
predict = torch.masked_select(predict.permute(1, 0), mask).view(predict.size()[-1], -1).permute(1, 0) | |||||
truth = torch.masked_select(truth, mask) | |||||
return predict, truth | |||||
def make_mask(lens, tar_len): | |||||
"""to generate a mask that select [:lens[i]] for i-th element | |||||
    borrowed from fastNLP.models.sequence_modeling.seq_mask
:param lens : list or LongTensor, [batch_size] | |||||
:param tar_len : int | |||||
:return mask : ByteTensor | |||||
""" | |||||
lens = torch.LongTensor(lens) | |||||
mask = [torch.ge(lens, i + 1) for i in range(tar_len)] | |||||
mask = torch.stack(mask, 1) | |||||
return mask | |||||
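# --- Usage sketch (illustrative only) -------------------------------------------
def _mask_example():
    """Sketch: build a length mask with make_mask() and keep only the non-padded
    positions with mask(); here 4 + 2 = 6 positions survive."""
    import torch
    predict = torch.randn(2, 4, 3)                 # [batch_size, max_len, tag_size]
    truth = torch.randint(0, 3, (2, 4))            # [batch_size, max_len]
    m = make_mask([4, 2], tar_len=4)               # mask of shape [2, 4]
    return mask(predict, truth, mask=m)            # predict: [6, 3], truth: [6]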
@@ -1,288 +1,310 @@ | |||||
import warnings | |||||
import inspect | |||||
from collections import defaultdict | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
class Evaluator(object): | |||||
def __init__(self): | |||||
pass | |||||
def __call__(self, predict, truth): | |||||
""" | |||||
:param predict: list of tensors, the network outputs from all batches. | |||||
:param truth: list of dict, the ground truths from all batch_y. | |||||
:return: | |||||
""" | |||||
raise NotImplementedError | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import CheckRes | |||||
from fastNLP.core.utils import _build_args | |||||
from fastNLP.core.utils import _check_arg_dict_list | |||||
from fastNLP.core.utils import get_func_signature | |||||
from fastNLP.core.utils import seq_lens_to_masks | |||||
class ClassifyEvaluator(Evaluator): | |||||
class MetricBase(object): | |||||
def __init__(self): | def __init__(self): | ||||
super(ClassifyEvaluator, self).__init__() | |||||
self.param_map = {} # key is param in function, value is input param. | |||||
self._checked = False | |||||
def __call__(self, predict, truth): | |||||
y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict] | |||||
y_prob = torch.cat(y_prob, dim=0) | |||||
y_pred = torch.argmax(y_prob, dim=-1) | |||||
y_true = torch.cat(truth, dim=0) | |||||
acc = float(torch.sum(y_pred == y_true)) / len(y_true) | |||||
return {"accuracy": acc} | |||||
def evaluate(self, *args, **kwargs): | |||||
raise NotImplementedError | |||||
def _init_param_map(self, key_map=None, **kwargs): | |||||
"""Check the validity of key_map and other param map. Add these into self.param_map | |||||
class SeqLabelEvaluator(Evaluator): | |||||
def __init__(self): | |||||
super(SeqLabelEvaluator, self).__init__() | |||||
def __call__(self, predict, truth, **_): | |||||
        :param key_map: dict mapping the parameter names of self.evaluate to the keys used in the input dicts
        :param kwargs: the same mappings given as keyword arguments
:return: None | |||||
""" | |||||
value_counter = defaultdict(set) | |||||
if key_map is not None: | |||||
if not isinstance(key_map, dict): | |||||
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | |||||
for key, value in key_map.items(): | |||||
if value is None: | |||||
self.param_map[key] = key | |||||
continue | |||||
if not isinstance(key, str): | |||||
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | |||||
if not isinstance(value, str): | |||||
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | |||||
self.param_map[key] = value | |||||
value_counter[value].add(key) | |||||
for key, value in kwargs.items(): | |||||
if value is None: | |||||
self.param_map[key] = key | |||||
continue | |||||
if not isinstance(value, str): | |||||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | |||||
self.param_map[key] = value | |||||
value_counter[value].add(key) | |||||
for value, key_set in value_counter.items(): | |||||
if len(key_set) > 1: | |||||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | |||||
        # check consistency between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate) | |||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||||
for func_param, input_param in self.param_map.items(): | |||||
if func_param not in func_args: | |||||
raise NameError( | |||||
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | |||||
f"initialization parameters, or change its signature.") | |||||
# evaluate should not have varargs. | |||||
if func_spect.varargs: | |||||
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " | |||||
f"positional argument.).") | |||||
def get_metric(self, reset=True): | |||||
        raise NotImplementedError
def _fast_param_map(self, pred_dict, target_dict): | |||||
""" | """ | ||||
:param predict: list of List, the network outputs from all batches. | |||||
:param truth: list of dict, the ground truths from all batch_y. | |||||
:return accuracy: | |||||
        Only used internally. When the mapping between pred_dict and target_dict is unambiguous
        (e.g. each of them contains exactly one element), users do not need to pass a key_map.
        :param pred_dict:
        :param target_dict:
        :return: dict. If it is not empty, it is passed directly to self.evaluate; otherwise the normal key mapping is performed.
""" | """ | ||||
total_correct, total_count = 0., 0. | |||||
for x, y in zip(predict, truth): | |||||
x = torch.tensor(x) | |||||
y = y.to(x) # make sure they are in the same device | |||||
mask = (y > 0) | |||||
correct = torch.sum(((x == y) * mask).long()) | |||||
total_correct += float(correct) | |||||
total_count += float(torch.sum(mask.long())) | |||||
accuracy = total_correct / total_count | |||||
return {"accuracy": float(accuracy)} | |||||
class SeqLabelEvaluator2(Evaluator): | |||||
    # The evaluator above appears to be incorrect.
def __init__(self, seq_lens_field_name='word_seq_origin_len'): | |||||
super(SeqLabelEvaluator2, self).__init__() | |||||
self.end_tagidx_set = set() | |||||
self.seq_lens_field_name = seq_lens_field_name | |||||
def __call__(self, predict, truth, **_): | |||||
fast_param = {} | |||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | |||||
            fast_param['target'] = list(target_dict.values())[0]
return fast_param | |||||
return fast_param | |||||
def __call__(self, pred_dict, target_dict): | |||||
""" | """ | ||||
:param predict: list of batch, the network outputs from all batches. | |||||
:param truth: list of dict, the ground truths from all batch_y. | |||||
:return accuracy: | |||||
        This method will call self.evaluate.
        Before calling self.evaluate, it first checks the validity of pred_dict and target_dict:
        (1) whether self.evaluate has varargs, which is not supported;
        (2) whether any param needed by self.evaluate is missing from pred_dict and target_dict;
        (3) whether any param needed by self.evaluate is duplicated in pred_dict and target_dict;
        (4) whether some params in pred_dict or target_dict are not used by self.evaluate (which may cause a warning).
        Besides, before passing params into self.evaluate, this function filters out the params in pred_dict and
        target_dict which are not used by self.evaluate (but if **kwargs is present in self.evaluate, no filtering
        is done).
        This function also supports _fast_param_map.
        :param pred_dict: usually the output of the forward or prediction function
        :param target_dict: usually the fields set as target
        :return:
""" | |||||
if not callable(self.evaluate): | |||||
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | |||||
fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) | |||||
if fast_param: | |||||
self.evaluate(**fast_param) | |||||
return | |||||
if not self._checked: | |||||
# 1. check consistence between signature and param_map | |||||
func_spect = inspect.getfullargspec(self.evaluate) | |||||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | |||||
for func_arg, input_arg in self.param_map.items(): | |||||
if func_arg not in func_args: | |||||
raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.") | |||||
# 2. only part of the param_map are passed, left are not | |||||
for arg in func_args: | |||||
if arg not in self.param_map: | |||||
self.param_map[arg] = arg # This param does not need mapping. | |||||
self._evaluate_args = func_args | |||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | |||||
# need to wrap inputs in dict. | |||||
mapped_pred_dict = {} | |||||
mapped_target_dict = {} | |||||
duplicated = [] | |||||
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): | |||||
not_duplicate_flag = 0 | |||||
if input_arg in self._reverse_param_map: | |||||
mapped_arg = self._reverse_param_map[input_arg] | |||||
not_duplicate_flag += 1 | |||||
else: | |||||
mapped_arg = input_arg | |||||
if input_arg in pred_dict: | |||||
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] | |||||
not_duplicate_flag += 1 | |||||
if input_arg in target_dict: | |||||
mapped_target_dict[mapped_arg] = target_dict[input_arg] | |||||
not_duplicate_flag += 1 | |||||
if not_duplicate_flag == 3: | |||||
duplicated.append(input_arg) | |||||
# missing | |||||
if not self._checked: | |||||
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) | |||||
# only check missing. | |||||
# replace missing. | |||||
missing = check_res.missing | |||||
replaced_missing = list(missing) | |||||
for idx, func_arg in enumerate(missing): | |||||
                # Keep the backtick quoting in this message exactly as it is; don't delete or add backticks.
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||||
f"in `{self.__class__.__name__}`)" | |||||
check_res = CheckRes(missing=replaced_missing, | |||||
unused=check_res.unused, | |||||
duplicated=duplicated, | |||||
required=check_res.required, | |||||
all_needed=check_res.all_needed, | |||||
varargs=check_res.varargs) | |||||
if check_res.missing or check_res.duplicated or check_res.varargs: | |||||
raise CheckError(check_res=check_res, | |||||
func_signature=get_func_signature(self.evaluate)) | |||||
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | |||||
self.evaluate(**refined_args) | |||||
self._checked = True | |||||
return | |||||
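# --- Sketch of a custom metric (illustrative only, not part of fastNLP) ----------
class _ExampleMatchMetric(MetricBase):
    """Counts element-wise matches; shows the evaluate()/get_metric() contract.
    `pred`/`target` go through the usual param_map mechanism of MetricBase."""

    def __init__(self, pred=None, target=None):
        super().__init__()
        self._init_param_map(pred=pred, target=target)
        self.correct = 0
        self.total = 0

    def evaluate(self, pred, target):
        # accumulate statistics over one batch
        self.correct += int((pred == target).sum())
        self.total += target.numel()

    def get_metric(self, reset=True):
        result = {'match': round(self.correct / max(self.total, 1), 6)}
        if reset:
            self.correct, self.total = 0, 0
        return result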
class AccuracyMetric(MetricBase): | |||||
def __init__(self, pred=None, target=None, seq_lens=None): | |||||
super().__init__() | |||||
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) | |||||
self.total = 0 | |||||
self.acc_count = 0 | |||||
def _fast_param_map(self, pred_dict, target_dict): | |||||
""" | """ | ||||
seq_lens = _[self.seq_lens_field_name] | |||||
corr_count = 0 | |||||
pred_count = 0 | |||||
truth_count = 0 | |||||
for x, y, seq_len in zip(predict, truth, seq_lens): | |||||
x = x.cpu().numpy() | |||||
y = y.cpu().numpy() | |||||
for idx, s_l in enumerate(seq_len): | |||||
x_ = x[idx] | |||||
y_ = y[idx] | |||||
x_ = x_[:s_l] | |||||
y_ = y_[:s_l] | |||||
flag = True | |||||
start = 0 | |||||
for idx_i, (x_i, y_i) in enumerate(zip(x_, y_)): | |||||
if x_i in self.end_tagidx_set: | |||||
truth_count += 1 | |||||
for j in range(start, idx_i + 1): | |||||
if y_[j]!=x_[j]: | |||||
flag = False | |||||
break | |||||
if flag: | |||||
corr_count += 1 | |||||
flag = True | |||||
start = idx_i + 1 | |||||
if y_i in self.end_tagidx_set: | |||||
pred_count += 1 | |||||
P = corr_count / (float(pred_count) + 1e-6) | |||||
R = corr_count / (float(truth_count) + 1e-6) | |||||
F = 2 * P * R / (P + R + 1e-6) | |||||
return {"P": P, 'R':R, 'F': F} | |||||
class SNLIEvaluator(Evaluator): | |||||
def __init__(self): | |||||
super(SNLIEvaluator, self).__init__() | |||||
def __call__(self, predict, truth): | |||||
y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict] | |||||
y_prob = torch.cat(y_prob, dim=0) | |||||
y_pred = torch.argmax(y_prob, dim=-1) | |||||
truth = [t['truth'] for t in truth] | |||||
y_true = torch.cat(truth, dim=0).view(-1) | |||||
acc = float(torch.sum(y_pred == y_true)) / y_true.size(0) | |||||
return {"accuracy": acc} | |||||
        Only used internally. When the mapping between pred_dict and target_dict is unambiguous
        (e.g. each of them contains exactly one element), users do not need to pass a key_map.
        :param pred_dict:
        :param target_dict:
        :return: dict. If it is not empty, it is passed directly to self.evaluate; otherwise the normal key mapping is performed.
""" | |||||
fast_param = {} | |||||
targets = list(target_dict.values()) | |||||
if len(targets) == 1 and isinstance(targets[0], torch.Tensor): | |||||
if len(pred_dict) == 1: | |||||
pred = list(pred_dict.values())[0] | |||||
fast_param['pred'] = pred | |||||
elif len(pred_dict) == 2: | |||||
pred1 = list(pred_dict.values())[0] | |||||
pred2 = list(pred_dict.values())[1] | |||||
if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): | |||||
return fast_param | |||||
if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1: | |||||
seq_lens = pred1 | |||||
pred = pred2 | |||||
elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1: | |||||
seq_lens = pred2 | |||||
pred = pred1 | |||||
else: | |||||
return fast_param | |||||
fast_param['pred'] = pred | |||||
fast_param['seq_lens'] = seq_lens | |||||
else: | |||||
return fast_param | |||||
fast_param['target'] = targets[0] | |||||
# TODO need to make sure they all have same batch_size | |||||
return fast_param | |||||
def evaluate(self, pred, target, seq_lens=None): | |||||
""" | |||||
def _conver_numpy(x): | |||||
"""convert input data to numpy array | |||||
:param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: | |||||
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) | |||||
        :param target: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
        torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
        :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
        None, None, torch.Size([B]), torch.Size([B]). Ignored if masks are provided.
:return: dict({'acc': float}) | |||||
""" | |||||
        # TODO: this error message needs changing, because the user does not know what `pred` refers to.
        #       Report the actual value/field name to the user.
if not isinstance(pred, torch.Tensor): | |||||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(pred)}.") | |||||
if not isinstance(target, torch.Tensor): | |||||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(target)}.") | |||||
if seq_lens is not None and not isinstance(seq_lens, torch.Tensor): | |||||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(seq_lens)}.") | |||||
if seq_lens is not None: | |||||
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) | |||||
else: | |||||
masks = None | |||||
""" | |||||
if isinstance(x, np.ndarray): | |||||
return x | |||||
elif isinstance(x, torch.Tensor): | |||||
return x.numpy() | |||||
elif isinstance(x, list): | |||||
return np.array(x) | |||||
raise TypeError('cannot accept object: {}'.format(x)) | |||||
if pred.size() == target.size(): | |||||
pass | |||||
elif len(pred.size()) == len(target.size()) + 1: | |||||
pred = pred.argmax(dim=-1) | |||||
else: | |||||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||||
f"{pred.size()[:-1]}, got {target.size()}.") | |||||
pred = pred.float() | |||||
target = target.float() | |||||
def _check_same_len(*arrays, axis=0): | |||||
"""check if input array list has same length for one dimension | |||||
if masks is not None: | |||||
self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item() | |||||
self.total += torch.sum(masks.float()).item() | |||||
else: | |||||
self.acc_count += torch.sum(torch.eq(pred, target).float()).item() | |||||
self.total += np.prod(list(pred.size())) | |||||
""" | |||||
lens = set([x.shape[axis] for x in arrays if x is not None]) | |||||
return len(lens) == 1 | |||||
def get_metric(self, reset=True): | |||||
evaluate_result = {'acc': round(self.acc_count / self.total, 6)} | |||||
if reset: | |||||
self.acc_count = 0 | |||||
self.total = 0 | |||||
return evaluate_result | |||||
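# --- Usage sketch (illustrative; "output" and "label" are hypothetical keys) -----
def _accuracy_metric_example():
    """Sketch: feed prediction/target dicts batch by batch, then read the result."""
    import torch
    metric = AccuracyMetric(pred="output", target="label")
    metric({"output": torch.tensor([1, 0, 1, 1])},
           {"label": torch.tensor([1, 0, 0, 1])})     # accumulate one batch
    return metric.get_metric()                        # {'acc': 0.75}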
def _label_types(y): | |||||
"""Determine the type | |||||
- "binary" | |||||
- "multiclass" | |||||
- "multiclass-multioutput" | |||||
- "multilabel" | |||||
- "unknown" | |||||
def _prepare_metrics(metrics): | |||||
""" | """ | ||||
# never squeeze the first dimension | |||||
y = y.squeeze() if y.shape[0] > 1 else y.resize(1, -1) | |||||
shape = y.shape | |||||
if len(shape) < 1: | |||||
raise ValueError('cannot accept data: {}'.format(y)) | |||||
if len(shape) == 1: | |||||
return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y | |||||
if len(shape) == 2: | |||||
return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y | |||||
return 'unknown', y | |||||
def _check_data(y_true, y_pred): | |||||
"""Check if y_true and y_pred is same type of data e.g both binary or multiclass | |||||
Prepare list of Metric based on input | |||||
:param metrics: | |||||
:return: List[fastNLP.MetricBase] | |||||
""" | """ | ||||
y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred) | |||||
if not _check_same_len(y_true, y_pred): | |||||
raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred)) | |||||
type_true, y_true = _label_types(y_true) | |||||
type_pred, y_pred = _label_types(y_pred) | |||||
type_set = set(['binary', 'multiclass']) | |||||
if type_true in type_set and type_pred in type_set: | |||||
return type_true if type_true == type_pred else 'multiclass', y_true, y_pred | |||||
type_set = set(['multiclass-multioutput', 'multilabel']) | |||||
if type_true in type_set and type_pred in type_set: | |||||
return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred | |||||
raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred)) | |||||
def _weight_sum(y, normalize=True, sample_weight=None): | |||||
if normalize: | |||||
return np.average(y, weights=sample_weight) | |||||
if sample_weight is None: | |||||
return y.sum() | |||||
else: | |||||
return np.dot(y, sample_weight) | |||||
def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None): | |||||
y_type, y_true, y_pred = _check_data(y_true, y_pred) | |||||
if y_type == 'multiclass-multioutput': | |||||
raise ValueError('cannot accept data type {0}'.format(y_type)) | |||||
if y_type == 'multilabel': | |||||
equel = (y_true == y_pred).sum(1) | |||||
count = equel == y_true.shape[1] | |||||
else: | |||||
count = y_true == y_pred | |||||
return _weight_sum(count, normalize=normalize, sample_weight=sample_weight) | |||||
def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||||
y_type, y_true, y_pred = _check_data(y_true, y_pred) | |||||
if average == 'binary': | |||||
if y_type != 'binary': | |||||
raise ValueError("data type is {} but use average type {}".format(y_type, average)) | |||||
else: | |||||
pos = (y_true == pos_label) | |||||
tp = np.logical_and((y_true == y_pred), pos).sum() | |||||
pos_sum = pos.sum() | |||||
return tp / pos_sum if pos_sum > 0 else 0 | |||||
elif average == None: | |||||
y_labels = set(list(np.unique(y_true))) | |||||
if labels is None: | |||||
labels = list(y_labels) | |||||
else: | |||||
for i in labels: | |||||
if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]): | |||||
warnings.warn('label {} is not contained in data'.format(i), UserWarning) | |||||
if y_type in ['binary', 'multiclass']: | |||||
y_pred_right = y_true == y_pred | |||||
pos_list = [y_true == i for i in labels] | |||||
pos_sum_list = [pos_i.sum() for pos_i in pos_list] | |||||
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ | |||||
for pos_i, sum_i in zip(pos_list, pos_sum_list)]) | |||||
elif y_type == 'multilabel': | |||||
y_pred_right = y_true == y_pred | |||||
pos = (y_true == pos_label) | |||||
tp = np.logical_and(y_pred_right, pos).sum(0) | |||||
pos_sum = pos.sum(0) | |||||
return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels]) | |||||
_metrics = [] | |||||
if metrics: | |||||
if isinstance(metrics, list): | |||||
for metric in metrics: | |||||
if isinstance(metric, type): | |||||
metric = metric() | |||||
if isinstance(metric, MetricBase): | |||||
metric_name = metric.__class__.__name__ | |||||
if not callable(metric.evaluate): | |||||
raise TypeError(f"{metric_name}.evaluate must be callable, got {type(metric.evaluate)}.") | |||||
if not callable(metric.get_metric): | |||||
raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.") | |||||
_metrics.append(metric) | |||||
else: | |||||
raise TypeError( | |||||
f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.") | |||||
elif isinstance(metrics, MetricBase): | |||||
_metrics = [metrics] | |||||
else: | else: | ||||
raise ValueError('not support targets type {}'.format(y_type)) | |||||
raise ValueError('not support for average type {}'.format(average)) | |||||
def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||||
y_type, y_true, y_pred = _check_data(y_true, y_pred) | |||||
if average == 'binary': | |||||
if y_type != 'binary': | |||||
raise ValueError("data type is {} but use average type {}".format(y_type, average)) | |||||
else: | |||||
pos = (y_true == pos_label) | |||||
tp = np.logical_and((y_true == y_pred), pos).sum() | |||||
pos_pred = (y_pred == pos_label).sum() | |||||
return tp / pos_pred if pos_pred > 0 else 0 | |||||
elif average == None: | |||||
y_labels = set(list(np.unique(y_true))) | |||||
if labels is None: | |||||
labels = list(y_labels) | |||||
else: | |||||
for i in labels: | |||||
if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]): | |||||
warnings.warn('label {} is not contained in data'.format(i), UserWarning) | |||||
if y_type in ['binary', 'multiclass']: | |||||
y_pred_right = y_true == y_pred | |||||
pos_list = [y_true == i for i in labels] | |||||
pos_sum_list = [(y_pred == i).sum() for i in labels] | |||||
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ | |||||
for pos_i, sum_i in zip(pos_list, pos_sum_list)]) | |||||
elif y_type == 'multilabel': | |||||
y_pred_right = y_true == y_pred | |||||
pos = (y_true == pos_label) | |||||
tp = np.logical_and(y_pred_right, pos).sum(0) | |||||
pos_sum = (y_pred == pos_label).sum(0) | |||||
return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels]) | |||||
else: | |||||
raise ValueError('not support targets type {}'.format(y_type)) | |||||
raise ValueError('not support for average type {}'.format(average)) | |||||
def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||||
precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) | |||||
recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) | |||||
if isinstance(precision, np.ndarray): | |||||
res = 2 * precision * recall / (precision + recall) | |||||
res[(precision + recall) <= 0] = 0 | |||||
return res | |||||
return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 | |||||
def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2): | |||||
raise NotImplementedError | |||||
raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, " | |||||
f"got {type(metrics)}.") | |||||
return _metrics | |||||
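# --- Usage sketch (illustrative only) -------------------------------------------
def _prepare_metrics_example():
    """Sketch: _prepare_metrics normalises its input into List[MetricBase];
    a bare class is instantiated, a single instance is wrapped in a list."""
    single = _prepare_metrics(AccuracyMetric())     # -> [AccuracyMetric instance]
    from_cls = _prepare_metrics([AccuracyMetric])   # classes inside a list are instantiated
    empty = _prepare_metrics(None)                  # no metrics -> []
    return single, from_cls, empty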
def accuracy_topk(y_true, y_prob, k=1): | def accuracy_topk(y_true, y_prob, k=1): | ||||
@@ -2,61 +2,48 @@ import torch | |||||
class Optimizer(object): | class Optimizer(object): | ||||
"""Wrapper of optimizer from framework | |||||
def __init__(self, model_params, **kwargs): | |||||
if model_params is not None and not hasattr(model_params, "__next__"): | |||||
raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params))) | |||||
self.model_params = model_params | |||||
self.settings = kwargs | |||||
1. Adam: lr (float), weight_decay (float) | |||||
2. AdaGrad | |||||
3. RMSProp | |||||
4. SGD: lr (float), momentum (float) | |||||
""" | |||||
def __init__(self, optimizer_name, **kwargs): | |||||
class SGD(Optimizer): | |||||
def __init__(self, lr=0.01, momentum=0, model_params=None): | |||||
""" | """ | ||||
:param optimizer_name: str, the name of the optimizer | |||||
:param kwargs: the arguments | |||||
:param float lr: learning rate. Default: 0.01 | |||||
:param float momentum: momentum. Default: 0 | |||||
:param model_params: a generator. E.g. model.parameters() for PyTorch models. | |||||
""" | """ | ||||
self.optim_name = optimizer_name | |||||
self.kwargs = kwargs | |||||
if not isinstance(lr, float): | |||||
raise TypeError("learning rate has to be float.") | |||||
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | |||||
@property | |||||
def name(self): | |||||
"""The name of the optimizer. | |||||
def construct_from_pytorch(self, model_params): | |||||
if self.model_params is None: | |||||
# careful! generator cannot be assigned. | |||||
return torch.optim.SGD(model_params, **self.settings) | |||||
else: | |||||
return torch.optim.SGD(self.model_params, **self.settings) | |||||
:return: str | |||||
""" | |||||
return self.optim_name | |||||
@property | |||||
def params(self): | |||||
"""The arguments used to create the optimizer. | |||||
class Adam(Optimizer): | |||||
def __init__(self, lr=0.01, weight_decay=0, model_params=None): | |||||
""" | |||||
:return: dict of (str, *) | |||||
:param float lr: learning rate | |||||
:param float weight_decay: | |||||
:param model_params: a generator. E.g. model.parameters() for PyTorch models. | |||||
""" | """ | ||||
return self.kwargs | |||||
if not isinstance(lr, float): | |||||
raise TypeError("learning rate has to be float.") | |||||
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) | |||||
def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
"""Construct a optimizer from framework over given model parameters.""" | |||||
if self.optim_name in ["SGD", "sgd"]: | |||||
if "lr" in self.kwargs: | |||||
if "momentum" not in self.kwargs: | |||||
self.kwargs["momentum"] = 0 | |||||
optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"]) | |||||
else: | |||||
raise ValueError("requires learning rate for SGD optimizer") | |||||
elif self.optim_name in ["adam", "Adam"]: | |||||
if "lr" in self.kwargs: | |||||
if "weight_decay" not in self.kwargs: | |||||
self.kwargs["weight_decay"] = 0 | |||||
optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"], | |||||
weight_decay=self.kwargs["weight_decay"]) | |||||
else: | |||||
raise ValueError("requires learning rate for Adam optimizer") | |||||
if self.model_params is None: | |||||
# careful! generator cannot be assigned. | |||||
return torch.optim.Adam(model_params, **self.settings) | |||||
else: | else: | ||||
raise NotImplementedError | |||||
return optimizer | |||||
return torch.optim.Adam(self.model_params, **self.settings) |
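# --- Usage sketch (illustrative only) -------------------------------------------
def _optimizer_usage_example(model):
    """Sketch: either hand over model.parameters() up front, or leave model_params
    as None and let the Trainer call construct_from_pytorch() later."""
    adam = Adam(lr=0.01, weight_decay=0)                       # model_params filled in later
    sgd = SGD(lr=0.1, momentum=0.9, model_params=model.parameters())
    return (adam.construct_from_pytorch(model.parameters()),   # torch.optim.Adam
            sgd.construct_from_pytorch(model.parameters()))    # stored parameters take precedence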
@@ -1,4 +1,3 @@ | |||||
import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
@@ -23,13 +22,13 @@ class Predictor(object): | |||||
:param network: a PyTorch model (cpu) | :param network: a PyTorch model (cpu) | ||||
:param data: a DataSet object. | :param data: a DataSet object. | ||||
:return: list of list of strings, [num_examples, tag_seq_length] | |||||
:return: list of batch outputs | |||||
""" | """ | ||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
self.mode(network, test=True) | self.mode(network, test=True) | ||||
batch_output = [] | batch_output = [] | ||||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) | |||||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||||
for batch_x, _ in data_iterator: | for batch_x, _ in data_iterator: | ||||
with torch.no_grad(): | with torch.no_grad(): | ||||
@@ -48,19 +47,3 @@ class Predictor(object): | |||||
"""Forward through network.""" | """Forward through network.""" | ||||
y = network(**x) | y = network(**x) | ||||
return y | return y | ||||
def seq_label_post_processor(batch_outputs, label_vocab): | |||||
results = [] | |||||
for batch in batch_outputs: | |||||
for example in np.array(batch): | |||||
results.append([label_vocab.to_word(int(x)) for x in example]) | |||||
return results | |||||
def text_classify_post_processor(batch_outputs, label_vocab): | |||||
results = [] | |||||
for batch_out in batch_outputs: | |||||
idx = np.argmax(batch_out.detach().numpy(), axis=-1) | |||||
results.extend([label_vocab.to_word(i) for i in idx]) | |||||
return results |
@@ -55,7 +55,7 @@ class BucketSampler(BaseSampler): | |||||
def __call__(self, data_set): | def __call__(self, data_set): | ||||
seq_lens = data_set[self.seq_lens_field_name].content | |||||
seq_lens = data_set.get_all_fields()[self.seq_lens_field_name].content | |||||
total_sample_num = len(seq_lens) | total_sample_num = len(seq_lens) | ||||
bucket_indexes = [] | bucket_indexes = [] | ||||
@@ -1,60 +1,88 @@ | |||||
import itertools | |||||
from collections import defaultdict | from collections import defaultdict | ||||
import torch | import torch | ||||
from torch import nn | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.sampler import RandomSampler | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.metrics import _prepare_metrics | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
from fastNLP.core.utils import _check_loss_evaluate | |||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
from fastNLP.core.utils import get_func_signature | |||||
class Tester(object): | class Tester(object): | ||||
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | ||||
def __init__(self, data, model, batch_size=16, use_cuda=False): | |||||
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=1): | |||||
super(Tester, self).__init__() | super(Tester, self).__init__() | ||||
self.use_cuda = use_cuda | |||||
if not isinstance(data, DataSet): | |||||
raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") | |||||
if not isinstance(model, nn.Module): | |||||
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") | |||||
self.metrics = _prepare_metrics(metrics) | |||||
self.data = data | self.data = data | ||||
self.use_cuda = use_cuda | |||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.verbose = verbose | |||||
self._model_device = model.parameters().__next__().device | |||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
self._model = model.cuda() | self._model = model.cuda() | ||||
else: | else: | ||||
self._model = model | self._model = model | ||||
# check predict | |||||
if hasattr(self._model, 'predict'): | if hasattr(self._model, 'predict'): | ||||
assert callable(self._model.predict) | |||||
self._predict_func = self._model.predict | self._predict_func = self._model.predict | ||||
if not callable(self._predict_func): | |||||
_model_name = model.__class__.__name__ | |||||
raise TypeError(f"`{_model_name}.predict` must be callable to be used " | |||||
f"for evaluation, not `{type(self._predict_func)}`.") | |||||
else: | else: | ||||
self._predict_func = self._model | |||||
assert hasattr(model, 'evaluate') | |||||
self._evaluator = model.evaluate | |||||
self.eval_history = [] # evaluation results of all batches | |||||
self._predict_func = self._model.forward | |||||
def test(self): | def test(self): | ||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
network = self._model | network = self._model | ||||
self.mode(network, is_test=True) | |||||
self.eval_history.clear() | |||||
output, truths = defaultdict(list), defaultdict(list) | |||||
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False) | |||||
with torch.no_grad(): | |||||
for batch_x, batch_y in data_iterator: | |||||
prediction = self.data_forward(network, batch_x) | |||||
assert isinstance(prediction, dict) | |||||
for k, v in prediction.items(): | |||||
output[k].append(v) | |||||
for k, v in batch_y.items(): | |||||
truths[k].append(v) | |||||
for k, v in output.items(): | |||||
output[k] = itertools.chain(*v) | |||||
for k, v in truths.items(): | |||||
truths[k] = itertools.chain(*v) | |||||
args = _build_args(self._evaluator, **output, **truths) | |||||
eval_results = self._evaluator(**args) | |||||
print("[tester] {}".format(self.print_eval_results(eval_results))) | |||||
self.mode(network, is_test=False) | |||||
self._mode(network, is_test=True) | |||||
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||||
eval_results = {} | |||||
try: | |||||
with torch.no_grad(): | |||||
for batch_x, batch_y in data_iterator: | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
pred_dict = self._data_forward(self._predict_func, batch_x) | |||||
if not isinstance(pred_dict, dict): | |||||
raise TypeError(f"The return value of {get_func_signature(self._predict_func)} " | |||||
f"must be `dict`, got {type(pred_dict)}.") | |||||
for metric in self.metrics: | |||||
metric(pred_dict, batch_y) | |||||
for metric in self.metrics: | |||||
eval_result = metric.get_metric() | |||||
if not isinstance(eval_result, dict): | |||||
raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be " | |||||
f"`dict`, got {type(eval_result)}") | |||||
metric_name = metric.__class__.__name__ | |||||
eval_results[metric_name] = eval_result | |||||
except CheckError as e: | |||||
prev_func_signature = get_func_signature(self._predict_func) | |||||
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | |||||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | |||||
dataset=self.data, check_level=0) | |||||
if self.verbose >= 1: | |||||
print("[tester] \n{}".format(self._format_eval_results(eval_results))) | |||||
self._mode(network, is_test=False) | |||||
return eval_results | return eval_results | ||||
def mode(self, model, is_test=False): | |||||
def _mode(self, model, is_test=False): | |||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
:param model: a PyTorch model | :param model: a PyTorch model | ||||
@@ -66,16 +94,21 @@ class Tester(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
def data_forward(self, network, x): | |||||
def _data_forward(self, func, x): | |||||
"""A forward pass of the model. """ | """A forward pass of the model. """ | ||||
x = _build_args(network.forward, **x) | |||||
y = self._predict_func(**x) | |||||
x = _build_args(func, **x) | |||||
y = func(**x) | |||||
return y | return y | ||||
def print_eval_results(self, results): | |||||
def _format_eval_results(self, results): | |||||
"""Override this method to support more print formats. | """Override this method to support more print formats. | ||||
:param results: dict, (str: float) is (metrics name: value) | :param results: dict, (str: float) is (metrics name: value) | ||||
""" | """ | ||||
return ", ".join([str(key) + "=" + str(value) for key, value in results.items()]) | |||||
_str = '' | |||||
for metric_name, metric_result in results.items(): | |||||
_str += metric_name + ': ' | |||||
_str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()]) | |||||
_str += '\n' | |||||
return _str[:-1] |
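# --- Usage sketch (illustrative; "output"/"label" are hypothetical field names) ---
def _tester_usage_example(dev_data, model):
    """Sketch: evaluate `model` on `dev_data` (a fastNLP DataSet) with AccuracyMetric.
    Returns something like {'AccuracyMetric': {'acc': 0.9}}."""
    from fastNLP.core.metrics import AccuracyMetric
    tester = Tester(data=dev_data, model=model,
                    metrics=AccuracyMetric(pred="output", target="label"),
                    batch_size=16, use_cuda=False, verbose=1)
    return tester.test()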
@@ -1,160 +1,275 @@ | |||||
import os | |||||
import time | import time | ||||
from datetime import timedelta | |||||
from datetime import datetime | from datetime import datetime | ||||
import warnings | |||||
from collections import defaultdict | |||||
import os | |||||
import itertools | |||||
import shutil | |||||
from datetime import timedelta | |||||
from tensorboardX import SummaryWriter | |||||
import torch | import torch | ||||
from tensorboardX import SummaryWriter | |||||
from torch import nn | |||||
from tqdm.autonotebook import tqdm | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.loss import Loss | |||||
from fastNLP.core.metrics import Evaluator | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.losses import _prepare_losser | |||||
from fastNLP.core.metrics import _prepare_metrics | |||||
from fastNLP.core.optimizer import Adam | |||||
from fastNLP.core.sampler import BaseSampler | |||||
from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.core.utils import _check_arg_dict_list | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
from fastNLP.core.utils import _syn_model_data | |||||
from fastNLP.core.utils import _check_forward_error | |||||
from fastNLP.core.utils import _check_loss_evaluate | |||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
class Trainer(object): | class Trainer(object): | ||||
"""Main Training Loop | """Main Training Loop | ||||
""" | """ | ||||
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, | |||||
dev_data=None, use_cuda=False, save_path="./save", | |||||
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, | |||||
**kwargs): | |||||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||||
validate_every=-1, dev_data=None, use_cuda=False, save_path=None, | |||||
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | |||||
metric_key=None, sampler=RandomSampler(), use_tqdm=True): | |||||
""" | |||||
:param DataSet train_data: the training data | |||||
:param torch.nn.modules.module model: a PyTorch model | |||||
:param LossBase loss: a loss object | |||||
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics | |||||
:param int n_epochs: the number of training epochs | |||||
:param int batch_size: batch size for training and validation | |||||
        :param int print_every: step interval at which to print training information. Default: 50.
        :param int validate_every: step interval at which to run validation. Default: -1 (validate at the end of every epoch).
        :param DataSet dev_data: the validation data
        :param use_cuda: whether to run the model on GPU (CUDA)
        :param save_path: file path to save models
        :param Optimizer optimizer: an optimizer object
        :param int check_code_level: level of the FastNLP code checker. -1: don't check; 0: ignore; 1: warning; 2: strict.
            `ignore` does not check unused fields; `warning` warns if some fields are not used; `strict` raises an error
            if some fields are not used.
        :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
            of the keys returned by the FIRST metric in `metrics`. If the overall result gets better as the indicator gets
            smaller, prefix the string with a `-` character. For example
:: | |||||
metric_key="-PPL" # language model gets better as perplexity gets smaller | |||||
:param sampler: method used to generate batch data. | |||||
:param use_tqdm: boolean, use tqdm to show train progress. | |||||
""" | |||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
if not isinstance(train_data, DataSet): | |||||
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | |||||
if not isinstance(model, nn.Module): | |||||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | |||||
# check metrics and dev_data | |||||
if (not metrics) and dev_data is not None: | |||||
raise ValueError("No metric for dev_data evaluation.") | |||||
if metrics and (dev_data is None): | |||||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | |||||
# check save_path | |||||
if not (save_path is None or isinstance(save_path, str)): | |||||
raise ValueError("save_path can only be None or `str`.") | |||||
# prepare evaluate | |||||
metrics = _prepare_metrics(metrics) | |||||
# parse metric_key | |||||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||||
# It is true by default. | |||||
self.increase_better = True | |||||
if metric_key is not None: | |||||
self.increase_better = False if metric_key[0] == "-" else True | |||||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||||
elif len(metrics) > 0: | |||||
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') | |||||
# prepare loss | |||||
losser = _prepare_losser(loss) | |||||
# sampler check | |||||
if not isinstance(sampler, BaseSampler): | |||||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | |||||
if check_code_level > -1: | |||||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||||
metric_key=metric_key, check_level=check_code_level) | |||||
self.train_data = train_data | self.train_data = train_data | ||||
self.dev_data = dev_data # If None, No validation. | self.dev_data = dev_data # If None, No validation. | ||||
self.model = model | self.model = model | ||||
self.losser = losser | |||||
self.metrics = metrics | |||||
self.n_epochs = int(n_epochs) | self.n_epochs = int(n_epochs) | ||||
self.batch_size = int(batch_size) | self.batch_size = int(batch_size) | ||||
self.use_cuda = bool(use_cuda) | self.use_cuda = bool(use_cuda) | ||||
self.save_path = save_path | self.save_path = save_path | ||||
self.print_every = int(print_every) | self.print_every = int(print_every) | ||||
self.validate_every = int(validate_every) | self.validate_every = int(validate_every) | ||||
self._best_accuracy = 0 | |||||
self.best_metric_indicator = None | |||||
self.sampler = sampler | |||||
if need_check_code: | |||||
_check_code(dataset=train_data, model=model, dev_data=dev_data) | |||||
self._model_device = model.parameters().__next__().device | |||||
model_name = model.__class__.__name__ | |||||
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name) | |||||
self.loss_func = self.model.get_loss | |||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
self.optimizer = optimizer | self.optimizer = optimizer | ||||
else: | else: | ||||
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | ||||
assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name) | |||||
self.evaluator = self.model.evaluate | |||||
self.use_tqdm = use_tqdm | |||||
if self.use_tqdm: | |||||
tester_verbose = 0 | |||||
else: | |||||
tester_verbose = 1 | |||||
if self.dev_data is not None: | if self.dev_data is not None: | ||||
self.tester = Tester(model=self.model, | self.tester = Tester(model=self.model, | ||||
data=self.dev_data, | data=self.dev_data, | ||||
metrics=self.metrics, | |||||
batch_size=self.batch_size, | batch_size=self.batch_size, | ||||
use_cuda=self.use_cuda) | |||||
for k, v in kwargs.items(): | |||||
setattr(self, k, v) | |||||
use_cuda=self.use_cuda, | |||||
verbose=tester_verbose) | |||||
self.step = 0 | self.step = 0 | ||||
self.start_time = None # start timestamp | self.start_time = None # start timestamp | ||||
# print(self.__dict__) | |||||
def train(self): | def train(self): | ||||
"""Start Training. | """Start Training. | ||||
:return: | |||||
""" | """ | ||||
try: | try: | ||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
self.model = self.model.cuda() | self.model = self.model.cuda() | ||||
self.mode(self.model, is_test=False) | |||||
self._mode(self.model, is_test=False) | |||||
start = time.time() | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||||
print("training epochs started " + self.start_time) | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S')) | |||||
print("training epochs started " + self.start_time, flush=True) | |||||
if self.save_path is None: | if self.save_path is None: | ||||
class psudoSW: | class psudoSW: | ||||
def __getattr__(self, item): | def __getattr__(self, item): | ||||
def pass_func(*args, **kwargs): | def pass_func(*args, **kwargs): | ||||
pass | pass | ||||
return pass_func | return pass_func | ||||
self._summary_writer = psudoSW() | self._summary_writer = psudoSW() | ||||
else: | else: | ||||
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) | path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) | ||||
self._summary_writer = SummaryWriter(path) | self._summary_writer = SummaryWriter(path) | ||||
if self.use_tqdm: | |||||
self._tqdm_train() | |||||
else: | |||||
self._print_train() | |||||
epoch = 1 | |||||
while epoch <= self.n_epochs: | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) | |||||
self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) | |||||
# validate_every override validation at end of epochs | |||||
if self.dev_data and self.validate_every <= 0: | |||||
self.do_validation() | |||||
epoch += 1 | |||||
finally: | finally: | ||||
self._summary_writer.close() | self._summary_writer.close() | ||||
del self._summary_writer | del self._summary_writer | ||||
def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs): | |||||
"""Training process in one epoch. | |||||
kwargs should contain: | |||||
- n_print: int, print training information every n steps. | |||||
- start: time.time(), the starting time of this step. | |||||
- epoch: int, | |||||
""" | |||||
for batch_x, batch_y in data_iterator: | |||||
prediction = self.data_forward(model, batch_x) | |||||
loss = self.get_loss(prediction, batch_y) | |||||
self.grad_backward(loss) | |||||
self.update() | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||||
if self.print_every > 0 and self.step % self.print_every == 0: | |||||
end = time.time() | |||||
diff = timedelta(seconds=round(end - start)) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
epoch, self.step, loss.data, diff) | |||||
print(print_output) | |||||
if self.validate_every > 0 and self.step % self.validate_every == 0: | |||||
self.do_validation() | |||||
self.step += 1 | |||||
def do_validation(self): | |||||
def _tqdm_train(self): | |||||
self.step = 0 | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||||
as_numpy=False) | |||||
total_steps = data_iterator.num_batches*self.n_epochs | |||||
epoch = 1 | |||||
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||||
avg_loss = 0
for epoch in range(1, self.n_epochs+1): | |||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||||
for batch_x, batch_y in data_iterator: | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
prediction = self._data_forward(self.model, batch_x) | |||||
loss = self._compute_loss(prediction, batch_y) | |||||
avg_loss += loss.item()
self._grad_backward(loss) | |||||
self._update() | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||||
if (self.step+1) % self.print_every == 0: | |||||
pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every))
avg_loss = 0
pbar.update(1) | |||||
self.step += 1 | |||||
if self.validate_every > 0 and self.step % self.validate_every == 0 \ | |||||
and self.dev_data is not None: | |||||
eval_res = self._do_validation() | |||||
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
if self.validate_every < 0 and self.dev_data: | |||||
eval_res = self._do_validation() | |||||
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
if epoch!=self.n_epochs: | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||||
as_numpy=False) | |||||
pbar.close() | |||||
def _print_train(self): | |||||
epoch = 1 | |||||
start = time.time() | |||||
while epoch <= self.n_epochs: | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||||
as_numpy=False) | |||||
for batch_x, batch_y in data_iterator: | |||||
# TODO: potential issue here - if the user moves the prediction to a different device inside the model, this will break
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
prediction = self._data_forward(self.model, batch_x) | |||||
loss = self._compute_loss(prediction, batch_y) | |||||
self._grad_backward(loss) | |||||
self._update() | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||||
if self.print_every > 0 and self.step % self.print_every == 0: | |||||
end = time.time() | |||||
diff = timedelta(seconds=round(end - start)) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
epoch, self.step, loss.data, diff) | |||||
print(print_output) | |||||
if (self.validate_every > 0 and self.step % self.validate_every == 0 and | |||||
self.dev_data is not None): | |||||
self._do_validation() | |||||
self.step += 1 | |||||
# validate_every overrides the validation at the end of each epoch
if self.dev_data and self.validate_every <= 0: | |||||
self._do_validation() | |||||
epoch += 1 | |||||
def _do_validation(self): | |||||
res = self.tester.test() | res = self.tester.test() | ||||
for name, num in res.items(): | |||||
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) | |||||
if self.save_path is not None and self.best_eval_result(res): | |||||
self.save_model(self.model, 'best_model_' + self.start_time) | |||||
def mode(self, model, is_test=False): | |||||
for name, metric in res.items(): | |||||
for metric_key, metric_val in metric.items(): | |||||
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, | |||||
global_step=self.step) | |||||
if self.save_path is not None and self._better_eval_result(res): | |||||
metric_key = self.metric_key if self.metric_key is not None else "" | |||||
self._save_model(self.model, | |||||
"best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time])) | |||||
return res | |||||
def _mode(self, model, is_test=False): | |||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
:param model: a PyTorch model | :param model: a PyTorch model | ||||
:param is_test: bool, whether in test mode or not. | |||||
:param bool is_test: whether in test mode or not. | |||||
""" | """ | ||||
if is_test: | if is_test: | ||||
@@ -162,18 +277,20 @@ class Trainer(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
def update(self): | |||||
def _update(self): | |||||
"""Perform weight update on a model. | """Perform weight update on a model. | ||||
""" | """ | ||||
self.optimizer.step() | self.optimizer.step() | ||||
def data_forward(self, network, x): | |||||
def _data_forward(self, network, x): | |||||
x = _build_args(network.forward, **x) | x = _build_args(network.forward, **x) | ||||
y = network(**x) | y = network(**x) | ||||
if not isinstance(y, dict): | |||||
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") | |||||
return y | return y | ||||
def grad_backward(self, loss): | |||||
def _grad_backward(self, loss): | |||||
"""Compute gradient with link rules. | """Compute gradient with link rules. | ||||
:param loss: a scalar where back-prop starts | :param loss: a scalar where back-prop starts | ||||
@@ -183,223 +300,130 @@ class Trainer(object): | |||||
self.model.zero_grad() | self.model.zero_grad() | ||||
loss.backward() | loss.backward() | ||||
def get_loss(self, predict, truth): | |||||
def _compute_loss(self, predict, truth): | |||||
"""Compute loss given prediction and ground truth. | """Compute loss given prediction and ground truth. | ||||
:param predict: prediction label vector | |||||
:param truth: ground truth label vector | |||||
:param predict: prediction dict, produced by model.forward | |||||
:param truth: ground truth dict, produced by batch_y | |||||
:return: a scalar | :return: a scalar | ||||
""" | """ | ||||
assert isinstance(predict, dict) and isinstance(truth, dict) | |||||
args = _build_args(self.loss_func, **predict, **truth) | |||||
return self.loss_func(**args) | |||||
def save_model(self, model, model_name, only_param=False): | |||||
model_name = os.path.join(self.save_path, model_name) | |||||
if only_param: | |||||
torch.save(model.state_dict(), model_name) | |||||
else: | |||||
torch.save(model, model_name) | |||||
return self.losser(predict, truth) | |||||
def _save_model(self, model, model_name, only_param=False): | |||||
if self.save_path is not None: | |||||
model_name = os.path.join(self.save_path, model_name) | |||||
if only_param: | |||||
torch.save(model.state_dict(), model_name) | |||||
else: | |||||
torch.save(model, model_name) | |||||
def best_eval_result(self, metrics): | |||||
def _better_eval_result(self, metrics): | |||||
"""Check if the current epoch yields better validation results. | """Check if the current epoch yields better validation results. | ||||
:return: bool, True means current results on dev set is the best. | |||||
:return bool value: True means current results on dev set is the best. | |||||
""" | """ | ||||
if isinstance(metrics, tuple): | |||||
loss, metrics = metrics | |||||
if isinstance(metrics, dict): | |||||
if len(metrics) == 1: | |||||
accuracy = list(metrics.values())[0] | |||||
else: | |||||
accuracy = metrics[self.eval_sort_key] | |||||
indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) | |||||
is_better = True | |||||
if self.best_metric_indicator is None: | |||||
# first-time validation | |||||
self.best_metric_indicator = indicator_val | |||||
else: | else: | ||||
accuracy = metrics | |||||
if accuracy > self._best_accuracy: | |||||
self._best_accuracy = accuracy | |||||
return True | |||||
else: | |||||
return False | |||||
if self.increase_better is True: | |||||
if indicator_val > self.best_metric_indicator: | |||||
self.best_metric_indicator = indicator_val | |||||
else: | |||||
is_better = False | |||||
else: | |||||
if indicator_val < self.best_metric_indicator: | |||||
self.best_metric_indicator = indicator_val | |||||
else: | |||||
is_better = False | |||||
return is_better | |||||
DEFAULT_CHECK_BATCH_SIZE = 2 | DEFAULT_CHECK_BATCH_SIZE = 2 | ||||
DEFAULT_CHECK_NUM_BATCH = 2 | DEFAULT_CHECK_NUM_BATCH = 2 | ||||
IGNORE_CHECK_LEVEL = 0 | |||||
WARNING_CHECK_LEVEL = 1 | |||||
STRICT_CHECK_LEVEL = 2 | |||||
def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL): | |||||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||||
dev_data=None, metric_key=None, | |||||
check_level=0): | |||||
# check the get_loss method
model_name = model.__class__.__name__ | |||||
if not hasattr(model, 'get_loss'): | |||||
raise AttributeError("{} has to have a 'get_loss' function.".format(model_name)) | |||||
model_device = model.parameters().__next__().device
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | ||||
for batch_count, (batch_x, batch_y) in enumerate(batch): | for batch_count, (batch_x, batch_y) in enumerate(batch): | ||||
_syn_model_data(model, batch_x, batch_y) | |||||
_move_dict_value_to_device(batch_x, batch_y, device=model_device)
# forward check | # forward check | ||||
if batch_count==0: | if batch_count==0: | ||||
_check_forward_error(model_func=model.forward, check_level=check_level, | |||||
batch_x=batch_x) | |||||
_check_forward_error(forward_func=model.forward, dataset=dataset, | |||||
batch_x=batch_x, check_level=check_level) | |||||
refined_batch_x = _build_args(model.forward, **batch_x) | refined_batch_x = _build_args(model.forward, **batch_x) | ||||
output = model(**refined_batch_x) | |||||
pred_dict = model(**refined_batch_x) | |||||
func_signature = get_func_signature(model.forward) | func_signature = get_func_signature(model.forward) | ||||
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) | |||||
if not isinstance(pred_dict, dict): | |||||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | |||||
# loss check | # loss check | ||||
if batch_count == 0: | |||||
_check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level, | |||||
output=output, batch_y=batch_y) | |||||
loss_input = _build_args(model.get_loss, **output, **batch_y) | |||||
loss = model.get_loss(**loss_input) | |||||
# check loss output | |||||
if batch_count == 0: | |||||
if not isinstance(loss, torch.Tensor): | |||||
raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.". | |||||
format(model_name, type(loss))) | |||||
if len(loss.size())!=0: | |||||
raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format( | |||||
model_name, loss.size() | |||||
)) | |||||
loss.backward() | |||||
try: | |||||
loss = losser(pred_dict, batch_y) | |||||
# check loss output | |||||
if batch_count == 0: | |||||
if not isinstance(loss, torch.Tensor): | |||||
raise TypeError( | |||||
f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, " | |||||
f"but got `{type(loss)}`.") | |||||
if len(loss.size()) != 0: | |||||
raise ValueError( | |||||
f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, " | |||||
f"should be torch.size([])") | |||||
loss.backward() | |||||
except CheckError as e: | |||||
# TODO: a different, more informative error is raised if a CheckError is caught
pre_func_signature = get_func_signature(model.forward) | |||||
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature, | |||||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | |||||
dataset=dataset, check_level=check_level) | |||||
model.zero_grad() | model.zero_grad() | ||||
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: | |||||
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: | |||||
break | break | ||||
if dev_data is not None: | if dev_data is not None: | ||||
if not hasattr(model, 'evaluate'): | |||||
raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set" | |||||
"dev_data to 'None'." | |||||
.format(model_name)) | |||||
outputs, truths = defaultdict(list), defaultdict(list) | |||||
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||||
with torch.no_grad(): | |||||
for batch_count, (batch_x, batch_y) in enumerate(dev_batch): | |||||
_syn_model_data(model, batch_x, batch_y) | |||||
if hasattr(model, 'predict'): | |||||
refined_batch_x = _build_args(model.predict, **batch_x) | |||||
prev_func = model.predict | |||||
output = prev_func(**refined_batch_x) | |||||
func_signature = get_func_signature(model.predict) | |||||
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) | |||||
else: | |||||
refined_batch_x = _build_args(model.forward, **batch_x) | |||||
prev_func = model.forward | |||||
output = prev_func(**refined_batch_x) | |||||
for k, v in output.items(): | |||||
outputs[k].append(v) | |||||
for k, v in batch_y.items(): | |||||
truths[k].append(v) | |||||
if batch_count+1>DEFAULT_CHECK_NUM_BATCH: | |||||
break | |||||
for k, v in outputs.items(): | |||||
outputs[k] = itertools.chain(*v) | |||||
for k, v in truths.items(): | |||||
truths[k] = itertools.chain(*v) | |||||
_check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level, | |||||
output=outputs, batch_y=truths) | |||||
refined_input = _build_args(model.evaluate, **outputs, **truths) | |||||
metrics = model.evaluate(**refined_input) | |||||
func_signature = get_func_signature(model.evaluate) | |||||
assert isinstance(metrics, dict), "The return value of {} should be dict.". \ | |||||
format(func_signature) | |||||
def _check_forward_error(model_func, check_level, batch_x): | |||||
check_res = _check_arg_dict_list(model_func, batch_x) | |||||
_missing = '' | |||||
_unused = '' | |||||
func_signature = get_func_signature(model_func) | |||||
if len(check_res.missing)!=0: | |||||
_missing = "Function {} misses {}, only provided with {}, " \ | |||||
".\n".format(func_signature, check_res.missing, | |||||
list(batch_x.keys())) | |||||
if len(check_res.unused)!=0: | |||||
if len(check_res.unused) > 1: | |||||
_unused = "{} are not used ".format(check_res.unused) | |||||
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | |||||
batch_size=batch_size, verbose=-1) | |||||
evaluate_results = tester.test() | |||||
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics) | |||||
def _check_eval_results(metrics, metric_key, metric_list): | |||||
# metrics: the results returned by the tester
# metric_key: the single metric key used for model selection, from the Trainer's initialization
# metric_list: the list of metrics used for evaluation, from the Trainer's initialization
if isinstance(metrics, tuple): | |||||
loss, metrics = metrics | |||||
if isinstance(metrics, dict): | |||||
if len(metrics) == 1: | |||||
# only single metric, just use it | |||||
metric_dict = list(metrics.values())[0] | |||||
metrics_name = list(metrics.keys())[0] | |||||
else: | else: | ||||
_unused = "{} is not used ".format(check_res.unused) | |||||
_unused += "in function {}.\n".format(func_signature) | |||||
if _missing: | |||||
if len(_unused)>0 and STRICT_CHECK_LEVEL: | |||||
_error_str = "(1).{}\n(2).{}".format(_missing, _unused) | |||||
metrics_name = metric_list[0].__class__.__name__ | |||||
if metrics_name not in metrics: | |||||
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") | |||||
metric_dict = metrics[metrics_name] | |||||
if len(metric_dict) == 1: | |||||
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] | |||||
elif len(metric_dict) > 1 and metric_key is None: | |||||
raise RuntimeError( | |||||
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?") | |||||
else: | else: | ||||
_error_str = _missing | |||||
# TODO: custom Error types may be needed here
raise TypeError(_error_str) | |||||
if _unused: | |||||
if check_level == STRICT_CHECK_LEVEL: | |||||
# TODO: custom Error types may be needed here
raise ValueError(_unused) | |||||
elif check_level == WARNING_CHECK_LEVEL: | |||||
warnings.warn(message=_unused) | |||||
def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): | |||||
check_res = _check_arg_dict_list(func, [output, batch_y]) | |||||
_missing = '' | |||||
_unused = '' | |||||
_duplicated = '' | |||||
func_signature = get_func_signature(func) | |||||
prev_func_signature = get_func_signature(prev_func) | |||||
if len(check_res.missing)>0: | |||||
_missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \ | |||||
"{}(from target in Dataset)." \ | |||||
.format(func_signature, check_res.missing, | |||||
list(output.keys()), prev_func_signature, | |||||
list(batch_y.keys())) | |||||
if len(check_res.unused)>0: | |||||
if len(check_res.unused) > 1: | |||||
_unused = "{} are not used ".format(check_res.unused) | |||||
else: | |||||
_unused = "{} is not used ".format(check_res.unused) | |||||
_unused += "in function {}.\n".format(func_signature) | |||||
if len(check_res.duplicated)>0: | |||||
if len(check_res.duplicated) > 1: | |||||
_duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \ | |||||
"them in {} at the same time.".format(check_res.duplicated, | |||||
func_signature, | |||||
check_res.duplicated, | |||||
prev_func_signature) | |||||
else: | |||||
_duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \ | |||||
"it in {} at the same time.".format(check_res.duplicated, | |||||
func_signature, | |||||
check_res.duplicated, | |||||
prev_func_signature) | |||||
_number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) | |||||
if _number_errs > 0: | |||||
_error_strs = [] | |||||
if _number_errs > 1: | |||||
count = 0 | |||||
order_words = ['Firstly', 'Secondly', 'Thirdly'] | |||||
if _missing: | |||||
_error_strs.append('{}, {}'.format(order_words[count], _missing)) | |||||
count += 1 | |||||
if _duplicated: | |||||
_error_strs.append('{}, {}'.format(order_words[count], _duplicated)) | |||||
count += 1 | |||||
if _unused and check_level == STRICT_CHECK_LEVEL: | |||||
_error_strs.append('{}, {}'.format(order_words[count], _unused)) | |||||
else: | |||||
if _unused: | |||||
if check_level == STRICT_CHECK_LEVEL: | |||||
# TODO: custom Error types may be needed here
_error_strs.append(_unused) | |||||
elif check_level == WARNING_CHECK_LEVEL: | |||||
_unused = _unused.strip() | |||||
warnings.warn(_unused) | |||||
else: | |||||
if _missing: | |||||
_error_strs.append(_missing) | |||||
if _duplicated: | |||||
_error_strs.append(_duplicated) | |||||
if _error_strs: | |||||
raise ValueError('\n' + '\n'.join(_error_strs)) | |||||
# metric_key is set | |||||
if metric_key not in metric_dict: | |||||
raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") | |||||
indicator_val = metric_dict[metric_key] | |||||
else: | |||||
raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics))) | |||||
return indicator_val |
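# Illustrative example (metric names are hypothetical): with
#   metrics = {'AccuracyMetric': {'acc': 0.92}}
# and metric_key=None, the single inner entry is picked and 0.92 is returned;
# if the inner dict held several keys, metric_key would have to name one of them.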
@@ -1,10 +1,15 @@ | |||||
import _pickle | import _pickle | ||||
import inspect | import inspect | ||||
import os | import os | ||||
import warnings | |||||
from collections import Counter | from collections import Counter | ||||
from collections import namedtuple | from collections import namedtuple | ||||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) | |||||
import numpy as np | |||||
import torch | |||||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||||
'varargs'], verbose=False) | |||||
def save_pickle(obj, pickle_path, file_name): | def save_pickle(obj, pickle_path, file_name): | ||||
@@ -50,6 +55,7 @@ def pickle_exist(pickle_path, pickle_name): | |||||
else: | else: | ||||
return False | return False | ||||
def _build_args(func, **kwargs): | def _build_args(func, **kwargs): | ||||
spect = inspect.getfullargspec(func) | spect = inspect.getfullargspec(func) | ||||
if spect.varkw is not None: | if spect.varkw is not None: | ||||
@@ -64,6 +70,38 @@ def _build_args(func, **kwargs): | |||||
return output | return output | ||||
def _map_args(maps: dict, **kwargs): | |||||
# maps: key=old name, value= new name | |||||
output = {} | |||||
for name, val in kwargs.items(): | |||||
if name in maps: | |||||
assert isinstance(maps[name], str) | |||||
output.update({maps[name]: val}) | |||||
else: | |||||
output.update({name: val}) | |||||
for keys in maps.keys(): | |||||
if keys not in output.keys(): | |||||
# TODO: add UNUSED warning. | |||||
pass | |||||
return output | |||||
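# Illustrative example (values are hypothetical, not from the source):
#   _map_args({'pred': 'output'}, pred=some_tensor, seq_len=lens)
# returns {'output': some_tensor, 'seq_len': lens}: keys found in `maps`
# are renamed, all other keyword arguments pass through unchanged.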
def _get_arg_list(func): | |||||
assert callable(func) | |||||
spect = inspect.getfullargspec(func) | |||||
if spect.defaults is not None: | |||||
args = spect.args[: -len(spect.defaults)] | |||||
defaults = spect.args[-len(spect.defaults):] | |||||
defaults_val = spect.defaults | |||||
else: | |||||
args = spect.args | |||||
defaults = None | |||||
defaults_val = None | |||||
varargs = spect.varargs | |||||
kwargs = spect.varkw | |||||
return args, defaults, defaults_val, varargs, kwargs | |||||
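# Illustrative example (the function below is hypothetical):
#   def forward(self, words, seq_len=None): ...
#   _get_arg_list(forward)
# returns roughly (['self', 'words'], ['seq_len'], (None,), None, None), i.e.
# positional args, args with defaults, their default values, the *varargs name, the **kwargs name.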
# check args | # check args | ||||
def _check_arg_dict_list(func, args): | def _check_arg_dict_list(func, args): | ||||
if isinstance(args, dict): | if isinstance(args, dict): | ||||
@@ -73,8 +111,7 @@ def _check_arg_dict_list(func, args): | |||||
assert callable(func) and isinstance(arg_dict_list, (list, tuple)) | assert callable(func) and isinstance(arg_dict_list, (list, tuple)) | ||||
assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict) | assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict) | ||||
spect = inspect.getfullargspec(func) | spect = inspect.getfullargspec(func) | ||||
assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs) | |||||
all_args = set([arg for arg in spect.args if arg!='self']) | |||||
all_args = set([arg for arg in spect.args if arg != 'self']) | |||||
defaults = [] | defaults = [] | ||||
if spect.defaults is not None: | if spect.defaults is not None: | ||||
defaults = [arg for arg in spect.defaults] | defaults = [arg for arg in spect.defaults] | ||||
@@ -88,19 +125,39 @@ def _check_arg_dict_list(func, args): | |||||
input_args = set(input_arg_count.keys()) | input_args = set(input_arg_count.keys()) | ||||
missing = list(require_args - input_args) | missing = list(require_args - input_args) | ||||
unused = list(input_args - all_args) | unused = list(input_args - all_args) | ||||
varargs = [] if not spect.varargs else [arg for arg in spect.varargs] | |||||
return CheckRes(missing=missing, | return CheckRes(missing=missing, | ||||
unused=unused, | unused=unused, | ||||
duplicated=duplicated, | duplicated=duplicated, | ||||
required=list(require_args), | required=list(require_args), | ||||
all_needed=list(all_args)) | |||||
all_needed=list(all_args), | |||||
varargs=varargs) | |||||
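# Illustrative example (names are hypothetical): checking
#   def get_loss(pred, target): ...
# against a dict {'pred': ..., 'label': ...} yields roughly
#   CheckRes(missing=['target'], unused=['label'], duplicated=[],
#            required=['pred', 'target'], all_needed=['pred', 'target'], varargs=[])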
def get_func_signature(func): | def get_func_signature(func): | ||||
# can only be used in function or class method | |||||
""" | |||||
Given a function or method, return its signature. | |||||
For example: | |||||
(1) function | |||||
def func(a, b='a', *args): | |||||
xxxx | |||||
get_func_signature(func) # 'func(a, b='a', *args)' | |||||
(2) method | |||||
class Demo: | |||||
def __init__(self): | |||||
xxx | |||||
def forward(self, a, b='a', **args) | |||||
demo = Demo() | |||||
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)' | |||||
:param func: a function or a method | |||||
:return: str or None | |||||
""" | |||||
if inspect.ismethod(func): | if inspect.ismethod(func): | ||||
class_name = func.__self__.__class__.__name__ | class_name = func.__self__.__class__.__name__ | ||||
signature = inspect.signature(func) | signature = inspect.signature(func) | ||||
signature_str = str(signature) | signature_str = str(signature) | ||||
if len(signature_str)>2: | |||||
if len(signature_str) > 2: | |||||
_self = '(self, ' | _self = '(self, ' | ||||
else: | else: | ||||
_self = '(self' | _self = '(self' | ||||
@@ -113,15 +170,263 @@ def get_func_signature(func): | |||||
return signature_str | return signature_str | ||||
# move data to model's device | |||||
import torch | |||||
def _syn_model_data(model, *args): | |||||
assert len(model.state_dict())!=0, "This model has no parameter." | |||||
device = model.parameters().__next__().device | |||||
def _is_function_or_method(func): | |||||
""" | |||||
:param func: the object to check
:return: bool, True if ``func`` is a function or a bound method
""" | |||||
if not inspect.ismethod(func) and not inspect.isfunction(func): | |||||
return False | |||||
return True | |||||
def _check_function_or_method(func): | |||||
if not _is_function_or_method(func): | |||||
raise TypeError(f"{type(func)} is not a method or function.") | |||||
def _move_dict_value_to_device(*args, device: torch.device): | |||||
""" | |||||
Move every tensor in the given dicts to the specified device. Each element in *args should be a dict; the change is made in place.
:param device: torch.device | |||||
:param args: | |||||
:return: | |||||
""" | |||||
if not isinstance(device, torch.device): | |||||
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | |||||
for arg in args: | for arg in args: | ||||
if isinstance(arg, dict): | if isinstance(arg, dict): | ||||
for key, value in arg.items(): | for key, value in arg.items(): | ||||
if isinstance(value, torch.Tensor): | if isinstance(value, torch.Tensor): | ||||
arg[key] = value.to(device) | arg[key] = value.to(device) | ||||
else: | else: | ||||
raise ValueError("Only support dict type right now.") | |||||
raise TypeError("Only support `dict` type right now.") | |||||
class CheckError(Exception): | |||||
""" | |||||
CheckError. Used in losses.LossBase, metrics.MetricBase. | |||||
""" | |||||
def __init__(self, check_res: CheckRes, func_signature: str): | |||||
errs = [f'Problems occurred when calling `{func_signature}`'] | |||||
if check_res.varargs: | |||||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | |||||
if check_res.missing: | |||||
errs.append(f"\tmissing param: {check_res.missing}") | |||||
if check_res.duplicated: | |||||
errs.append(f"\tduplicated param: {check_res.duplicated}") | |||||
if check_res.unused: | |||||
errs.append(f"\tunused param: {check_res.unused}") | |||||
Exception.__init__(self, '\n'.join(errs)) | |||||
self.check_res = check_res | |||||
self.func_signature = func_signature | |||||
IGNORE_CHECK_LEVEL = 0 | |||||
WARNING_CHECK_LEVEL = 1 | |||||
STRICT_CHECK_LEVEL = 2 | |||||
def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes, | |||||
pred_dict: dict, target_dict: dict, dataset, check_level=0): | |||||
errs = [] | |||||
unuseds = [] | |||||
_unused_field = [] | |||||
_unused_param = [] | |||||
suggestions = [] | |||||
if check_res.varargs: | |||||
errs.append(f"\tvarargs: *{check_res.varargs}") | |||||
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||||
if check_res.unused: | |||||
for _unused in check_res.unused: | |||||
if _unused in target_dict: | |||||
_unused_field.append(_unused) | |||||
else: | |||||
_unused_param.append(_unused) | |||||
if _unused_field: | |||||
unuseds.append(f"\tunused field: {_unused_field}") | |||||
if _unused_param: | |||||
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | |||||
module_name = func_signature.split('.')[0] | |||||
if check_res.missing: | |||||
errs.append(f"\tmissing param: {check_res.missing}") | |||||
import re | |||||
mapped_missing = [] | |||||
unmapped_missing = [] | |||||
input_func_map = {} | |||||
for _miss in check_res.missing: | |||||
if '(' in _miss: | |||||
# if they are like 'SomeParam(assign to xxx)' | |||||
_miss = _miss.split('(')[0] | |||||
matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss) | |||||
if len(matches) == 2: | |||||
fun_arg, module_name = matches | |||||
input_func_map[_miss] = fun_arg | |||||
if fun_arg == _miss: | |||||
unmapped_missing.append(_miss) | |||||
else: | |||||
mapped_missing.append(_miss) | |||||
else: | |||||
unmapped_missing.append(_miss) | |||||
for _miss in mapped_missing: | |||||
if _miss in dataset: | |||||
suggestions.append(f"Set {_miss} as target.") | |||||
else: | |||||
_tmp = '' | |||||
if check_res.unused: | |||||
_tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." | |||||
if _tmp: | |||||
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.' | |||||
else: | |||||
_tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.' | |||||
suggestions.append(_tmp) | |||||
for _miss in unmapped_missing: | |||||
if _miss in dataset: | |||||
suggestions.append(f"Set {_miss} as target.") | |||||
else: | |||||
_tmp = '' | |||||
if check_res.unused: | |||||
_tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." | |||||
if _tmp: | |||||
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.' | |||||
else: | |||||
_tmp = f'Provide {_miss} in output of {prev_func_signature} or DataSet.' | |||||
suggestions.append(_tmp) | |||||
if check_res.duplicated: | |||||
errs.append(f"\tduplicated param: {check_res.duplicated}.") | |||||
suggestions.append(f"Delete {check_res.duplicated} in the output of " | |||||
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | |||||
if len(errs)>0: | |||||
errs.extend(unuseds) | |||||
elif check_level == STRICT_CHECK_LEVEL: | |||||
errs.extend(unuseds) | |||||
if len(errs) > 0: | |||||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||||
sugg_str = "" | |||||
if len(suggestions) > 1: | |||||
for idx, sugg in enumerate(suggestions): | |||||
if idx>0: | |||||
sugg_str += '\t\t\t' | |||||
sugg_str += f'({idx+1}). {sugg}\n' | |||||
sugg_str = sugg_str[:-1] | |||||
else: | |||||
sugg_str += suggestions[0] | |||||
errs.append(f'\ttarget field: {list(target_dict.keys())}') | |||||
errs.append(f'\tparam from {prev_func_signature}: {list(pred_dict.keys())}') | |||||
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str | |||||
raise NameError(err_str) | |||||
if check_res.unused: | |||||
if check_level == WARNING_CHECK_LEVEL: | |||||
if not module_name: | |||||
module_name = func_signature.split('.')[0] | |||||
_unused_warn = f'{check_res.unused} is not used by {module_name}.' | |||||
warnings.warn(message=_unused_warn) | |||||
def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||||
check_res = _check_arg_dict_list(forward_func, batch_x) | |||||
func_signature = get_func_signature(forward_func) | |||||
errs = [] | |||||
suggestions = [] | |||||
_unused = [] | |||||
if check_res.varargs: | |||||
errs.append(f"\tvarargs: {check_res.varargs}") | |||||
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||||
if check_res.missing: | |||||
errs.append(f"\tmissing param: {check_res.missing}") | |||||
_miss_in_dataset = [] | |||||
_miss_out_dataset = [] | |||||
for _miss in check_res.missing: | |||||
if _miss in dataset: | |||||
_miss_in_dataset.append(_miss) | |||||
else: | |||||
_miss_out_dataset.append(_miss) | |||||
if _miss_in_dataset: | |||||
suggestions.append(f"You might need to set {_miss_in_dataset} as input. ") | |||||
if _miss_out_dataset: | |||||
_tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. " | |||||
# if check_res.unused: | |||||
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | |||||
# f"rename the field in `unused field:`." | |||||
suggestions.append(_tmp) | |||||
if check_res.unused: | |||||
_unused = [f"\tunused field: {check_res.unused}"] | |||||
if len(errs)>0: | |||||
errs.extend(_unused) | |||||
elif check_level == STRICT_CHECK_LEVEL: | |||||
errs.extend(_unused) | |||||
if len(errs) > 0: | |||||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||||
sugg_str = "" | |||||
if len(suggestions) > 1: | |||||
for idx, sugg in enumerate(suggestions): | |||||
sugg_str += f'({idx+1}). {sugg}' | |||||
else: | |||||
sugg_str += suggestions[0] | |||||
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str | |||||
raise NameError(err_str) | |||||
if _unused: | |||||
if check_level == WARNING_CHECK_LEVEL: | |||||
_unused_warn = _unused[0] + f' in {func_signature}.' | |||||
warnings.warn(message=_unused_warn) | |||||
def seq_lens_to_masks(seq_lens, float=False): | |||||
""" | |||||
Convert seq_lens to masks. | |||||
:param seq_lens: list, np.ndarray, or torch.LongTensor, shape should all be (B,) | |||||
:param float: if True, the returned masks are float tensors; otherwise they are byte tensors.
:return: list, np.ndarray or torch.Tensor, shape will be (B, max_length) | |||||
""" | |||||
if isinstance(seq_lens, np.ndarray): | |||||
assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}." | |||||
assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}." | |||||
raise NotImplementedError
elif isinstance(seq_lens, torch.LongTensor): | |||||
assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())}."
batch_size = seq_lens.size(0) | |||||
max_len = seq_lens.max() | |||||
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) | |||||
masks = indexes.lt(seq_lens.unsqueeze(1)) | |||||
if float: | |||||
masks = masks.float() | |||||
return masks | |||||
elif isinstance(seq_lens, list): | |||||
raise NotImplementedError
else: | |||||
raise NotImplementedError
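# Illustrative usage (values are made up):
#   seq_lens_to_masks(torch.LongTensor([2, 4]))
# gives a byte mask of shape (2, 4):
#   [[1, 1, 0, 0],
#    [1, 1, 1, 1]]
# pass float=True to get the same mask as a float tensor.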
def seq_mask(seq_len, max_len): | |||||
"""Create sequence mask. | |||||
:param seq_len: list or torch.Tensor, the lengths of sequences in a batch. | |||||
:param max_len: int, the maximum sequence length in a batch. | |||||
:return mask: torch.ByteTensor, [batch_size, max_len]
""" | |||||
if not isinstance(seq_len, torch.Tensor): | |||||
seq_len = torch.LongTensor(seq_len) | |||||
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] | |||||
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] | |||||
return torch.gt(seq_len, seq_range) # [batch_size, max_len] |
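# Illustrative usage (values are made up):
#   seq_mask(torch.LongTensor([3, 1]), max_len=4)
# gives a mask of shape (2, 4):
#   [[1, 1, 1, 0],
#    [1, 0, 0, 0]]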
@@ -1,24 +1,31 @@ | |||||
from collections import Counter | from collections import Counter | ||||
from copy import deepcopy | |||||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||||
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1} | |||||
def check_build_vocab(func): | |||||
"""A decorator to make sure the indexing is built before used. | |||||
""" | |||||
def _wrapper(self, *args, **kwargs): | |||||
if self.word2idx is None or self.rebuild is True: | |||||
self.build_vocab() | |||||
return func(self, *args, **kwargs) | |||||
def isiterable(p_object): | |||||
try: | |||||
_ = iter(p_object) | |||||
except TypeError: | |||||
return False | |||||
return True | |||||
return _wrapper | |||||
def check_build_vocab(func): | |||||
def check_build_status(func): | |||||
"""A decorator to check whether the vocabulary updates after the last build. | |||||
""" | |||||
def _wrapper(self, *args, **kwargs): | def _wrapper(self, *args, **kwargs): | ||||
if self.word2idx is None: | |||||
self.build_vocab() | |||||
if self.rebuild is False: | |||||
self.rebuild = True | |||||
if self.max_size is not None and len(self.word_count) >= self.max_size: | |||||
print("[Warning] Vocabulary has reached the max size {} when calling {} method. " | |||||
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( | |||||
self.max_size, func.__name__)) | |||||
return func(self, *args, **kwargs) | return func(self, *args, **kwargs) | ||||
return _wrapper | return _wrapper | ||||
@@ -36,25 +43,21 @@ class Vocabulary(object): | |||||
vocab.to_word(5) | vocab.to_word(5) | ||||
""" | """ | ||||
def __init__(self, need_default=True, max_size=None, min_freq=None): | |||||
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): | |||||
""" | """ | ||||
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | |||||
:param int max_size: set the max number of words in Vocabulary. Default: None
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None
:param str unknown: the token used for unknown / out-of-vocabulary words. Default: '<unk>'
:param str padding: the token used for padding. Default: '<pad>'
""" | """ | ||||
self.max_size = max_size | self.max_size = max_size | ||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
self.word_count = Counter() | self.word_count = Counter() | ||||
self.has_default = need_default | |||||
if self.has_default: | |||||
self.padding_label = DEFAULT_PADDING_LABEL | |||||
self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||||
else: | |||||
self.padding_label = None | |||||
self.unknown_label = None | |||||
self.unknown = unknown | |||||
self.padding = padding | |||||
self.word2idx = None | self.word2idx = None | ||||
self.idx2word = None | self.idx2word = None | ||||
self.rebuild = True | |||||
@check_build_status | |||||
def update(self, word_lst): | def update(self, word_lst): | ||||
"""Add a list of words into the vocabulary. | """Add a list of words into the vocabulary. | ||||
@@ -62,6 +65,7 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.word_count.update(word_lst) | self.word_count.update(word_lst) | ||||
@check_build_status | |||||
def add(self, word): | def add(self, word): | ||||
"""Add a single word into the vocabulary. | """Add a single word into the vocabulary. | ||||
@@ -69,6 +73,7 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.word_count[word] += 1 | self.word_count[word] += 1 | ||||
@check_build_status | |||||
def add_word(self, word): | def add_word(self, word): | ||||
"""Add a single word into the vocabulary. | """Add a single word into the vocabulary. | ||||
@@ -76,6 +81,7 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.add(word) | self.add(word) | ||||
@check_build_status | |||||
def add_word_lst(self, word_lst): | def add_word_lst(self, word_lst): | ||||
"""Add a list of words into the vocabulary. | """Add a list of words into the vocabulary. | ||||
@@ -87,20 +93,22 @@ class Vocabulary(object): | |||||
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | """Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | ||||
""" | """ | ||||
if self.has_default: | |||||
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||||
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL) | |||||
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL) | |||||
else: | |||||
self.word2idx = {} | |||||
self.word2idx = {} | |||||
if self.padding is not None: | |||||
self.word2idx[self.padding] = 0 | |||||
if self.unknown is not None: | |||||
self.word2idx[self.unknown] = 1 | |||||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | ||||
words = self.word_count.most_common(max_size) | words = self.word_count.most_common(max_size) | ||||
if self.min_freq is not None: | if self.min_freq is not None: | ||||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | words = filter(lambda kv: kv[1] >= self.min_freq, words) | ||||
if self.word2idx is not None: | |||||
words = filter(lambda kv: kv[0] not in self.word2idx, words) | |||||
start_idx = len(self.word2idx) | start_idx = len(self.word2idx) | ||||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | ||||
self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
self.rebuild = False | |||||
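# Illustrative example (words are made up): after
#   vocab = Vocabulary(); vocab.update(["apple", "apple", "banana"])
# building the vocab yields word2idx roughly equal to
#   {'<pad>': 0, '<unk>': 1, 'apple': 2, 'banana': 3}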
def build_reverse_vocab(self): | def build_reverse_vocab(self): | ||||
"""Build 'index to word' dict based on 'word to index' dict. | """Build 'index to word' dict based on 'word to index' dict. | ||||
@@ -132,8 +140,8 @@ class Vocabulary(object): | |||||
""" | """ | ||||
if w in self.word2idx: | if w in self.word2idx: | ||||
return self.word2idx[w] | return self.word2idx[w] | ||||
elif self.has_default: | |||||
return self.word2idx[self.unknown_label] | |||||
if self.unknown is not None: | |||||
return self.word2idx[self.unknown] | |||||
else: | else: | ||||
raise ValueError("word {} not in vocabulary".format(w)) | raise ValueError("word {} not in vocabulary".format(w)) | ||||
@@ -148,21 +156,16 @@ class Vocabulary(object): | |||||
@property | @property | ||||
@check_build_vocab | @check_build_vocab | ||||
def unknown_idx(self): | def unknown_idx(self): | ||||
if self.unknown_label is None: | |||||
if self.unknown is None: | |||||
return None | return None | ||||
return self.word2idx[self.unknown_label] | |||||
def __setattr__(self, name, val): | |||||
self.__dict__[name] = val | |||||
if name in ["unknown_label", "padding_label"]: | |||||
self.word2idx = None | |||||
return self.word2idx[self.unknown] | |||||
@property | @property | ||||
@check_build_vocab | @check_build_vocab | ||||
def padding_idx(self): | def padding_idx(self): | ||||
if self.padding_label is None: | |||||
if self.padding is None: | |||||
return None | return None | ||||
return self.word2idx[self.padding_label] | |||||
return self.word2idx[self.padding] | |||||
@check_build_vocab | @check_build_vocab | ||||
def to_word(self, idx): | def to_word(self, idx): | ||||
@@ -188,4 +191,3 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.__dict__.update(state) | self.__dict__.update(state) | ||||
self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
@@ -31,17 +31,21 @@ class BaseLoader(object): | |||||
return obj | return obj | ||||
class ToyLoader0(BaseLoader): | |||||
""" | |||||
For CharLM | |||||
""" | |||||
def __init__(self, data_path): | |||||
super(ToyLoader0, self).__init__(data_path) | |||||
def load(self): | |||||
with open(self.data_path, 'r') as f: | |||||
corpus = f.read().lower() | |||||
import re | |||||
corpus = re.sub(r"<unk>", "unk", corpus) | |||||
return corpus.split() | |||||
class DataLoaderRegister: | |||||
""""register for data sets""" | |||||
_readers = {} | |||||
@classmethod | |||||
def set_reader(cls, reader_cls, read_fn_name): | |||||
# def wrapper(reader_cls): | |||||
if read_fn_name in cls._readers: | |||||
raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name)) | |||||
if hasattr(reader_cls, 'load'): | |||||
cls._readers[read_fn_name] = reader_cls().load | |||||
return reader_cls | |||||
@classmethod | |||||
def get_reader(cls, read_fn_name): | |||||
if read_fn_name in cls._readers: | |||||
return cls._readers[read_fn_name] | |||||
raise AttributeError('no read function: {}'.format(read_fn_name)) |
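# Usage sketch (the file path is illustrative): once a loader class is registered,
# e.g. DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') below,
# its bound load method can be looked up by name:
#   read = DataLoaderRegister.get_reader('read_rawdata')
#   dataset = read('path/to/data.txt')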
@@ -1,6 +1,152 @@ | |||||
import configparser | |||||
import json | |||||
import os | import os | ||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
from fastNLP.io.base_loader import BaseLoader | |||||
class ConfigLoader(BaseLoader): | |||||
"""loader for configuration files""" | |||||
def __init__(self, data_path=None): | |||||
super(ConfigLoader, self).__init__() | |||||
if data_path is not None: | |||||
self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
@staticmethod | |||||
def parse(string): | |||||
raise NotImplementedError | |||||
@staticmethod | |||||
def load_config(file_path, sections): | |||||
""" | |||||
:param file_path: the path of config file | |||||
:param sections: the dict of {section_name(string): Section instance} | |||||
Example: | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
:return: returns nothing, but the values of the attributes are saved into the given sections
""" | |||||
assert isinstance(sections, dict) | |||||
cfg = configparser.ConfigParser() | |||||
if not os.path.exists(file_path): | |||||
raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||||
cfg.read(file_path) | |||||
for s in sections: | |||||
attr_list = [i for i in sections[s].__dict__.keys() if | |||||
not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||||
if s not in cfg: | |||||
print('section %s not found in config file' % (s)) | |||||
continue | |||||
gen_sec = cfg[s] | |||||
for attr in gen_sec.keys(): | |||||
try: | |||||
val = json.loads(gen_sec[attr]) | |||||
# print(s, attr, val, type(val)) | |||||
if attr in attr_list: | |||||
assert type(val) == type(getattr(sections[s], attr)), \ | |||||
'type mismatch, expected %s but got %s' % \
(type(getattr(sections[s], attr)), type(val)) | |||||
""" | |||||
if attr in attr_list then check its type and | |||||
update its value. | |||||
else add a new attr in sections[s] | |||||
""" | |||||
setattr(sections[s], attr, val) | |||||
except Exception as e: | |||||
print("cannot load attribute %s in section %s" | |||||
% (attr, s)) | |||||
pass | |||||
class ConfigSection(object): | |||||
def __init__(self): | |||||
pass | |||||
def __getitem__(self, key): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:return attr: the value of this attribute | |||||
if key not in self.__dict__.keys(): | |||||
return self[key] | |||||
else: | |||||
raise AttributeError | |||||
""" | |||||
if key in self.__dict__.keys(): | |||||
return getattr(self, key) | |||||
raise AttributeError("do NOT have attribute %s" % key) | |||||
def __setitem__(self, key, value): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:param value: the value of this attribute | |||||
if key not in self.__dict__.keys(): | |||||
self[key] will be added | |||||
else: | |||||
self[key] will be updated | |||||
""" | |||||
if key in self.__dict__.keys(): | |||||
if not isinstance(value, type(getattr(self, key))): | |||||
raise AttributeError("attr %s except %s but got %s" % | |||||
(key, str(type(getattr(self, key))), str(type(value)))) | |||||
setattr(self, key, value) | |||||
def __contains__(self, item): | |||||
""" | |||||
:param item: The key of item. | |||||
:return: True if the key in self.__dict__.keys() else False. | |||||
""" | |||||
return item in self.__dict__.keys() | |||||
def __eq__(self, other): | |||||
"""Overwrite the == operator | |||||
:param other: Another ConfigSection() object which to be compared. | |||||
:return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||||
""" | |||||
for k in self.__dict__.keys(): | |||||
if k not in other.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(other, k):
return False | |||||
for k in other.__dict__.keys(): | |||||
if k not in self.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(other, k):
return False | |||||
return True | |||||
def __ne__(self, other): | |||||
"""Overwrite the != operator | |||||
:param other: | |||||
:return: | |||||
""" | |||||
return not self.__eq__(other) | |||||
@property | |||||
def data(self): | |||||
return self.__dict__ | |||||
if __name__ == "__main__": | |||||
config = ConfigLoader('there is no data') | |||||
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | |||||
""" | |||||
General and My can be found in config file, so the attr and | |||||
value will be updated | |||||
A cannot be found in config file, so nothing will be done | |||||
""" | |||||
config.load_config("../../test/data_for_tests/config", section) | |||||
for s in section: | |||||
print(s) | |||||
for attr in section[s].__dict__.keys(): | |||||
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) | |||||
class ConfigSaver(object): | class ConfigSaver(object): | ||||
@@ -125,7 +271,7 @@ class ConfigSaver(object): | |||||
# logger = create_logger(__name__, "./config_loader.log") | # logger = create_logger(__name__, "./config_loader.log") | ||||
# logger.warning("section [%s] in config file [%s] has been changed" % ( | # logger.warning("section [%s] in config file [%s] has been changed" % ( | ||||
# section_name, self.file_path | # section_name, self.file_path | ||||
#)) | |||||
# )) | |||||
change_file = True | change_file = True | ||||
break | break | ||||
if not change_file: | if not change_file: |
@@ -1,149 +0,0 @@ | |||||
import configparser | |||||
import json | |||||
import os | |||||
from fastNLP.io.base_loader import BaseLoader | |||||
class ConfigLoader(BaseLoader): | |||||
"""loader for configuration files""" | |||||
def __init__(self, data_path=None): | |||||
super(ConfigLoader, self).__init__() | |||||
if data_path is not None: | |||||
self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
@staticmethod | |||||
def parse(string): | |||||
raise NotImplementedError | |||||
@staticmethod | |||||
def load_config(file_path, sections): | |||||
""" | |||||
:param file_path: the path of config file | |||||
:param sections: the dict of {section_name(string): Section instance} | |||||
Example: | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
:return: return nothing, but the value of attributes are saved in sessions | |||||
""" | |||||
assert isinstance(sections, dict) | |||||
cfg = configparser.ConfigParser() | |||||
if not os.path.exists(file_path): | |||||
raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||||
cfg.read(file_path) | |||||
for s in sections: | |||||
attr_list = [i for i in sections[s].__dict__.keys() if | |||||
not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||||
if s not in cfg: | |||||
print('section %s not found in config file' % (s)) | |||||
continue | |||||
gen_sec = cfg[s] | |||||
for attr in gen_sec.keys(): | |||||
try: | |||||
val = json.loads(gen_sec[attr]) | |||||
# print(s, attr, val, type(val)) | |||||
if attr in attr_list: | |||||
assert type(val) == type(getattr(sections[s], attr)), \ | |||||
'type not match, except %s but got %s' % \ | |||||
(type(getattr(sections[s], attr)), type(val)) | |||||
""" | |||||
if attr in attr_list then check its type and | |||||
update its value. | |||||
else add a new attr in sections[s] | |||||
""" | |||||
setattr(sections[s], attr, val) | |||||
except Exception as e: | |||||
print("cannot load attribute %s in section %s" | |||||
% (attr, s)) | |||||
pass | |||||
class ConfigSection(object): | |||||
def __init__(self): | |||||
pass | |||||
def __getitem__(self, key): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:return attr: the value of this attribute | |||||
if key not in self.__dict__.keys(): | |||||
return self[key] | |||||
else: | |||||
raise AttributeError | |||||
""" | |||||
if key in self.__dict__.keys(): | |||||
return getattr(self, key) | |||||
raise AttributeError("do NOT have attribute %s" % key) | |||||
def __setitem__(self, key, value): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:param value: the value of this attribute | |||||
if key not in self.__dict__.keys(): | |||||
self[key] will be added | |||||
else: | |||||
self[key] will be updated | |||||
""" | |||||
if key in self.__dict__.keys(): | |||||
if not isinstance(value, type(getattr(self, key))): | |||||
raise AttributeError("attr %s except %s but got %s" % | |||||
(key, str(type(getattr(self, key))), str(type(value)))) | |||||
setattr(self, key, value) | |||||
def __contains__(self, item): | |||||
""" | |||||
:param item: The key of item. | |||||
:return: True if the key in self.__dict__.keys() else False. | |||||
""" | |||||
return item in self.__dict__.keys() | |||||
def __eq__(self, other): | |||||
"""Overwrite the == operator | |||||
:param other: Another ConfigSection() object which to be compared. | |||||
:return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||||
""" | |||||
for k in self.__dict__.keys(): | |||||
if k not in other.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(self, k): | |||||
return False | |||||
for k in other.__dict__.keys(): | |||||
if k not in self.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(self, k): | |||||
return False | |||||
return True | |||||
def __ne__(self, other): | |||||
"""Overwrite the != operator | |||||
:param other: | |||||
:return: | |||||
""" | |||||
return not self.__eq__(other) | |||||
@property | |||||
def data(self): | |||||
return self.__dict__ | |||||
if __name__ == "__main__": | |||||
config = ConfigLoader('there is no data') | |||||
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | |||||
""" | |||||
General and My can be found in config file, so the attr and | |||||
value will be updated | |||||
A cannot be found in config file, so nothing will be done | |||||
""" | |||||
config.load_config("../../test/data_for_tests/config", section) | |||||
for s in section: | |||||
print(s) | |||||
for attr in section[s].__dict__.keys(): | |||||
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) |
@@ -1,9 +1,8 @@ | |||||
#TODO: need fix for current DataSet | |||||
import os | import os | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.io.base_loader import BaseLoader | |||||
from fastNLP.io.base_loader import DataLoaderRegister | |||||
def convert_seq_dataset(data): | def convert_seq_dataset(data): | ||||
@@ -20,8 +19,7 @@ def convert_seq_dataset(data): | |||||
""" | """ | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for word_seq in data: | for word_seq in data: | ||||
x = TextField(word_seq, is_target=False) | |||||
dataset.append(Instance(word_seq=x)) | |||||
dataset.append(Instance(word_seq=word_seq)) | |||||
return dataset | return dataset | ||||
@@ -40,11 +38,7 @@ def convert_seq2tag_dataset(data): | |||||
""" | """ | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for sample in data: | for sample in data: | ||||
word_seq, label = sample[0], sample[1] | |||||
ins = Instance() | |||||
ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ | |||||
.add_field("label", LabelField(label, is_target=True)) | |||||
dataset.append(ins) | |||||
dataset.append(Instance(word_seq=sample[0], label=sample[1])) | |||||
return dataset | return dataset | ||||
@@ -63,20 +57,13 @@ def convert_seq2seq_dataset(data): | |||||
""" | """ | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for sample in data: | for sample in data: | ||||
word_seq, label_seq = sample[0], sample[1] | |||||
ins = Instance() | |||||
ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ | |||||
.add_field("label_seq", TextField(label_seq, is_target=True)) | |||||
dataset.append(ins) | |||||
dataset.append(Instance(word_seq=sample[0], label_seq=sample[1])) | |||||
return dataset | return dataset | ||||
class DataSetLoader(BaseLoader): | |||||
class DataSetLoader: | |||||
""""loader for data sets""" | """"loader for data sets""" | ||||
def __init__(self): | |||||
super(DataSetLoader, self).__init__() | |||||
def load(self, path): | def load(self, path): | ||||
""" load data in `path` into a dataset | """ load data in `path` into a dataset | ||||
""" | """ | ||||
@@ -88,7 +75,20 @@ class DataSetLoader(BaseLoader): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
@DataSet.set_reader('read_raw') | |||||
class NativeDataSetLoader(DataSetLoader): | |||||
def __init__(self): | |||||
super(NativeDataSetLoader, self).__init__() | |||||
def load(self, path): | |||||
ds = DataSet.read_csv(path, headers=("raw_sentence", "label"), sep="\t") | |||||
ds.set_input("raw_sentence") | |||||
ds.set_target("label") | |||||
return ds | |||||
DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive') | |||||
class RawDataSetLoader(DataSetLoader): | class RawDataSetLoader(DataSetLoader): | ||||
def __init__(self): | def __init__(self): | ||||
super(RawDataSetLoader, self).__init__() | super(RawDataSetLoader, self).__init__() | ||||
@@ -104,7 +104,9 @@ class RawDataSetLoader(DataSetLoader): | |||||
return convert_seq_dataset(data) | return convert_seq_dataset(data) | ||||
@DataSet.set_reader('read_pos') | |||||
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | |||||
class POSDataSetLoader(DataSetLoader): | class POSDataSetLoader(DataSetLoader): | ||||
"""Dataset Loader for POS Tag datasets. | """Dataset Loader for POS Tag datasets. | ||||
@@ -174,7 +176,9 @@ class POSDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_tokenize') | |||||
DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos') | |||||
class TokenizeDataSetLoader(DataSetLoader): | class TokenizeDataSetLoader(DataSetLoader): | ||||
""" | """ | ||||
Data set loader for tokenization data sets | Data set loader for tokenization data sets | ||||
@@ -234,7 +238,6 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_class') | |||||
class ClassDataSetLoader(DataSetLoader): | class ClassDataSetLoader(DataSetLoader): | ||||
"""Loader for classification data sets""" | """Loader for classification data sets""" | ||||
@@ -273,7 +276,6 @@ class ClassDataSetLoader(DataSetLoader): | |||||
return convert_seq2tag_dataset(data) | return convert_seq2tag_dataset(data) | ||||
@DataSet.set_reader('read_conll') | |||||
class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
"""loader for conll format files""" | """loader for conll format files""" | ||||
@@ -315,7 +317,6 @@ class ConllLoader(DataSetLoader): | |||||
pass | pass | ||||
@DataSet.set_reader('read_lm') | |||||
class LMDataSetLoader(DataSetLoader): | class LMDataSetLoader(DataSetLoader): | ||||
"""Language Model Dataset Loader | """Language Model Dataset Loader | ||||
@@ -352,7 +353,6 @@ class LMDataSetLoader(DataSetLoader): | |||||
pass | pass | ||||
@DataSet.set_reader('read_people_daily') | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
""" | """ | ||||
People Daily Corpus: Chinese word segmentation, POS tag, NER | People Daily Corpus: Chinese word segmentation, POS tag, NER | ||||
@@ -403,10 +403,19 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
pos_tag_examples.append([sent_words, sent_pos_tag]) | pos_tag_examples.append([sent_words, sent_pos_tag]) | ||||
ner_examples.append([sent_words, sent_ner]) | ner_examples.append([sent_words, sent_ner]) | ||||
# List[List[List[str], List[str]]] | # List[List[List[str], List[str]]] | ||||
return pos_tag_examples, ner_examples | |||||
# ner_examples not used | |||||
return self.convert(pos_tag_examples) | |||||
def convert(self, data): | def convert(self, data): | ||||
pass | |||||
data_set = DataSet() | |||||
for item in data: | |||||
sent_words, sent_pos_tag = item[0], item[1] | |||||
data_set.append(Instance(words=sent_words, tags=sent_pos_tag)) | |||||
data_set.apply(lambda ins: len(ins), new_field_name="seq_len") | |||||
data_set.set_target("tags") | |||||
data_set.set_input("sent_words") | |||||
data_set.set_input("seq_len") | |||||
return data_set | |||||
class SNLIDataSetLoader(DataSetLoader): | class SNLIDataSetLoader(DataSetLoader): | ||||
@@ -462,17 +471,13 @@ class SNLIDataSetLoader(DataSetLoader): | |||||
for example in data: | for example in data: | ||||
p, h, l = example | p, h, l = example | ||||
# list, list, str | # list, list, str | ||||
x1 = TextField(p, is_target=False) | |||||
x2 = TextField(h, is_target=False) | |||||
x1_len = TextField([1] * len(p), is_target=False) | |||||
x2_len = TextField([1] * len(h), is_target=False) | |||||
y = LabelField(l, is_target=True) | |||||
instance = Instance() | instance = Instance() | ||||
instance.add_field("premise", x1) | |||||
instance.add_field("hypothesis", x2) | |||||
instance.add_field("premise_len", x1_len) | |||||
instance.add_field("hypothesis_len", x2_len) | |||||
instance.add_field("truth", y) | |||||
instance.add_field("premise", p) | |||||
instance.add_field("hypothesis", h) | |||||
instance.add_field("truth", l) | |||||
data_set.append(instance) | data_set.append(instance) | ||||
data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len") | |||||
data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len") | |||||
data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") | |||||
data_set.set_target("truth") | |||||
return data_set | return data_set |
@@ -1,3 +1,4 @@ | |||||
import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
@@ -26,7 +27,7 @@ class EmbedLoader(BaseLoader): | |||||
emb = {} | emb = {} | ||||
with open(emb_file, 'r', encoding='utf-8') as f: | with open(emb_file, 'r', encoding='utf-8') as f: | ||||
for line in f: | for line in f: | ||||
line = list(filter(lambda w: len(w)>0, line.strip().split(' '))) | |||||
line = list(filter(lambda w: len(w) > 0, line.strip().split(' '))) | |||||
if len(line) > 2: | if len(line) > 2: | ||||
emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | ||||
return emb | return emb | ||||
@@ -35,9 +36,9 @@ class EmbedLoader(BaseLoader): | |||||
def _load_pretrain(emb_file, emb_type): | def _load_pretrain(emb_file, emb_type): | ||||
"""Read txt data from embedding file and convert to np.array as pre-trained embedding | """Read txt data from embedding file and convert to np.array as pre-trained embedding | ||||
:param emb_file: str, the pre-trained embedding file path | |||||
:param emb_type: str, the pre-trained embedding data format | |||||
:return dict: {str: np.array} | |||||
:param str emb_file: the pre-trained embedding file path | |||||
:param str emb_type: the pre-trained embedding data format | |||||
:return dict embedding: `{str: np.array}` | |||||
""" | """ | ||||
if emb_type == 'glove': | if emb_type == 'glove': | ||||
return EmbedLoader._load_glove(emb_file) | return EmbedLoader._load_glove(emb_file) | ||||
@@ -45,38 +46,68 @@ class EmbedLoader(BaseLoader): | |||||
raise Exception("embedding type {} not support yet".format(emb_type)) | raise Exception("embedding type {} not support yet".format(emb_type)) | ||||
@staticmethod | @staticmethod | ||||
def load_embedding(emb_dim, emb_file, emb_type, vocab, emb_pkl): | |||||
def load_embedding(emb_dim, emb_file, emb_type, vocab): | |||||
"""Load the pre-trained embedding and combine with the given dictionary. | """Load the pre-trained embedding and combine with the given dictionary. | ||||
:param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding. | |||||
:param emb_file: str, the pre-trained embedding file path. | |||||
:param emb_type: str, the pre-trained embedding format, support glove now | |||||
:param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding | |||||
:param emb_pkl: str, the embedding pickle file. | |||||
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding. | |||||
:param str emb_file: the pre-trained embedding file path. | |||||
:param str emb_type: the pre-trained embedding format, support glove now | |||||
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding | |||||
:return embedding_tensor: Tensor of shape (len(word_dict), emb_dim) | :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim) | ||||
vocab: input vocab or vocab built by pre-train | vocab: input vocab or vocab built by pre-train | ||||
TODO: fragile code | |||||
""" | """ | ||||
# If the embedding pickle exists, load it and return. | |||||
# if os.path.exists(emb_pkl): | |||||
# with open(emb_pkl, "rb") as f: | |||||
# embedding_tensor, vocab = _pickle.load(f) | |||||
# return embedding_tensor, vocab | |||||
# Otherwise, load the pre-trained embedding. | |||||
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) | pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) | ||||
if vocab is None: | if vocab is None: | ||||
# build vocabulary from pre-trained embedding | # build vocabulary from pre-trained embedding | ||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
for w in pretrain.keys(): | for w in pretrain.keys(): | ||||
vocab.update(w) | |||||
vocab.add(w) | |||||
embedding_tensor = torch.randn(len(vocab), emb_dim) | embedding_tensor = torch.randn(len(vocab), emb_dim) | ||||
for w, v in pretrain.items(): | for w, v in pretrain.items(): | ||||
if len(v.shape) > 1 or emb_dim != v.shape[0]: | if len(v.shape) > 1 or emb_dim != v.shape[0]: | ||||
raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,))) | |||||
raise ValueError( | |||||
"Pretrained embedding dim is {}. Dimension dismatched. Required {}".format(v.shape, (emb_dim,))) | |||||
if vocab.has_word(w): | if vocab.has_word(w): | ||||
embedding_tensor[vocab[w]] = v | embedding_tensor[vocab[w]] = v | ||||
# save and return the result | |||||
# with open(emb_pkl, "wb") as f: | |||||
# _pickle.dump((embedding_tensor, vocab), f) | |||||
return embedding_tensor, vocab | return embedding_tensor, vocab | ||||
@staticmethod | |||||
def parse_glove_line(line): | |||||
line = list(filter(lambda w: len(w) > 0, line.strip().split(" "))) | |||||
if len(line) <= 2: | |||||
raise RuntimeError("something goes wrong in parsing glove embedding") | |||||
return line[0], torch.Tensor(list(map(float, line[1:]))) | |||||
@staticmethod | |||||
def fast_load_embedding(emb_dim, emb_file, vocab): | |||||
"""Fast load the pre-trained embedding and combine with the given dictionary. | |||||
This loading method uses line-by-line operation. | |||||
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding. | |||||
:param str emb_file: the pre-trained embedding file path. | |||||
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding | |||||
:return numpy.ndarray embedding_matrix: | |||||
""" | |||||
if vocab is None: | |||||
raise RuntimeError("You must provide a vocabulary.") | |||||
embedding_matrix = np.zeros(shape=(len(vocab), emb_dim)) | |||||
hit_flags = np.zeros(shape=(len(vocab),), dtype=int) | |||||
with open(emb_file, "r", encoding="utf-8") as f: | |||||
for line in f: | |||||
word, vector = EmbedLoader.parse_glove_line(line) | |||||
if word in vocab: | |||||
if len(vector.shape) > 1 or emb_dim != vector.shape[0]: | |||||
raise ValueError("Pre-trained embedding dim is {}. Expect {}.".format(vector.shape, (emb_dim,))) | |||||
embedding_matrix[vocab[word]] = vector | |||||
hit_flags[vocab[word]] = 1 | |||||
if np.sum(hit_flags) < len(vocab): | |||||
# some words from vocab are missing in the pre-trained embedding, | |||||
# so we sample each of their dimensions from a normal distribution fitted to the hit vectors | |||||
vocab_embed = embedding_matrix[np.where(hit_flags)] | |||||
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||||
size=(len(vocab) - np.sum(hit_flags), emb_dim)) | |||||
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | |||||
return embedding_matrix |
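A hedged usage sketch of the new fast_load_embedding entry point; the embedding file path and vocabulary words below are placeholders, not part of the diff:

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader

vocab = Vocabulary()
for word in ["the", "cat", "sat"]:  # hypothetical vocabulary
    vocab.add(word)

# "glove.6B.50d.txt" is a placeholder path to a GloVe-format text file
embedding_matrix = EmbedLoader.fast_load_embedding(emb_dim=50, emb_file="glove.6B.50d.txt", vocab=vocab)
print(embedding_matrix.shape)  # (len(vocab), 50); rows for words missing from the file are sampled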
@@ -1,5 +1,32 @@ | |||||
import torch | import torch | ||||
from fastNLP.io.base_loader import BaseLoader | |||||
class ModelLoader(BaseLoader): | |||||
""" | |||||
Loader for models. | |||||
""" | |||||
def __init__(self): | |||||
super(ModelLoader, self).__init__() | |||||
@staticmethod | |||||
def load_pytorch(empty_model, model_path): | |||||
""" | |||||
Load model parameters from .pkl files into the empty PyTorch model. | |||||
:param empty_model: a PyTorch model with initialized parameters. | |||||
:param model_path: str, the path to the saved model. | |||||
""" | |||||
empty_model.load_state_dict(torch.load(model_path)) | |||||
@staticmethod | |||||
def load_pytorch_model(model_path): | |||||
"""Load the entire model. | |||||
""" | |||||
return torch.load(model_path) | |||||
class ModelSaver(object): | class ModelSaver(object): | ||||
"""Save a model | """Save a model | ||||
@@ -8,6 +35,7 @@ class ModelSaver(object): | |||||
saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
""" | """ | ||||
def __init__(self, save_path): | def __init__(self, save_path): | ||||
""" | """ | ||||
@@ -1,28 +0,0 @@ | |||||
import torch | |||||
from fastNLP.io.base_loader import BaseLoader | |||||
class ModelLoader(BaseLoader): | |||||
""" | |||||
Loader for models. | |||||
""" | |||||
def __init__(self): | |||||
super(ModelLoader, self).__init__() | |||||
@staticmethod | |||||
def load_pytorch(empty_model, model_path): | |||||
""" | |||||
Load model parameters from .pkl files into the empty PyTorch model. | |||||
:param empty_model: a PyTorch model with initialized parameters. | |||||
:param model_path: str, the path to the saved model. | |||||
""" | |||||
empty_model.load_state_dict(torch.load(model_path)) | |||||
@staticmethod | |||||
def load_pytorch_model(model_path): | |||||
"""Load the entire model. | |||||
""" | |||||
return torch.load(model_path) |
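With ModelLoader and ModelSaver now living in the single fastNLP.io.model_io module, a save/load round trip might look like the following sketch (the checkpoint path and model are arbitrary examples):

import torch
from fastNLP.io.model_io import ModelLoader, ModelSaver

model = torch.nn.Linear(10, 3)                      # any PyTorch module
ModelSaver("./model_ckpt_100.pkl").save_pytorch(model)

empty_model = torch.nn.Linear(10, 3)                # same architecture, fresh parameters
ModelLoader.load_pytorch(empty_model, "./model_ckpt_100.pkl")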
@@ -1,6 +1,6 @@ | |||||
import torch | import torch | ||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.modules.decoder.MLP import MLP | |||||
class BaseModel(torch.nn.Module): | class BaseModel(torch.nn.Module): | ||||
@@ -11,8 +11,19 @@ class BaseModel(torch.nn.Module): | |||||
super(BaseModel, self).__init__() | super(BaseModel, self).__init__() | ||||
def fit(self, train_data, dev_data=None, **train_args): | def fit(self, train_data, dev_data=None, **train_args): | ||||
trainer = Trainer(**train_args) | |||||
trainer.train(self, train_data, dev_data) | |||||
pass | |||||
def predict(self, *args, **kwargs): | def predict(self, *args, **kwargs): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
class NaiveClassifier(BaseModel): | |||||
def __init__(self, in_feature_dim, out_feature_dim): | |||||
super(NaiveClassifier, self).__init__() | |||||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||||
def forward(self, x): | |||||
return {"predict": torch.sigmoid(self.mlp(x))} | |||||
def predict(self, x): | |||||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} |
@@ -18,8 +18,8 @@ class CNNText(torch.nn.Module): | |||||
def __init__(self, embed_num, | def __init__(self, embed_num, | ||||
embed_dim, | embed_dim, | ||||
num_classes, | num_classes, | ||||
kernel_nums=(3,4,5), | |||||
kernel_sizes=(3,4,5), | |||||
kernel_nums=(3, 4, 5), | |||||
kernel_sizes=(3, 4, 5), | |||||
padding=0, | padding=0, | ||||
dropout=0.5): | dropout=0.5): | ||||
super(CNNText, self).__init__() | super(CNNText, self).__init__() | ||||
@@ -33,7 +33,6 @@ class CNNText(torch.nn.Module): | |||||
padding=padding) | padding=padding) | ||||
self.dropout = nn.Dropout(dropout) | self.dropout = nn.Dropout(dropout) | ||||
self.fc = encoder.Linear(sum(kernel_nums), num_classes) | self.fc = encoder.Linear(sum(kernel_nums), num_classes) | ||||
self._loss = nn.CrossEntropyLoss() | |||||
def forward(self, word_seq): | def forward(self, word_seq): | ||||
""" | """ | ||||
@@ -45,7 +44,7 @@ class CNNText(torch.nn.Module): | |||||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | x = self.conv_pool(x) # [N,L,C] -> [N,C] | ||||
x = self.dropout(x) | x = self.dropout(x) | ||||
x = self.fc(x) # [N,C] -> [N, N_class] | x = self.fc(x) # [N,C] -> [N, N_class] | ||||
return {'output':x} | |||||
return {'pred': x} | |||||
def predict(self, word_seq): | def predict(self, word_seq): | ||||
""" | """ | ||||
@@ -54,28 +53,5 @@ class CNNText(torch.nn.Module): | |||||
:return predict: dict of torch.LongTensor, [batch_size, seq_len] | :return predict: dict of torch.LongTensor, [batch_size, seq_len] | ||||
""" | """ | ||||
output = self(word_seq) | output = self(word_seq) | ||||
_, predict = output['output'].max(dim=1) | |||||
return {'predict': predict} | |||||
def get_loss(self, output, label_seq): | |||||
""" | |||||
:param output: output of forward(), [batch_size, seq_len] | |||||
:param label_seq: true label in DataSet, [batch_size, seq_len] | |||||
:return loss: torch.Tensor | |||||
""" | |||||
return self._loss(output, label_seq) | |||||
def evaluate(self, predict, label_seq): | |||||
""" | |||||
:param predict: iterable predict tensors | |||||
:param label_seq: iterable true label tensors | |||||
:return accuracy: dict of float | |||||
""" | |||||
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0) | |||||
predict, label_seq = predict.squeeze(), label_seq.squeeze() | |||||
correct = (predict == label_seq).long().sum().item() | |||||
total = label_seq.size(0) | |||||
return {'acc': 1.0 * correct / total} | |||||
_, predict = output['pred'].max(dim=1) | |||||
return {'pred': predict} |
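Since CNNText no longer defines get_loss/evaluate, loss and metric objects consume the 'pred' key externally. A hedged sketch of that flow; the module path fastNLP.models.cnn_text_classification, the vocabulary size, and the batch shapes are assumptions:

import torch
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric
from fastNLP.models.cnn_text_classification import CNNText

model = CNNText(embed_num=1000, embed_dim=50, num_classes=3)
word_seq = torch.randint(0, 1000, (8, 20))   # fake batch: 8 sentences of 20 token ids
pred_dict = model(word_seq)                  # {'pred': logits of shape [8, 3]}
target_dict = {"target": torch.randint(0, 3, (8,))}

loss = CrossEntropyLoss(pred="pred", target="target")(pred_dict, target_dict)
metric = AccuracyMetric(pred="pred", target="target")
metric(pred_dict=model.predict(word_seq), target_dict=target_dict)
print(loss, metric.get_metric())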
@@ -43,7 +43,7 @@ class ConvCharEmbedding(nn.Module): | |||||
# [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] | # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] | ||||
y = torch.squeeze(y, 2) | y = torch.squeeze(y, 2) | ||||
# [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] | # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] | ||||
y = F.tanh(y) | |||||
y = torch.tanh(y) | |||||
y, __ = torch.max(y, 2) | y, __ = torch.max(y, 2) | ||||
# [batch_size*sent_length, feature_maps[i]] | # [batch_size*sent_length, feature_maps[i]] | ||||
feats.append(y) | feats.append(y) | ||||
@@ -5,7 +5,7 @@ sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) | |||||
from fastNLP.api.processor import * | from fastNLP.api.processor import * | ||||
from fastNLP.models.biaffine_parser import BiaffineParser | from fastNLP.models.biaffine_parser import BiaffineParser | ||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
from fastNLP.io.config_io import ConfigSection, ConfigLoader | |||||
import _pickle as pickle | import _pickle as pickle | ||||
import torch | import torch | ||||
@@ -13,11 +13,10 @@ from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.field import TextField, SeqLabelField | from fastNLP.core.field import TextField, SeqLabelField | ||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.io.model_loader import ModelLoader | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.io.embed_loader import EmbedLoader | from fastNLP.io.embed_loader import EmbedLoader | ||||
from fastNLP.models.biaffine_parser import BiaffineParser | from fastNLP.models.biaffine_parser import BiaffineParser | ||||
from fastNLP.io.model_saver import ModelSaver | |||||
BOS = '<BOS>' | BOS = '<BOS>' | ||||
EOS = '<EOS>' | EOS = '<EOS>' | ||||
@@ -2,8 +2,8 @@ import torch.nn.functional as F | |||||
from fastNLP.core.trainer import ClassificationTrainer | from fastNLP.core.trainer import ClassificationTrainer | ||||
from fastNLP.core.utils import ClassPreprocess as Preprocess | from fastNLP.core.utils import ClassPreprocess as Preprocess | ||||
from fastNLP.io.config_loader import ConfigLoader | |||||
from fastNLP.io.config_loader import ConfigSection | |||||
from fastNLP.io.config_io import ConfigLoader | |||||
from fastNLP.io.config_io import ConfigSection | |||||
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader | from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader | ||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
@@ -3,12 +3,11 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader | from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader | ||||
from fastNLP.core.utils import load_pickle | from fastNLP.core.utils import load_pickle | ||||
from fastNLP.io.model_saver import ModelSaver | |||||
from fastNLP.io.model_loader import ModelLoader | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | from fastNLP.models.sequence_modeling import AdvSeqLabel | ||||
from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
@@ -1,4 +1,4 @@ | |||||
numpy>=1.14.2 | numpy>=1.14.2 | ||||
torch>=0.4.0 | torch>=0.4.0 | ||||
torchvision>=0.1.8 | |||||
tensorboardX | tensorboardX | ||||
tqdm>=4.28.1 |
@@ -12,12 +12,12 @@ with open('requirements.txt', encoding='utf-8') as f: | |||||
reqs = f.read() | reqs = f.read() | ||||
setup( | setup( | ||||
name='fastNLP', | |||||
name='FastNLP', | |||||
version='0.1.1', | version='0.1.1', | ||||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
long_description=readme, | long_description=readme, | ||||
license=license, | license=license, | ||||
author='fudanNLP', | |||||
author='FudanNLP', | |||||
python_requires='>=3.5', | python_requires='>=3.5', | ||||
packages=find_packages(), | packages=find_packages(), | ||||
install_requires=reqs.strip().split('\n'), | install_requires=reqs.strip().split('\n'), | ||||
@@ -0,0 +1,12 @@ | |||||
import unittest | |||||
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor | |||||
from fastNLP.core.dataset import DataSet | |||||
class TestProcessor(unittest.TestCase): | |||||
def test_FullSpaceToHalfSpaceProcessor(self): | |||||
ds = DataSet({"word": ["00, u1, u), (u2, u2"]}) | |||||
proc = FullSpaceToHalfSpaceProcessor("word") | |||||
ds = proc(ds) | |||||
self.assertTrue(ds.field_arrays["word"].content, ["00, u1, u), (u2, u2"]) |
@@ -22,8 +22,8 @@ class TestCase1(unittest.TestCase): | |||||
def test_dataset_batching(self): | def test_dataset_batching(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
ds.set_input(x=True) | |||||
ds.set_target(y=True) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | ||||
for x, y in iter: | for x, y in iter: | ||||
self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | ||||
@@ -1,6 +1,8 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.fieldarray import FieldArray | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
@@ -44,6 +46,9 @@ class TestDataSet(unittest.TestCase): | |||||
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | ||||
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | ||||
with self.assertRaises(RuntimeError): | |||||
dd.add_field("??", [[1, 2]] * 40) | |||||
def test_delete_field(self): | def test_delete_field(self): | ||||
dd = DataSet() | dd = DataSet() | ||||
dd.add_field("x", [[1, 2, 3]] * 10) | dd.add_field("x", [[1, 2, 3]] * 10) | ||||
@@ -55,7 +60,7 @@ class TestDataSet(unittest.TestCase): | |||||
def test_getitem(self): | def test_getitem(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
ins_1, ins_0 = ds[0], ds[1] | ins_1, ins_0 = ds[0], ds[1] | ||||
self.assertTrue(isinstance(ins_1, DataSet.Instance) and isinstance(ins_0, DataSet.Instance)) | |||||
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance)) | |||||
self.assertEqual(ins_1["x"], [1, 2, 3, 4]) | self.assertEqual(ins_1["x"], [1, 2, 3, 4]) | ||||
self.assertEqual(ins_1["y"], [5, 6]) | self.assertEqual(ins_1["y"], [5, 6]) | ||||
self.assertEqual(ins_0["x"], [1, 2, 3, 4]) | self.assertEqual(ins_0["x"], [1, 2, 3, 4]) | ||||
@@ -65,11 +70,131 @@ class TestDataSet(unittest.TestCase): | |||||
self.assertTrue(isinstance(sub_ds, DataSet)) | self.assertTrue(isinstance(sub_ds, DataSet)) | ||||
self.assertEqual(len(sub_ds), 10) | self.assertEqual(len(sub_ds), 10) | ||||
field = ds["x"] | |||||
self.assertEqual(field, ds.field_arrays["x"]) | |||||
def test_get_item_error(self): | |||||
with self.assertRaises(RuntimeError): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
_ = ds[40:] | |||||
with self.assertRaises(KeyError): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
_ = ds["kom"] | |||||
def test_len_(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
self.assertEqual(len(ds), 40) | |||||
ds = DataSet() | |||||
self.assertEqual(len(ds), 0) | |||||
def test_apply(self): | def test_apply(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") | ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") | ||||
self.assertTrue("rx" in ds.field_arrays) | self.assertTrue("rx" in ds.field_arrays) | ||||
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) | self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) | ||||
ds.apply(lambda ins: len(ins["y"]), new_field_name="y") | |||||
self.assertEqual(ds.field_arrays["y"].content[0], 2) | |||||
res = ds.apply(lambda ins: len(ins["x"])) | |||||
self.assertTrue(isinstance(res, list) and len(res) > 0) | |||||
self.assertEqual(res[0], 4) | |||||
def test_drop(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | |||||
ds.drop(lambda ins: len(ins["y"]) < 3) | |||||
self.assertEqual(len(ds), 20) | |||||
def test_contains(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
self.assertTrue("x" in ds) | |||||
self.assertTrue("y" in ds) | |||||
self.assertFalse("z" in ds) | |||||
def test_rename_field(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ds.rename_field("x", "xx") | |||||
self.assertTrue("xx" in ds) | |||||
self.assertFalse("x" in ds) | |||||
with self.assertRaises(KeyError): | |||||
ds.rename_field("yyy", "oo") | |||||
def test_input_target(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
self.assertTrue(ds.field_arrays["x"].is_input) | |||||
self.assertTrue(ds.field_arrays["y"].is_target) | |||||
with self.assertRaises(KeyError): | |||||
ds.set_input("xxx") | |||||
with self.assertRaises(KeyError): | |||||
ds.set_input("yyy") | |||||
def test_get_input_name(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input]) | |||||
def test_get_target_name(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) | |||||
def test_apply2(self): | |||||
def split_sent(ins): | |||||
return ins['raw_sentence'].split() | |||||
dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | |||||
sep='\t') | |||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0) | |||||
dataset.apply(split_sent, new_field_name='words', is_input=True) | |||||
# print(dataset) | |||||
def test_add_field(self): | |||||
ds = DataSet({"x": [3, 4]}) | |||||
ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True) | |||||
# ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y') | |||||
print(ds) | |||||
def test_save_load(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ds.save("./my_ds.pkl") | |||||
self.assertTrue(os.path.exists("./my_ds.pkl")) | |||||
ds_1 = DataSet.load("./my_ds.pkl") | |||||
os.remove("my_ds.pkl") | |||||
def test_get_all_fields(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ans = ds.get_all_fields() | |||||
self.assertEqual(ans["x"].content, [[1, 2, 3, 4]] * 10) | |||||
self.assertEqual(ans["y"].content, [[5, 6]] * 10) | |||||
def test_get_field(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ans = ds.get_field("x") | |||||
self.assertTrue(isinstance(ans, FieldArray)) | |||||
self.assertEqual(ans.content, [[1, 2, 3, 4]] * 10) | |||||
ans = ds.get_field("y") | |||||
self.assertTrue(isinstance(ans, FieldArray)) | |||||
self.assertEqual(ans.content, [[5, 6]] * 10) | |||||
def test_reader(self): | |||||
# just make sure it runs | |||||
ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv") | |||||
self.assertTrue(isinstance(ds, DataSet)) | |||||
self.assertTrue(len(ds) > 0) | |||||
ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt") | |||||
self.assertTrue(isinstance(ds, DataSet)) | |||||
self.assertTrue(len(ds) > 0) | |||||
ds = DataSet().read_pos("test/data_for_tests/people.txt") | |||||
self.assertTrue(isinstance(ds, DataSet)) | |||||
self.assertTrue(len(ds) > 0) | |||||
class TestDataSetIter(unittest.TestCase): | |||||
def test__repr__(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
for iter in ds: | |||||
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}") |
@@ -20,3 +20,80 @@ class TestFieldArray(unittest.TestCase): | |||||
self.assertEqual(fa.get(0), 1) | self.assertEqual(fa.get(0), 1) | ||||
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) | self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) | ||||
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | ||||
def test_type_conversion(self): | |||||
fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | |||||
fa.append(1.3333) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) | |||||
fa.append(10) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True) | |||||
fa.append("e") | |||||
self.assertEqual(fa.dtype, np.str) | |||||
self.assertEqual(fa.pytype, str) | |||||
def test_support_np_array(self): | |||||
fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=True) | |||||
self.assertEqual(fa.dtype, np.ndarray) | |||||
self.assertEqual(fa.pytype, np.ndarray) | |||||
fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | |||||
self.assertEqual(fa.dtype, np.ndarray) | |||||
self.assertEqual(fa.pytype, np.ndarray) | |||||
fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) | |||||
# in this case, pytype is actually a float. We do not care about it. | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
def test_nested_list(self): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
def test_getitem_v1(self): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | |||||
ans = fa[[0, 1]] | |||||
self.assertTrue(isinstance(ans, np.ndarray)) | |||||
self.assertTrue(isinstance(ans[0], np.ndarray)) | |||||
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5]) | |||||
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5]) | |||||
self.assertEqual(ans.dtype, np.float64) | |||||
def test_getitem_v2(self): | |||||
x = np.random.rand(10, 5) | |||||
fa = FieldArray("my_field", x, is_input=True) | |||||
indices = [0, 1, 3, 4, 6] | |||||
for a, b in zip(fa[indices], x[indices]): | |||||
self.assertListEqual(a.tolist(), b.tolist()) | |||||
def test_append(self): | |||||
with self.assertRaises(Exception): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa.append(0) | |||||
with self.assertRaises(Exception): | |||||
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) | |||||
fa.append([1, 2, 3, 4, 5]) | |||||
with self.assertRaises(Exception): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa.append([]) | |||||
with self.assertRaises(Exception): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa.append(["str", 0, 0, 0, 1.89]) | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | |||||
self.assertEqual(len(fa), 3) | |||||
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) |
@@ -27,3 +27,9 @@ class TestCase(unittest.TestCase): | |||||
self.assertEqual(ins["x"], [1, 2, 3]) | self.assertEqual(ins["x"], [1, 2, 3]) | ||||
self.assertEqual(ins["y"], [4, 5, 6]) | self.assertEqual(ins["y"], [4, 5, 6]) | ||||
self.assertEqual(ins["z"], [1, 1, 1]) | self.assertEqual(ins["z"], [1, 1, 1]) | ||||
def test_repr(self): | |||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||||
ins = Instance(**fields) | |||||
# simple print, that is enough. | |||||
print(ins) |
@@ -1,306 +1,87 @@ | |||||
import unittest | import unittest | ||||
import fastNLP.core.loss as loss | |||||
import math | |||||
import torch as tc | |||||
import pdb | |||||
import torch | |||||
import torch.nn.functional as F | |||||
class TestLoss(unittest.TestCase): | |||||
def test_case_1(self): | |||||
# verify how NLLLoss works | |||||
print (".----------------------------------") | |||||
loss_func = loss.Loss("nll") | |||||
#pdb.set_trace() | |||||
y = tc.Tensor( | |||||
[ | |||||
[.3,.4,.3], | |||||
[.5,.3,.2], | |||||
[.3,.6,.1], | |||||
] | |||||
) | |||||
gy = tc.LongTensor( | |||||
[ | |||||
0, | |||||
1, | |||||
2, | |||||
] | |||||
) | |||||
y = tc.log(y) | |||||
los = loss_func(y , gy) | |||||
r = -math.log(.3) - math.log(.3) - math.log(.1) | |||||
r /= 3 | |||||
print ("loss = %f" % (los)) | |||||
print ("r = %f" % (r)) | |||||
self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
def test_case_2(self): | |||||
# verify the correctness of squash() | |||||
print ("----------------------------------") | |||||
log = math.log | |||||
loss_func = loss.Loss("nll") | |||||
#pdb.set_trace() | |||||
y = tc.Tensor( | |||||
[ | |||||
[[.3,.4,.3],[.3,.4,.3],], | |||||
[[.5,.3,.2],[.1,.2,.7],], | |||||
[[.3,.6,.1],[.2,.1,.7],], | |||||
] | |||||
) | |||||
gy = tc.LongTensor( | |||||
[ | |||||
[0,2], | |||||
[1,2], | |||||
[2,1], | |||||
] | |||||
) | |||||
#pdb.set_trace() | |||||
y = tc.log(y) | |||||
los = loss_func(y , gy) | |||||
print ("loss = %f" % (los)) | |||||
r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1) | |||||
r /= 6 | |||||
print ("r = %f" % (r)) | |||||
self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
def test_case_3(self): | |||||
# verify the correctness of pack_padded_sequence() | |||||
print ("----------------------------------") | |||||
log = math.log | |||||
loss_func = loss.Loss("nll") | |||||
#pdb.set_trace() | |||||
y = tc.Tensor( | |||||
[ | |||||
[[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],], | |||||
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],], | |||||
[[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],], | |||||
] | |||||
) | |||||
gy = tc.LongTensor( | |||||
[ | |||||
[0,2,1,], | |||||
[1,2,0,], | |||||
[2,0,0,], | |||||
] | |||||
) | |||||
lens = [3,2,1] | |||||
#pdb.set_trace() | |||||
y = tc.log(y) | |||||
yy = tc.nn.utils.rnn.pack_padded_sequence(y , lens , batch_first = True).data | |||||
gyy = tc.nn.utils.rnn.pack_padded_sequence(gy , lens , batch_first = True).data | |||||
los = loss_func(yy , gyy) | |||||
print ("loss = %f" % (los)) | |||||
r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) | |||||
r /= 6 | |||||
print ("r = %f" % (r)) | |||||
self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
def test_case_4(self): | |||||
# verify the correctness of unpad() | |||||
print ("----------------------------------") | |||||
log = math.log | |||||
#pdb.set_trace() | |||||
y = tc.Tensor( | |||||
[ | |||||
[[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],], | |||||
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],], | |||||
[[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],], | |||||
] | |||||
) | |||||
gy = tc.LongTensor( | |||||
[ | |||||
[0,2,1,2,], | |||||
[1,2,0,0,], | |||||
[2,0,0,0,], | |||||
] | |||||
) | |||||
lens = [4,2,1] | |||||
#pdb.set_trace() | |||||
y = tc.log(y) | |||||
import fastNLP.core.losses as loss | |||||
from fastNLP.core.losses import squash, unpad | |||||
loss_func = loss.Loss("nll" , pre_pro = ["unpad"]) | |||||
los = loss_func(y , gy , lens = lens) | |||||
print ("loss = %f" % (los)) | |||||
r = -log(.1) -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) | |||||
r /= 7 | |||||
print ("r = %f" % (r)) | |||||
self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
def test_case_5(self): | |||||
# verify the correctness of mask() and make_mask() | |||||
print ("----------------------------------") | |||||
log = math.log | |||||
#pdb.set_trace() | |||||
y = tc.Tensor( | |||||
[ | |||||
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],], | |||||
[[.5,.4,.1],[.3,.2,.5],[.4,.5,.1,],[.6,.1,.3,],], | |||||
[[.3,.6,.1],[.3,.2,.5],[.0,.0,.0,],[.0,.0,.0,],], | |||||
] | |||||
) | |||||
gy = tc.LongTensor( | |||||
[ | |||||
[1,2,0,0,], | |||||
[0,2,1,2,], | |||||
[2,1,0,0,], | |||||
] | |||||
) | |||||
mask = tc.ByteTensor( | |||||
[ | |||||
[1,1,0,0,], | |||||
[1,1,1,1,], | |||||
[1,1,0,0,], | |||||
] | |||||
) | |||||
y = tc.log(y) | |||||
lens = [2,4,2] | |||||
loss_func = loss.Loss("nll" , pre_pro = ["mask"]) | |||||
los = loss_func(y , gy , mask = mask) | |||||
print ("loss = %f" % (los)) | |||||
los2 = loss_func(y , gy , mask = loss.make_mask(lens,gy.size()[-1])) | |||||
print ("loss2 = %f" % (los2)) | |||||
r = -log(.3) -log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2) | |||||
r /= 8 | |||||
print ("r = %f" % (r)) | |||||
self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
self.assertEqual(int(los2 * 1000), int(r * 1000)) | |||||
def test_case_6(self): | |||||
# verify the correctness of unpad_mask() | |||||
print ("----------------------------------") | |||||
log = math.log | |||||
#pdb.set_trace() | |||||
y = tc.Tensor( | |||||
[ | |||||
[[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],], | |||||
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],], | |||||
[[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],], | |||||
] | |||||
) | |||||
gy = tc.LongTensor( | |||||
[ | |||||
[0,2,1,2,], | |||||
[1,2,0,0,], | |||||
[2,0,0,0,], | |||||
] | |||||
) | |||||
lens = [4,2,1] | |||||
#pdb.set_trace() | |||||
y = tc.log(y) | |||||
loss_func = loss.Loss("nll" , pre_pro = ["unpad_mask"]) | |||||
los = loss_func(y , gy , lens = lens) | |||||
print ("loss = %f" % (los)) | |||||
r = -log(.1) -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) | |||||
r /= 7 | |||||
print ("r = %f" % (r)) | |||||
self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
def test_case_7(self): | |||||
# verify some other behaviors | |||||
print ("----------------------------------") | |||||
log = math.log | |||||
#pdb.set_trace() | |||||
y = tc.Tensor( | |||||
[ | |||||
[[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],], | |||||
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],], | |||||
[[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],], | |||||
] | |||||
) | |||||
gy = tc.LongTensor( | |||||
[ | |||||
[0,2,1,2,], | |||||
[1,2,0,0,], | |||||
[2,0,0,0,], | |||||
] | |||||
) | |||||
lens = [4,2,1] | |||||
#pdb.set_trace() | |||||
y = tc.log(y) | |||||
loss_func = loss.Loss("nll" , pre_pro = [] , weight = tc.Tensor([1,1,0])) | |||||
loss_func.add_pre_pro("unpad_mask") | |||||
los = loss_func(y , gy , lens = lens) | |||||
print ("loss = %f" % (los)) | |||||
r = - log(.3) - log(.5) - log(.3) | |||||
r /= 3 | |||||
print ("r = %f" % (r)) | |||||
self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
if __name__ == "__main__": | |||||
unittest.main() | |||||
class TestLoss(unittest.TestCase): | |||||
def test_CrossEntropyLoss(self): | |||||
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth") | |||||
a = torch.randn(3, 5, requires_grad=False) | |||||
b = torch.empty(3, dtype=torch.long).random_(5) | |||||
ans = ce({"my_predict": a}, {"my_truth": b}) | |||||
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | |||||
def test_BCELoss(self): | |||||
bce = loss.BCELoss(pred="my_predict", target="my_truth") | |||||
a = torch.sigmoid(torch.randn((3, 5), requires_grad=False)) | |||||
b = torch.randn((3, 5), requires_grad=False) | |||||
ans = bce({"my_predict": a}, {"my_truth": b}) | |||||
self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b)) | |||||
def test_L1Loss(self): | |||||
l1 = loss.L1Loss(pred="my_predict", target="my_truth") | |||||
a = torch.randn(3, 5, requires_grad=False) | |||||
b = torch.randn(3, 5) | |||||
ans = l1({"my_predict": a}, {"my_truth": b}) | |||||
self.assertEqual(ans, torch.nn.functional.l1_loss(a, b)) | |||||
def test_NLLLoss(self): | |||||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||||
b = torch.tensor([1, 0, 4]) | |||||
ans = l1({"my_predict": a}, {"my_truth": b}) | |||||
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) | |||||
class TestLosserError(unittest.TestCase): | |||||
def test_losser1(self): | |||||
# (1) only input, targets passed | |||||
pred_dict = {"pred": torch.zeros(4, 3)} | |||||
target_dict = {'target': torch.zeros(4).long()} | |||||
los = loss.CrossEntropyLoss() | |||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||||
# | |||||
def test_losser2(self): | |||||
# (2) with corrupted size | |||||
pred_dict = {"pred": torch.zeros(16, 3)} | |||||
target_dict = {'target': torch.zeros(16, 3).long()} | |||||
los = loss.CrossEntropyLoss() | |||||
with self.assertRaises(RuntimeError): | |||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||||
def test_losser3(self): | |||||
# (3) with stop_fast_param to skip the fast parameter mapping | |||||
pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0} | |||||
target_dict = {'target': torch.zeros(16).long()} | |||||
los = loss.CrossEntropyLoss() | |||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||||
def test_check_error(self): | |||||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||||
b = torch.tensor([1, 0, 4]) | |||||
with self.assertRaises(Exception): | |||||
ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) | |||||
with self.assertRaises(Exception): | |||||
ans = l1({"my_predict": a}, {"truth": b, "my": a}) | |||||
class TestLossUtils(unittest.TestCase): | |||||
def test_squash(self): | |||||
a, b = squash(torch.randn(3, 5), torch.randn(3, 5)) | |||||
self.assertEqual(tuple(a.size()), (3, 5)) | |||||
self.assertEqual(tuple(b.size()), (15,)) | |||||
def test_unpad(self): | |||||
a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8)) | |||||
self.assertEqual(tuple(a.size()), (5, 8, 3)) | |||||
self.assertEqual(tuple(b.size()), (5, 8)) |
@@ -0,0 +1,145 @@ | |||||
import unittest | |||||
import numpy as np | |||||
import torch | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.metrics import pred_topk, accuracy_topk | |||||
class TestAccuracyMetric(unittest.TestCase): | |||||
def test_AccuracyMetric1(self): | |||||
# (1) only input, targets passed | |||||
pred_dict = {"pred": torch.zeros(4, 3)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric = AccuracyMetric() | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
print(metric.get_metric()) | |||||
def test_AccuracyMetric2(self): | |||||
# (2) with corrupted size | |||||
try: | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric = AccuracyMetric() | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
print(metric.get_metric()) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_AccuracyMetric3(self): | |||||
# (3) the second batch is corrupted size | |||||
try: | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_AccuaryMetric4(self): | |||||
# (5) check reset | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3) + 1} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 0}) | |||||
def test_AccuaryMetric5(self): | |||||
# (5) check reset | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3) + 1} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 0.5}) | |||||
def test_AccuaryMetric6(self): | |||||
# (6) check numpy array is not acceptable | |||||
try: | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"pred": np.zeros((4, 3, 2))} | |||||
target_dict = {'target': np.zeros((4, 3))} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_AccuaryMetric7(self): | |||||
# (7) check map, match | |||||
metric = AccuracyMetric(pred='predictions', target='targets') | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
def test_AccuaryMetric8(self): | |||||
# (8) check map, does not match. use stop_fast_param to stop fast param map | |||||
try: | |||||
metric = AccuracyMetric(pred='predictions', target='targets') | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_AccuaryMetric9(self): | |||||
# (9) check map, include unused | |||||
try: | |||||
metric = AccuracyMetric(pred='prediction', target='targets') | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_AccuaryMetric10(self): | |||||
# (10) check _fast_metric | |||||
try: | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
class TestUsefulFunctions(unittest.TestCase): | |||||
# test some seemingly useful helper functions in metrics.py | |||||
def test_case_1(self): | |||||
# multi-class | |||||
_ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3) | |||||
_ = pred_topk(np.random.randint(0, 3, size=(10, 1))) | |||||
# just make sure it runs | |||||
@@ -0,0 +1,54 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP.core.optimizer import SGD, Adam | |||||
class TestOptim(unittest.TestCase): | |||||
def test_SGD(self): | |||||
optim = SGD(model_params=torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue("lr" in optim.__dict__["settings"]) | |||||
self.assertTrue("momentum" in optim.__dict__["settings"]) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||||
optim = SGD(lr=0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||||
optim = SGD(lr=0.002, momentum=0.989) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||||
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | |||||
optim = SGD(0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||||
with self.assertRaises(TypeError): | |||||
_ = SGD("???") | |||||
with self.assertRaises(TypeError): | |||||
_ = SGD(0.001, lr=0.002) | |||||
def test_Adam(self): | |||||
optim = Adam(model_params=torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue("lr" in optim.__dict__["settings"]) | |||||
self.assertTrue("weight_decay" in optim.__dict__["settings"]) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.Adam)) | |||||
optim = Adam(lr=0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.Adam)) | |||||
optim = Adam(lr=0.002, weight_decay=0.989) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||||
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) | |||||
optim = Adam(0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.Adam)) |
@@ -1,6 +1,34 @@ | |||||
import unittest | import unittest | ||||
import numpy as np | |||||
import torch | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.predictor import Predictor | |||||
from fastNLP.modules.encoder.linear import Linear | |||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
class TestPredictor(unittest.TestCase): | class TestPredictor(unittest.TestCase): | ||||
def test(self): | def test(self): | ||||
pass | |||||
predictor = Predictor() | |||||
model = Linear(2, 1) | |||||
data = prepare_fake_dataset() | |||||
data.set_input("x") | |||||
ans = predictor.predict(model, data) | |||||
self.assertEqual(len(ans), 2000) | |||||
self.assertTrue(isinstance(ans[0], torch.Tensor)) |
@@ -1,9 +1,11 @@ | |||||
import random | |||||
import unittest | import unittest | ||||
import torch | import torch | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | ||||
k_means_1d, k_means_bucketing, simple_sort_bucketing | |||||
k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler | |||||
class TestSampler(unittest.TestCase): | class TestSampler(unittest.TestCase): | ||||
@@ -40,3 +42,11 @@ class TestSampler(unittest.TestCase): | |||||
def test_simple_sort_bucketing(self): | def test_simple_sort_bucketing(self): | ||||
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | ||||
assert len(_) == 10 | assert len(_) == 10 | ||||
def test_BucketSampler(self): | |||||
sampler = BucketSampler(num_buckets=3, batch_size=16, seq_lens_field_name="seq_len") | |||||
data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10}) | |||||
data_set.apply(lambda ins: len(ins["x"]), new_field_name="seq_len") | |||||
indices = sampler(data_set) | |||||
self.assertEqual(len(indices), 10) | |||||
# just make sure it runs; the results are not verified | |||||
@@ -4,6 +4,64 @@ data_name = "pku_training.utf8" | |||||
pickle_path = "data_for_tests" | pickle_path = "data_for_tests" | ||||
import numpy as np | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
import time | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.losses import BCELoss | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.optimizer import SGD | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.models.base_model import NaiveClassifier | |||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
def prepare_fake_dataset2(*args, size=100): | |||||
ys = np.random.randint(4, size=size, dtype=np.int64) | |||||
data = {'y': ys} | |||||
for arg in args: | |||||
data[arg] = np.random.randn(size, 5) | |||||
return DataSet(data=data) | |||||
class TestTester(unittest.TestCase): | class TestTester(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
pass | |||||
# check that the error message gives the user a correct hint | |||||
dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
dataset.rename_field('x_unused', 'x2') | |||||
dataset.set_input('x1', 'x2') | |||||
dataset.set_target('y', 'x1') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
time.sleep(0.1) | |||||
# loss = F.cross_entropy(x, y) | |||||
return {'preds': x} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
tester = Tester( | |||||
data=dataset, | |||||
model=model, | |||||
metrics=AccuracyMetric()) | |||||
tester.test() |
@@ -1,6 +1,242 @@ | |||||
import unittest | import unittest | ||||
import numpy as np | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
import time | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.losses import BCELoss | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.optimizer import SGD | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.models.base_model import NaiveClassifier | |||||
class TestTrainer(unittest.TestCase): | |||||
def test_case_1(self): | |||||
pass | |||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
def prepare_fake_dataset2(*args, size=100): | |||||
ys = np.random.randint(4, size=size, dtype=np.int64) | |||||
data = {'y': ys} | |||||
for arg in args: | |||||
data[arg] = np.random.randn(size, 5) | |||||
return DataSet(data=data) | |||||
class TrainerTestGround(unittest.TestCase): | |||||
def test_case(self): | |||||
data_set = prepare_fake_dataset() | |||||
data_set.set_input("x", flag=True) | |||||
data_set.set_target("y", flag=True) | |||||
train_set, dev_set = data_set.split(0.3) | |||||
model = NaiveClassifier(2, 1) | |||||
trainer = Trainer(train_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
n_epochs=10, | |||||
batch_size=32, | |||||
print_every=50, | |||||
validate_every=-1, | |||||
dev_data=dev_set, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=True, | |||||
save_path=None) | |||||
trainer.train() | |||||
""" | |||||
# should run without errors | |||||
""" | |||||
def test_trainer_suggestion1(self): | |||||
# check that the error message gives the user a correct hint. | |||||
# the data required by forward is not provided here; the trainer should tell the user how to set it. | |||||
dataset = prepare_fake_dataset2('x') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'loss': loss} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model | |||||
) | |||||
""" | |||||
# the error message that should be produced | |||||
NameError: | |||||
The following problems occurred when calling Model.forward(self, x1, x2, y) | |||||
missing param: ['y', 'x1', 'x2'] | |||||
Suggestion: (1). You might need to set ['y'] as input. | |||||
(2). You need to provide ['x1', 'x2'] in DataSet and set it as input. | |||||
""" | |||||
def test_trainer_suggestion2(self): | |||||
# check that the error message gives the user a correct hint | |||||
# the data required by forward is provided here; training should run | |||||
dataset = prepare_fake_dataset2('x1', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'loss': loss} | |||||
model = Model() | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer.train() | |||||
""" | |||||
# should run without errors | |||||
""" | |||||
def test_trainer_suggestion3(self): | |||||
# check that the error message gives the user a correct hint | |||||
# the data required by forward is provided, but forward does not return a 'loss' key | |||||
dataset = prepare_fake_dataset2('x1', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'wrong_loss_key': loss} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer.train() | |||||
def test_trainer_suggestion4(self): | |||||
# check that the error message gives the user a correct hint | |||||
# the data required by forward is provided; check that unused fields are reported correctly | |||||
dataset = prepare_fake_dataset2('x1', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'losses': loss} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
def test_trainer_suggestion5(self): | |||||
# check that the error message gives the user a correct hint | |||||
# extra fields are passed in to create duplicates, but since y is never used no error is actually raised | |||||
dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
dataset.rename_field('x_unused', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y') | |||||
dataset.set_target('y') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'loss': loss} | |||||
model = Model() | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
def test_trainer_suggestion6(self): | |||||
# check that the error message gives the user a correct hint | |||||
# extra fields are passed in to create duplicates | |||||
dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
dataset.rename_field('x_unused', 'x2') | |||||
dataset.set_input('x1', 'x2') | |||||
dataset.set_target('y', 'x1') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
time.sleep(0.1) | |||||
# loss = F.cross_entropy(x, y) | |||||
return {'preds': x} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
dev_data=dataset, | |||||
loss=CrossEntropyLoss(), | |||||
metrics=AccuracyMetric(), | |||||
use_tqdm=False, | |||||
print_every=2) | |||||
def test_case2(self): | |||||
# check the behaviour when the metric configuration is wrong | |||||
data_set = prepare_fake_dataset2('x1', 'x2') |
@@ -10,36 +10,36 @@ counter = Counter(text) | |||||
class TestAdd(unittest.TestCase): | class TestAdd(unittest.TestCase): | ||||
def test_add(self): | def test_add(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
for word in text: | for word in text: | ||||
vocab.add(word) | vocab.add(word) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_add_word(self): | def test_add_word(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
for word in text: | for word in text: | ||||
vocab.add_word(word) | vocab.add_word(word) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_add_word_lst(self): | def test_add_word_lst(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.add_word_lst(text) | vocab.add_word_lst(text) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_update(self): | def test_update(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
class TestIndexing(unittest.TestCase): | class TestIndexing(unittest.TestCase): | ||||
def test_len(self): | def test_len(self): | ||||
vocab = Vocabulary(need_default=False, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(len(vocab), len(counter)) | self.assertEqual(len(vocab), len(counter)) | ||||
def test_contains(self): | def test_contains(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertTrue(text[-1] in vocab) | self.assertTrue(text[-1] in vocab) | ||||
self.assertFalse("~!@#" in vocab) | self.assertFalse("~!@#" in vocab) | ||||
@@ -47,7 +47,7 @@ class TestIndexing(unittest.TestCase): | |||||
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | ||||
def test_index(self): | def test_index(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
res = [vocab[w] for w in set(text)] | res = [vocab[w] for w in set(text)] | ||||
self.assertEqual(len(res), len(set(res))) | self.assertEqual(len(res), len(set(res))) | ||||
@@ -56,6 +56,33 @@ class TestIndexing(unittest.TestCase): | |||||
self.assertEqual(len(res), len(set(res))) | self.assertEqual(len(res), len(set(res))) | ||||
def test_to_word(self): | def test_to_word(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | ||||
class TestOther(unittest.TestCase): | |||||
def test_additional_update(self): | |||||
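# vocab.rebuild marks whether the word-index mapping is stale: looking a word up resets it to False, adding new words sets it back to True | |||||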
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
_ = vocab["well"] | |||||
self.assertEqual(vocab.rebuild, False) | |||||
vocab.add("hahaha") | |||||
self.assertEqual(vocab.rebuild, True) | |||||
_ = vocab["hahaha"] | |||||
self.assertEqual(vocab.rebuild, False) | |||||
self.assertTrue("hahaha" in vocab) | |||||
def test_warning(self): | |||||
vocab = Vocabulary(max_size=len(set(text)), min_freq=None) | |||||
vocab.update(text) | |||||
self.assertEqual(vocab.rebuild, True) | |||||
print(len(vocab)) | |||||
self.assertEqual(vocab.rebuild, False) | |||||
vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]) | |||||
# this will print a warning | |||||
self.assertEqual(vocab.rebuild, True) |
@@ -1,12 +1,6 @@ | |||||
the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581 | the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581 | ||||
, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -0.42852 -0.55641 -0.364 -0.23938 0.13001 -0.063734 -0.39575 -0.48162 0.23291 0.090201 -0.13324 0.078639 -0.41634 -0.15428 0.10068 0.48891 0.31226 -0.1252 -0.037512 -1.5179 0.12612 -0.02442 -0.042961 -0.28351 3.5416 -0.11956 -0.014533 -0.1499 0.21864 -0.33412 -0.13872 0.31806 0.70358 0.44858 -0.080262 0.63003 0.32111 -0.46765 0.22786 0.36034 -0.37818 -0.56657 0.044691 0.30392 | |||||
. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -0.43478 -0.31086 -0.44999 -0.29486 0.16608 0.11963 -0.41328 -0.42353 0.59868 0.28825 -0.11547 -0.041848 -0.67989 -0.25063 0.18472 0.086876 0.46582 0.015035 0.043474 -1.4671 -0.30384 -0.023441 0.30589 -0.21785 3.746 0.0042284 -0.18436 -0.46209 0.098329 -0.11907 0.23919 0.1161 0.41705 0.056763 -6.3681e-05 0.068987 0.087939 -0.10285 -0.13931 0.22314 -0.080803 -0.35652 0.016413 0.10216 | |||||
of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375 | of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375 | ||||
to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044 | to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044 | ||||
and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097 | and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097 | ||||
in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285 | in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285 | ||||
a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796 | |||||
" 0.25769 0.45629 -0.76974 -0.37679 0.59272 -0.063527 0.20545 -0.57385 -0.29009 -0.13662 0.32728 1.4719 -0.73681 -0.12036 0.71354 -0.46098 0.65248 0.48887 -0.51558 0.039951 -0.34307 -0.014087 0.86488 0.3546 0.7999 -1.4995 -1.8153 0.41128 0.23921 -0.43139 3.6623 -0.79834 -0.54538 0.16943 -0.82017 -0.3461 0.69495 -1.2256 -0.17992 -0.057474 0.030498 -0.39543 -0.38515 -1.0002 0.087599 -0.31009 -0.34677 -0.31438 0.75004 0.97065 | |||||
's 0.23727 0.40478 -0.20547 0.58805 0.65533 0.32867 -0.81964 -0.23236 0.27428 0.24265 0.054992 0.16296 -1.2555 -0.086437 0.44536 0.096561 -0.16519 0.058378 -0.38598 0.086977 0.0033869 0.55095 -0.77697 -0.62096 0.092948 -2.5685 -0.67739 0.10151 -0.48643 -0.057805 3.1859 -0.017554 -0.16138 0.055486 -0.25885 -0.33938 -0.19928 0.26049 0.10478 -0.55934 -0.12342 0.65961 -0.51802 -0.82995 -0.082739 0.28155 -0.423 -0.27378 -0.007901 -0.030231 | |||||
a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796 |
@@ -0,0 +1,77 @@ | |||||
A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . 1 | |||||
This quiet , introspective and entertaining independent is worth seeking . 4 | |||||
Even fans of Ismail Merchant 's work , I suspect , would have a hard time sitting through this one . 1 | |||||
A positively thrilling combination of ethnography and all the intrigue , betrayal , deceit and murder of a Shakespearean tragedy or a juicy soap opera . 3 | |||||
Aggressive self-glorification and a manipulative whitewash . 1 | |||||
A comedy-drama of nearly epic proportions rooted in a sincere performance by the title character undergoing midlife crisis . 4 | |||||
Narratively , Trouble Every Day is a plodding mess . 1 | |||||
The Importance of Being Earnest , so thick with wit it plays like a reading from Bartlett 's Familiar Quotations 3 | |||||
But it does n't leave you with much . 1 | |||||
You could hate it for the same reason . 1 | |||||
There 's little to recommend Snow Dogs , unless one considers cliched dialogue and perverse escapism a source of high hilarity . 1 | |||||
Kung Pow is Oedekerk 's realization of his childhood dream to be in a martial-arts flick , and proves that sometimes the dreams of youth should remain just that . 1 | |||||
The performances are an absolute joy . 4 | |||||
Fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense . 3 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
While The Importance of Being Earnest offers opportunities for occasional smiles and chuckles , it does n't give us a reason to be in the theater beyond Wilde 's wit and the actors ' performances . 1 | |||||
The latest vapid actor 's exercise to appropriate the structure of Arthur Schnitzler 's Reigen . 1 | |||||
More vaudeville show than well-constructed narrative , but on those terms it 's inoffensive and actually rather sweet . 2 | |||||
Nothing more than a run-of-the-mill action flick . 2 | |||||
Hampered -- no , paralyzed -- by a self-indulgent script ... that aims for poetry and ends up sounding like satire . 0 | |||||
Ice Age is the first computer-generated feature cartoon to feel like other movies , and that makes for some glacial pacing early on . 2 | |||||
There 's very little sense to what 's going on here , but the makers serve up the cliches with considerable dash . 2 | |||||
Cattaneo should have followed the runaway success of his first film , The Full Monty , with something different . 2 | |||||
They 're the unnamed , easily substitutable forces that serve as whatever terror the heroes of horror movies try to avoid . 1 | |||||
It almost feels as if the movie is more interested in entertaining itself than in amusing us . 1 | |||||
The movie 's progression into rambling incoherence gives new meaning to the phrase ` fatal script error . ' 0 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 |
@@ -1,8 +1,7 @@ | |||||
import os | import os | ||||
import unittest | import unittest | ||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
from fastNLP.io.config_saver import ConfigSaver | |||||
from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver | |||||
class TestConfigSaver(unittest.TestCase): | class TestConfigSaver(unittest.TestCase): | ||||
@@ -0,0 +1,12 @@ | |||||
import unittest | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
class TestEmbedLoader(unittest.TestCase): | |||||
def test_case(self): | |||||
vocab = Vocabulary() | |||||
vocab.update(["the", "in", "I", "to", "of", "hahaha"]) | |||||
embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab) | |||||
self.assertEqual(tuple(embedding.shape), (len(vocab), 50)) |
@@ -0,0 +1,91 @@ | |||||
import unittest | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import Tester | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.models import CNNText | |||||
class TestTutorial(unittest.TestCase): | |||||
def test_tutorial(self): | |||||
# read data from a csv file into a DataSet | |||||
sample_path = "test/data_for_tests/tutorial_sample_dataset.csv" | |||||
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), | |||||
sep='\t') | |||||
print(len(dataset)) | |||||
print(dataset[0]) | |||||
dataset.append(Instance(raw_sentence='fake data', label='0')) | |||||
dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||||
# convert label to int | |||||
dataset.apply(lambda x: int(x['label']), new_field_name='label') | |||||
# split sentences on whitespace | |||||
def split_sent(ins): | |||||
return ins['raw_sentence'].split() | |||||
dataset.apply(split_sent, new_field_name='words') | |||||
# add sequence length information | |||||
dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') | |||||
print(len(dataset)) | |||||
print(dataset[0]) | |||||
# DataSet.drop(func) filters out instances | |||||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||||
print(len(dataset)) | |||||
# configure which fields of the DataSet are converted to tensors | |||||
# set target: the gold labels used for computing the loss and for evaluation | |||||
dataset.set_target("label") | |||||
# set input: the fields fed to the model's forward | |||||
dataset.set_input("words") | |||||
# split into test and training sets | |||||
test_data, train_data = dataset.split(0.5) | |||||
print(len(test_data)) | |||||
print(len(train_data)) | |||||
# build the vocabulary with Vocabulary.add(word) | |||||
vocab = Vocabulary(min_freq=2) | |||||
train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||||
vocab.build_vocab() | |||||
# index the sentences with Vocabulary.to_index(word) | |||||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||||
print(test_data[0]) | |||||
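# CNNText: embed_num is set to the vocabulary size and num_classes to the number of label classes (5 in this sample) | |||||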
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
from fastNLP import Trainer | |||||
from copy import deepcopy | |||||
# rename the DataSet fields so that they match the parameter names of the model's forward | |||||
train_data.rename_field('words', 'word_seq') # input field name matches the forward parameter | |||||
train_data.rename_field('label', 'label_seq') | |||||
test_data.rename_field('words', 'word_seq') | |||||
test_data.rename_field('label', 'label_seq') | |||||
# instantiate a Trainer with the model and data, then train | |||||
copy_model = deepcopy(model) | |||||
overfit_trainer = Trainer(train_data=test_data, model=copy_model, | |||||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||||
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, | |||||
dev_data=test_data, save_path="./save") | |||||
overfit_trainer.train() | |||||
trainer = Trainer(train_data=train_data, model=model, | |||||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||||
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, | |||||
dev_data=test_data, save_path="./save") | |||||
trainer.train() | |||||
print('Train finished!') | |||||
# evaluate the model with fastNLP's Tester | |||||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||||
batch_size=4) | |||||
acc = tester.test() | |||||
print(acc) |
@@ -0,0 +1,911 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"fastNLP上手教程\n", | |||||
"-------\n", | |||||
"\n", | |||||
"fastNLP提供方便的数据预处理,训练和测试模型的功能" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"DataSet & Instance\n", | |||||
"------\n", | |||||
"\n", | |||||
"fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n", | |||||
"\n", | |||||
"有一些read_*方法,可以轻松从文件读取数据,存成DataSet。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 9, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"8529" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import DataSet\n", | |||||
"from fastNLP import Instance\n", | |||||
"\n", | |||||
"# 从csv读取数据到DataSet\n", | |||||
"dataset = DataSet.read_csv('../sentence.csv', headers=('raw_sentence', 'label'), sep='\\t')\n", | |||||
"print(len(dataset))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 10, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 使用数字索引[k],获取第k个样本\n", | |||||
"print(dataset[0])\n", | |||||
"\n", | |||||
"# 索引也可以是负数\n", | |||||
"print(dataset[-3])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## Instance\n", | |||||
"Instance表示一个样本,由一个或多个field(域,属性,特征)组成,每个field有名字和值。\n", | |||||
"\n", | |||||
"在初始化Instance时即可定义它包含的域,使用 \"field_name=field_value\"的写法。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 11, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'raw_sentence': fake data,\n'label': 0}" | |||||
] | |||||
}, | |||||
"execution_count": 11, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# DataSet.append(Instance)加入新数据\n", | |||||
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n", | |||||
"dataset[-1]" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## DataSet.apply方法\n", | |||||
"数据预处理利器" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 12, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 将所有数字转为小写\n", | |||||
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", | |||||
"print(dataset[0])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 13, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# label转int\n", | |||||
"dataset.apply(lambda x: int(x['label']), new_field_name='label')\n", | |||||
"print(dataset[0])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 14, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 使用空格分割句子\n", | |||||
"def split_sent(ins):\n", | |||||
" return ins['raw_sentence'].split()\n", | |||||
"dataset.apply(split_sent, new_field_name='words')\n", | |||||
"print(dataset[0])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 15, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n'seq_len': 37}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 增加长度信息\n", | |||||
"dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n", | |||||
"print(dataset[0])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## DataSet.drop\n", | |||||
"筛选数据" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 16, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"8358" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"dataset.drop(lambda x: x['seq_len'] <= 3)\n", | |||||
"print(len(dataset))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 配置DataSet\n", | |||||
"1. 哪些域是特征,哪些域是标签\n", | |||||
"2. 切分训练集/验证集" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 17, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# 设置DataSet中,哪些field要转为tensor\n", | |||||
"\n", | |||||
"# set target,loss或evaluate中的golden,计算loss,模型评估时使用\n", | |||||
"dataset.set_target(\"label\")\n", | |||||
"# set input,模型forward时使用\n", | |||||
"dataset.set_input(\"words\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 18, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"5851" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"2507" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 分出测试集、训练集\n", | |||||
"\n", | |||||
"test_data, train_data = dataset.split(0.3)\n", | |||||
"print(len(test_data))\n", | |||||
"print(len(train_data))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Vocabulary\n", | |||||
"------\n", | |||||
"\n", | |||||
"fastNLP中的Vocabulary轻松构建词表,将词转成数字" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 19, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': the project 's filmmakers forgot to include anything even halfway scary as they poorly rejigger fatal attraction into a high school setting .,\n'label': 0,\n'words': [4, 423, 9, 316, 1, 8, 1, 312, 72, 1478, 885, 14, 86, 725, 1, 1913, 1431, 53, 5, 455, 736, 1, 2],\n'seq_len': 23}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"# 构建词表, Vocabulary.add(word)\n", | |||||
"vocab = Vocabulary(min_freq=2)\n", | |||||
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n", | |||||
"vocab.build_vocab()\n", | |||||
"\n", | |||||
"# index句子, Vocabulary.to_index(word)\n", | |||||
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n", | |||||
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n", | |||||
"\n", | |||||
"\n", | |||||
"print(test_data[0])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# Model\n", | |||||
"定义一个PyTorch模型" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 20, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"CNNText(\n (embed): Embedding(\n (embed): Embedding(3459, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)" | |||||
] | |||||
}, | |||||
"execution_count": 20, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.models import CNNText\n", | |||||
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n", | |||||
"model" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"这是上述模型的forward方法。如果你不知道什么是forward方法,请参考我们的PyTorch教程。\n", | |||||
"\n", | |||||
"注意两点:\n", | |||||
"1. forward参数名字叫**word_seq**,请记住。\n", | |||||
"2. forward的返回值是一个**dict**,其中有个key的名字叫**output**。\n", | |||||
"\n", | |||||
"```Python\n", | |||||
" def forward(self, word_seq):\n", | |||||
" \"\"\"\n", | |||||
"\n", | |||||
" :param word_seq: torch.LongTensor, [batch_size, seq_len]\n", | |||||
" :return output: dict of torch.LongTensor, [batch_size, num_classes]\n", | |||||
" \"\"\"\n", | |||||
" x = self.embed(word_seq) # [N,L] -> [N,L,C]\n", | |||||
" x = self.conv_pool(x) # [N,L,C] -> [N,C]\n", | |||||
" x = self.dropout(x)\n", | |||||
" x = self.fc(x) # [N,C] -> [N, N_class]\n", | |||||
" return {'output': x}\n", | |||||
"```" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"这是上述模型的predict方法,是用来直接输出该任务的预测结果,与forward目的不同。\n", | |||||
"\n", | |||||
"注意两点:\n", | |||||
"1. predict参数名也叫**word_seq**。\n", | |||||
"2. predict的返回值是也一个**dict**,其中有个key的名字叫**predict**。\n", | |||||
"\n", | |||||
"```\n", | |||||
" def predict(self, word_seq):\n", | |||||
" \"\"\"\n", | |||||
"\n", | |||||
" :param word_seq: torch.LongTensor, [batch_size, seq_len]\n", | |||||
" :return predict: dict of torch.LongTensor, [batch_size, seq_len]\n", | |||||
" \"\"\"\n", | |||||
" output = self(word_seq)\n", | |||||
" _, predict = output['output'].max(dim=1)\n", | |||||
" return {'predict': predict}\n", | |||||
"```" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Trainer & Tester\n", | |||||
"------\n", | |||||
"\n", | |||||
"使用fastNLP的Trainer训练模型" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 21, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import Trainer\n", | |||||
"from copy import deepcopy\n", | |||||
"from fastNLP.core.losses import CrossEntropyLoss\n", | |||||
"from fastNLP.core.metrics import AccuracyMetric\n", | |||||
"\n", | |||||
"\n", | |||||
"# 更改DataSet中对应field的名称,与模型的forward的参数名一致\n", | |||||
"# 因为forward的参数叫word_seq, 所以要把原本叫words的field改名为word_seq\n", | |||||
"# 这里的演示是让你了解这种**命名规则**\n", | |||||
"train_data.rename_field('words', 'word_seq')\n", | |||||
"test_data.rename_field('words', 'word_seq')\n", | |||||
"\n", | |||||
"# 顺便把label换名为label_seq\n", | |||||
"train_data.rename_field('label', 'label_seq')\n", | |||||
"test_data.rename_field('label', 'label_seq')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### loss\n", | |||||
"训练模型需要提供一个损失函数\n", | |||||
"\n", | |||||
"下面提供了一个在分类问题中常用的交叉熵损失。注意它的**初始化参数**。\n", | |||||
"\n", | |||||
"pred参数对应的是模型的forward返回的dict的一个key的名字,这里是\"output\"。\n", | |||||
"\n", | |||||
"target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 22, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"loss = CrossEntropyLoss(pred=\"output\", target=\"label_seq\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Metric\n", | |||||
"定义评价指标\n", | |||||
"\n", | |||||
"这里使用准确率。参数的“命名规则”跟上面类似。\n", | |||||
"\n", | |||||
"pred参数对应的是模型的predict方法返回的dict的一个key的名字,这里是\"predict\"。\n", | |||||
"\n", | |||||
"target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 23, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"metric = AccuracyMetric(pred=\"predict\", target=\"label_seq\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 24, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2018-12-07 14:11:31" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=915), HTML(value='')), layout=Layout(display=…" | |||||
] | |||||
}, | |||||
"execution_count": 0, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/5. Step:183/915. AccuracyMetric: acc=0.350367" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/5. Step:366/915. AccuracyMetric: acc=0.409332" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/5. Step:549/915. AccuracyMetric: acc=0.572552" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/5. Step:732/915. AccuracyMetric: acc=0.711331" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/5. Step:915/915. AccuracyMetric: acc=0.801572" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 实例化Trainer,传入模型和数据,进行训练\n", | |||||
"# 先在test_data拟合\n", | |||||
"copy_model = deepcopy(model)\n", | |||||
"overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,\n", | |||||
" loss=loss,\n", | |||||
" metrics=metric,\n", | |||||
" save_path=None,\n", | |||||
" batch_size=32,\n", | |||||
" n_epochs=5)\n", | |||||
"overfit_trainer.train()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 25, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2018-12-07 14:12:21" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=395), HTML(value='')), layout=Layout(display=…" | |||||
] | |||||
}, | |||||
"execution_count": 0, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/5. Step:79/395. AccuracyMetric: acc=0.250043" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/5. Step:158/395. AccuracyMetric: acc=0.280807" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/5. Step:237/395. AccuracyMetric: acc=0.280978" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/5. Step:316/395. AccuracyMetric: acc=0.285592" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/5. Step:395/395. AccuracyMetric: acc=0.278927" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 用train_data训练,在test_data验证\n", | |||||
"trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n", | |||||
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n", | |||||
" metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n", | |||||
" save_path=None,\n", | |||||
" batch_size=32,\n", | |||||
" n_epochs=5)\n", | |||||
"trainer.train()\n", | |||||
"print('Train finished!')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 26, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"[tester] \nAccuracyMetric: acc=0.280636" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'AccuracyMetric': {'acc': 0.280636}}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 调用Tester在test_data上评价效果\n", | |||||
"from fastNLP import Tester\n", | |||||
"\n", | |||||
"tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n", | |||||
" batch_size=4)\n", | |||||
"acc = tester.test()\n", | |||||
"print(acc)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.6.7" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,860 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"fastNLP上手教程\n", | |||||
"-------\n", | |||||
"\n", | |||||
"fastNLP提供方便的数据预处理,训练和测试模型的功能" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"DataSet & Instance\n", | |||||
"------\n", | |||||
"\n", | |||||
"fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n", | |||||
"\n", | |||||
"有一些read_*方法,可以轻松从文件读取数据,存成DataSet。" | |||||
] | |||||
}, | |||||
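   { | |||||
    "cell_type": "markdown", | |||||
    "metadata": {}, | |||||
    "source": [ | |||||
     "As a quick illustration (not part of the original run), a DataSet can also be assembled in memory from Instances instead of being read from a file. The sentences and labels below are made up, and it is assumed that DataSet() can be constructed empty in this fastNLP version:\n", | |||||
     "\n", | |||||
     "```python\n", | |||||
     "from fastNLP import DataSet\n", | |||||
     "from fastNLP import Instance\n", | |||||
     "\n", | |||||
     "# build a tiny DataSet by hand; each Instance carries the same two fields used below\n", | |||||
     "data = DataSet()\n", | |||||
     "data.append(Instance(raw_sentence='A fun little movie .', label='3'))\n", | |||||
     "data.append(Instance(raw_sentence='Dull and overlong .', label='1'))\n", | |||||
     "print(len(data))  # 2\n", | |||||
     "print(data[0])    # an Instance holding 'raw_sentence' and 'label'\n", | |||||
     "```" | |||||
    ] | |||||
   }, | |||||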
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import DataSet\n", | |||||
"from fastNLP import Instance\n", | |||||
"\n", | |||||
"# 从csv读取数据到DataSet\n", | |||||
"win_path = \"C:\\\\Users\\zyfeng\\Desktop\\FudanNLP\\\\fastNLP\\\\test\\\\data_for_tests\\\\tutorial_sample_dataset.csv\"\n", | |||||
"dataset = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')\n", | |||||
"print(dataset[0])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'raw_sentence': fake data,\n'label': 0}" | |||||
] | |||||
}, | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# DataSet.append(Instance)加入新数据\n", | |||||
"\n", | |||||
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n", | |||||
"dataset[-1]" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# DataSet.apply(func, new_field_name)对数据预处理\n", | |||||
"\n", | |||||
"# 将所有数字转为小写\n", | |||||
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", | |||||
"# label转int\n", | |||||
"dataset.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n", | |||||
"# 使用空格分割句子\n", | |||||
"dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)\n", | |||||
"def split_sent(ins):\n", | |||||
" return ins['raw_sentence'].split()\n", | |||||
"dataset.apply(split_sent, new_field_name='words', is_input=True)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# DataSet.drop(func)筛除数据\n", | |||||
"# 删除低于某个长度的词语\n", | |||||
"dataset.drop(lambda x: len(x['words']) <= 3)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Train size: " | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" " | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"54" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Test size: " | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 分出测试集、训练集\n", | |||||
"\n", | |||||
"test_data, train_data = dataset.split(0.3)\n", | |||||
"print(\"Train size: \", len(test_data))\n", | |||||
"print(\"Test size: \", len(train_data))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Vocabulary\n", | |||||
"------\n", | |||||
"\n", | |||||
"fastNLP中的Vocabulary轻松构建词表,将词转成数字" | |||||
] | |||||
}, | |||||
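   { | |||||
    "cell_type": "markdown", | |||||
    "metadata": {}, | |||||
    "source": [ | |||||
     "A minimal, self-contained sketch of the typical Vocabulary workflow (the word list is made up, min_freq=1 is assumed so that every word is kept, and unseen words are assumed to fall back to the default unknown token):\n", | |||||
     "\n", | |||||
     "```python\n", | |||||
     "from fastNLP import Vocabulary\n", | |||||
     "\n", | |||||
     "vocab = Vocabulary(min_freq=1)\n", | |||||
     "for word in ['the', 'movie', 'is', 'good', 'the', 'plot', 'is', 'thin']:\n", | |||||
     "    vocab.add(word)\n", | |||||
     "vocab.build_vocab()\n", | |||||
     "\n", | |||||
     "# every known word gets an integer index\n", | |||||
     "print(vocab.to_index('movie'))\n", | |||||
     "# words never added map to the unknown-token index\n", | |||||
     "print(vocab.to_index('some-unseen-word'))\n", | |||||
     "```" | |||||
    ] | |||||
   }, | |||||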
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': the plot is romantic comedy boilerplate from start to finish .,\n'label': 2,\n'label_seq': 2,\n'words': ['the', 'plot', 'is', 'romantic', 'comedy', 'boilerplate', 'from', 'start', 'to', 'finish', '.'],\n'word_seq': [2, 13, 9, 24, 25, 26, 15, 27, 11, 28, 3]}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"# 构建词表, Vocabulary.add(word)\n", | |||||
"vocab = Vocabulary(min_freq=2)\n", | |||||
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n", | |||||
"vocab.build_vocab()\n", | |||||
"\n", | |||||
"# index句子, Vocabulary.to_index(word)\n", | |||||
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n", | |||||
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n", | |||||
"\n", | |||||
"\n", | |||||
"print(test_data[0])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": { | |||||
"scrolled": true | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x has: {'words': array([list(['this', 'kind', 'of', 'hands-on', 'storytelling', 'is', 'ultimately', 'what', 'makes', 'shanghai', 'ghetto', 'move', 'beyond', 'a', 'good', ',', 'dry', ',', 'reliable', 'textbook', 'and', 'what', 'allows', 'it', 'to', 'rank', 'with', 'its', 'worthy', 'predecessors', '.']),\n", | |||||
" list(['the', 'entire', 'movie', 'is', 'filled', 'with', 'deja', 'vu', 'moments', '.'])],\n", | |||||
" dtype=object), 'word_seq': tensor([[ 19, 184, 6, 1, 481, 9, 206, 50, 91, 1210, 1609, 1330,\n", | |||||
" 495, 5, 63, 4, 1269, 4, 1, 1184, 7, 50, 1050, 10,\n", | |||||
" 8, 1611, 16, 21, 1039, 1, 2],\n", | |||||
" [ 3, 711, 22, 9, 1282, 16, 2482, 2483, 200, 2, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0]])}\n", | |||||
"batch_y has: {'label_seq': tensor([3, 2])}\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 假设你们需要做强化学习或者gan之类的项目,也许你们可以使用这里的dataset\n", | |||||
"from fastNLP.core.batch import Batch\n", | |||||
"from fastNLP.core.sampler import RandomSampler\n", | |||||
"\n", | |||||
"batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())\n", | |||||
"for batch_x, batch_y in batch_iterator:\n", | |||||
" print(\"batch_x has: \", batch_x)\n", | |||||
" print(\"batch_y has: \", batch_y)\n", | |||||
" break" | |||||
] | |||||
}, | |||||
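   { | |||||
    "cell_type": "markdown", | |||||
    "metadata": {}, | |||||
    "source": [ | |||||
     "If you do want such a hand-written loop, a rough sketch might look like the following. It assumes a model like the CNNText defined in the next section, whose forward takes the input field word_seq and returns a dict containing the key 'output' (the key used by the loss mapping later in this tutorial); the Adam optimizer and learning rate are arbitrary choices:\n", | |||||
     "\n", | |||||
     "```python\n", | |||||
     "import torch\n", | |||||
     "import torch.nn.functional as F\n", | |||||
     "from fastNLP.core.batch import Batch\n", | |||||
     "from fastNLP.core.sampler import RandomSampler\n", | |||||
     "\n", | |||||
     "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", | |||||
     "for batch_x, batch_y in Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()):\n", | |||||
     "    pred = model(word_seq=batch_x['word_seq'])['output']  # forward is called with the input field\n", | |||||
     "    loss = F.cross_entropy(pred, batch_y['label_seq'])     # label_seq is the target field\n", | |||||
     "    optimizer.zero_grad()\n", | |||||
     "    loss.backward()\n", | |||||
     "    optimizer.step()\n", | |||||
     "```" | |||||
    ] | |||||
   }, | |||||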
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# Model\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 9, | |||||
"metadata": { | |||||
"collapsed": false | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"CNNText(\n (embed): Embedding(\n (embed): Embedding(77, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)" | |||||
] | |||||
}, | |||||
"execution_count": 9, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 定义一个简单的Pytorch模型\n", | |||||
"\n", | |||||
"from fastNLP.models import CNNText\n", | |||||
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n", | |||||
"model" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Trainer & Tester\n", | |||||
"------\n", | |||||
"\n", | |||||
"使用fastNLP的Trainer训练模型" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 11, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import Trainer\n", | |||||
"from copy import deepcopy\n", | |||||
"from fastNLP import CrossEntropyLoss\n", | |||||
"from fastNLP import AccuracyMetric" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 12, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2018-12-07 14:07:20" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=20), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"execution_count": 0, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/10. Step:2/20. AccuracyMetric: acc=0.037037" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/10. Step:4/20. AccuracyMetric: acc=0.296296" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/10. Step:6/20. AccuracyMetric: acc=0.333333" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/10. Step:8/20. AccuracyMetric: acc=0.555556" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/10. Step:10/20. AccuracyMetric: acc=0.611111" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 6/10. Step:12/20. AccuracyMetric: acc=0.481481" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 7/10. Step:14/20. AccuracyMetric: acc=0.62963" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 8/10. Step:16/20. AccuracyMetric: acc=0.685185" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 9/10. Step:18/20. AccuracyMetric: acc=0.722222" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 10/10. Step:20/20. AccuracyMetric: acc=0.777778" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 进行overfitting测试\n", | |||||
"copy_model = deepcopy(model)\n", | |||||
"overfit_trainer = Trainer(model=copy_model, \n", | |||||
" train_data=test_data, \n", | |||||
" dev_data=test_data,\n", | |||||
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n", | |||||
" metrics=AccuracyMetric(),\n", | |||||
" n_epochs=10,\n", | |||||
" save_path=None)\n", | |||||
"overfit_trainer.train()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 14, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2018-12-07 14:08:10" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=5), HTML(value='')), layout=Layout(display='i…" | |||||
] | |||||
}, | |||||
"execution_count": 0, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.037037" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.037037" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.037037" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.185185" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.240741" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Train finished!" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 实例化Trainer,传入模型和数据,进行训练\n", | |||||
"trainer = Trainer(model=model, \n", | |||||
" train_data=train_data, \n", | |||||
" dev_data=test_data,\n", | |||||
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n", | |||||
" metrics=AccuracyMetric(),\n", | |||||
" n_epochs=5)\n", | |||||
"trainer.train()\n", | |||||
"print('Train finished!')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 15, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"[tester] \nAccuracyMetric: acc=0.240741" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Tester\n", | |||||
"\n", | |||||
"tester = Tester(data=test_data, model=model, metrics=AccuracyMetric())\n", | |||||
"acc = tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# In summary\n", | |||||
"\n", | |||||
"## fastNLP Trainer的伪代码逻辑\n", | |||||
"### 1. 准备DataSet,假设DataSet中共有如下的fields\n", | |||||
" ['raw_sentence', 'word_seq1', 'word_seq2', 'raw_label','label']\n", | |||||
" 通过\n", | |||||
" DataSet.set_input('word_seq1', word_seq2', flag=True)将'word_seq1', 'word_seq2'设置为input\n", | |||||
" 通过\n", | |||||
" DataSet.set_target('label', flag=True)将'label'设置为target\n", | |||||
"### 2. 初始化模型\n", | |||||
" class Model(nn.Module):\n", | |||||
" def __init__(self):\n", | |||||
" xxx\n", | |||||
" def forward(self, word_seq1, word_seq2):\n", | |||||
" # (1) 这里使用的形参名必须和DataSet中的input field的名称对应。因为我们是通过形参名, 进行赋值的\n", | |||||
" # (2) input field的数量可以多于这里的形参数量。但是不能少于。\n", | |||||
" xxxx\n", | |||||
" # 输出必须是一个dict\n", | |||||
"### 3. Trainer的训练过程\n", | |||||
" (1) 从DataSet中按照batch_size取出一个batch,调用Model.forward\n", | |||||
" (2) 将 Model.forward的结果 与 标记为target的field 传入Losser当中。\n", | |||||
" 由于每个人写的Model.forward的output的dict可能key并不一样,比如有人是{'pred':xxx}, {'output': xxx}; \n", | |||||
" 另外每个人将target可能也会设置为不同的名称, 比如有人是label, 有人设置为target;\n", | |||||
" 为了解决以上的问题,我们的loss提供映射机制\n", | |||||
" 比如CrossEntropyLosser的需要的输入是(prediction, target)。但是forward的output是{'output': xxx}; 'label'是target\n", | |||||
" 那么初始化losser的时候写为CrossEntropyLosser(prediction='output', target='label')即可\n", | |||||
" (3) 对于Metric是同理的\n", | |||||
" Metric计算也是从 forward的结果中取值 与 设置target的field中取值。 也是可以通过映射找到对应的值 \n", | |||||
" \n", | |||||
" \n", | |||||
"\n", | |||||
"## 一些问题.\n", | |||||
"### 1. DataSet中为什么需要设置input和target\n", | |||||
" 只有被设置为input或者target的数据才会在train的过程中被取出来\n", | |||||
" (1.1) 我们只会在设置为input的field中寻找传递给Model.forward的参数。\n", | |||||
" (1.2) 我们在传递值给losser或者metric的时候会使用来自: \n", | |||||
" (a)Model.forward的output\n", | |||||
" (b)被设置为target的field\n", | |||||
" \n", | |||||
"\n", | |||||
"### 2. 我们是通过forwad中的形参名将DataSet中的field赋值给对应的参数\n", | |||||
" (1.1) 构建模型过程中,\n", | |||||
" 例如:\n", | |||||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||||
" def forward(self, x, seq_lens):\n", | |||||
" pass\n", | |||||
" 我们是通过形参名称进行匹配的field的\n", | |||||
" \n", | |||||
"\n", | |||||
"\n", | |||||
"### 1. 加载数据到DataSet\n", | |||||
"### 2. 使用apply操作对DataSet进行预处理\n", | |||||
" (2.1) 处理过程中将某些field设置为input,某些field设置为target\n", | |||||
"### 3. 构建模型\n", | |||||
" (3.1) 构建模型过程中,需要注意forward函数的形参名需要和DataSet中设置为input的field名称是一致的。\n", | |||||
" 例如:\n", | |||||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||||
" def forward(self, x, seq_lens):\n", | |||||
" pass\n", | |||||
" 我们是通过形参名称进行匹配的field的\n", | |||||
" (3.2) 模型的forward的output需要是dict类型的。\n", | |||||
" 建议将输出设置为{\"pred\": xx}.\n", | |||||
" \n" | |||||
] | |||||
}, | |||||
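   { | |||||
    "cell_type": "markdown", | |||||
    "metadata": {}, | |||||
    "source": [ | |||||
     "To make the parameter-name convention above concrete, here is a hedged, minimal sketch (not from the original tutorial) of a model whose forward matches an input field named word_seq and returns its prediction in a dict, as recommended; the class name and hyper-parameters are made up:\n", | |||||
     "\n", | |||||
     "```python\n", | |||||
     "import torch.nn as nn\n", | |||||
     "\n", | |||||
     "class BagOfWordsClassifier(nn.Module):\n", | |||||
     "    # toy model: the parameter names of forward must match the DataSet input fields\n", | |||||
     "    def __init__(self, vocab_size, num_classes, embed_dim=50):\n", | |||||
     "        super().__init__()\n", | |||||
     "        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)\n", | |||||
     "        self.fc = nn.Linear(embed_dim, num_classes)\n", | |||||
     "\n", | |||||
     "    def forward(self, word_seq):\n", | |||||
     "        # 'word_seq' is filled from the DataSet field of the same name\n", | |||||
     "        emb = self.embed(word_seq)        # [batch, seq_len, embed_dim]\n", | |||||
     "        pooled = emb.mean(dim=1)          # average over the sequence\n", | |||||
     "        return {'pred': self.fc(pooled)}  # forward must return a dict\n", | |||||
     "\n", | |||||
     "# hypothetical usage with the mapping mechanism described above:\n", | |||||
     "# Trainer(model=BagOfWordsClassifier(len(vocab), 5), train_data=train_data, dev_data=test_data,\n", | |||||
     "#         loss=CrossEntropyLoss(pred='pred', target='label_seq'),\n", | |||||
     "#         metrics=AccuracyMetric(pred='pred', target='label_seq'))\n", | |||||
     "```" | |||||
    ] | |||||
   }, | |||||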
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.6.7" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,333 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"source": [ | |||||
"# FastNLP 1分钟上手教程" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## step 1\n", | |||||
"读取数据集" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 50, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import DataSet\n", | |||||
"# linux_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n", | |||||
"win_path = \"C:\\\\Users\\zyfeng\\Desktop\\FudanNLP\\\\fastNLP\\\\test\\\\data_for_tests\\\\tutorial_sample_dataset.csv\"\n", | |||||
"ds = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## step 2\n", | |||||
"数据预处理\n", | |||||
"1. 类型转换\n", | |||||
"2. 切分验证集\n", | |||||
"3. 构建词典" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 52, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# 将所有数字转为小写\n", | |||||
"ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", | |||||
"# label转int\n", | |||||
"ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n", | |||||
"\n", | |||||
"def split_sent(ins):\n", | |||||
" return ins['raw_sentence'].split()\n", | |||||
"ds.apply(split_sent, new_field_name='words', is_input=True)\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 60, | |||||
"metadata": { | |||||
"collapsed": false | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Train size: " | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" " | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"54" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Test size: " | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" " | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"23" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 分割训练集/验证集\n", | |||||
"train_data, dev_data = ds.split(0.3)\n", | |||||
"print(\"Train size: \", len(train_data))\n", | |||||
"print(\"Test size: \", len(dev_data))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 61, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import Vocabulary\n", | |||||
"vocab = Vocabulary(min_freq=2)\n", | |||||
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n", | |||||
"\n", | |||||
"# index句子, Vocabulary.to_index(word)\n", | |||||
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n", | |||||
"dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## step 3\n", | |||||
" 定义模型" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 62, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP.models import CNNText\n", | |||||
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## step 4\n", | |||||
"开始训练" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 63, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2018-12-07 14:03:41" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…" | |||||
] | |||||
}, | |||||
"execution_count": 0, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Train finished!" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n", | |||||
"trainer = Trainer(model=model, \n", | |||||
" train_data=train_data, \n", | |||||
" dev_data=dev_data,\n", | |||||
" loss=CrossEntropyLoss(),\n", | |||||
" metrics=AccuracyMetric()\n", | |||||
" )\n", | |||||
"trainer.train()\n", | |||||
"print('Train finished!')\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### 本教程结束。更多操作请参考进阶教程。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 2", | |||||
"language": "python", | |||||
"name": "python2" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 2 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython2", | |||||
"version": "2.7.6" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 0 | |||||
} |
@@ -0,0 +1,101 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"source": [ | |||||
"## FastNLP 进阶教程\n", | |||||
"本教程阅读时间平均30分钟" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 数据部分\n", | |||||
"### DataSet\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Instance" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Vocabulary" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 模型部分\n", | |||||
"### model" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 训练测试部分\n", | |||||
"### Loss" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Metric" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Trainer" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Tester" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 2", | |||||
"language": "python", | |||||
"name": "python2" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 2 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython2", | |||||
"version": "2.7.6" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 0 | |||||
} |