Browse Source

* DataSet __getitem__ returns copy of Instance

* refine interface of set_target & set_input
* rename DataSet.Instance into DataSet.DataSetIter
* remove unused methods in DataSet.DataSetIter
* remove __setattr__ in DataSet; It is dangerous.
* comment adjustment
tags/v0.2.0^2
FengZiYjun 5 years ago
parent
commit
da901ed5b0
2 changed files with 70 additions and 109 deletions
  1. +69
    -105
      fastNLP/core/dataset.py
  2. +1
    -4
      test/core/test_dataset.py

+ 69
- 105
fastNLP/core/dataset.py View File

@@ -1,5 +1,4 @@
import numpy as np
from copy import copy

from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
@@ -28,38 +27,22 @@ class DataSet(object):

"""

class Instance(object):
def __init__(self, dataset, idx=-1, **fields):
self.dataset = dataset
class DataSetIter(object):
def __init__(self, data_set, idx=-1, **fields):
self.data_set = data_set
self.idx = idx
self.fields = fields

def __next__(self):
self.idx += 1
if self.idx >= len(self.dataset):
if self.idx >= len(self.data_set):
raise StopIteration
return copy(self)

def add_field(self, field_name, field):
"""Add a new field to the instance.

:param field_name: str, the name of the field.
:param field:
"""
self.fields[field_name] = field

def __getitem__(self, name):
return self.dataset[name][self.idx]

def __setitem__(self, name, val):
if name not in self.dataset:
new_fields = [None] * len(self.dataset)
self.dataset.add_field(name, new_fields)
self.dataset[name][self.idx] = val
# this returns a copy
return self.data_set[self.idx]

def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
in self.dataset.get_fields().keys()])
return "\n".join(['{}: {}'.format(name, repr(self.data_set[name][self.idx])) for name
in self.data_set.get_fields().keys()])

def __init__(self, data=None):
"""
@@ -89,14 +72,41 @@ class DataSet(object):
return item in self.field_arrays

def __iter__(self):
return self.Instance(self)
return self.DataSetIter(self)

def _convert_ins(self, ins_list):
if isinstance(ins_list, list):
for ins in ins_list:
self.append(ins)
def __getitem__(self, idx):
"""Fetch Instance(s) at the `idx` position(s) in the dataset.
Notice: This method returns a copy of the actual instance(s). Any change to the returned value would not modify
the origin instance(s) of the DataSet.
If you want to make in-place changes to all Instances, use `apply` method.

:param idx: can be int or slice.
:return: If `idx` is int, return an Instance object.
If `idx` is slice, return a DataSet object.
"""
if isinstance(idx, int):
return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays})
elif isinstance(idx, slice):
data_set = DataSet()
for field in self.field_arrays.values():
data_set.add_field(name=field.name,
fields=field.content[idx],
padding_val=field.padding_val,
is_input=field.is_input,
is_target=field.is_target)
return data_set
else:
self.append(ins_list)
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

def __len__(self):
"""Fetch the length of the dataset.

:return int length:
"""
if len(self.field_arrays) == 0:
return 0
field = iter(self.field_arrays.values()).__next__()
return len(field)

def append(self, ins):
"""Add an instance to the DataSet.
@@ -143,72 +153,47 @@ class DataSet(object):
"""
return self.field_arrays

def __getitem__(self, idx):
"""

:param idx: can be int, slice, or str.
:return: If `idx` is int, return an Instance object.
If `idx` is slice, return a DataSet object.
If `idx` is str, it must be a field name, return the field.

"""
if isinstance(idx, int):
return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays})
elif isinstance(idx, slice):
data_set = DataSet()
for field in self.field_arrays.values():
data_set.add_field(name=field.name,
fields=field.content[idx],
padding_val=field.padding_val,
is_input=field.is_input,
is_target=field.is_target)
return data_set
elif isinstance(idx, str):
return self.field_arrays[idx]
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

def __len__(self):
if len(self.field_arrays) == 0:
return 0
field = iter(self.field_arrays.values()).__next__()
return len(field)

def get_length(self):
"""The same as __len__
"""Fetch the length of the dataset.

:return int length:
"""
return len(self)

def rename_field(self, old_name, new_name):
"""rename a field
"""Rename a field.

:param str old_name:
:param str new_name:
"""
if old_name in self.field_arrays:
self.field_arrays[new_name] = self.field_arrays.pop(old_name)
else:
raise KeyError("{} is not a valid name. ".format(old_name))

def set_target(self, **fields):
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged.
def set_target(self, *field_names, flag=True):
"""Change the target flag of these fields.

:param key-value pairs for field-name and `is_target` value(True, False).
:param field_names: a sequence of str, indicating field names
:param bool flag: Set these fields as target if True. Unset them if False.
"""
for name, val in fields.items():
for name in field_names:
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].is_target = val
self.field_arrays[name].is_target = flag
else:
raise KeyError("{} is not a valid field name.".format(name))
return self

def set_input(self, **fields):
for name, val in fields.items():
def set_input(self, *field_name, flag=True):
"""Set the input flag of these fields.

:param field_name: a sequence of str, indicating field names.
:param bool flag: Set these fields as input if True. Unset them if False.
"""
for name in field_name:
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].is_input = val
self.field_arrays[name].is_input = flag
else:
raise KeyError("{} is not a valid field name.".format(name))
return self

def get_input_name(self):
return [name for name, field in self.field_arrays.items() if field.is_input]
@@ -216,27 +201,6 @@ class DataSet(object):
def get_target_name(self):
return [name for name, field in self.field_arrays.items() if field.is_target]

def __getattr__(self, item):
# block infinite recursion for copy, pickle
if item == '__setstate__':
raise AttributeError(item)
try:
return self.field_arrays.__getitem__(item)
except KeyError:
pass
try:
reader_cls = _READERS[item]

# add read_*data() support
def _read(*args, **kwargs):
data = reader_cls().load(*args, **kwargs)
self.extend(data)
return self

return _read
except KeyError:
raise AttributeError('{} does not exist.'.format(item))

@classmethod
def set_reader(cls, method_name):
"""decorator to add dataloader support
@@ -275,7 +239,6 @@ class DataSet(object):
results = [ins for ins in self if not func(ins)]
for name, old_field in self.field_arrays.items():
self.field_arrays[name].content = [ins[name] for ins in results]
# print(self.field_arrays[name])

def split(self, dev_ratio):
"""Split the dataset into training and development(validation) set.
@@ -300,27 +263,28 @@ class DataSet(object):
return train_set, dev_set

@classmethod
def read_csv(cls, csv_path, headers=None, sep='\t', dropna=True):
with open(csv_path, 'r') as f:
def read_csv(cls, csv_path, headers=None, sep=",", dropna=True):
with open(csv_path, "r") as f:
start_idx = 0
if headers is None:
headers = f.readline().rstrip('\r\n')
headers = headers.split(sep)
start_idx += 1
else:
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(type(headers))
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(
type(headers))
_dict = {}
for col in headers:
_dict[col] = []
for line_idx, line in enumerate(f, start_idx):
contents = line.split(sep)
if len(contents)!=len(headers):
if len(contents) != len(headers):
if dropna:
continue
else:
#TODO change error type
raise ValueError("Line {} has {} parts, while header has {} parts."\
.format(line_idx, len(contents), len(headers)))
# TODO change error type
raise ValueError("Line {} has {} parts, while header has {} parts." \
.format(line_idx, len(contents), len(headers)))
for header, content in zip(headers, contents):
_dict[header].append(content)
return cls(_dict)

+ 1
- 4
test/core/test_dataset.py View File

@@ -55,7 +55,7 @@ class TestDataSet(unittest.TestCase):
def test_getitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ins_1, ins_0 = ds[0], ds[1]
self.assertTrue(isinstance(ins_1, DataSet.Instance) and isinstance(ins_0, DataSet.Instance))
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
self.assertEqual(ins_1["x"], [1, 2, 3, 4])
self.assertEqual(ins_1["y"], [5, 6])
self.assertEqual(ins_0["x"], [1, 2, 3, 4])
@@ -65,9 +65,6 @@ class TestDataSet(unittest.TestCase):
self.assertTrue(isinstance(sub_ds, DataSet))
self.assertEqual(len(sub_ds), 10)

field = ds["x"]
self.assertEqual(field, ds.field_arrays["x"])

def test_apply(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")


Loading…
Cancel
Save