@@ -22,7 +22,10 @@ class CRLoader(JsonLoader):
     """
     def __init__(self, fields=None, dropna=False):
         super().__init__(fields, dropna)
-        self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
+        # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
+        # TODO check 1
+        self.fields = {"doc_key": "raw_key", "speakers": "raw_speakers", "clusters": "raw_clusters",
+                       "sentences": "raw_words"}

     def _load(self, path):
         """
@@ -22,21 +22,56 @@ class CoreferencePipe(Pipe):
         self.config = config

     def process(self, data_bundle: DataBundle):
+        """
+        Further processes the loaded data.
+        The raw data contains: raw_key, raw_speakers, raw_words, raw_clusters.
+
+        .. csv-table::
+           :header: "raw_key", "raw_speakers", "raw_words", "raw_clusters"
+
+           "bc/cctv/00/cctv_0000_0", "[["Speaker#1", "Speaker#1"],[]]", "[["I","am"],[]]", "[[[2,3],[6,7]],[[10,12],[20,22]]]"
+           "bc/cctv/00/cctv_0000_1", "[["Speaker#1", "Speaker#1"],[]]", "[["He","is"],[]]", "[[[2,3],[6,7]],[[10,12],[20,22]]]"
+           "[...]", "[...]", "[...]", "[...]"
+
+        After processing, the data contains the document genre, speaker information, the sentences, the word indices of each sentence, chars, sentence lengths and the target:
+
+        .. csv-table::
+           :header: "words1", "words2", "words3", "words4", "chars", "seq_len", "target"
+
+           "bc", "[[0,0],[1,1]]", "[["I","am"],[]]", "[[1,2],[]]", "[[[1],[2,3]],[]]", "[2,3]", "[[[2,3],[6,7]],[[10,12],[20,22]]]"
+           "[...]", "[...]", "[...]", "[...]", "[...]", "[...]", "[...]"
+
+        :param data_bundle:
+        :return:
+        """
         genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
-        vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name=Const.INPUTS(2))
+        vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name="raw_words")
         vocab.build_vocab()
         word2id = vocab.word2idx
-        data_bundle.vocabs = {"vocab":vocab}
-        char_dict = get_char_dict(self.config.char_path)
+        data_bundle.set_vocab(vocab, "vocab")
+        if self.config.char_path:
+            char_dict = get_char_dict(self.config.char_path)
+        else:
+            char_set = set()
+            for i, w in enumerate(word2id):
+                if i < 2:
+                    continue
+                for c in w:
+                    char_set.add(c)
+            char_dict = collections.defaultdict(int)
+            char_dict.update({c: i for i, c in enumerate(char_set)})
+
         for name, ds in data_bundle.datasets.items():
             # genre
-            ds.apply(lambda x: genres[x[Const.INPUTS(0)][:2]], new_field_name=Const.INPUTS(0))
+            ds.apply(lambda x: genres[x["raw_key"][:2]], new_field_name=Const.INPUTS(0))

             # speaker_ids_np
-            ds.apply(lambda x: speaker2numpy(x[Const.INPUTS(1)], self.config.max_sentences, is_train=name == 'train'),
+            ds.apply(lambda x: speaker2numpy(x["raw_speakers"], self.config.max_sentences, is_train=name == 'train'),
                      new_field_name=Const.INPUTS(1))

+            # sentences
+            ds.rename_field("raw_words", Const.INPUTS(2))
+
             # doc_np
             ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
                                          self.config.max_sentences, is_train=name == 'train')[0],
@@ -50,6 +85,9 @@ class CoreferencePipe(Pipe):
                                          self.config.max_sentences, is_train=name == 'train')[2],
                      new_field_name=Const.INPUT_LEN)

+            # clusters
+            ds.rename_field("raw_clusters", Const.TARGET)
+
             ds.set_ignore_type(Const.TARGET)
             ds.set_padder(Const.TARGET, None)
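
The new else branch above builds a character vocabulary directly from the word vocabulary when no char_path is configured. A standalone sketch of that fallback (assuming, as in fastNLP's default Vocabulary, that the first two entries are the padding and unknown tokens and are therefore skipped):

    import collections

    def build_char_dict(word2id):
        # Collect every character occurring in the vocabulary words,
        # ignoring the first two entries (assumed to be <pad> and <unk>).
        char_set = set()
        for i, w in enumerate(word2id):
            if i < 2:
                continue
            char_set.update(w)
        # Characters not seen here map to id 0 via the defaultdict.
        char_dict = collections.defaultdict(int)
        char_dict.update({c: i for i, c in enumerate(char_set)})
        return char_dict

    # Example: chars of "am" and "He" get ids; "<pad>"/"<unk>" are skipped.
    print(build_char_dict(["<pad>", "<unk>", "am", "He"]))
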
@@ -37,15 +37,15 @@ if __name__ == "__main__":
     print(config)

-    @cache_results('cache.pkl')
+    # @cache_results('cache.pkl')
     def cache():
-        bundle = CoreferencePipe(Config()).process_from_file({'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
+        bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
         return bundle
-    data_info = cache()
-    print("Dataset split:\ntrain:", str(len(data_info.datasets["train"])),
-          "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))
+    data_bundle = cache()
+    print("Dataset split:\ntrain:", str(len(data_bundle.get_dataset("train"))),
+          "\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test'))))
     # print(data_info)
-    model = Model(data_info.vocabs['vocab'], config)
+    model = Model(data_bundle.vocabs['vocab'], config)
     print(model)

     loss = SoftmaxLoss()
@@ -56,8 +56,8 @@ if __name__ == "__main__":
     lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay)

-    trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"],
-                      loss=loss, metrics=metric, check_code_level=-1,sampler=None,
+    trainer = Trainer(model=model, train_data=data_bundle.datasets["train"], dev_data=data_bundle.datasets["dev"],
+                      loss=loss, metrics=metric, check_code_level=-1, sampler=None,
                       batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch,
                       optimizer=optim,
                       save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save',
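
Side note on the decorator commented out above: fastNLP's cache_results stores the decorated function's return value at the given path and reuses it on later runs. If caching the processed DataBundle is still desired, a sketch of keeping it enabled (same calls as in the script, nothing beyond what the original lines already used):

    @cache_results('cache.pkl')  # saves the returned DataBundle to cache.pkl and reloads it next run
    def cache():
        return CoreferencePipe(config).process_from_file(
            {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})

    data_bundle = cache()
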
@@ -0,0 +1,16 @@
+from fastNLP.io.loader.coreference import CRLoader
+import unittest
+
+
+class TestCR(unittest.TestCase):
+    def test_load(self):
+        test_root = "../../data_for_tests/coreference/"
+        train_path = test_root + "coreference_train.json"
+        dev_path = test_root + "coreference_dev.json"
+        test_path = test_root + "coreference_test.json"
+        paths = {"train": train_path, "dev": dev_path, "test": test_path}
+
+        bundle1 = CRLoader().load(paths)
+        bundle2 = CRLoader().load(test_root)
+        print(bundle1)
+        print(bundle2)
@@ -0,0 +1,24 @@
+import unittest
+from fastNLP.io.pipe.coreference import CoreferencePipe
+
+
+class TestCR(unittest.TestCase):
+
+    def test_load(self):
+        class Config():
+            max_sentences = 50
+            filter = [3, 4, 5]
+            char_path = None
+        config = Config()
+
+        file_root_path = "../../data_for_tests/coreference/"
+        train_path = file_root_path + "coreference_train.json"
+        dev_path = file_root_path + "coreference_dev.json"
+        test_path = file_root_path + "coreference_test.json"
+
+        paths = {"train": train_path, "dev": dev_path, "test": test_path}
+
+        bundle1 = CoreferencePipe(config).process_from_file(paths)
+        bundle2 = CoreferencePipe(config).process_from_file(file_root_path)
+        print(bundle1)
+        print(bundle2)
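
The two tests above only print the bundles. A slightly stronger variant is sketched below; it reuses the same fixtures and asserts the column names listed in CoreferencePipe.process's docstring (the test class name and the has_field-based assertions are an illustration, not part of this PR):

    import unittest
    from fastNLP.io.pipe.coreference import CoreferencePipe


    class TestCRProcessedFields(unittest.TestCase):
        def test_fields(self):
            class Config():
                max_sentences = 50
                filter = [3, 4, 5]
                char_path = None

            root = "../../data_for_tests/coreference/"
            paths = {"train": root + "coreference_train.json",
                     "dev": root + "coreference_dev.json",
                     "test": root + "coreference_test.json"}
            bundle = CoreferencePipe(Config()).process_from_file(paths)
            train = bundle.get_dataset("train")
            # Column names as documented in the pipe's docstring table.
            for field in ("words1", "words2", "words3", "words4", "chars", "seq_len", "target"):
                self.assertTrue(train.has_field(field))
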