|
|
@@ -126,8 +126,8 @@ class RawDataSetLoader(DataSetLoader): |
|
|
|
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') |
|
|
|
|
|
|
|
|
|
|
|
class POSDataSetLoader(DataSetLoader): |
|
|
|
"""Dataset Loader for a POS Tag dataset. |
|
|
|
class DummyPOSReader(DataSetLoader): |
|
|
|
"""A simple reader for a dummy POS tagging dataset. |
|
|
|
|
|
|
|
In these datasets, the columns of each line are separated by "\t". The first column is the word and the second |
|
|
|
column is the label. Different sentences are separated by an empty line. |
|
|
@@ -146,7 +146,7 @@ class POSDataSetLoader(DataSetLoader): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(POSDataSetLoader, self).__init__() |
|
|
|
super(DummyPOSReader, self).__init__() |
|
|
|
|
|
|
|
def load(self, data_path): |
|
|
|
""" |
|
|
@@ -194,16 +194,14 @@ class POSDataSetLoader(DataSetLoader): |
|
|
|
return convert_seq2seq_dataset(data) |
|
|
|
|
|
|
|
|
|
|
|
DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos') |
|
|
|
DataLoaderRegister.set_reader(DummyPOSReader, 'read_pos') |
|
|
|
|
|
|
|
|
|
|
|
class TokenizeDataSetLoader(DataSetLoader): |
|
|
|
class DummyCWSReader(DataSetLoader): |
|
|
|
"""Load pku dataset for Chinese word segmentation. |
|
|
|
""" |
|
|
|
Data set loader for tokenization data sets |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(TokenizeDataSetLoader, self).__init__() |
|
|
|
super(DummyCWSReader, self).__init__() |
|
|
|
|
|
|
|
def load(self, data_path, max_seq_len=32): |
|
|
|
"""Load pku dataset for Chinese word segmentation. |
|
|
@@ -256,11 +254,11 @@ class TokenizeDataSetLoader(DataSetLoader): |
|
|
|
return convert_seq2seq_dataset(data) |
|
|
|
|
|
|
|
|
|
|
|
class ClassDataSetLoader(DataSetLoader): |
|
|
|
class DummyClassificationReader(DataSetLoader): |
|
|
|
"""Loader for a dummy classification data set""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(ClassDataSetLoader, self).__init__() |
|
|
|
super(DummyClassificationReader, self).__init__() |
|
|
|
|
|
|
|
def load(self, data_path): |
|
|
|
assert os.path.exists(data_path) |
|
|
@@ -271,7 +269,7 @@ class ClassDataSetLoader(DataSetLoader): |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def parse(lines): |
|
|
|
""" |
|
|
|
"""每行第一个token是标签,其余是字/词;由空格分隔。 |
|
|
|
|
|
|
|
:param lines: lines from dataset |
|
|
|
:return: list(list(list())): the three level of lists are words, sentence, and dataset |
|
|
@@ -327,16 +325,11 @@ class ConllLoader(DataSetLoader): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class LMDataSetLoader(DataSetLoader): |
|
|
|
"""Language Model Dataset Loader |
|
|
|
|
|
|
|
This loader produces data for language model training in a supervised way. |
|
|
|
That means it has X and Y. |
|
|
|
|
|
|
|
class DummyLMReader(DataSetLoader): |
|
|
|
"""A Dummy Language Model Dataset Reader |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(LMDataSetLoader, self).__init__() |
|
|
|
super(DummyLMReader, self).__init__() |
|
|
|
|
|
|
|
def load(self, data_path): |
|
|
|
if not os.path.exists(data_path): |
|
|
@@ -364,19 +357,25 @@ class LMDataSetLoader(DataSetLoader): |
|
|
|
|
|
|
|
|
|
|
|
class PeopleDailyCorpusLoader(DataSetLoader): |
|
|
|
"""人民日报数据集 |
|
|
|
""" |
|
|
|
People Daily Corpus: Chinese word segmentation, POS tag, NER |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(PeopleDailyCorpusLoader, self).__init__() |
|
|
|
self.pos = True |
|
|
|
self.ner = True |
|
|
|
|
|
|
|
def load(self, data_path): |
|
|
|
def load(self, data_path, pos=True, ner=True): |
|
|
|
""" |
|
|
|
|
|
|
|
:param str data_path: 数据路径 |
|
|
|
:param bool pos: 是否使用词性标签 |
|
|
|
:param bool ner: 是否使用命名实体标签 |
|
|
|
:return: a DataSet object |
|
|
|
""" |
|
|
|
self.pos, self.ner = pos, ner |
|
|
|
with open(data_path, "r", encoding="utf-8") as f: |
|
|
|
sents = f.readlines() |
|
|
|
|
|
|
|
pos_tag_examples = [] |
|
|
|
ner_examples = [] |
|
|
|
examples = [] |
|
|
|
for sent in sents: |
|
|
|
if len(sent) <= 2: |
|
|
|
continue |
|
|
@@ -410,40 +409,44 @@ class PeopleDailyCorpusLoader(DataSetLoader): |
|
|
|
sent_ner.append(ner_tag) |
|
|
|
sent_pos_tag.append(pos) |
|
|
|
sent_words.append(token) |
|
|
|
pos_tag_examples.append([sent_words, sent_pos_tag]) |
|
|
|
ner_examples.append([sent_words, sent_ner]) |
|
|
|
# List[List[List[str], List[str]]] |
|
|
|
# ner_examples not used |
|
|
|
return self.convert(pos_tag_examples) |
|
|
|
example = [sent_words] |
|
|
|
if self.pos is True: |
|
|
|
example.append(sent_pos_tag) |
|
|
|
if self.ner is True: |
|
|
|
example.append(sent_ner) |
|
|
|
examples.append(example) |
|
|
|
return self.convert(examples) |
|
|
|
|
|
|
|
def convert(self, data): |
|
|
|
data_set = DataSet() |
|
|
|
for item in data: |
|
|
|
sent_words, sent_pos_tag = item[0], item[1] |
|
|
|
data_set.append(Instance(words=sent_words, tags=sent_pos_tag)) |
|
|
|
data_set.apply(lambda ins: len(ins), new_field_name="seq_len") |
|
|
|
data_set.set_target("tags") |
|
|
|
data_set.set_input("sent_words") |
|
|
|
data_set.set_input("seq_len") |
|
|
|
sent_words = item[0] |
|
|
|
if self.pos is True and self.ner is True: |
|
|
|
instance = Instance(words=sent_words, pos_tags=item[1], ner=item[2]) |
|
|
|
elif self.pos is True: |
|
|
|
instance = Instance(words=sent_words, pos_tags=item[1]) |
|
|
|
elif self.ner is True: |
|
|
|
instance = Instance(words=sent_words, ner=item[1]) |
|
|
|
else: |
|
|
|
instance = Instance(words=sent_words) |
|
|
|
data_set.append(instance) |
|
|
|
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len") |
|
|
|
return data_set |
|
|
|
|
|
|
|
|
|
|
|
class Conll2003Loader(DataSetLoader): |
|
|
|
"""Self-defined loader of conll2003 dataset |
|
|
|
"""Loader for conll2003 dataset |
|
|
|
|
|
|
|
More information about the given dataset can be found at |
|
|
|
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(Conll2003Loader, self).__init__() |
|
|
|
|
|
|
|
def load(self, dataset_path): |
|
|
|
with open(dataset_path, "r", encoding="utf-8") as f: |
|
|
|
lines = f.readlines() |
|
|
|
|
|
|
|
##Parse the dataset line by line |
|
|
|
parsed_data = [] |
|
|
|
sentence = [] |
|
|
|
tokens = [] |
|
|
@@ -470,21 +473,20 @@ class Conll2003Loader(DataSetLoader): |
|
|
|
lambda labels: labels[1], sample[1])) |
|
|
|
label2_list = list(map( |
|
|
|
lambda labels: labels[2], sample[1])) |
|
|
|
dataset.append(Instance(token_list=sample[0], |
|
|
|
label0_list=label0_list, |
|
|
|
label1_list=label1_list, |
|
|
|
label2_list=label2_list)) |
|
|
|
dataset.append(Instance(tokens=sample[0], |
|
|
|
pos=label0_list, |
|
|
|
chucks=label1_list, |
|
|
|
ner=label2_list)) |
|
|
|
|
|
|
|
return dataset |
|
|
|
|
|
|
|
|
|
|
|
class SNLIDataSetLoader(DataSetLoader): |
|
|
|
class SNLIDataSetReader(DataSetLoader): |
|
|
|
"""A data set loader for SNLI data set. |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(SNLIDataSetLoader, self).__init__() |
|
|
|
super(SNLIDataSetReader, self).__init__() |
|
|
|
|
|
|
|
def load(self, path_list): |
|
|
|
""" |
|
|
@@ -553,6 +555,8 @@ class ConllCWSReader(object): |
|
|
|
""" |
|
|
|
返回的DataSet只包含raw_sentence这个field,内容为str。 |
|
|
|
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 |
|
|
|
:: |
|
|
|
|
|
|
|
1 编者按 编者按 NN O 11 nmod:topic |
|
|
|
2 : : PU O 11 punct |
|
|
|
3 7月 7月 NT DATE 4 compound:nn |
|
|
@@ -564,6 +568,7 @@ class ConllCWSReader(object): |
|
|
|
3 飞行 飞行 NN O 8 nsubj |
|
|
|
4 从 从 P O 5 case |
|
|
|
5 外型 外型 NN O 8 nmod:prep |
|
|
|
|
|
|
|
""" |
|
|
|
datalist = [] |
|
|
|
with open(path, 'r', encoding='utf-8') as f: |
|
|
@@ -575,7 +580,7 @@ class ConllCWSReader(object): |
|
|
|
elif line.startswith('#'): |
|
|
|
continue |
|
|
|
else: |
|
|
|
sample.append(line.split('\t')) |
|
|
|
sample.append(line.strip().split()) |
|
|
|
if len(sample) > 0: |
|
|
|
datalist.append(sample) |
|
|
|
|
|
|
@@ -592,7 +597,6 @@ class ConllCWSReader(object): |
|
|
|
sents = [line] |
|
|
|
for raw_sentence in sents: |
|
|
|
ds.append(Instance(raw_sentence=raw_sentence)) |
|
|
|
|
|
|
|
return ds |
|
|
|
|
|
|
|
def get_char_lst(self, sample): |
|
|
@@ -607,70 +611,22 @@ class ConllCWSReader(object): |
|
|
|
return text |
|
|
|
|
|
|
|
|
|
|
|
class POSCWSReader(DataSetLoader):
    """Read a one-token-per-line tagged file as raw, space-joined sentences.

    Each non-empty line holds a token followed by whitespace and a tag, and an
    empty line marks the boundary between two sentences. Example::

        迈 N
        向 N
        充 N
        ...
        泽 I-PER
        民 I-PER

        ( N
        一 N
        九 N
        ...

    Only the first whitespace-separated column of each line is kept; the tag
    column is ignored.
    """

    def __init__(self, in_word_splitter=None):
        """Store the default splitter used by :meth:`load`.

        :param in_word_splitter: optional separator; when set, each token is
            split on it and only the part before the separator is kept.
        """
        super().__init__()
        self.in_word_splitter = in_word_splitter

    def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
        """Load ``filepath`` into a DataSet with a single ``raw_sentence`` field.

        :param filepath: path of the file to read (decoded as UTF-8).
        :param in_word_splitter: overrides the splitter given to ``__init__``
            when it is not None.
        :param cut_long_sent: when true, each joined sentence is passed through
            ``cut_long_sentence`` and every piece it yields becomes its own
            instance.
        :return: a DataSet whose instances hold ``raw_sentence`` (str), the
            tokens of one sentence joined by single spaces.
        """
        if in_word_splitter is None:
            in_word_splitter = self.in_word_splitter
        dataset = DataSet()

        def flush(tokens):
            # Turn the collected tokens of one sentence into instance(s).
            sentence = ' '.join(tokens)
            if cut_long_sent:
                pieces = cut_long_sentence(sentence)
            else:
                pieces = [sentence]
            for piece in pieces:
                dataset.append(Instance(raw_sentence=piece))

        words = []
        # Explicit UTF-8: the data is Chinese text and must not depend on the
        # platform's default encoding.
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Sentence boundary; consecutive empty lines are ignored.
                    if words:
                        flush(words)
                        words = []
                    continue
                token = line.split()[0]
                if in_word_splitter is not None:
                    token = token.split(in_word_splitter)[0]
                words.append(token)
        # The file may end without a trailing empty line; do not drop the
        # final sentence in that case.
        if words:
            flush(words)
        return dataset
|
|
|
|
|
|
|
|
|
|
|
class NaiveCWSReader(DataSetLoader): |
|
|
|
""" |
|
|
|
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 |
|
|
|
例如:: |
|
|
|
|
|
|
|
这是 fastNLP , 一个 非常 good 的 包 . |
|
|
|
|
|
|
|
或者,即每个part后面还有一个pos tag |
|
|
|
例如:: |
|
|
|
|
|
|
|
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, in_word_splitter=None): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
super(NaiveCWSReader, self).__init__() |
|
|
|
self.in_word_splitter = in_word_splitter |
|
|
|
|
|
|
|
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): |
|
|
@@ -680,8 +636,10 @@ class NaiveCWSReader(DataSetLoader): |
|
|
|
和 |
|
|
|
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY |
|
|
|
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] |
|
|
|
|
|
|
|
:param filepath: |
|
|
|
:param in_word_splitter: |
|
|
|
:param cut_long_sent: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if in_word_splitter == None: |
|
|
@@ -740,7 +698,9 @@ def cut_long_sentence(sent, max_sample_length=200): |
|
|
|
|
|
|
|
|
|
|
|
class ZhConllPOSReader(object): |
|
|
|
# 中文colln格式reader |
|
|
|
"""读取中文Conll格式。返回“字级别”的标签,使用BMES记号扩展原来的词级别标签。 |
|
|
|
|
|
|
|
""" |
|
|
|
def __init__(self): |
|
|
|
pass |
|
|
|
|
|
|
@@ -750,6 +710,8 @@ class ZhConllPOSReader(object): |
|
|
|
words:list of str, |
|
|
|
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] |
|
|
|
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 |
|
|
|
:: |
|
|
|
|
|
|
|
1 编者按 编者按 NN O 11 nmod:topic |
|
|
|
2 : : PU O 11 punct |
|
|
|
3 7月 7月 NT DATE 4 compound:nn |
|
|
@@ -761,6 +723,7 @@ class ZhConllPOSReader(object): |
|
|
|
3 飞行 飞行 NN O 8 nsubj |
|
|
|
4 从 从 P O 5 case |
|
|
|
5 外型 外型 NN O 8 nmod:prep |
|
|
|
|
|
|
|
""" |
|
|
|
datalist = [] |
|
|
|
with open(path, 'r', encoding='utf-8') as f: |
|
|
@@ -815,67 +778,10 @@ class ZhConllPOSReader(object): |
|
|
|
return text, pos_tags |
|
|
|
|
|
|
|
|
|
|
|
class ConllPOSReader(object):
    """Read a tab-separated CoNLL-style file and emit character-level tags.

    The returned DataSet has two fields per instance:

    - ``words``: list of str, the individual characters of the sentence;
    - ``tag``: list of str, one entry per character, made of the word's tag
      prefixed with ``B-``/``M-``/``E-`` (first/inner/last character of a
      multi-character word) or ``S-`` (single-character word).
    """

    def __init__(self):
        pass

    def load(self, path):
        """Parse the file at ``path`` into a DataSet.

        Sentences are separated by blank lines; lines starting with ``#`` are
        skipped. Sentences for which :meth:`get_one` returns None are dropped.

        :param path: path of the UTF-8 encoded input file.
        :return: a DataSet with the ``words`` and ``tag`` fields.
        :raises ValueError: if a row contains an empty word.
        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if not line.strip():
                    # Blank or whitespace-only line ends the current sentence.
                    if sample:
                        datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    # Drop the line terminator so the last column is clean.
                    sample.append(line.rstrip('\r\n').split('\t'))
            if sample:
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            res = self.get_one(sample)
            if res is None:
                continue
            char_seq, pos_seq = self._to_char_level(res[0], res[1])
            ds.append(Instance(words=char_seq, tag=pos_seq))
        return ds

    @staticmethod
    def _to_char_level(words, tags):
        """Expand word-level tags into per-character B/M/E/S-prefixed tags."""
        char_seq = []
        pos_seq = []
        for word, tag in zip(words, tags):
            if len(word) == 0:
                raise ValueError("Zero length of word detected.")
            if len(word) == 1:
                pos_seq.append('S-{}'.format(tag))
            else:
                pos_seq.append('B-{}'.format(tag))
                pos_seq.extend('M-{}'.format(tag) for _ in range(len(word) - 2))
                pos_seq.append('E-{}'.format(tag))
            char_seq.extend(word)
        return char_seq, pos_seq

    def get_one(self, sample):
        """Extract the words and tags of one sentence.

        :param sample: list of rows, each row a list of column strings. Column
            1 is read as the word and column 3 as its tag.
        :return: ``(words, tags)``, or None when ``sample`` is empty or any row
            has ``'_'`` in column 6 (presumably an unannotated head column;
            confirm against the data format).
        """
        if len(sample) == 0:
            return None
        text = []
        pos_tags = []
        for w in sample:
            # Only columns 1, 3 and 6 are needed; reading further columns
            # would reject rows that have exactly seven fields.
            word, tag, marker = w[1], w[3], w[6]
            if marker == '_':
                return None
            text.append(word)
            pos_tags.append(tag)
        return text, pos_tags
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConllxDataLoader(object): |
|
|
|
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 |
|
|
|
|
|
|
|
""" |
|
|
|
def load(self, path): |
|
|
|
datalist = [] |
|
|
|
with open(path, 'r', encoding='utf-8') as f: |
|
|
|