From d643a7a894520d50b030bc026f9bc000c6516e5f Mon Sep 17 00:00:00 2001 From: yunfan Date: Fri, 23 Nov 2018 17:14:42 +0800 Subject: [PATCH] update set_target, batch's as_numpy --- fastNLP/api/api.py | 2 +- fastNLP/api/processor.py | 8 +++---- fastNLP/core/batch.py | 7 ++++-- fastNLP/core/dataset.py | 24 +++++++++++++++---- fastNLP/core/metrics.py | 5 ---- fastNLP/core/utils.py | 17 ++++++++++++- fastNLP/modules/__init__.py | 2 -- fastNLP/modules/interactor/__init__.py | 0 .../process/cws_processor.py | 6 ++--- 9 files changed, 48 insertions(+), 23 deletions(-) delete mode 100644 fastNLP/modules/interactor/__init__.py diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 51559bfd..38658bcf 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -109,7 +109,7 @@ class POS(API): "use_cuda": True, "evaluator": evaluator} pp(te_dataset) - te_dataset.set_is_target(truth=True) + te_dataset.set_target(truth=True) tester = Tester(**default_valid_args) diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index 999cebac..711f2b67 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -152,7 +152,7 @@ class IndexerProcessor(Processor): index = [self.vocab.to_index(token) for token in tokens] ins[self.new_added_field_name] = index - dataset.set_need_tensor(**{self.new_added_field_name: True}) + dataset._set_need_tensor(**{self.new_added_field_name: True}) if self.delete_old_field: dataset.delete_field(self.field_name) @@ -186,7 +186,7 @@ class SeqLenProcessor(Processor): for ins in dataset: length = len(ins[self.field_name]) ins[self.new_added_field_name] = length - dataset.set_need_tensor(**{self.new_added_field_name: True}) + dataset._set_need_tensor(**{self.new_added_field_name: True}) return dataset class ModelProcessor(Processor): @@ -259,7 +259,7 @@ class SetTensorProcessor(Processor): def process(self, dataset): set_dict = {name: self.default for name in dataset.get_fields().keys()} set_dict.update(self.field_dict) - 
dataset.set_need_tensor(**set_dict) + dataset._set_need_tensor(**set_dict) return dataset @@ -272,5 +272,5 @@ class SetIsTargetProcessor(Processor): def process(self, dataset): set_dict = {name: self.default for name in dataset.get_fields().keys()} set_dict.update(self.field_dict) - dataset.set_is_target(**set_dict) + dataset.set_target(**set_dict) return dataset diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index b047081a..ce7e25c0 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -9,7 +9,7 @@ class Batch(object): """ - def __init__(self, dataset, batch_size, sampler, use_cuda=False): + def __init__(self, dataset, batch_size, sampler, as_numpy=False, use_cuda=False): """ :param dataset: a DataSet object @@ -21,6 +21,7 @@ class Batch(object): self.dataset = dataset self.batch_size = batch_size self.sampler = sampler + self.as_numpy = as_numpy self.use_cuda = use_cuda self.idx_list = None self.curidx = 0 @@ -53,7 +54,9 @@ class Batch(object): for field_name, field in self.dataset.get_fields().items(): if field.need_tensor: - batch = torch.from_numpy(field.get(indices)) + batch = field.get(indices) + if not self.as_numpy: + batch = torch.from_numpy(batch) if self.use_cuda: batch = batch.cuda() if field.is_target: diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index db0ebc53..702d37a1 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -30,21 +30,25 @@ class DataSet(object): def __init__(self, dataset, idx=-1): self.dataset = dataset self.idx = idx + self.fields = None def __next__(self): self.idx += 1 - if self.idx >= len(self.dataset): + try: + self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()} + except IndexError: raise StopIteration return self def __getitem__(self, name): - return self.dataset[name][self.idx] + return self.fields[name] def __setitem__(self, name, val): if name not in self.dataset: new_fields = [None] * len(self.dataset) 
self.dataset.add_field(name, new_fields) self.dataset[name][self.idx] = val + self.fields[name] = val def __repr__(self): return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name @@ -163,9 +167,8 @@ class DataSet(object): self.field_arrays[new_name] = self.field_arrays.pop(old_name) else: raise KeyError("{} is not a valid name. ".format(old_name)) - return self - def set_is_target(self, **fields): + def set_target(self, **fields): """Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. :param key-value pairs for field-name and `is_target` value(True, False). @@ -176,9 +179,20 @@ class DataSet(object): self.field_arrays[name].is_target = val else: raise KeyError("{} is not a valid field name.".format(name)) + self._set_need_tensor(**fields) + return self + + def set_input(self, **fields): + for name, val in fields.items(): + if name in self.field_arrays: + assert isinstance(val, bool) + self.field_arrays[name].is_target = not val + else: + raise KeyError("{} is not a valid field name.".format(name)) + self._set_need_tensor(**fields) return self - def set_need_tensor(self, **kwargs): + def _set_need_tensor(self, **kwargs): for name, val in kwargs.items(): if name in self.field_arrays: assert isinstance(val, bool) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 35c6b544..adc0326f 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -320,8 +320,3 @@ def pred_topk(y_prob, k=1): (1, k)) y_prob_topk = y_prob[x_axis_index, y_pred_topk] return y_pred_topk, y_prob_topk - - -if __name__ == '__main__': - y = np.array([1, 0, 1, 0, 1, 1]) - print(_label_types(y)) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 63c4be17..c773ae15 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -1,6 +1,6 @@ import _pickle import os - +import inspect def save_pickle(obj, pickle_path, file_name): """Save an object into a pickle file. 
@@ -44,3 +44,19 @@ def pickle_exist(pickle_path, pickle_name): return True else: return False + +def build_args(func, kwargs): + assert callable(func) and isinstance(kwargs, dict) + spect = inspect.getfullargspec(func) + assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs) + needed_args = set(spect.args) + output = {name: default for name, default in zip(reversed(spect.args), reversed(spect.defaults))} + output.update({name: val for name, val in kwargs.items() if name in needed_args}) + if spect.varkw is not None: + output.update(kwargs) + + # check missing args + return output + + diff --git a/fastNLP/modules/__init__.py b/fastNLP/modules/__init__.py index 3af1ebad..f0f0404a 100644 --- a/fastNLP/modules/__init__.py +++ b/fastNLP/modules/__init__.py @@ -1,7 +1,6 @@ from . import aggregator from . import decoder from . import encoder -from . import interactor from .aggregator import * from .decoder import * from .encoder import * @@ -12,5 +11,4 @@ __version__ = '0.0.0' __all__ = ['encoder', 'decoder', 'aggregator', - 'interactor', 'TimestepDropout'] diff --git a/fastNLP/modules/interactor/__init__.py b/fastNLP/modules/interactor/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py index 03b6ea22..e7c069f1 100644 --- a/reproduction/chinese_word_segment/process/cws_processor.py +++ b/reproduction/chinese_word_segment/process/cws_processor.py @@ -111,8 +111,8 @@ sentence = ins[self.field_name] tag_list = self._generate_tag(sentence) ins[self.new_added_field_name] = tag_list - dataset.set_is_target(**{self.new_added_field_name:True}) - dataset.set_need_tensor(**{self.new_added_field_name:True}) + dataset.set_target(**{self.new_added_field_name:True}) + dataset._set_need_tensor(**{self.new_added_field_name:True}) return dataset def _tags_from_word_len(self,
word_len): @@ -230,7 +230,7 @@ class SeqLenProcessor(Processor): for ins in dataset: length = len(ins[self.field_name]) ins[self.new_added_field_name] = length - dataset.set_need_tensor(**{self.new_added_field_name:True}) + dataset._set_need_tensor(**{self.new_added_field_name:True}) return dataset class SegApp2OutputProcessor(Processor):