Browse Source

Rename DataSet.set_is_target to set_target (and set_need_tensor to _set_need_tensor); add an as_numpy option to Batch

tags/v0.2.0
yunfan 6 years ago
parent
commit
d643a7a894
9 changed files with 48 additions and 23 deletions
  1. +1
    -1
      fastNLP/api/api.py
  2. +4
    -4
      fastNLP/api/processor.py
  3. +5
    -2
      fastNLP/core/batch.py
  4. +19
    -5
      fastNLP/core/dataset.py
  5. +0
    -5
      fastNLP/core/metrics.py
  6. +16
    -1
      fastNLP/core/utils.py
  7. +0
    -2
      fastNLP/modules/__init__.py
  8. +0
    -0
      fastNLP/modules/interactor/__init__.py
  9. +3
    -3
      reproduction/chinese_word_segment/process/cws_processor.py

+ 1
- 1
fastNLP/api/api.py View File

@@ -109,7 +109,7 @@ class POS(API):
"use_cuda": True, "evaluator": evaluator} "use_cuda": True, "evaluator": evaluator}


pp(te_dataset) pp(te_dataset)
te_dataset.set_is_target(truth=True)
te_dataset.set_target(truth=True)


tester = Tester(**default_valid_args) tester = Tester(**default_valid_args)




+ 4
- 4
fastNLP/api/processor.py View File

@@ -152,7 +152,7 @@ class IndexerProcessor(Processor):
index = [self.vocab.to_index(token) for token in tokens] index = [self.vocab.to_index(token) for token in tokens]
ins[self.new_added_field_name] = index ins[self.new_added_field_name] = index


dataset.set_need_tensor(**{self.new_added_field_name: True})
dataset._set_need_tensor(**{self.new_added_field_name: True})


if self.delete_old_field: if self.delete_old_field:
dataset.delete_field(self.field_name) dataset.delete_field(self.field_name)
@@ -186,7 +186,7 @@ class SeqLenProcessor(Processor):
for ins in dataset: for ins in dataset:
length = len(ins[self.field_name]) length = len(ins[self.field_name])
ins[self.new_added_field_name] = length ins[self.new_added_field_name] = length
dataset.set_need_tensor(**{self.new_added_field_name: True})
dataset._set_need_tensor(**{self.new_added_field_name: True})
return dataset return dataset


class ModelProcessor(Processor): class ModelProcessor(Processor):
@@ -259,7 +259,7 @@ class SetTensorProcessor(Processor):
def process(self, dataset): def process(self, dataset):
set_dict = {name: self.default for name in dataset.get_fields().keys()} set_dict = {name: self.default for name in dataset.get_fields().keys()}
set_dict.update(self.field_dict) set_dict.update(self.field_dict)
dataset.set_need_tensor(**set_dict)
dataset._set_need_tensor(**set_dict)
return dataset return dataset




@@ -272,5 +272,5 @@ class SetIsTargetProcessor(Processor):
def process(self, dataset): def process(self, dataset):
set_dict = {name: self.default for name in dataset.get_fields().keys()} set_dict = {name: self.default for name in dataset.get_fields().keys()}
set_dict.update(self.field_dict) set_dict.update(self.field_dict)
dataset.set_is_target(**set_dict)
dataset.set_target(**set_dict)
return dataset return dataset

+ 5
- 2
fastNLP/core/batch.py View File

@@ -9,7 +9,7 @@ class Batch(object):


""" """


def __init__(self, dataset, batch_size, sampler, use_cuda=False):
def __init__(self, dataset, batch_size, sampler, as_numpy=False, use_cuda=False):
""" """


:param dataset: a DataSet object :param dataset: a DataSet object
@@ -21,6 +21,7 @@ class Batch(object):
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
self.sampler = sampler self.sampler = sampler
self.as_numpy = as_numpy
self.use_cuda = use_cuda self.use_cuda = use_cuda
self.idx_list = None self.idx_list = None
self.curidx = 0 self.curidx = 0
@@ -53,7 +54,9 @@ class Batch(object):


for field_name, field in self.dataset.get_fields().items(): for field_name, field in self.dataset.get_fields().items():
if field.need_tensor: if field.need_tensor:
batch = torch.from_numpy(field.get(indices))
batch = field.get(indices)
if not self.as_numpy:
batch = torch.from_numpy(batch)
if self.use_cuda: if self.use_cuda:
batch = batch.cuda() batch = batch.cuda()
if field.is_target: if field.is_target:


+ 19
- 5
fastNLP/core/dataset.py View File

@@ -30,21 +30,25 @@ class DataSet(object):
def __init__(self, dataset, idx=-1): def __init__(self, dataset, idx=-1):
self.dataset = dataset self.dataset = dataset
self.idx = idx self.idx = idx
self.fields = None


def __next__(self): def __next__(self):
self.idx += 1 self.idx += 1
if self.idx >= len(self.dataset):
try:
self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()}
except IndexError:
raise StopIteration raise StopIteration
return self return self


def __getitem__(self, name): def __getitem__(self, name):
return self.dataset[name][self.idx]
return self.fields[name]


def __setitem__(self, name, val): def __setitem__(self, name, val):
if name not in self.dataset: if name not in self.dataset:
new_fields = [None] * len(self.dataset) new_fields = [None] * len(self.dataset)
self.dataset.add_field(name, new_fields) self.dataset.add_field(name, new_fields)
self.dataset[name][self.idx] = val self.dataset[name][self.idx] = val
self.fields[name] = val


def __repr__(self): def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
@@ -163,9 +167,8 @@ class DataSet(object):
self.field_arrays[new_name] = self.field_arrays.pop(old_name) self.field_arrays[new_name] = self.field_arrays.pop(old_name)
else: else:
raise KeyError("{} is not a valid name. ".format(old_name)) raise KeyError("{} is not a valid name. ".format(old_name))
return self


def set_is_target(self, **fields):
def set_target(self, **fields):
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. """Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged.


:param key-value pairs for field-name and `is_target` value(True, False). :param key-value pairs for field-name and `is_target` value(True, False).
@@ -176,9 +179,20 @@ class DataSet(object):
self.field_arrays[name].is_target = val self.field_arrays[name].is_target = val
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
self._set_need_tensor(**fields)
return self

def set_input(self, **fields):
for name, val in fields.items():
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].is_target = not val
else:
raise KeyError("{} is not a valid field name.".format(name))
self._set_need_tensor(**fields)
return self return self


def set_need_tensor(self, **kwargs):
def _set_need_tensor(self, **kwargs):
for name, val in kwargs.items(): for name, val in kwargs.items():
if name in self.field_arrays: if name in self.field_arrays:
assert isinstance(val, bool) assert isinstance(val, bool)


+ 0
- 5
fastNLP/core/metrics.py View File

@@ -320,8 +320,3 @@ def pred_topk(y_prob, k=1):
(1, k)) (1, k))
y_prob_topk = y_prob[x_axis_index, y_pred_topk] y_prob_topk = y_prob[x_axis_index, y_pred_topk]
return y_pred_topk, y_prob_topk return y_pred_topk, y_prob_topk


if __name__ == '__main__':
y = np.array([1, 0, 1, 0, 1, 1])
print(_label_types(y))

+ 16
- 1
fastNLP/core/utils.py View File

@@ -1,6 +1,6 @@
import _pickle import _pickle
import os import os
import inspect


def save_pickle(obj, pickle_path, file_name): def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file. """Save an object into a pickle file.
@@ -44,3 +44,18 @@ def pickle_exist(pickle_path, pickle_name):
return True return True
else: else:
return False return False

def build_args(func, kwargs):
assert isinstance(func, function) and isinstance(kwargs, dict)
spect = inspect.getfullargspec(func)
assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
needed_args = set(spect.args)
output = {name: default for name, default in zip(reversed(spect.args), reversed(spect.defaults))}
output.update({name: val for name, val in kwargs.items() if name in needed_args})
if spect.varkw is not None:
output.update(kwargs)

# check miss args




+ 0
- 2
fastNLP/modules/__init__.py View File

@@ -1,7 +1,6 @@
from . import aggregator from . import aggregator
from . import decoder from . import decoder
from . import encoder from . import encoder
from . import interactor
from .aggregator import * from .aggregator import *
from .decoder import * from .decoder import *
from .encoder import * from .encoder import *
@@ -12,5 +11,4 @@ __version__ = '0.0.0'
__all__ = ['encoder', __all__ = ['encoder',
'decoder', 'decoder',
'aggregator', 'aggregator',
'interactor',
'TimestepDropout'] 'TimestepDropout']

+ 0
- 0
fastNLP/modules/interactor/__init__.py View File


+ 3
- 3
reproduction/chinese_word_segment/process/cws_processor.py View File

@@ -111,8 +111,8 @@ class CWSTagProcessor(Processor):
sentence = ins[self.field_name] sentence = ins[self.field_name]
tag_list = self._generate_tag(sentence) tag_list = self._generate_tag(sentence)
ins[self.new_added_field_name] = tag_list ins[self.new_added_field_name] = tag_list
dataset.set_is_target(**{self.new_added_field_name:True})
dataset.set_need_tensor(**{self.new_added_field_name:True})
dataset.set_target(**{self.new_added_field_name:True})
dataset._set_need_tensor(**{self.new_added_field_name:True})
return dataset return dataset


def _tags_from_word_len(self, word_len): def _tags_from_word_len(self, word_len):
@@ -230,7 +230,7 @@ class SeqLenProcessor(Processor):
for ins in dataset: for ins in dataset:
length = len(ins[self.field_name]) length = len(ins[self.field_name])
ins[self.new_added_field_name] = length ins[self.new_added_field_name] = length
dataset.set_need_tensor(**{self.new_added_field_name:True})
dataset._set_need_tensor(**{self.new_added_field_name:True})
return dataset return dataset


class SegApp2OutputProcessor(Processor): class SegApp2OutputProcessor(Processor):


Loading…
Cancel
Save