@@ -9,9 +9,7 @@ from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.utils import load_url | from fastNLP.api.utils import load_url | ||||
from fastNLP.api.processor import ModelProcessor | from fastNLP.api.processor import ModelProcessor | ||||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | |||||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.io.dataset_loader import ConllCWSReader, ZhConllPOSReader, ConllxDataLoader, add_seg_tag | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
@@ -31,6 +29,16 @@ class API: | |||||
self._dict = None | self._dict = None | ||||
def predict(self, *args, **kwargs): | def predict(self, *args, **kwargs): | ||||
"""Do prediction for the given input. | |||||
""" | |||||
raise NotImplementedError | |||||
def test(self, file_path): | |||||
"""Test performance over the given data set. | |||||
:param str file_path: path to the test data file
:return: a dictionary of metric values | |||||
""" | |||||
raise NotImplementedError | raise NotImplementedError | ||||
def load(self, path, device): | def load(self, path, device): | ||||
@@ -322,3 +322,103 @@ class SetInputProcessor(Processor): | |||||
def process(self, dataset): | def process(self, dataset): | ||||
dataset.set_input(*self.fields, flag=self.flag) | dataset.set_input(*self.fields, flag=self.flag) | ||||
return dataset | return dataset | ||||
class VocabIndexerProcessor(Processor): | |||||
""" | |||||
根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 | |||||
new_added_field_name, 则覆盖原有的field_name. | |||||
""" | |||||
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | |||||
verbose=0, is_input=True): | |||||
""" | |||||
:param field_name: the field from which the vocabulary is built and on which indexing is performed
:param new_added_filed_name: name of the index field produced by indexing; if not provided, field_name is overwritten
:param min_freq: minimum number of occurrences a token needs to enter the Vocabulary
:param max_size: maximum number of tokens the Vocabulary may hold
:param verbose: 0 prints nothing; 1 prints progress information
:param bool is_input: whether the indexed field is set as a model input
""" | |||||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | |||||
self.min_freq = min_freq | |||||
self.max_size = max_size | |||||
self.verbose = verbose | |||||
self.is_input = is_input | |||||
def construct_vocab(self, *datasets): | |||||
""" | |||||
使用传入的DataSet创建vocabulary | |||||
:param datasets: DataSet类型的数据,用于构建vocabulary | |||||
:return: | |||||
""" | |||||
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
self.vocab.build_vocab() | |||||
if self.verbose: | |||||
print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) | |||||
def process(self, *datasets, only_index_dataset=None): | |||||
""" | |||||
若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary | |||||
后,则会index datasets与only_index_dataset。 | |||||
:param datasets: DataSet类型的数据 | |||||
:param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 | |||||
:return: | |||||
""" | |||||
if len(datasets) == 0 and not hasattr(self, 'vocab'): | |||||
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.") | |||||
if not hasattr(self, 'vocab'): | |||||
self.construct_vocab(*datasets) | |||||
else: | |||||
if self.verbose: | |||||
print("Using constructed vocabulary with {} items.".format(len(self.vocab))) | |||||
to_index_datasets = [] | |||||
if len(datasets) != 0: | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
if not (only_index_dataset is None): | |||||
if isinstance(only_index_dataset, list): | |||||
for dataset in only_index_dataset: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
elif isinstance(only_index_dataset, DataSet): | |||||
to_index_datasets.append(only_index_dataset) | |||||
else: | |||||
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) | |||||
for dataset in to_index_datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||||
new_field_name=self.new_added_field_name, is_input=self.is_input) | |||||
# return a single DataSet so that, at inference time, this processor behaves like the other processors
if len(to_index_datasets) == 1: | |||||
return to_index_datasets[0] | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def delete_vocab(self): | |||||
del self.vocab | |||||
def get_vocab_size(self): | |||||
return len(self.vocab) | |||||
def set_verbose(self, verbose): | |||||
""" | |||||
设置processor verbose状态。 | |||||
:param verbose: int, 0,不输出任何信息;1,输出vocab 信息。 | |||||
:return: | |||||
""" | |||||
self.verbose = verbose |
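For reference, a minimal usage sketch of the processor above. The toy data and field names are made up; the import path mirrors the one used in `train_pos_tag.py` later in this diff.

```python
from fastNLP.api.processor import VocabIndexerProcessor
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

# toy training data: the "words" field holds a list of tokens per instance
train_data = DataSet()
train_data.append(Instance(words=["我", "爱", "北京"]))
train_data.append(Instance(words=["北京", "天安门"]))

# build the vocabulary from the training data and index the "words" field
proc = VocabIndexerProcessor(field_name="words", new_added_filed_name="word_seq", verbose=1)
train_data = proc.process(train_data)

# dev/test data is indexed with the already-built vocabulary, without updating it
dev_data = DataSet()
dev_data.append(Instance(words=["我", "爱", "天安门"]))
dev_data = proc.process(only_index_dataset=dev_data)
```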
@@ -283,7 +283,7 @@ class Trainer(object): | |||||
self.callback_manager.after_batch() | self.callback_manager.after_batch() | ||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | ||||
(self.validate_every < 0 and self.step % len(data_iterator)) == 0) \ | |||||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||||
and self.dev_data is not None: | and self.dev_data is not None: | ||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | eval_res = self._do_validation(epoch=epoch, step=self.step) | ||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | ||||
@@ -367,12 +367,23 @@ class Trainer(object): | |||||
return self.losser(predict, truth) | return self.losser(predict, truth) | ||||
def _save_model(self, model, model_name, only_param=False): | def _save_model(self, model, model_name, only_param=False): | ||||
""" 存储不含有显卡信息的state_dict或model | |||||
:param model: | |||||
:param model_name: | |||||
:param only_param: | |||||
:return: | |||||
""" | |||||
if self.save_path is not None: | if self.save_path is not None: | ||||
model_name = os.path.join(self.save_path, model_name) | |||||
model_path = os.path.join(self.save_path, model_name) | |||||
if only_param: | if only_param: | ||||
torch.save(model.state_dict(), model_name) | |||||
state_dict = model.state_dict() | |||||
for key in state_dict: | |||||
state_dict[key] = state_dict[key].cpu() | |||||
torch.save(state_dict, model_path) | |||||
else: | else: | ||||
torch.save(model, model_name) | |||||
model.cpu() | |||||
torch.save(model, model_path) | |||||
model.cuda() | |||||
def _load_model(self, model, model_name, only_param=False): | def _load_model(self, model, model_name, only_param=False): | ||||
# returns a bool indicating whether the model was reloaded successfully
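The `_save_model` change above moves every tensor to CPU before `torch.save`, so checkpoints no longer carry GPU device information. A minimal sketch of the same idea outside the Trainer (the `nn.Linear` stands in for any trained model; the file name is illustrative):

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for any trained model

# strip device information before saving so the checkpoint loads on CPU-only machines
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
torch.save(state_dict, "model_param.pkl")

# loading later needs no GPU; map_location adds an extra safeguard
model.load_state_dict(torch.load("model_param.pkl", map_location="cpu"))
```

Note that the full-model branch calls `model.cuda()` unconditionally after saving; if CPU-only training is also expected, recording the original device beforehand (e.g. `next(model.parameters()).device`) and restoring it would be safer.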
@@ -90,6 +90,7 @@ class NativeDataSetLoader(DataSetLoader): | |||||
"""A simple example of DataSetLoader | """A simple example of DataSetLoader | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(NativeDataSetLoader, self).__init__() | super(NativeDataSetLoader, self).__init__() | ||||
@@ -107,6 +108,7 @@ class RawDataSetLoader(DataSetLoader): | |||||
"""A simple example of raw data reader | """A simple example of raw data reader | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(RawDataSetLoader, self).__init__() | super(RawDataSetLoader, self).__init__() | ||||
@@ -142,6 +144,7 @@ class POSDataSetLoader(DataSetLoader): | |||||
In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. | In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(POSDataSetLoader, self).__init__() | super(POSDataSetLoader, self).__init__() | ||||
@@ -540,3 +543,373 @@ class SNLIDataSetLoader(DataSetLoader): | |||||
data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") | data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") | ||||
data_set.set_target("truth") | data_set.set_target("truth") | ||||
return data_set | return data_set | ||||
class ConllCWSReader(object): | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path, cut_long_sent=False): | |||||
""" | |||||
返回的DataSet只包含raw_sentence这个field,内容为str。 | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_char_lst(sample) | |||||
if res is None: | |||||
continue | |||||
line = ' '.join(res) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for raw_sentence in sents: | |||||
ds.append(Instance(raw_sentence=raw_sentence)) | |||||
return ds | |||||
def get_char_lst(self, sample): | |||||
if len(sample) == 0: | |||||
return None | |||||
text = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
return text | |||||
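A quick usage sketch of `ConllCWSReader`. The two-sentence fragment and file name are made up for the demo; the import path matches the one added at the top of this diff.

```python
from fastNLP.io.dataset_loader import ConllCWSReader

# a tiny two-sentence CoNLL-style fragment with tab-separated columns
sample = (
    "1\t编者按\t编者按\tNN\tNN\tO\t11\tnmod:topic\n"
    "2\t:\t:\tPU\tPU\tO\t11\tpunct\n"
    "\n"
    "1\t这\t这\tDT\tDT\tO\t3\tdet\n"
    "2\t款\t款\tM\tM\tO\t1\tmark:clf\n"
)
with open("toy.conllx", "w", encoding="utf-8") as f:
    f.write(sample)

ds = ConllCWSReader().load("toy.conllx")
print(ds)  # a DataSet with one raw_sentence per sentence, e.g. "编者按 :"
```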
class POSCWSReader(DataSetLoader): | |||||
""" | |||||
支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限. | |||||
迈 N | |||||
向 N | |||||
充 N | |||||
... | |||||
泽 I-PER | |||||
民 I-PER | |||||
( N | |||||
一 N | |||||
九 N | |||||
... | |||||
:param filepath: | |||||
:return: | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
if in_word_splitter is None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
words = [] | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line) == 0: # blank line marks a sentence boundary
if len(words) == 0: # skip consecutive blank lines
continue | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
words = [] | |||||
else: | |||||
line = line.split()[0] | |||||
if in_word_splitter is None: | |||||
words.append(line) | |||||
else: | |||||
words.append(line.split(in_word_splitter)[0]) | |||||
return dataset | |||||
class NaiveCWSReader(DataSetLoader): | |||||
""" | |||||
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
或者,即每个part后面还有一个pos tag | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
""" | |||||
允许使用的情况有(默认以\t或空格作为seg) | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
和 | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] | |||||
:param filepath: | |||||
:param in_word_splitter: | |||||
:return: | |||||
""" | |||||
if in_word_splitter == None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line.replace(' ', '')) == 0: # skip empty lines
continue
if in_word_splitter is not None:
words = [] | |||||
for part in line.split(): | |||||
word = part.split(in_word_splitter)[0] | |||||
words.append(word) | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
return dataset | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
""" | |||||
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length | |||||
:param sent: str. | |||||
:param max_sample_length: int. | |||||
:return: list of str. | |||||
""" | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
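For example (assuming `cut_long_sentence` is importable from `fastNLP.io.dataset_loader`, where it is defined above):

```python
from fastNLP.io.dataset_loader import cut_long_sentence

sent = "这是 一个 非常 长 的 句子 示例"
print(cut_long_sentence(sent, max_sample_length=5))
# ['这是 一个 非常', '长 的 句子 示例']
```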
class ZhConllPOSReader(object): | |||||
# reader for Chinese CoNLL-format data
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
""" | |||||
返回的DataSet, 包含以下的field | |||||
words:list of str, | |||||
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
char_seq.extend(list(word)) | |||||
if len(word) == 1: | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word) > 1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word) - 2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
def get_one(self, sample): | |||||
if len(sample) == 0: | |||||
return None | |||||
text = [] | |||||
pos_tags = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
return text, pos_tags | |||||
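The character-level BMES expansion performed in `load` boils down to the following helper (a hypothetical function, shown only to illustrate the tag scheme):

```python
def to_bmes(word, tag):
    # hypothetical helper, not part of fastNLP:
    # single-character words get S-, longer words get B- / M-... / E-
    if len(word) == 1:
        return ["S-" + tag]
    return ["B-" + tag] + ["M-" + tag] * (len(word) - 2) + ["E-" + tag]

print(to_bmes("编者按", "NN"))  # ['B-NN', 'M-NN', 'E-NN']
print(to_bmes(":", "PU"))      # ['S-PU']
```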
class ConllPOSReader(object): | |||||
# The returned DataSet contains two fields: words (list of list, where the inner list holds characters) and tag (list of str, each str being a BMES-style tag).
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
if len(word) == 1: | |||||
char_seq.append(word) | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word) > 1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word) - 2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
char_seq.extend(list(word)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
class ConllxDataLoader(object): | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
data = [self.get_one(sample) for sample in datalist] | |||||
return list(filter(lambda x: x is not None, data)) | |||||
def get_one(self, sample): | |||||
sample = list(map(list, zip(*sample))) | |||||
if len(sample) == 0: | |||||
return None | |||||
for w in sample[7]: | |||||
if w == '_': | |||||
print('Error Sample {}'.format(sample)) | |||||
return None | |||||
# return word_seq, pos_seq, head_seq, head_tag_seq | |||||
return sample[1], sample[3], list(map(int, sample[6])), sample[7] | |||||
def add_seg_tag(data): | |||||
""" | |||||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||||
:return: list of ([word], [pos]) | |||||
""" | |||||
_processed = [] | |||||
for word_list, pos_list, _, _ in data: | |||||
new_sample = [] | |||||
for word, pos in zip(word_list, pos_list): | |||||
if len(word) == 1: | |||||
new_sample.append((word, 'S-' + pos)) | |||||
else: | |||||
new_sample.append((word[0], 'B-' + pos)) | |||||
for c in word[1:-1]: | |||||
new_sample.append((c, 'M-' + pos)) | |||||
new_sample.append((word[-1], 'E-' + pos)) | |||||
_processed.append(list(map(list, zip(*new_sample)))) | |||||
return _processed |
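A small usage sketch of `add_seg_tag` (the heads and head tags below are placeholders; the function only uses the word and POS lists):

```python
from fastNLP.io.dataset_loader import add_seg_tag

# one sentence in the ([word], [pos], [heads], [head_tags]) form returned by ConllxDataLoader.load
data = [(["编者按", ":"], ["NN", "PU"], [2, 0], ["nmod", "root"])]
chars, tags = add_seg_tag(data)[0]
print(chars)  # ['编', '者', '按', ':']
print(tags)   # ['B-NN', 'M-NN', 'E-NN', 'S-PU']
```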
@@ -6,6 +6,7 @@ from torch import nn | |||||
from torch.nn import functional as F | from torch.nn import functional as F | ||||
from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | from fastNLP.modules.encoder.variational_rnn import VarLSTM | ||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from fastNLP.modules.dropout import TimestepDropout | from fastNLP.modules.dropout import TimestepDropout | ||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.utils import seq_mask | from fastNLP.modules.utils import seq_mask | ||||
@@ -197,53 +198,49 @@ class BiaffineParser(GraphParser): | |||||
pos_vocab_size, | pos_vocab_size, | ||||
pos_emb_dim, | pos_emb_dim, | ||||
num_label, | num_label, | ||||
word_hid_dim=100, | |||||
pos_hid_dim=100, | |||||
rnn_layers=1, | rnn_layers=1, | ||||
rnn_hidden_size=200, | rnn_hidden_size=200, | ||||
arc_mlp_size=100, | arc_mlp_size=100, | ||||
label_mlp_size=100, | label_mlp_size=100, | ||||
dropout=0.3, | dropout=0.3, | ||||
use_var_lstm=False, | |||||
encoder='lstm', | |||||
use_greedy_infer=False): | use_greedy_infer=False): | ||||
super(BiaffineParser, self).__init__() | super(BiaffineParser, self).__init__() | ||||
rnn_out_size = 2 * rnn_hidden_size | rnn_out_size = 2 * rnn_hidden_size | ||||
word_hid_dim = pos_hid_dim = rnn_hidden_size | |||||
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | ||||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | ||||
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) | self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) | ||||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | ||||
self.word_norm = nn.LayerNorm(word_hid_dim) | self.word_norm = nn.LayerNorm(word_hid_dim) | ||||
self.pos_norm = nn.LayerNorm(pos_hid_dim) | self.pos_norm = nn.LayerNorm(pos_hid_dim) | ||||
self.use_var_lstm = use_var_lstm | |||||
if use_var_lstm: | |||||
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
input_dropout=dropout, | |||||
hidden_dropout=dropout, | |||||
bidirectional=True) | |||||
self.encoder_name = encoder | |||||
if encoder == 'var-lstm': | |||||
self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
input_dropout=dropout, | |||||
hidden_dropout=dropout, | |||||
bidirectional=True) | |||||
elif encoder == 'lstm': | |||||
self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
dropout=dropout, | |||||
bidirectional=True) | |||||
else: | else: | ||||
self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
dropout=dropout, | |||||
bidirectional=True) | |||||
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | |||||
nn.LayerNorm(arc_mlp_size), | |||||
raise ValueError('unsupported encoder type: {}'.format(encoder)) | |||||
self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), | |||||
nn.ELU(), | nn.ELU(), | ||||
TimestepDropout(p=dropout),) | TimestepDropout(p=dropout),) | ||||
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | |||||
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | |||||
nn.LayerNorm(label_mlp_size), | |||||
nn.ELU(), | |||||
TimestepDropout(p=dropout),) | |||||
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | |||||
self.arc_mlp_size = arc_mlp_size | |||||
self.label_mlp_size = label_mlp_size | |||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | ||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | ||||
self.use_greedy_infer = use_greedy_infer | self.use_greedy_infer = use_greedy_infer | ||||
@@ -286,24 +283,22 @@ class BiaffineParser(GraphParser): | |||||
word, pos = self.word_fc(word), self.pos_fc(pos) | word, pos = self.word_fc(word), self.pos_fc(pos) | ||||
word, pos = self.word_norm(word), self.pos_norm(pos) | word, pos = self.word_norm(word), self.pos_norm(pos) | ||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | x = torch.cat([word, pos], dim=2) # -> [N,L,C] | ||||
del word, pos | |||||
# lstm, extract features | |||||
# encoder, extract features | |||||
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | ||||
x = x[sort_idx] | x = x[sort_idx] | ||||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | ||||
feat, _ = self.lstm(x) # -> [N,L,C] | |||||
feat, _ = self.encoder(x) # -> [N,L,C] | |||||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | ||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | ||||
feat = feat[unsort_idx] | feat = feat[unsort_idx] | ||||
# for arc biaffine | # for arc biaffine | ||||
# mlp, reduce dim | # mlp, reduce dim | ||||
arc_dep = self.arc_dep_mlp(feat) | |||||
arc_head = self.arc_head_mlp(feat) | |||||
label_dep = self.label_dep_mlp(feat) | |||||
label_head = self.label_head_mlp(feat) | |||||
del feat | |||||
feat = self.mlp(feat) | |||||
arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size | |||||
arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz] | |||||
label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:] | |||||
# biaffine arc classifier | # biaffine arc classifier | ||||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | ||||
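The refactor above replaces the four separate MLPs with one shared projection whose output is sliced into the arc-dep, arc-head, label-dep and label-head parts. A toy-sized sketch of the slicing (all sizes are made up):

```python
import torch
from torch import nn

arc_sz, label_sz = 3, 2
mlp = nn.Linear(10, arc_sz * 2 + label_sz * 2)   # one projection instead of four
feat = mlp(torch.rand(4, 6, 10))                 # [N, L, 2*arc_sz + 2*label_sz]

arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:]
print(arc_dep.shape, label_head.shape)           # torch.Size([4, 6, 3]) torch.Size([4, 6, 2])
```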
@@ -349,7 +344,7 @@ class BiaffineParser(GraphParser): | |||||
batch_size, seq_len, _ = arc_pred.shape | batch_size, seq_len, _ = arc_pred.shape | ||||
flip_mask = (mask == 0) | flip_mask = (mask == 0) | ||||
_arc_pred = arc_pred.clone() | _arc_pred = arc_pred.clone() | ||||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) | |||||
arc_logits = F.log_softmax(_arc_pred, dim=2) | arc_logits = F.log_softmax(_arc_pred, dim=2) | ||||
label_logits = F.log_softmax(label_pred, dim=2) | label_logits = F.log_softmax(label_pred, dim=2) | ||||
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | ||||
@@ -357,12 +352,11 @@ class BiaffineParser(GraphParser): | |||||
arc_loss = arc_logits[batch_index, child_index, arc_true] | arc_loss = arc_logits[batch_index, child_index, arc_true] | ||||
label_loss = label_logits[batch_index, child_index, label_true] | label_loss = label_logits[batch_index, child_index, label_true] | ||||
arc_loss = arc_loss[:, 1:] | |||||
label_loss = label_loss[:, 1:] | |||||
float_mask = mask[:, 1:].float() | |||||
arc_nll = -(arc_loss*float_mask).mean() | |||||
label_nll = -(label_loss*float_mask).mean() | |||||
byte_mask = flip_mask.byte() | |||||
arc_loss.masked_fill_(byte_mask, 0) | |||||
label_loss.masked_fill_(byte_mask, 0) | |||||
arc_nll = -arc_loss.mean() | |||||
label_nll = -label_loss.mean() | |||||
return arc_nll + label_nll | return arc_nll + label_nll | ||||
def predict(self, word_seq, pos_seq, seq_lens): | def predict(self, word_seq, pos_seq, seq_lens): | ||||
@@ -5,6 +5,7 @@ import torch.nn.functional as F | |||||
from torch import nn | from torch import nn | ||||
from fastNLP.modules.utils import mask_softmax | from fastNLP.modules.utils import mask_softmax | ||||
from fastNLP.modules.dropout import TimestepDropout | |||||
class Attention(torch.nn.Module): | class Attention(torch.nn.Module): | ||||
@@ -23,62 +24,89 @@ class Attention(torch.nn.Module): | |||||
class DotAtte(nn.Module): | class DotAtte(nn.Module): | ||||
def __init__(self, key_size, value_size): | |||||
def __init__(self, key_size, value_size, dropout=0.1): | |||||
super(DotAtte, self).__init__() | super(DotAtte, self).__init__() | ||||
self.key_size = key_size | self.key_size = key_size | ||||
self.value_size = value_size | self.value_size = value_size | ||||
self.scale = math.sqrt(key_size) | self.scale = math.sqrt(key_size) | ||||
self.drop = nn.Dropout(dropout) | |||||
self.softmax = nn.Softmax(dim=2) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
def forward(self, Q, K, V, mask_out=None): | |||||
""" | """ | ||||
:param Q: [batch, seq_len, key_size] | :param Q: [batch, seq_len, key_size] | ||||
:param K: [batch, seq_len, key_size] | :param K: [batch, seq_len, key_size] | ||||
:param V: [batch, seq_len, value_size] | :param V: [batch, seq_len, value_size] | ||||
:param seq_mask: [batch, seq_len] | |||||
:param mask_out: [batch, seq_len] | |||||
""" | """ | ||||
output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | ||||
if seq_mask is not None: | |||||
output.masked_fill_(seq_mask.lt(1), -float('inf')) | |||||
output = nn.functional.softmax(output, dim=2) | |||||
if mask_out is not None: | |||||
output.masked_fill_(mask_out, -float('inf')) | |||||
output = self.softmax(output) | |||||
output = self.drop(output) | |||||
return torch.matmul(output, V) | return torch.matmul(output, V) | ||||
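A quick shape check of the scaled dot-product attention above. Tensor sizes are illustrative; the import path is assumed to be `fastNLP.modules.aggregator.attention`, the same module the transformer encoder imports from.

```python
import torch
from fastNLP.modules.aggregator.attention import DotAtte  # assumed import path

atte = DotAtte(key_size=8, value_size=8)
Q = K = V = torch.rand(2, 5, 8)          # [batch, seq_len, key_size/value_size]
seq_mask = torch.ones(2, 5)              # 1 = real token, 0 = padding
mask_out = (seq_mask < 1)[:, None, :]    # [batch, 1, seq_len]; nothing masked here
print(atte(Q, K, V, mask_out=mask_out).size())  # torch.Size([2, 5, 8])
```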
class MultiHeadAtte(nn.Module): | class MultiHeadAtte(nn.Module): | ||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
def __init__(self, model_size, key_size, value_size, num_head, dropout=0.1): | |||||
""" | """ | ||||
实现的是以下内容 | |||||
QW1: (batch_size, seq_len, input_size) * (input_size, key_size) | |||||
KW2: (batch_size, seq_len, input_size) * (input_size, key_size) | |||||
VW3: (batch_size, seq_len, input_size) * (input_size, value_size) | |||||
softmax(QK^T/sqrt(scale))*V: (batch_size, seq_len, value_size) 多个head(num_atten指定)的结果为 | |||||
(batch_size, seq_len, value_size*num_atte) | |||||
最终结果将上式过一个(value_size*num_atte, output_size)的线性层,output为(batch_size, seq_len, output_size) | |||||
:param input_size: int, 输入的维度 | |||||
:param output_size: int, 输出特征的维度 | |||||
:param key_size: int, query和key映射到该维度 | |||||
:param value_size: int, value映射到该维度 | |||||
:param num_atte: | |||||
:param model_size: int, size of the input dimension, which is also the size of the output dimension.
:param key_size: int, dimension of each attention head.
:param value_size: int, dimension of the value vectors in each head.
:param num_head: int, number of attention heads.
:param dropout: float, dropout probability.
""" | """ | ||||
super(MultiHeadAtte, self).__init__() | super(MultiHeadAtte, self).__init__() | ||||
self.in_linear = nn.ModuleList() | |||||
for i in range(num_atte * 3): | |||||
out_feat = key_size if (i % 3) != 2 else value_size | |||||
self.in_linear.append(nn.Linear(input_size, out_feat)) | |||||
self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)]) | |||||
self.out_linear = nn.Linear(value_size * num_atte, output_size) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
heads = [] | |||||
for i in range(len(self.attes)): | |||||
j = i * 3 | |||||
qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V) | |||||
headi = self.attes[i](qi, ki, vi, seq_mask) | |||||
heads.append(headi) | |||||
output = torch.cat(heads, dim=2) | |||||
return self.out_linear(output) | |||||
self.input_size = model_size | |||||
self.key_size = key_size | |||||
self.value_size = value_size | |||||
self.num_head = num_head | |||||
self.q_in = nn.Linear(model_size, key_size * num_head)
self.k_in = nn.Linear(model_size, key_size * num_head)
# use value_size for the value projection so that key_size and value_size may differ
self.v_in = nn.Linear(model_size, value_size * num_head)
self.attention = DotAtte(key_size=key_size, value_size=value_size) | |||||
self.out = nn.Linear(value_size * num_head, model_size) | |||||
self.drop = TimestepDropout(dropout) | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
sqrt = math.sqrt | |||||
nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size))) | |||||
nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size))) | |||||
nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size))) | |||||
nn.init.xavier_normal_(self.out.weight) | |||||
def forward(self, Q, K, V, atte_mask_out=None): | |||||
""" | |||||
:param Q: [batch, seq_len, model_size] | |||||
:param K: [batch, seq_len, model_size] | |||||
:param V: [batch, seq_len, model_size] | |||||
:param atte_mask_out: [batch, 1, seq_len] mask of positions to ignore in attention, or None
""" | |||||
batch, seq_len, _ = Q.size() | |||||
d_k, d_v, n_head = self.key_size, self.value_size, self.num_head | |||||
# input linear | |||||
q = self.q_in(Q).view(batch, seq_len, n_head, d_k) | |||||
k = self.k_in(K).view(batch, seq_len, n_head, d_k) | |||||
v = self.v_in(V).view(batch, seq_len, n_head, d_v)
# transpose q, k and v to do batch attention | |||||
q = q.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) | |||||
k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) | |||||
v = v.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_v) | |||||
if atte_mask_out is not None: | |||||
atte_mask_out = atte_mask_out.repeat(n_head, 1, 1) | |||||
atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, seq_len, d_v) | |||||
# concat all heads, do output linear | |||||
atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, seq_len, -1) | |||||
output = self.drop(self.out(atte)) | |||||
return output | |||||
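And the multi-head wrapper, which projects into `num_head` heads, runs them in one batched attention call, and projects back to `model_size` (sizes below are illustrative, import path assumed as above):

```python
import torch
from fastNLP.modules.aggregator.attention import MultiHeadAtte  # assumed import path

mha = MultiHeadAtte(model_size=64, key_size=16, value_size=16, num_head=4)
x = torch.rand(2, 7, 64)                     # [batch, seq_len, model_size]
seq_mask = torch.ones(2, 7)
atte_mask_out = (seq_mask < 1)[:, None, :]   # [batch, 1, seq_len]
print(mha(x, x, x, atte_mask_out).size())    # torch.Size([2, 7, 64])
```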
class Bi_Attention(nn.Module): | class Bi_Attention(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -1,29 +1,57 @@ | |||||
import torch | |||||
from torch import nn | from torch import nn | ||||
from ..aggregator.attention import MultiHeadAtte | from ..aggregator.attention import MultiHeadAtte | ||||
from ..other_modules import LayerNormalization | |||||
from ..dropout import TimestepDropout | |||||
class TransformerEncoder(nn.Module): | class TransformerEncoder(nn.Module): | ||||
class SubLayer(nn.Module): | class SubLayer(nn.Module): | ||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | |||||
""" | |||||
:param model_size: int, size of the input dimension, which is also the size of the output dimension.
:param inner_size: int, hidden size of the FFN layer.
:param key_size: int, dimension of each attention head.
:param value_size: int, dimension of the value vectors in each head.
:param num_head: int, number of attention heads.
:param dropout: float, dropout probability.
""" | |||||
super(TransformerEncoder.SubLayer, self).__init__() | super(TransformerEncoder.SubLayer, self).__init__() | ||||
self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte) | |||||
self.norm1 = LayerNormalization(output_size) | |||||
self.ffn = nn.Sequential(nn.Linear(output_size, output_size), | |||||
self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) | |||||
self.norm1 = nn.LayerNorm(model_size) | |||||
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | |||||
nn.ReLU(), | nn.ReLU(), | ||||
nn.Linear(output_size, output_size)) | |||||
self.norm2 = LayerNormalization(output_size) | |||||
nn.Linear(inner_size, model_size), | |||||
TimestepDropout(dropout),) | |||||
self.norm2 = nn.LayerNorm(model_size) | |||||
def forward(self, input, seq_mask=None, atte_mask_out=None): | |||||
""" | |||||
def forward(self, input, seq_mask): | |||||
attention = self.atte(input) | |||||
:param input: [batch, seq_len, model_size] | |||||
:param seq_mask: [batch, seq_len, 1] mask multiplied onto the output, or None
:param atte_mask_out: [batch, 1, seq_len] attention mask, or None
:return: [batch, seq_len, model_size] | |||||
""" | |||||
attention = self.atte(input, input, input, atte_mask_out) | |||||
norm_atte = self.norm1(attention + input) | norm_atte = self.norm1(attention + input) | ||||
attention *= seq_mask | |||||
output = self.ffn(norm_atte) | output = self.ffn(norm_atte) | ||||
return self.norm2(output + norm_atte) | |||||
output = self.norm2(output + norm_atte) | |||||
output *= seq_mask | |||||
return output | |||||
def __init__(self, num_layers, **kargs): | def __init__(self, num_layers, **kargs): | ||||
super(TransformerEncoder, self).__init__() | super(TransformerEncoder, self).__init__() | ||||
self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)]) | |||||
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | |||||
def forward(self, x, seq_mask=None): | def forward(self, x, seq_mask=None): | ||||
return self.layers(x, seq_mask) | |||||
output = x | |||||
if seq_mask is None: | |||||
atte_mask_out = None | |||||
else: | |||||
atte_mask_out = (seq_mask < 1)[:,None,:] | |||||
seq_mask = seq_mask[:,:,None] | |||||
for layer in self.layers: | |||||
output = layer(output, seq_mask, atte_mask_out) | |||||
return output |
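Putting it together, the rewritten encoder now takes the raw `[batch, seq_len]` mask and derives both the attention mask and the output mask internally. A minimal sketch (layer sizes are illustrative; the import path is the one added to `biaffine_parser.py` above):

```python
import torch
from fastNLP.modules.encoder.transformer import TransformerEncoder

encoder = TransformerEncoder(num_layers=2, model_size=64, inner_size=128,
                             key_size=16, value_size=16, num_head=4, dropout=0.1)
x = torch.rand(2, 7, 64)            # [batch, seq_len, model_size]
seq_mask = torch.ones(2, 7)         # 1 = real token, 0 = padding
print(encoder(x, seq_mask).size())  # torch.Size([2, 7, 64])
```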
@@ -2,7 +2,8 @@ | |||||
n_epochs = 40 | n_epochs = 40 | ||||
batch_size = 32 | batch_size = 32 | ||||
use_cuda = true | use_cuda = true | ||||
validate_every = 500 | |||||
use_tqdm=true | |||||
validate_every = -1 | |||||
use_golden_train=true | use_golden_train=true | ||||
[test] | [test] | ||||
@@ -19,15 +20,13 @@ word_vocab_size = -1 | |||||
word_emb_dim = 100 | word_emb_dim = 100 | ||||
pos_vocab_size = -1 | pos_vocab_size = -1 | ||||
pos_emb_dim = 100 | pos_emb_dim = 100 | ||||
word_hid_dim = 100 | |||||
pos_hid_dim = 100 | |||||
rnn_layers = 3 | rnn_layers = 3 | ||||
rnn_hidden_size = 400 | |||||
rnn_hidden_size = 256 | |||||
arc_mlp_size = 500 | arc_mlp_size = 500 | ||||
label_mlp_size = 100 | label_mlp_size = 100 | ||||
num_label = -1 | num_label = -1 | ||||
dropout = 0.33 | |||||
use_var_lstm=true | |||||
dropout = 0.3 | |||||
encoder="transformer" | |||||
use_greedy_infer=false | use_greedy_infer=false | ||||
[optim] | [optim] | ||||
@@ -5,7 +5,7 @@ sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) | |||||
import torch | import torch | ||||
import argparse | import argparse | ||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.io.dataset_loader import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
@@ -4,20 +4,15 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
import fastNLP | import fastNLP | ||||
import torch | |||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss | from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss | ||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | from fastNLP.io.config_io import ConfigLoader, ConfigSection | ||||
from fastNLP.io.model_io import ModelLoader | from fastNLP.io.model_io import ModelLoader | ||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from fastNLP.io.model_io import ModelSaver | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, MyDataloader | |||||
from fastNLP.io.dataset_loader import ConllxDataLoader | |||||
from fastNLP.api.processor import * | from fastNLP.api.processor import * | ||||
BOS = '<BOS>' | BOS = '<BOS>' | ||||
@@ -141,7 +136,7 @@ model_args['pos_vocab_size'] = len(pos_v) | |||||
model_args['num_label'] = len(tag_v) | model_args['num_label'] = len(tag_v) | ||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
model.reset_parameters() | |||||
print(model) | |||||
word_idxp = IndexerProcessor(word_v, 'words', 'word_seq') | word_idxp = IndexerProcessor(word_v, 'words', 'word_seq') | ||||
pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq') | pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq') | ||||
@@ -209,7 +204,8 @@ def save_pipe(path): | |||||
pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p]) | pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p]) | ||||
pipe.add_processor(ModelProcessor(model=model, batch_size=32)) | pipe.add_processor(ModelProcessor(model=model, batch_size=32)) | ||||
pipe.add_processor(label_toword_p) | pipe.add_processor(label_toword_p) | ||||
torch.save(pipe, os.path.join(path, 'pipe.pkl')) | |||||
os.makedirs(path, exist_ok=True) | |||||
torch.save({'pipeline': pipe}, os.path.join(path, 'pipe.pkl')) | |||||
def test(path): | def test(path): | ||||
@@ -1,34 +1,3 @@ | |||||
class ConllxDataLoader(object): | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
data = [self.get_one(sample) for sample in datalist] | |||||
return list(filter(lambda x: x is not None, data)) | |||||
def get_one(self, sample): | |||||
sample = list(map(list, zip(*sample))) | |||||
if len(sample) == 0: | |||||
return None | |||||
for w in sample[7]: | |||||
if w == '_': | |||||
print('Error Sample {}'.format(sample)) | |||||
return None | |||||
# return word_seq, pos_seq, head_seq, head_tag_seq | |||||
return sample[1], sample[3], list(map(int, sample[6])), sample[7] | |||||
class MyDataloader: | class MyDataloader: | ||||
def load(self, data_path): | def load(self, data_path): | ||||
with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
@@ -56,23 +25,3 @@ class MyDataloader: | |||||
return data | return data | ||||
def add_seg_tag(data): | |||||
""" | |||||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||||
:return: list of ([word], [pos]) | |||||
""" | |||||
_processed = [] | |||||
for word_list, pos_list, _, _ in data: | |||||
new_sample = [] | |||||
for word, pos in zip(word_list, pos_list): | |||||
if len(word) == 1: | |||||
new_sample.append((word, 'S-' + pos)) | |||||
else: | |||||
new_sample.append((word[0], 'B-' + pos)) | |||||
for c in word[1:-1]: | |||||
new_sample.append((c, 'M-' + pos)) | |||||
new_sample.append((word[-1], 'E-' + pos)) | |||||
_processed.append(list(map(list, zip(*new_sample)))) | |||||
return _processed |
@@ -0,0 +1,3 @@ | |||||
@@ -1,11 +1,11 @@ | |||||
from torch import nn | |||||
import torch | import torch | ||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from reproduction.chinese_word_segment.utils import seq_lens_to_mask | |||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask | |||||
class CWSBiLSTMEncoder(BaseModel): | class CWSBiLSTMEncoder(BaseModel): | ||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, |
@@ -4,7 +4,7 @@ import re | |||||
from fastNLP.api.processor import Processor | from fastNLP.api.processor import Processor | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from reproduction.chinese_word_segment.process.span_converter import SpanConverter | |||||
from reproduction.Chinese_word_segmentation.process.span_converter import SpanConverter | |||||
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | ||||
@@ -226,109 +226,6 @@ class Pre2Post2BigramProcessor(BigramProcessor): | |||||
return bigrams | return bigrams | ||||
# 这里需要建立vocabulary了,但是遇到了以下的问题 | |||||
# (1) 如果使用Processor的方式的话,但是在这种情况返回的不是dataset。所以建立vocabulary的工作用另外的方式实现,不借用 | |||||
# Processor了 | |||||
# TODO 如何将建立vocab和index这两步统一了? | |||||
class VocabIndexerProcessor(Processor): | |||||
""" | |||||
根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 | |||||
new_added_field_name, 则覆盖原有的field_name. | |||||
""" | |||||
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | |||||
verbose=0, is_input=True): | |||||
""" | |||||
:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 | |||||
:param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name. | |||||
:param min_freq: 创建的Vocabulary允许的单词最少出现次数. | |||||
:param max_size: 创建的Vocabulary允许的最大的单词数量 | |||||
:param verbose: 0, 不输出任何信息;1,输出信息 | |||||
:param bool is_input: | |||||
""" | |||||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | |||||
self.min_freq = min_freq | |||||
self.max_size = max_size | |||||
self.verbose =verbose | |||||
self.is_input = is_input | |||||
def construct_vocab(self, *datasets): | |||||
""" | |||||
使用传入的DataSet创建vocabulary | |||||
:param datasets: DataSet类型的数据,用于构建vocabulary | |||||
:return: | |||||
""" | |||||
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
self.vocab.build_vocab() | |||||
if self.verbose: | |||||
print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) | |||||
def process(self, *datasets, only_index_dataset=None): | |||||
""" | |||||
若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary | |||||
后,则会index datasets与only_index_dataset。 | |||||
:param datasets: DataSet类型的数据 | |||||
:param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 | |||||
:return: | |||||
""" | |||||
if len(datasets)==0 and not hasattr(self,'vocab'): | |||||
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.") | |||||
if not hasattr(self, 'vocab'): | |||||
self.construct_vocab(*datasets) | |||||
else: | |||||
if self.verbose: | |||||
print("Using constructed vocabulary with {} items.".format(len(self.vocab))) | |||||
to_index_datasets = [] | |||||
if len(datasets)!=0: | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
if not (only_index_dataset is None): | |||||
if isinstance(only_index_dataset, list): | |||||
for dataset in only_index_dataset: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
elif isinstance(only_index_dataset, DataSet): | |||||
to_index_datasets.append(only_index_dataset) | |||||
else: | |||||
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) | |||||
for dataset in to_index_datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||||
new_field_name=self.new_added_field_name, is_input=self.is_input) | |||||
# 只返回一个,infer时为了跟其他processor保持一致 | |||||
if len(to_index_datasets) == 1: | |||||
return to_index_datasets[0] | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def delete_vocab(self): | |||||
del self.vocab | |||||
def get_vocab_size(self): | |||||
return len(self.vocab) | |||||
def set_verbose(self, verbose): | |||||
""" | |||||
设置processor verbose状态。 | |||||
:param verbose: int, 0,不输出任何信息;1,输出vocab 信息。 | |||||
:return: | |||||
""" | |||||
self.verbose = verbose | |||||
class VocabProcessor(Processor): | class VocabProcessor(Processor): | ||||
def __init__(self, field_name, min_freq=1, max_size=None): | def __init__(self, field_name, min_freq=1, max_size=None): | ||||
@@ -0,0 +1,29 @@ | |||||
from fastNLP.io.dataset_loader import ZhConllPOSReader | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
if __name__ == '__main__': | |||||
reader = ZhConllPOSReader() | |||||
d = reader.load('/home/hyan/train.conllx') | |||||
print(d) |
@@ -10,13 +10,12 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.api.processor import SeqLenProcessor | |||||
from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | from fastNLP.io.config_io import ConfigLoader, ConfigSection | ||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | from fastNLP.models.sequence_modeling import AdvSeqLabel | ||||
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor | |||||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||||
from fastNLP.io.dataset_loader import ZhConllPOSReader | |||||
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | ||||
cfgfile = './pos_tag.cfg' | cfgfile = './pos_tag.cfg' |
@@ -1,197 +0,0 @@ | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.io.dataset_loader import DataSetLoader | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
""" | |||||
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length | |||||
:param sent: str. | |||||
:param max_sample_length: int. | |||||
:return: list of str. | |||||
""" | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
class NaiveCWSReader(DataSetLoader): | |||||
""" | |||||
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
或者,即每个part后面还有一个pos tag | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
""" | |||||
允许使用的情况有(默认以\t或空格作为seg) | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
和 | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] | |||||
:param filepath: | |||||
:param in_word_splitter: | |||||
:return: | |||||
""" | |||||
if in_word_splitter == None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line.replace(' ', ''))==0: # 不能接受空行 | |||||
continue | |||||
if not in_word_splitter is None: | |||||
words = [] | |||||
for part in line.split(): | |||||
word = part.split(in_word_splitter)[0] | |||||
words.append(word) | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
return dataset | |||||
class POSCWSReader(DataSetLoader): | |||||
""" | |||||
支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限. | |||||
迈 N | |||||
向 N | |||||
充 N | |||||
... | |||||
泽 I-PER | |||||
民 I-PER | |||||
( N | |||||
一 N | |||||
九 N | |||||
... | |||||
:param filepath: | |||||
:return: | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
if in_word_splitter is None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
words = [] | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line) == 0: # new line | |||||
if len(words)==0: # 不能接受空行 | |||||
continue | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
words = [] | |||||
else: | |||||
line = line.split()[0] | |||||
if in_word_splitter is None: | |||||
words.append(line) | |||||
else: | |||||
words.append(line.split(in_word_splitter)[0]) | |||||
return dataset | |||||
class ConllCWSReader(object): | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path, cut_long_sent=False): | |||||
""" | |||||
返回的DataSet只包含raw_sentence这个field,内容为str。 | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_char_lst(sample) | |||||
if res is None: | |||||
continue | |||||
line = ' '.join(res) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for raw_sentence in sents: | |||||
ds.append(Instance(raw_sentence=raw_sentence)) | |||||
return ds | |||||
def get_char_lst(self, sample): | |||||
if len(sample)==0: | |||||
return None | |||||
text = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
return text | |||||
@@ -28,8 +28,9 @@ class TransformerCWS(nn.Module):
        self.fc1 = nn.Linear(input_size, hidden_size)
        value_size = hidden_size//num_heads
-       self.transformer = TransformerEncoder(num_layers, input_size=input_size, output_size=hidden_size,
-                                             key_size=value_size, value_size=value_size, num_atte=num_heads)
+       self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size,
+                                             key_size=value_size,
+                                             value_size=value_size, num_head=num_heads)
        self.fc2 = nn.Linear(hidden_size, tag_size)

@@ -39,7 +40,7 @@ class TransformerCWS(nn.Module):
    def forward(self, chars, target, seq_lens, bigrams=None):
        seq_lens = seq_lens
-       masks = seq_len_to_byte_mask(seq_lens)
+       masks = seq_len_to_byte_mask(seq_lens).float()
        x = self.embedding(chars)
        batch_size = x.size(0)
        length = x.size(1)
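
The hunk above renames the TransformerEncoder keyword arguments (input_size/output_size/num_atte become model_size/inner_size/num_head) and casts the byte mask to float before use. Below is a minimal construction sketch using the new keyword names; the concrete sizes and the import path are assumptions, not taken from this diff.

# Hedged sketch of the updated constructor call; sizes are made up and the
# import path is assumed.
from fastNLP.modules.encoder.transformer import TransformerEncoder

hidden_size, num_heads = 200, 4
value_size = hidden_size // num_heads
encoder = TransformerEncoder(2,  # num_layers, passed positionally as in the diff
                             model_size=hidden_size,
                             inner_size=hidden_size,
                             key_size=value_size,
                             value_size=value_size,
                             num_head=num_heads)
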
@@ -1,151 +0,0 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.core.trainer import SeqLabelTrainer | |||||
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader | |||||
from fastNLP.core.utils import load_pickle | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||||
from fastNLP.core.predictor import SeqLabelInfer | |||||
from fastNLP.core.utils import save_pickle | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
# if not already running from this file's directory, change into it
if len(os.path.dirname(__file__)) != 0: | |||||
os.chdir(os.path.dirname(__file__)) | |||||
datadir = "/home/zyfeng/data/" | |||||
cfgfile = './cws.cfg' | |||||
cws_data_path = os.path.join(datadir, "pku_training.utf8") | |||||
pickle_path = "save" | |||||
data_infer_path = os.path.join(datadir, "infer.utf8") | |||||
def infer(): | |||||
# Config Loader | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(cfgfile, {"POS_test": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
# Define the same model | |||||
model = AdvSeqLabel(test_args) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/trained_model.pkl") | |||||
print('model loaded!') | |||||
except Exception as e: | |||||
print('cannot load model!') | |||||
raise | |||||
# Data Loader | |||||
infer_data = SeqLabelDataSet(load_func=BaseLoader.load_lines) | |||||
infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True) | |||||
print('data loaded') | |||||
# Inference interface | |||||
infer = SeqLabelInfer(pickle_path) | |||||
results = infer.predict(model, infer_data) | |||||
print(results) | |||||
print("Inference finished!") | |||||
def train(): | |||||
# Config Loader | |||||
train_args = ConfigSection() | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args}) | |||||
print("loading data set...") | |||||
data = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load) | |||||
data.load(cws_data_path) | |||||
data_train, data_dev = data.split(ratio=0.3) | |||||
train_args["vocab_size"] = len(data.word_vocab) | |||||
train_args["num_classes"] = len(data.label_vocab) | |||||
print("vocab size={}, num_classes={}".format(len(data.word_vocab), len(data.label_vocab))) | |||||
change_field_is_target(data_dev, "truth", True) | |||||
save_pickle(data_dev, "./save/", "data_dev.pkl") | |||||
save_pickle(data.word_vocab, "./save/", "word2id.pkl") | |||||
save_pickle(data.label_vocab, "./save/", "label2id.pkl") | |||||
# Trainer | |||||
trainer = SeqLabelTrainer(epochs=train_args["epochs"], batch_size=train_args["batch_size"], | |||||
validate=train_args["validate"], | |||||
use_cuda=train_args["use_cuda"], pickle_path=train_args["pickle_path"], | |||||
save_best_dev=True, print_every_step=10, model_name="trained_model.pkl", | |||||
evaluator=SeqLabelEvaluator()) | |||||
# Model | |||||
model = AdvSeqLabel(train_args) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model parameter loaded!') | |||||
except Exception as e: | |||||
print("No saved model. Continue.") | |||||
pass | |||||
# Start training | |||||
trainer.train(model, data_train, data_dev) | |||||
print("Training finished!") | |||||
# Saver | |||||
saver = ModelSaver("./save/trained_model.pkl") | |||||
saver.save_pytorch(model) | |||||
print("Model saved!") | |||||
def predict(): | |||||
# Config Loader | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(cfgfile, {"POS_test": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
# load dev data | |||||
dev_data = load_pickle(pickle_path, "data_dev.pkl") | |||||
# Define the same model | |||||
model = AdvSeqLabel(test_args) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, "./save/trained_model.pkl") | |||||
print("model loaded!") | |||||
# Tester | |||||
test_args["evaluator"] = SeqLabelEvaluator() | |||||
tester = SeqLabelTester(**test_args.data) | |||||
# Start testing | |||||
tester.test(model, dev_data) | |||||
if __name__ == "__main__": | |||||
import argparse | |||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | |||||
    parser.add_argument('--mode', help="set the script's running mode", choices=['train', 'test', 'infer'])
args = parser.parse_args() | |||||
if args.mode == 'train': | |||||
train() | |||||
elif args.mode == 'test': | |||||
predict() | |||||
elif args.mode == 'infer': | |||||
infer() | |||||
else: | |||||
print('no mode specified for model!') | |||||
        parser.print_help()
@@ -1,153 +0,0 @@ | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
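
A worked example of cut_long_sentence (illustrative, not from the original file): with a small max_sample_length, a space-segmented sentence is split greedily on word boundaries once the accumulated character count (spaces excluded) exceeds the limit.

# Illustrative example with a deliberately small limit.
parts = cut_long_sentence('今天 天气 很 好 我们 去 公园', max_sample_length=4)
# expected: ['今天 天气 很', '好 我们 去 公园']
print(parts)
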
class ConllPOSReader(object): | |||||
    # The returned DataSet contains two fields: words (a list of lists whose inner lists hold
    # single characters) and tag (a list of str, each str being a character-level BMES tag such as 'B-NN').
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
if len(word)==1: | |||||
char_seq.append(word) | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word)>1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word)-2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
char_seq.extend(list(word)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
class ZhConllPOSReader(object): | |||||
    # Reader for Chinese CoNLL-format (conllx) data.
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
""" | |||||
返回的DataSet, 包含以下的field | |||||
words:list of str, | |||||
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
char_seq.extend(list(word)) | |||||
if len(word)==1: | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word)>1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word)-2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
def get_one(self, sample): | |||||
if len(sample)==0: | |||||
return None | |||||
text = [] | |||||
pos_tags = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
return text, pos_tags | |||||
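
The load() loop above expands word-level POS tags into character-level BMES tags. A self-contained sketch of that expansion follows; to_bmes is a hypothetical helper that mirrors the loop and is not part of the original module.

def to_bmes(words, tags):
    # hypothetical helper mirroring the expansion loop in ZhConllPOSReader.load()
    char_seq, pos_seq = [], []
    for word, tag in zip(words, tags):
        char_seq.extend(list(word))
        if len(word) == 1:
            pos_seq.append('S-{}'.format(tag))
        else:
            pos_seq.append('B-{}'.format(tag))
            pos_seq.extend('M-{}'.format(tag) for _ in range(len(word) - 2))
            pos_seq.append('E-{}'.format(tag))
    return char_seq, pos_seq

print(to_bmes(['编者按', ':'], ['NN', 'PU']))
# (['编', '者', '按', ':'], ['B-NN', 'M-NN', 'E-NN', 'S-PU'])
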
if __name__ == '__main__': | |||||
reader = ZhConllPOSReader() | |||||
d = reader.load('/home/hyan/train.conllx') | |||||
    print(d)
@@ -1,6 +1,7 @@
import unittest

import numpy as np
import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
@@ -31,3 +32,47 @@ class TestCase1(unittest.TestCase):
        self.assertEqual(len(y["y"]), 4)
        self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4])
        self.assertListEqual(list(y["y"][-1]), [5, 6])
def test_list_padding(self): | |||||
ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | |||||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | |||||
self.assertEqual(x["x"].shape, (4, 4)) | |||||
self.assertEqual(y["y"].shape, (4, 4)) | |||||
def test_numpy_padding(self): | |||||
ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | |||||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | |||||
self.assertEqual(x["x"].shape, (4, 4)) | |||||
self.assertEqual(y["y"].shape, (4, 4)) | |||||
def test_list_to_tensor(self): | |||||
ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | |||||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | |||||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||||
def test_numpy_to_tensor(self): | |||||
ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | |||||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | |||||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||||
self.assertEqual(tuple(y["y"].shape), (4, 4)) |
@@ -77,9 +77,10 @@ class TestBiaffineParser(unittest.TestCase):
        ds, v1, v2, v3 = init_data()
        model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30,
                               pos_vocab_size=len(v2), pos_emb_dim=30,
-                              num_label=len(v3), use_var_lstm=True)
+                              num_label=len(v3), encoder='var-lstm')
        trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
                                  loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
+                                 batch_size=1, validate_every=10,
                                  n_epochs=10, use_cuda=False, use_tqdm=False)
        trainer.train(load_best_model=False)
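
This change replaces the boolean use_var_lstm switch with a string-valued encoder argument. A hedged construction sketch follows, with made-up vocabulary sizes; 'var-lstm' is the only encoder value confirmed by this diff, and the import path is an assumption.

# Hedged sketch; sizes are placeholders and the import path is assumed.
from fastNLP.models.biaffine_parser import BiaffineParser

model = BiaffineParser(word_vocab_size=10000, word_emb_dim=30,
                       pos_vocab_size=50, pos_emb_dim=30,
                       num_label=40, encoder='var-lstm')
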