From da901ed5b092bda93c73fe3a85d753ba5da04b96 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Fri, 30 Nov 2018 23:56:44 +0800 Subject: [PATCH] * DataSet __getitem__ returns copy of Instance * refine interface of set_target & set_input * rename DataSet.Instance into DataSet.DataSetIter * remove unused methods in DataSet.DataSetIter * remove __setattr__ in DataSet; It is dangerous. * comment adjustment --- fastNLP/core/dataset.py | 174 +++++++++++++++----------------------- test/core/test_dataset.py | 5 +- 2 files changed, 70 insertions(+), 109 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 8583b95b..920e9f11 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -1,5 +1,4 @@ import numpy as np -from copy import copy from fastNLP.core.fieldarray import FieldArray from fastNLP.core.instance import Instance @@ -28,38 +27,22 @@ class DataSet(object): """ - class Instance(object): - def __init__(self, dataset, idx=-1, **fields): - self.dataset = dataset + class DataSetIter(object): + def __init__(self, data_set, idx=-1, **fields): + self.data_set = data_set self.idx = idx self.fields = fields def __next__(self): self.idx += 1 - if self.idx >= len(self.dataset): + if self.idx >= len(self.data_set): raise StopIteration - return copy(self) - - def add_field(self, field_name, field): - """Add a new field to the instance. - - :param field_name: str, the name of the field. - :param field: - """ - self.fields[field_name] = field - - def __getitem__(self, name): - return self.dataset[name][self.idx] - - def __setitem__(self, name, val): - if name not in self.dataset: - new_fields = [None] * len(self.dataset) - self.dataset.add_field(name, new_fields) - self.dataset[name][self.idx] = val + # this returns a copy + return self.data_set[self.idx] def __repr__(self): - return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name - in self.dataset.get_fields().keys()]) + return "\n".join(['{}: {}'.format(name, repr(self.data_set[name][self.idx])) for name + in self.data_set.get_fields().keys()]) def __init__(self, data=None): """ @@ -89,14 +72,41 @@ class DataSet(object): return item in self.field_arrays def __iter__(self): - return self.Instance(self) + return self.DataSetIter(self) - def _convert_ins(self, ins_list): - if isinstance(ins_list, list): - for ins in ins_list: - self.append(ins) + def __getitem__(self, idx): + """Fetch Instance(s) at the `idx` position(s) in the dataset. + Notice: This method returns a copy of the actual instance(s). Any change to the returned value would not modify + the origin instance(s) of the DataSet. + If you want to make in-place changes to all Instances, use `apply` method. + + :param idx: can be int or slice. + :return: If `idx` is int, return an Instance object. + If `idx` is slice, return a DataSet object. + """ + if isinstance(idx, int): + return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays}) + elif isinstance(idx, slice): + data_set = DataSet() + for field in self.field_arrays.values(): + data_set.add_field(name=field.name, + fields=field.content[idx], + padding_val=field.padding_val, + is_input=field.is_input, + is_target=field.is_target) + return data_set else: - self.append(ins_list) + raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) + + def __len__(self): + """Fetch the length of the dataset. + + :return int length: + """ + if len(self.field_arrays) == 0: + return 0 + field = iter(self.field_arrays.values()).__next__() + return len(field) def append(self, ins): """Add an instance to the DataSet. @@ -143,72 +153,47 @@ class DataSet(object): """ return self.field_arrays - def __getitem__(self, idx): - """ - - :param idx: can be int, slice, or str. - :return: If `idx` is int, return an Instance object. - If `idx` is slice, return a DataSet object. - If `idx` is str, it must be a field name, return the field. - - """ - if isinstance(idx, int): - return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays}) - elif isinstance(idx, slice): - data_set = DataSet() - for field in self.field_arrays.values(): - data_set.add_field(name=field.name, - fields=field.content[idx], - padding_val=field.padding_val, - is_input=field.is_input, - is_target=field.is_target) - return data_set - elif isinstance(idx, str): - return self.field_arrays[idx] - else: - raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) - - def __len__(self): - if len(self.field_arrays) == 0: - return 0 - field = iter(self.field_arrays.values()).__next__() - return len(field) - def get_length(self): - """The same as __len__ + """Fetch the length of the dataset. + :return int length: """ return len(self) def rename_field(self, old_name, new_name): - """rename a field + """Rename a field. + + :param str old_name: + :param str new_name: """ if old_name in self.field_arrays: self.field_arrays[new_name] = self.field_arrays.pop(old_name) else: raise KeyError("{} is not a valid name. ".format(old_name)) - def set_target(self, **fields): - """Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. + def set_target(self, *field_names, flag=True): + """Change the target flag of these fields. - :param key-value pairs for field-name and `is_target` value(True, False). + :param field_names: a sequence of str, indicating field names + :param bool flag: Set these fields as target if True. Unset them if False. """ - for name, val in fields.items(): + for name in field_names: if name in self.field_arrays: - assert isinstance(val, bool) - self.field_arrays[name].is_target = val + self.field_arrays[name].is_target = flag else: raise KeyError("{} is not a valid field name.".format(name)) - return self - def set_input(self, **fields): - for name, val in fields.items(): + def set_input(self, *field_name, flag=True): + """Set the input flag of these fields. + + :param field_name: a sequence of str, indicating field names. + :param bool flag: Set these fields as input if True. Unset them if False. + """ + for name in field_name: if name in self.field_arrays: - assert isinstance(val, bool) - self.field_arrays[name].is_input = val + self.field_arrays[name].is_input = flag else: raise KeyError("{} is not a valid field name.".format(name)) - return self def get_input_name(self): return [name for name, field in self.field_arrays.items() if field.is_input] @@ -216,27 +201,6 @@ class DataSet(object): def get_target_name(self): return [name for name, field in self.field_arrays.items() if field.is_target] - def __getattr__(self, item): - # block infinite recursion for copy, pickle - if item == '__setstate__': - raise AttributeError(item) - try: - return self.field_arrays.__getitem__(item) - except KeyError: - pass - try: - reader_cls = _READERS[item] - - # add read_*data() support - def _read(*args, **kwargs): - data = reader_cls().load(*args, **kwargs) - self.extend(data) - return self - - return _read - except KeyError: - raise AttributeError('{} does not exist.'.format(item)) - @classmethod def set_reader(cls, method_name): """decorator to add dataloader support @@ -275,7 +239,6 @@ class DataSet(object): results = [ins for ins in self if not func(ins)] for name, old_field in self.field_arrays.items(): self.field_arrays[name].content = [ins[name] for ins in results] - # print(self.field_arrays[name]) def split(self, dev_ratio): """Split the dataset into training and development(validation) set. @@ -300,27 +263,28 @@ class DataSet(object): return train_set, dev_set @classmethod - def read_csv(cls, csv_path, headers=None, sep='\t', dropna=True): - with open(csv_path, 'r') as f: + def read_csv(cls, csv_path, headers=None, sep=",", dropna=True): + with open(csv_path, "r") as f: start_idx = 0 if headers is None: headers = f.readline().rstrip('\r\n') headers = headers.split(sep) start_idx += 1 else: - assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(type(headers)) + assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format( + type(headers)) _dict = {} for col in headers: _dict[col] = [] for line_idx, line in enumerate(f, start_idx): contents = line.split(sep) - if len(contents)!=len(headers): + if len(contents) != len(headers): if dropna: continue else: - #TODO change error type - raise ValueError("Line {} has {} parts, while header has {} parts."\ - .format(line_idx, len(contents), len(headers))) + # TODO change error type + raise ValueError("Line {} has {} parts, while header has {} parts." \ + .format(line_idx, len(contents), len(headers))) for header, content in zip(headers, contents): _dict[header].append(content) return cls(_dict) diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index b985b253..786e7248 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -55,7 +55,7 @@ class TestDataSet(unittest.TestCase): def test_getitem(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ins_1, ins_0 = ds[0], ds[1] - self.assertTrue(isinstance(ins_1, DataSet.Instance) and isinstance(ins_0, DataSet.Instance)) + self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance)) self.assertEqual(ins_1["x"], [1, 2, 3, 4]) self.assertEqual(ins_1["y"], [5, 6]) self.assertEqual(ins_0["x"], [1, 2, 3, 4]) @@ -65,9 +65,6 @@ class TestDataSet(unittest.TestCase): self.assertTrue(isinstance(sub_ds, DataSet)) self.assertEqual(len(sub_ds), 10) - field = ds["x"] - self.assertEqual(field, ds.field_arrays["x"]) - def test_apply(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")