@@ -1,5 +1,4 @@
import numpy as np
from copy import copy
from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
@@ -28,38 +27,22 @@ class DataSet(object):
"""
class Ins tanc e(object):
def __init__(self, dataset, idx=-1, **fields):
self.dataset = dataset
class DataSet Iter (object):
def __init__(self, data_ set, idx=-1, **fields):
self.data_ set = data_ set
self.idx = idx
self.fields = fields
def __next__(self):
self.idx += 1
if self.idx >= len(self.dataset):
if self.idx >= len(self.data_ set):
raise StopIteration
return copy(self)
def add_field(self, field_name, field):
"""Add a new field to the instance.
:param field_name: str, the name of the field.
:param field:
"""
self.fields[field_name] = field
def __getitem__(self, name):
return self.dataset[name][self.idx]
def __setitem__(self, name, val):
if name not in self.dataset:
new_fields = [None] * len(self.dataset)
self.dataset.add_field(name, new_fields)
self.dataset[name][self.idx] = val
# this returns a copy
return self.data_set[self.idx]
def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
in self.dataset.get_fields().keys()])
return "\n".join(['{}: {}'.format(name, repr(self.data_set[name][self.idx])) for name
in self.data_set.get_fields().keys()])
def __init__(self, data=None):
"""
@@ -89,14 +72,41 @@ class DataSet(object):
return item in self.field_arrays
def __iter__(self):
return self.Ins tanc e(self)
return self.DataSet Iter (self)
def _convert_ins(self, ins_list):
if isinstance(ins_list, list):
for ins in ins_list:
self.append(ins)
def __getitem__(self, idx):
"""Fetch Instance(s) at the `idx` position(s) in the dataset.
Notice: This method returns a copy of the actual instance(s). Any change to the returned value would not modify
the origin instance(s) of the DataSet.
If you want to make in-place changes to all Instances, use `apply` method.
:param idx: can be int or slice.
:return: If `idx` is int, return an Instance object.
If `idx` is slice, return a DataSet object.
"""
if isinstance(idx, int):
return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays})
elif isinstance(idx, slice):
data_set = DataSet()
for field in self.field_arrays.values():
data_set.add_field(name=field.name,
fields=field.content[idx],
padding_val=field.padding_val,
is_input=field.is_input,
is_target=field.is_target)
return data_set
else:
self.append(ins_list)
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
def __len__(self):
"""Fetch the length of the dataset.
:return int length:
"""
if len(self.field_arrays) == 0:
return 0
field = iter(self.field_arrays.values()).__next__()
return len(field)
def append(self, ins):
"""Add an instance to the DataSet.
@@ -143,72 +153,47 @@ class DataSet(object):
"""
return self.field_arrays
def __getitem__(self, idx):
"""
:param idx: can be int, slice, or str.
:return: If `idx` is int, return an Instance object.
If `idx` is slice, return a DataSet object.
If `idx` is str, it must be a field name, return the field.
"""
if isinstance(idx, int):
return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays})
elif isinstance(idx, slice):
data_set = DataSet()
for field in self.field_arrays.values():
data_set.add_field(name=field.name,
fields=field.content[idx],
padding_val=field.padding_val,
is_input=field.is_input,
is_target=field.is_target)
return data_set
elif isinstance(idx, str):
return self.field_arrays[idx]
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
def __len__(self):
if len(self.field_arrays) == 0:
return 0
field = iter(self.field_arrays.values()).__next__()
return len(field)
def get_length(self):
"""The same as __len__
"""Fetch the length of the dataset.
:return int length:
"""
return len(self)
def rename_field(self, old_name, new_name):
"""rename a field
"""Rename a field.
:param str old_name:
:param str new_name:
"""
if old_name in self.field_arrays:
self.field_arrays[new_name] = self.field_arrays.pop(old_name)
else:
raise KeyError("{} is not a valid name. ".format(old_name))
def set_target(self, ** fields):
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged .
def set_target(self, *field_names, flag=True):
"""Change the target flag of these fields.
:param key-value pairs for field-name and `is_target` value(True, False).
:param field_names: a sequence of str, indicating field names
:param bool flag: Set these fields as target if True. Unset them if False.
"""
for name, val in fields.items() :
for name in field_names:
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].is_target = val
self.field_arrays[name].is_target = flag
else:
raise KeyError("{} is not a valid field name.".format(name))
return self
def set_input(self, **fields):
for name, val in fields.items():
def set_input(self, *field_name, flag=True):
"""Set the input flag of these fields.
:param field_name: a sequence of str, indicating field names.
:param bool flag: Set these fields as input if True. Unset them if False.
"""
for name in field_name:
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].is_input = val
self.field_arrays[name].is_input = flag
else:
raise KeyError("{} is not a valid field name.".format(name))
return self
def get_input_name(self):
return [name for name, field in self.field_arrays.items() if field.is_input]
@@ -216,27 +201,6 @@ class DataSet(object):
def get_target_name(self):
return [name for name, field in self.field_arrays.items() if field.is_target]
def __getattr__(self, item):
# block infinite recursion for copy, pickle
if item == '__setstate__':
raise AttributeError(item)
try:
return self.field_arrays.__getitem__(item)
except KeyError:
pass
try:
reader_cls = _READERS[item]
# add read_*data() support
def _read(*args, **kwargs):
data = reader_cls().load(*args, **kwargs)
self.extend(data)
return self
return _read
except KeyError:
raise AttributeError('{} does not exist.'.format(item))
@classmethod
def set_reader(cls, method_name):
"""decorator to add dataloader support
@@ -275,7 +239,6 @@ class DataSet(object):
results = [ins for ins in self if not func(ins)]
for name, old_field in self.field_arrays.items():
self.field_arrays[name].content = [ins[name] for ins in results]
# print(self.field_arrays[name])
def split(self, dev_ratio):
"""Split the dataset into training and development(validation) set.
@@ -300,27 +263,28 @@ class DataSet(object):
return train_set, dev_set
@classmethod
def read_csv(cls, csv_path, headers=None, sep='\t' , dropna=True):
with open(csv_path, 'r' ) as f:
def read_csv(cls, csv_path, headers=None, sep="," , dropna=True):
with open(csv_path, "r" ) as f:
start_idx = 0
if headers is None:
headers = f.readline().rstrip('\r\n')
headers = headers.split(sep)
start_idx += 1
else:
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(type(headers))
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(
type(headers))
_dict = {}
for col in headers:
_dict[col] = []
for line_idx, line in enumerate(f, start_idx):
contents = line.split(sep)
if len(contents)!=len(headers):
if len(contents) != len(headers):
if dropna:
continue
else:
#TODO change error type
raise ValueError("Line {} has {} parts, while header has {} parts."\
.format(line_idx, len(contents), len(headers)))
# TODO change error type
raise ValueError("Line {} has {} parts, while header has {} parts." \
.format(line_idx, len(contents), len(headers)))
for header, content in zip(headers, contents):
_dict[header].append(content)
return cls(_dict)