@@ -109,7 +109,7 @@ class POS(API):
                               "use_cuda": True, "evaluator": evaluator}
         pp(te_dataset)
-        te_dataset.set_is_target(truth=True)
+        te_dataset.set_target(truth=True)
         tester = Tester(**default_valid_args)
@@ -152,7 +152,7 @@ class IndexerProcessor(Processor):
             index = [self.vocab.to_index(token) for token in tokens]
             ins[self.new_added_field_name] = index
-        dataset.set_need_tensor(**{self.new_added_field_name: True})
+        dataset._set_need_tensor(**{self.new_added_field_name: True})
         if self.delete_old_field:
             dataset.delete_field(self.field_name)
@@ -186,7 +186,7 @@ class SeqLenProcessor(Processor):
         for ins in dataset:
             length = len(ins[self.field_name])
             ins[self.new_added_field_name] = length
-        dataset.set_need_tensor(**{self.new_added_field_name: True})
+        dataset._set_need_tensor(**{self.new_added_field_name: True})
         return dataset

 class ModelProcessor(Processor):
@@ -259,7 +259,7 @@ class SetTensorProcessor(Processor):
     def process(self, dataset):
         set_dict = {name: self.default for name in dataset.get_fields().keys()}
         set_dict.update(self.field_dict)
-        dataset.set_need_tensor(**set_dict)
+        dataset._set_need_tensor(**set_dict)
         return dataset
@@ -272,5 +272,5 @@ class SetIsTargetProcessor(Processor):
     def process(self, dataset):
         set_dict = {name: self.default for name in dataset.get_fields().keys()}
         set_dict.update(self.field_dict)
-        dataset.set_is_target(**set_dict)
+        dataset.set_target(**set_dict)
         return dataset
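Both `SetTensorProcessor` and `SetIsTargetProcessor` share one merge pattern: a default flag for every field, overlaid with per-field overrides, then splatted into the (renamed) setter. That merge step in isolation, as a runnable sketch with hypothetical field names:

```python
def merged_flags(field_names, default, overrides):
    """The SetTensorProcessor / SetIsTargetProcessor merge step, in isolation."""
    set_dict = {name: default for name in field_names}
    set_dict.update(overrides)
    return set_dict

flags = merged_flags(["words", "seq_len", "truth"], default=False,
                     overrides={"truth": True})
# -> {'words': False, 'seq_len': False, 'truth': True}; passed on via
#    dataset.set_target(**flags) or dataset._set_need_tensor(**flags)
print(flags)
```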
@@ -9,7 +9,7 @@ class Batch(object):
     """
-    def __init__(self, dataset, batch_size, sampler, use_cuda=False):
+    def __init__(self, dataset, batch_size, sampler, as_numpy=False, use_cuda=False):
         """
         :param dataset: a DataSet object
@@ -21,6 +21,7 @@ class Batch(object):
         self.dataset = dataset
         self.batch_size = batch_size
         self.sampler = sampler
+        self.as_numpy = as_numpy
         self.use_cuda = use_cuda
         self.idx_list = None
         self.curidx = 0
@@ -53,7 +54,9 @@ class Batch(object):
         for field_name, field in self.dataset.get_fields().items():
             if field.need_tensor:
-                batch = torch.from_numpy(field.get(indices))
+                batch = field.get(indices)
+                if not self.as_numpy:
+                    batch = torch.from_numpy(batch)
                 if self.use_cuda:
                     batch = batch.cuda()
                 if field.is_target:
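With the new flag, callers can keep batches as NumPy arrays instead of eagerly converting them to torch tensors. A minimal, self-contained sketch of the changed fetch step; `fetch_batch` and the list-backed field are stand-ins, not the library's API:

```python
import numpy as np
import torch

def fetch_batch(field_values, indices, as_numpy=False, use_cuda=False):
    """Toy of the changed fetch step: convert to a tensor only when asked to."""
    batch = np.asarray(field_values)[indices]   # stand-in for field.get(indices)
    if not as_numpy:
        batch = torch.from_numpy(batch)
        if use_cuda:                            # .cuda() only applies to tensors
            batch = batch.cuda()
    return batch

x = fetch_batch([[1, 2], [3, 4], [5, 6]], indices=[0, 2])                 # torch.Tensor
y = fetch_batch([[1, 2], [3, 4], [5, 6]], indices=[0, 2], as_numpy=True)  # np.ndarray
```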
@@ -30,21 +30,25 @@ class DataSet(object):
         def __init__(self, dataset, idx=-1):
             self.dataset = dataset
             self.idx = idx
+            self.fields = None

         def __next__(self):
             self.idx += 1
-            if self.idx >= len(self.dataset):
+            try:
+                self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()}
+            except IndexError:
                 raise StopIteration
             return self

         def __getitem__(self, name):
-            return self.dataset[name][self.idx]
+            return self.fields[name]

         def __setitem__(self, name, val):
             if name not in self.dataset:
                 new_fields = [None] * len(self.dataset)
                 self.dataset.add_field(name, new_fields)
             self.dataset[name][self.idx] = val
+            self.fields[name] = val

         def __repr__(self):
             return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
@@ -163,9 +167,8 @@ class DataSet(object):
             self.field_arrays[new_name] = self.field_arrays.pop(old_name)
         else:
             raise KeyError("{} is not a valid name. ".format(old_name))
-        return self

-    def set_is_target(self, **fields):
+    def set_target(self, **fields):
         """Change the `is_target` flag for all instances. For fields not set here, `is_target` is left unchanged.

         :param fields: key-value pairs of field name and `is_target` value (True or False).
@@ -176,9 +179,20 @@ class DataSet(object):
                 self.field_arrays[name].is_target = val
             else:
                 raise KeyError("{} is not a valid field name.".format(name))
+        self._set_need_tensor(**fields)
+        return self
+
+    def set_input(self, **fields):
+        for name, val in fields.items():
+            if name in self.field_arrays:
+                assert isinstance(val, bool)
+                self.field_arrays[name].is_target = not val
+            else:
+                raise KeyError("{} is not a valid field name.".format(name))
+        self._set_need_tensor(**fields)
         return self

-    def set_need_tensor(self, **kwargs):
+    def _set_need_tensor(self, **kwargs):
         for name, val in kwargs.items():
             if name in self.field_arrays:
                 assert isinstance(val, bool)
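Taken together, `set_target(name=True)` now marks a field as a target and tensorizable in one call, while the new `set_input(name=True)` marks it as a non-target input and tensorizable. A runnable toy model of exactly those flag semantics; `ToyField` and `ToyDataSet` are hypothetical stand-ins:

```python
class ToyField:
    """Stand-in for a FieldArray: just the two flags this diff touches."""
    def __init__(self):
        self.is_target = False
        self.need_tensor = False

class ToyDataSet:
    def __init__(self, names):
        self.field_arrays = {name: ToyField() for name in names}

    def set_target(self, **fields):
        for name, val in fields.items():
            self.field_arrays[name].is_target = val
        return self._set_need_tensor(**fields)

    def set_input(self, **fields):
        for name, val in fields.items():
            self.field_arrays[name].is_target = not val
        return self._set_need_tensor(**fields)

    def _set_need_tensor(self, **kwargs):
        for name, val in kwargs.items():
            self.field_arrays[name].need_tensor = val
        return self

ds = ToyDataSet(["words", "truth"])
ds.set_input(words=True).set_target(truth=True)
assert ds.field_arrays["truth"].is_target and ds.field_arrays["truth"].need_tensor
assert not ds.field_arrays["words"].is_target and ds.field_arrays["words"].need_tensor
```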
@@ -320,8 +320,3 @@ def pred_topk(y_prob, k=1):
                              (1, k))
     y_prob_topk = y_prob[x_axis_index, y_pred_topk]
     return y_pred_topk, y_prob_topk
-
-
-if __name__ == '__main__':
-    y = np.array([1, 0, 1, 0, 1, 1])
-    print(_label_types(y))
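With the `__main__` smoke test removed, a quick check of `pred_topk` has to live in calling code. A sketch, assuming `pred_topk` is importable from this metrics module and takes an `(n_samples, n_classes)` probability array:

```python
import numpy as np
from fastNLP.core.metrics import pred_topk  # module path assumed

# Hypothetical probabilities for 3 samples over 4 classes.
y_prob = np.array([[0.1, 0.6, 0.2, 0.1],
                   [0.5, 0.1, 0.3, 0.1],
                   [0.2, 0.2, 0.2, 0.4]])
y_pred_topk, y_prob_topk = pred_topk(y_prob, k=2)
print(y_pred_topk)  # top-2 class indices per row, e.g. [[1, 2], [0, 2], [3, 0]]
print(y_prob_topk)  # the matching probabilities
```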
@@ -1,6 +1,6 @@
 import _pickle
 import os
+import inspect


 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.
@@ -44,3 +44,18 @@ def pickle_exist(pickle_path, pickle_name):
         return True
     else:
         return False
+
+
+def build_args(func, kwargs):
+    assert callable(func) and isinstance(kwargs, dict)
+    spect = inspect.getfullargspec(func)
+    assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
+    needed_args = set(spect.args)
+    output = {name: default for name, default in zip(reversed(spect.args), reversed(spect.defaults or ()))}
+    output.update({name: val for name, val in kwargs.items() if name in needed_args})
+    if spect.varkw is not None:
+        output.update(kwargs)
+    # check missing args
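The hunk is cut off at the announced missing-argument check, so the tail of `build_args` is not shown here. The filtering idea itself is easy to model in isolation: declared defaults first, then any caller kwargs the callable actually accepts, then everything if the callable takes `**kwargs`. A sketch under that reading; `toy_build_args` is a hypothetical stand-in:

```python
import inspect

def toy_build_args(func, kwargs):
    """Toy of the filtering idea: keep only kwargs the callable accepts."""
    assert callable(func) and isinstance(kwargs, dict)
    spec = inspect.getfullargspec(func)
    assert spec.varargs is None, "positional *args are not supported"
    needed = set(spec.args)
    # Start from declared defaults, then overlay matching caller kwargs.
    output = {n: d for n, d in zip(reversed(spec.args), reversed(spec.defaults or ()))}
    output.update({n: v for n, v in kwargs.items() if n in needed})
    if spec.varkw is not None:  # func accepts **kwargs: pass everything through
        output.update(kwargs)
    # the real helper goes on to verify that required, defaultless args are present
    return output

def f(x, y=2, **extra):
    pass

print(toy_build_args(f, {"x": 1, "z": 3}))  # {'y': 2, 'x': 1, 'z': 3}
```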
@@ -1,7 +1,6 @@
 from . import aggregator
 from . import decoder
 from . import encoder
-from . import interactor
 from .aggregator import *
 from .decoder import *
 from .encoder import *
@@ -12,5 +11,4 @@ __version__ = '0.0.0'
 __all__ = ['encoder',
            'decoder',
            'aggregator',
-           'interactor',
            'TimestepDropout']
@@ -111,8 +111,8 @@ class CWSTagProcessor(Processor):
             sentence = ins[self.field_name]
             tag_list = self._generate_tag(sentence)
             ins[self.new_added_field_name] = tag_list
-        dataset.set_is_target(**{self.new_added_field_name:True})
-        dataset.set_need_tensor(**{self.new_added_field_name:True})
+        dataset.set_target(**{self.new_added_field_name:True})
+        dataset._set_need_tensor(**{self.new_added_field_name:True})
         return dataset

     def _tags_from_word_len(self, word_len):
@@ -230,7 +230,7 @@ class SeqLenProcessor(Processor):
         for ins in dataset:
             length = len(ins[self.field_name])
             ins[self.new_added_field_name] = length
-        dataset.set_need_tensor(**{self.new_added_field_name:True})
+        dataset._set_need_tensor(**{self.new_added_field_name:True})
         return dataset

 class SegApp2OutputProcessor(Processor):
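For context, the `_tags_from_word_len` hook above suggests segmentation tags are derived per word from its length. A hypothetical BMES-style sketch of that idea only; the real `CWSTagProcessor` subclasses define their own schemes:

```python
def tags_from_word_len(word_len):
    """Hypothetical BMES-style tagging for one word of length word_len."""
    if word_len == 1:
        return ["S"]  # single-character word
    return ["B"] + ["M"] * (word_len - 2) + ["E"]

# A sentence of two words with lengths 4 and 3 -> B M M E B M E
print(tags_from_word_len(4) + tags_from_word_len(3))
```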