@@ -0,0 +1,93 @@
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths
from fastNLP.io import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2


class Conll2003DataLoader(DataSetLoader):
    def __init__(self, task: str = 'ner', encoding_type: str = 'bioes'):
        """
        Load English data in the CoNLL-2003 format; information on the dataset is available at
        https://www.clips.uantwerpen.be/conll2003/ner/. Each line has four columns (word, POS, chunk, NER), e.g.
        "U.N. NNP I-NP I-ORG". When task is 'pos', the target of the returned DataSet is taken from the second
        column; when task is 'chunk', from the third column; when task is 'ner', from the fourth column. All
        "-DOCSTART- -X- O O" lines are skipped, so the number of samples is lower than what many papers report;
        since "-DOCSTART- -X- O O" is only a document-separator marker and should not be predicted, lines starting
        with -DOCSTART- are ignored.
        For the ner and chunk tasks the loaded target follows encoding_type; for the pos task the target is the raw
        POS column.

        :param task: the labelling task to load. One of ner, pos, chunk
        """
        assert task in ('ner', 'pos', 'chunk')
        index = {'ner': 3, 'pos': 1, 'chunk': 2}[task]
        self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
        self._tag_converters = []
        if task in ('ner', 'chunk'):
            self._tag_converters = [iob2]
            if encoding_type == 'bioes':
                self._tag_converters.append(iob2bioes)

    def load(self, path: str):
        dataset = self._loader.load(path)

        def convert_tag_schema(tags):
            for converter in self._tag_converters:
                tags = converter(tags)
            return tags

        if self._tag_converters:
            dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt: VocabularyOption = None,
                lower: bool = False):
        """
        Read and process the data. Lines starting with '-DOCSTART-' are ignored.

        :param paths: a single path or a dict mapping split names (e.g. 'train') to paths
        :param word_vocab_opt: options used to initialise the word vocabulary
        :param lower: whether to lowercase all words
        :return:
        """
        # load the data
        paths = check_dataloader_paths(paths)
        data = DataInfo()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, path in paths.items():
            dataset = self.load(path)
            # copy raw_words into the words field, which will be indexed below
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                # lowercase the indexed copy; raw_words keeps the original casing
                dataset.apply_field(lambda words: [word.lower() for word in words],
                                    field_name=Const.INPUT, new_field_name=Const.INPUT)
            data.datasets[name] = dataset

        # construct the word vocabulary
        word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
                                no_create_entry_dataset=[dataset for name, dataset in data.datasets.items()
                                                         if name != 'train'])
        word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # vocabulary over the case-preserving words
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words',
                                    no_create_entry_dataset=[dataset for name, dataset in data.datasets.items()
                                                             if name != 'train'])
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # construct the target vocabulary
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data


if __name__ == '__main__':
    pass
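    # A minimal usage sketch (hypothetical paths; assumes CoNLL-2003 formatted
    # files with the columns word/POS/chunk/NER):
    #
    #     loader = Conll2003DataLoader(task='ner', encoding_type='bioes')
    #     data = loader.process({'train': 'path/to/eng.train',
    #                            'test': 'path/to/eng.testb'}, lower=True)
    #     print(data.datasets['train'][:2])
    #     print(data.vocabs[Const.TARGET].idx2word)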
@@ -0,0 +1,152 @@
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict
from fastNLP import DataSet
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths
from fastNLP.io.dataset_loader import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2


class OntoNoteNERDataLoader(DataSetLoader):
    """
    Reads OntoNotes data that has been converted to the CoNLL format. The conversion of OntoNotes to the CoNLL
    format is described at https://github.com/yhcc/OntoNotes-5.0-NER.
    """

    def __init__(self, encoding_type: str = 'bioes'):
        assert encoding_type in ('bioes', 'bio')
        self.encoding_type = encoding_type
        if encoding_type == 'bioes':
            self.encoding_method = iob2bioes
        else:
            self.encoding_method = iob2

    def load(self, path: str) -> DataSet:
        """
        Given a file path, read the data. The returned DataSet contains the following fields:
            raw_words: List[str]
            target: List[str]

        :param path: path to a CoNLL-formatted OntoNotes file
        :return:
        """
        dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path)

        def convert_to_bio(tags):
            # The NE column marks entity spans with brackets, e.g.
            # ['(PERSON*', '*', '*)'] becomes ['B-PERSON', 'I-PERSON', 'I-PERSON'].
            bio_tags = []
            flag = None
            for tag in tags:
                label = tag.strip("()*")
                if '(' in tag:
                    bio_label = 'B-' + label
                    flag = label
                elif flag:
                    bio_label = 'I-' + flag
                else:
                    bio_label = 'O'
                if ')' in tag:
                    flag = None
                bio_tags.append(bio_label)
            return self.encoding_method(bio_tags)

        def convert_word(words):
            converted_words = []
            for word in words:
                word = word.replace('/.', '.')  # some sentence-final periods appear as '/.'
                if not word.startswith('-'):
                    converted_words.append(word)
                    continue
                # the following bracket tokens are escaped in the data; map them back
                tfrs = {'-LRB-': '(',
                        '-RRB-': ')',
                        '-LSB-': '[',
                        '-RSB-': ']',
                        '-LCB-': '{',
                        '-RCB-': '}'
                        }
                if word in tfrs:
                    converted_words.append(tfrs[word])
                else:
                    converted_words.append(word)
            return converted_words

        dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words')
        dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target')
        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt: VocabularyOption = None,
                lower: bool = True) -> DataInfo:
        """
        Read and process the data. The returned DataInfo contains the following:
            vocabs:
                word: Vocabulary
                target: Vocabulary
            datasets:
                train: DataSet
                    words: List[int], set as input
                    target: List[int], the labels, set as both input and target
                    seq_len: int, the sentence length, set as both input and target
                    raw_words: List[str]
                xxx (other splits, depending on the paths passed in)

        :param paths: a single path or a dict mapping split names (e.g. 'train') to paths
        :param word_vocab_opt: options used to initialise the word vocabulary
        :param lower: whether to lowercase all words
        :return:
        """
        paths = check_dataloader_paths(paths)
        data = DataInfo()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, path in paths.items():
            dataset = self.load(path)
            # copy raw_words into the words field, which will be indexed below
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                # lowercase the indexed copy; raw_words keeps the original casing
                dataset.apply_field(lambda words: [word.lower() for word in words],
                                    field_name=Const.INPUT, new_field_name=Const.INPUT)
            data.datasets[name] = dataset

        # construct the word vocabulary
        word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
                                no_create_entry_dataset=[dataset for name, dataset in data.datasets.items()
                                                         if name != 'train'])
        word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # vocabulary over the case-preserving words
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # construct the target vocabulary
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data


if __name__ == '__main__':
    loader = OntoNoteNERDataLoader()
    dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt')
    print(dataset.target.value_count())
    print(dataset[:4])
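
    # A hedged sketch of the full pipeline (placeholder paths; assumes splits
    # produced with https://github.com/yhcc/OntoNotes-5.0-NER):
    #
    #     data = OntoNoteNERDataLoader(encoding_type='bioes').process(
    #         {'train': 'path/to/train.txt', 'dev': 'path/to/dev.txt'}, lower=True)
    #     print(data.vocabs[Const.TARGET].idx2word)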
""" | |||||
train 115812 2200752 | |||||
development 15680 304684 | |||||
test 12217 230111 | |||||
train 92403 1901772 | |||||
valid 13606 279180 | |||||
test 10258 204135 | |||||
""" |
@@ -0,0 +1,49 @@
from typing import List


def iob2(tags: List[str]) -> List[str]:
    """
    Check that the tags form a valid IOB sequence; IOB1 tags are automatically converted to IOB2.

    :param tags: the tags to convert
    :return: the tags in IOB2 (modified in place and returned)
    """
    for i, tag in enumerate(tags):
        if tag == "O":
            continue
        split = tag.split("-")
        if len(split) != 2 or split[0] not in ["I", "B"]:
            raise TypeError("The encoding schema is not a valid IOB type.")
        if split[0] == "B":
            continue
        elif i == 0 or tags[i - 1] == "O":  # convert IOB1 to IOB2
            tags[i] = "B" + tag[1:]
        elif tags[i - 1][1:] == tag[1:]:
            continue
        else:  # convert IOB1 to IOB2
            tags[i] = "B" + tag[1:]
    return tags


def iob2bioes(tags: List[str]) -> List[str]:
    """
    Convert IOB tags to the BIOES encoding.

    :param tags: tags in IOB2
    :return: tags in BIOES
    """
    new_tags = []
    for i, tag in enumerate(tags):
        if tag == 'O':
            new_tags.append(tag)
        else:
            split = tag.split('-')[0]
            if split == 'B':
                if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('B-', 'S-'))
            elif split == 'I':
                if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('I-', 'E-'))
            else:
                raise TypeError("Invalid IOB format.")
    return new_tags
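

if __name__ == '__main__':
    # Small self-check on made-up tags (not part of the original file): iob2()
    # normalises an IOB1 sequence whose entities begin with I- to IOB2, and
    # iob2bioes() then rewrites single-token and entity-final tags to S-/E-.
    tags = ['I-PER', 'I-PER', 'O', 'I-LOC']
    print(iob2(list(tags)))             # ['B-PER', 'I-PER', 'O', 'B-LOC']
    print(iob2bioes(iob2(list(tags))))  # ['B-PER', 'E-PER', 'O', 'S-LOC']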