diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 8368dcc9..b9bc7b70 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -9,9 +9,7 @@ from fastNLP.core.dataset import DataSet from fastNLP.api.utils import load_url from fastNLP.api.processor import ModelProcessor -from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader -from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader -from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag +from fastNLP.io.dataset_loader import ConllCWSReader, ZhConllPOSReader, ConllxDataLoader, add_seg_tag from fastNLP.core.instance import Instance from fastNLP.api.pipeline import Pipeline from fastNLP.core.metrics import SpanFPreRecMetric @@ -31,6 +29,16 @@ class API: self._dict = None def predict(self, *args, **kwargs): + """Do prediction for the given input. + """ + raise NotImplementedError + + def test(self, file_path): + """Test performance over the given data set. + + :param str file_path: + :return: a dictionary of metric values + """ raise NotImplementedError def load(self, path, device): diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index 7354fe0f..6867dae8 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -322,3 +322,103 @@ class SetInputProcessor(Processor): def process(self, dataset): dataset.set_input(*self.fields, flag=self.flag) return dataset + + +class VocabIndexerProcessor(Processor): + """ + 根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 + new_added_field_name, 则覆盖原有的field_name. + + """ + + def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, + verbose=0, is_input=True): + """ + + :param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 + :param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name. + :param min_freq: 创建的Vocabulary允许的单词最少出现次数. + :param max_size: 创建的Vocabulary允许的最大的单词数量 + :param verbose: 0, 不输出任何信息;1,输出信息 + :param bool is_input: + """ + super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) + self.min_freq = min_freq + self.max_size = max_size + + self.verbose = verbose + self.is_input = is_input + + def construct_vocab(self, *datasets): + """ + 使用传入的DataSet创建vocabulary + + :param datasets: DataSet类型的数据,用于构建vocabulary + :return: + """ + self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) + for dataset in datasets: + assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) + dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) + self.vocab.build_vocab() + if self.verbose: + print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) + + def process(self, *datasets, only_index_dataset=None): + """ + 若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary + 后,则会index datasets与only_index_dataset。 + + :param datasets: DataSet类型的数据 + :param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 + :return: + """ + if len(datasets) == 0 and not hasattr(self, 'vocab'): + raise RuntimeError("You have to construct vocabulary first. 
Or you have to pass datasets to construct it.") + if not hasattr(self, 'vocab'): + self.construct_vocab(*datasets) + else: + if self.verbose: + print("Using constructed vocabulary with {} items.".format(len(self.vocab))) + to_index_datasets = [] + if len(datasets) != 0: + for dataset in datasets: + assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) + to_index_datasets.append(dataset) + + if not (only_index_dataset is None): + if isinstance(only_index_dataset, list): + for dataset in only_index_dataset: + assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) + to_index_datasets.append(dataset) + elif isinstance(only_index_dataset, DataSet): + to_index_datasets.append(only_index_dataset) + else: + raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) + + for dataset in to_index_datasets: + assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) + dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], + new_field_name=self.new_added_field_name, is_input=self.is_input) + # 只返回一个,infer时为了跟其他processor保持一致 + if len(to_index_datasets) == 1: + return to_index_datasets[0] + + def set_vocab(self, vocab): + assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) + self.vocab = vocab + + def delete_vocab(self): + del self.vocab + + def get_vocab_size(self): + return len(self.vocab) + + def set_verbose(self, verbose): + """ + 设置processor verbose状态。 + + :param verbose: int, 0,不输出任何信息;1,输出vocab 信息。 + :return: + """ + self.verbose = verbose diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index add86156..ccb3d18e 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -283,7 +283,7 @@ class Trainer(object): self.callback_manager.after_batch() if ((self.validate_every > 0 and self.step % self.validate_every == 0) or - (self.validate_every < 0 and self.step % len(data_iterator)) == 0) \ + (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ and self.dev_data is not None: eval_res = self._do_validation(epoch=epoch, step=self.step) eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. 
".format(epoch, self.n_epochs, self.step, @@ -367,12 +367,23 @@ class Trainer(object): return self.losser(predict, truth) def _save_model(self, model, model_name, only_param=False): + """ 存储不含有显卡信息的state_dict或model + :param model: + :param model_name: + :param only_param: + :return: + """ if self.save_path is not None: - model_name = os.path.join(self.save_path, model_name) + model_path = os.path.join(self.save_path, model_name) if only_param: - torch.save(model.state_dict(), model_name) + state_dict = model.state_dict() + for key in state_dict: + state_dict[key] = state_dict[key].cpu() + torch.save(state_dict, model_path) else: - torch.save(model, model_name) + model.cpu() + torch.save(model, model_path) + model.cuda() def _load_model(self, model, model_name, only_param=False): # 返回bool值指示是否成功reload模型 diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 27d8a360..2d157da3 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -90,6 +90,7 @@ class NativeDataSetLoader(DataSetLoader): """A simple example of DataSetLoader """ + def __init__(self): super(NativeDataSetLoader, self).__init__() @@ -107,6 +108,7 @@ class RawDataSetLoader(DataSetLoader): """A simple example of raw data reader """ + def __init__(self): super(RawDataSetLoader, self).__init__() @@ -142,6 +144,7 @@ class POSDataSetLoader(DataSetLoader): In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. """ + def __init__(self): super(POSDataSetLoader, self).__init__() @@ -540,3 +543,373 @@ class SNLIDataSetLoader(DataSetLoader): data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") data_set.set_target("truth") return data_set + + +class ConllCWSReader(object): + def __init__(self): + pass + + def load(self, path, cut_long_sent=False): + """ + 返回的DataSet只包含raw_sentence这个field,内容为str。 + 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 + 1 编者按 编者按 NN O 11 nmod:topic + 2 : : PU O 11 punct + 3 7月 7月 NT DATE 4 compound:nn + 4 12日 12日 NT DATE 11 nmod:tmod + 5 , , PU O 11 punct + + 1 这 这 DT O 3 det + 2 款 款 M O 1 mark:clf + 3 飞行 飞行 NN O 8 nsubj + 4 从 从 P O 5 case + 5 外型 外型 NN O 8 nmod:prep + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + sample = [] + for line in f: + if line.startswith('\n'): + datalist.append(sample) + sample = [] + elif line.startswith('#'): + continue + else: + sample.append(line.split('\t')) + if len(sample) > 0: + datalist.append(sample) + + ds = DataSet() + for sample in datalist: + # print(sample) + res = self.get_char_lst(sample) + if res is None: + continue + line = ' '.join(res) + if cut_long_sent: + sents = cut_long_sentence(line) + else: + sents = [line] + for raw_sentence in sents: + ds.append(Instance(raw_sentence=raw_sentence)) + + return ds + + def get_char_lst(self, sample): + if len(sample) == 0: + return None + text = [] + for w in sample: + t1, t2, t3, t4 = w[1], w[3], w[6], w[7] + if t3 == '_': + return None + text.append(t1) + return text + + +class POSCWSReader(DataSetLoader): + """ + 支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限. + 迈 N + 向 N + 充 N + ... + 泽 I-PER + 民 I-PER + + ( N + 一 N + 九 N + ... 
+ + + :param filepath: + :return: + """ + + def __init__(self, in_word_splitter=None): + super().__init__() + self.in_word_splitter = in_word_splitter + + def load(self, filepath, in_word_splitter=None, cut_long_sent=False): + if in_word_splitter is None: + in_word_splitter = self.in_word_splitter + dataset = DataSet() + with open(filepath, 'r') as f: + words = [] + for line in f: + line = line.strip() + if len(line) == 0: # new line + if len(words) == 0: # 不能接受空行 + continue + line = ' '.join(words) + if cut_long_sent: + sents = cut_long_sentence(line) + else: + sents = [line] + for sent in sents: + instance = Instance(raw_sentence=sent) + dataset.append(instance) + words = [] + else: + line = line.split()[0] + if in_word_splitter is None: + words.append(line) + else: + words.append(line.split(in_word_splitter)[0]) + return dataset + + +class NaiveCWSReader(DataSetLoader): + """ + 这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 + 这是 fastNLP , 一个 非常 good 的 包 . + 或者,即每个part后面还有一个pos tag + 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY + """ + + def __init__(self, in_word_splitter=None): + super().__init__() + + self.in_word_splitter = in_word_splitter + + def load(self, filepath, in_word_splitter=None, cut_long_sent=False): + """ + 允许使用的情况有(默认以\t或空格作为seg) + 这是 fastNLP , 一个 非常 good 的 包 . + 和 + 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY + 如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] + :param filepath: + :param in_word_splitter: + :return: + """ + if in_word_splitter == None: + in_word_splitter = self.in_word_splitter + dataset = DataSet() + with open(filepath, 'r') as f: + for line in f: + line = line.strip() + if len(line.replace(' ', '')) == 0: # 不能接受空行 + continue + + if not in_word_splitter is None: + words = [] + for part in line.split(): + word = part.split(in_word_splitter)[0] + words.append(word) + line = ' '.join(words) + if cut_long_sent: + sents = cut_long_sentence(line) + else: + sents = [line] + for sent in sents: + instance = Instance(raw_sentence=sent) + dataset.append(instance) + + return dataset + + +def cut_long_sentence(sent, max_sample_length=200): + """ + 将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length + + :param sent: str. + :param max_sample_length: int. + :return: list of str. + """ + sent_no_space = sent.replace(' ', '') + cutted_sentence = [] + if len(sent_no_space) > max_sample_length: + parts = sent.strip().split() + new_line = '' + length = 0 + for part in parts: + length += len(part) + new_line += part + ' ' + if length > max_sample_length: + new_line = new_line[:-1] + cutted_sentence.append(new_line) + length = 0 + new_line = '' + if new_line != '': + cutted_sentence.append(new_line[:-1]) + else: + cutted_sentence.append(sent) + return cutted_sentence + + +class ZhConllPOSReader(object): + # 中文colln格式reader + def __init__(self): + pass + + def load(self, path): + """ + 返回的DataSet, 包含以下的field + words:list of str, + tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] 
+ 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 + 1 编者按 编者按 NN O 11 nmod:topic + 2 : : PU O 11 punct + 3 7月 7月 NT DATE 4 compound:nn + 4 12日 12日 NT DATE 11 nmod:tmod + 5 , , PU O 11 punct + + 1 这 这 DT O 3 det + 2 款 款 M O 1 mark:clf + 3 飞行 飞行 NN O 8 nsubj + 4 从 从 P O 5 case + 5 外型 外型 NN O 8 nmod:prep + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + sample = [] + for line in f: + if line.startswith('\n'): + datalist.append(sample) + sample = [] + elif line.startswith('#'): + continue + else: + sample.append(line.split('\t')) + if len(sample) > 0: + datalist.append(sample) + + ds = DataSet() + for sample in datalist: + # print(sample) + res = self.get_one(sample) + if res is None: + continue + char_seq = [] + pos_seq = [] + for word, tag in zip(res[0], res[1]): + char_seq.extend(list(word)) + if len(word) == 1: + pos_seq.append('S-{}'.format(tag)) + elif len(word) > 1: + pos_seq.append('B-{}'.format(tag)) + for _ in range(len(word) - 2): + pos_seq.append('M-{}'.format(tag)) + pos_seq.append('E-{}'.format(tag)) + else: + raise ValueError("Zero length of word detected.") + + ds.append(Instance(words=char_seq, + tag=pos_seq)) + + return ds + + def get_one(self, sample): + if len(sample) == 0: + return None + text = [] + pos_tags = [] + for w in sample: + t1, t2, t3, t4 = w[1], w[3], w[6], w[7] + if t3 == '_': + return None + text.append(t1) + pos_tags.append(t2) + return text, pos_tags + + +class ConllPOSReader(object): + # 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。 + def __init__(self): + pass + + def load(self, path): + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + sample = [] + for line in f: + if line.startswith('\n'): + datalist.append(sample) + sample = [] + elif line.startswith('#'): + continue + else: + sample.append(line.split('\t')) + if len(sample) > 0: + datalist.append(sample) + + ds = DataSet() + for sample in datalist: + # print(sample) + res = self.get_one(sample) + if res is None: + continue + char_seq = [] + pos_seq = [] + for word, tag in zip(res[0], res[1]): + if len(word) == 1: + char_seq.append(word) + pos_seq.append('S-{}'.format(tag)) + elif len(word) > 1: + pos_seq.append('B-{}'.format(tag)) + for _ in range(len(word) - 2): + pos_seq.append('M-{}'.format(tag)) + pos_seq.append('E-{}'.format(tag)) + char_seq.extend(list(word)) + else: + raise ValueError("Zero length of word detected.") + + ds.append(Instance(words=char_seq, + tag=pos_seq)) + + return ds + + +class ConllxDataLoader(object): + def load(self, path): + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + sample = [] + for line in f: + if line.startswith('\n'): + datalist.append(sample) + sample = [] + elif line.startswith('#'): + continue + else: + sample.append(line.split('\t')) + if len(sample) > 0: + datalist.append(sample) + + data = [self.get_one(sample) for sample in datalist] + return list(filter(lambda x: x is not None, data)) + + def get_one(self, sample): + sample = list(map(list, zip(*sample))) + if len(sample) == 0: + return None + for w in sample[7]: + if w == '_': + print('Error Sample {}'.format(sample)) + return None + # return word_seq, pos_seq, head_seq, head_tag_seq + return sample[1], sample[3], list(map(int, sample[6])), sample[7] + + +def add_seg_tag(data): + """ + + :param data: list of ([word], [pos], [heads], [head_tags]) + :return: list of ([word], [pos]) + """ + + _processed = [] + for word_list, pos_list, _, _ in data: + new_sample = [] + for word, pos in zip(word_list, pos_list): + if len(word) == 
1: + new_sample.append((word, 'S-' + pos)) + else: + new_sample.append((word[0], 'B-' + pos)) + for c in word[1:-1]: + new_sample.append((c, 'M-' + pos)) + new_sample.append((word[-1], 'E-' + pos)) + _processed.append(list(map(list, zip(*new_sample)))) + return _processed diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index fb687301..b9b9dd56 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -6,6 +6,7 @@ from torch import nn from torch.nn import functional as F from fastNLP.modules.utils import initial_parameter from fastNLP.modules.encoder.variational_rnn import VarLSTM +from fastNLP.modules.encoder.transformer import TransformerEncoder from fastNLP.modules.dropout import TimestepDropout from fastNLP.models.base_model import BaseModel from fastNLP.modules.utils import seq_mask @@ -197,53 +198,49 @@ class BiaffineParser(GraphParser): pos_vocab_size, pos_emb_dim, num_label, - word_hid_dim=100, - pos_hid_dim=100, rnn_layers=1, rnn_hidden_size=200, arc_mlp_size=100, label_mlp_size=100, dropout=0.3, - use_var_lstm=False, + encoder='lstm', use_greedy_infer=False): super(BiaffineParser, self).__init__() rnn_out_size = 2 * rnn_hidden_size + word_hid_dim = pos_hid_dim = rnn_hidden_size self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) self.word_norm = nn.LayerNorm(word_hid_dim) self.pos_norm = nn.LayerNorm(pos_hid_dim) - self.use_var_lstm = use_var_lstm - if use_var_lstm: - self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, - hidden_size=rnn_hidden_size, - num_layers=rnn_layers, - bias=True, - batch_first=True, - input_dropout=dropout, - hidden_dropout=dropout, - bidirectional=True) + self.encoder_name = encoder + if encoder == 'var-lstm': + self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim, + hidden_size=rnn_hidden_size, + num_layers=rnn_layers, + bias=True, + batch_first=True, + input_dropout=dropout, + hidden_dropout=dropout, + bidirectional=True) + elif encoder == 'lstm': + self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, + hidden_size=rnn_hidden_size, + num_layers=rnn_layers, + bias=True, + batch_first=True, + dropout=dropout, + bidirectional=True) else: - self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, - hidden_size=rnn_hidden_size, - num_layers=rnn_layers, - bias=True, - batch_first=True, - dropout=dropout, - bidirectional=True) - - self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), - nn.LayerNorm(arc_mlp_size), + raise ValueError('unsupported encoder type: {}'.format(encoder)) + + self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), nn.ELU(), TimestepDropout(p=dropout),) - self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) - self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), - nn.LayerNorm(label_mlp_size), - nn.ELU(), - TimestepDropout(p=dropout),) - self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) + self.arc_mlp_size = arc_mlp_size + self.label_mlp_size = label_mlp_size self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) self.use_greedy_infer = use_greedy_infer @@ -286,24 +283,22 @@ class BiaffineParser(GraphParser): word, pos = self.word_fc(word), 
self.pos_fc(pos) word, pos = self.word_norm(word), self.pos_norm(pos) x = torch.cat([word, pos], dim=2) # -> [N,L,C] - del word, pos - # lstm, extract features + # encoder, extract features sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) x = x[sort_idx] x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) - feat, _ = self.lstm(x) # -> [N,L,C] + feat, _ = self.encoder(x) # -> [N,L,C] feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) feat = feat[unsort_idx] # for arc biaffine # mlp, reduce dim - arc_dep = self.arc_dep_mlp(feat) - arc_head = self.arc_head_mlp(feat) - label_dep = self.label_dep_mlp(feat) - label_head = self.label_head_mlp(feat) - del feat + feat = self.mlp(feat) + arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size + arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz] + label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:] # biaffine arc classifier arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] @@ -349,7 +344,7 @@ class BiaffineParser(GraphParser): batch_size, seq_len, _ = arc_pred.shape flip_mask = (mask == 0) _arc_pred = arc_pred.clone() - _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) + _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) arc_logits = F.log_softmax(_arc_pred, dim=2) label_logits = F.log_softmax(label_pred, dim=2) batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) @@ -357,12 +352,11 @@ class BiaffineParser(GraphParser): arc_loss = arc_logits[batch_index, child_index, arc_true] label_loss = label_logits[batch_index, child_index, label_true] - arc_loss = arc_loss[:, 1:] - label_loss = label_loss[:, 1:] - - float_mask = mask[:, 1:].float() - arc_nll = -(arc_loss*float_mask).mean() - label_nll = -(label_loss*float_mask).mean() + byte_mask = flip_mask.byte() + arc_loss.masked_fill_(byte_mask, 0) + label_loss.masked_fill_(byte_mask, 0) + arc_nll = -arc_loss.mean() + label_nll = -label_loss.mean() return arc_nll + label_nll def predict(self, word_seq, pos_seq, seq_lens): diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index 9f7d72dc..ef3f3fe5 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from torch import nn from fastNLP.modules.utils import mask_softmax +from fastNLP.modules.dropout import TimestepDropout class Attention(torch.nn.Module): @@ -23,62 +24,89 @@ class Attention(torch.nn.Module): class DotAtte(nn.Module): - def __init__(self, key_size, value_size): + def __init__(self, key_size, value_size, dropout=0.1): super(DotAtte, self).__init__() self.key_size = key_size self.value_size = value_size self.scale = math.sqrt(key_size) + self.drop = nn.Dropout(dropout) + self.softmax = nn.Softmax(dim=2) - def forward(self, Q, K, V, seq_mask=None): + def forward(self, Q, K, V, mask_out=None): """ :param Q: [batch, seq_len, key_size] :param K: [batch, seq_len, key_size] :param V: [batch, seq_len, value_size] - :param seq_mask: [batch, seq_len] + :param mask_out: [batch, seq_len] """ output = torch.matmul(Q, K.transpose(1, 2)) / self.scale - if seq_mask is not None: - output.masked_fill_(seq_mask.lt(1), -float('inf')) - output = nn.functional.softmax(output, dim=2) + if mask_out is not None: + output.masked_fill_(mask_out, -float('inf')) + output = self.softmax(output) + output = 
self.drop(output) return torch.matmul(output, V) class MultiHeadAtte(nn.Module): - def __init__(self, input_size, output_size, key_size, value_size, num_atte): + def __init__(self, model_size, key_size, value_size, num_head, dropout=0.1): """ - 实现的是以下内容 - QW1: (batch_size, seq_len, input_size) * (input_size, key_size) - KW2: (batch_size, seq_len, input_size) * (input_size, key_size) - VW3: (batch_size, seq_len, input_size) * (input_size, value_size) - - softmax(QK^T/sqrt(scale))*V: (batch_size, seq_len, value_size) 多个head(num_atten指定)的结果为 - (batch_size, seq_len, value_size*num_atte) - 最终结果将上式过一个(value_size*num_atte, output_size)的线性层,output为(batch_size, seq_len, output_size) - :param input_size: int, 输入的维度 - :param output_size: int, 输出特征的维度 - :param key_size: int, query和key映射到该维度 - :param value_size: int, value映射到该维度 - :param num_atte: + + :param model_size: int, 输入维度的大小。同时也是输出维度的大小。 + :param key_size: int, 每个head的维度大小。 + :param value_size: int,每个head中value的维度。 + :param num_head: int,head的数量。 + :param dropout: float。 """ super(MultiHeadAtte, self).__init__() - self.in_linear = nn.ModuleList() - for i in range(num_atte * 3): - out_feat = key_size if (i % 3) != 2 else value_size - self.in_linear.append(nn.Linear(input_size, out_feat)) - self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)]) - self.out_linear = nn.Linear(value_size * num_atte, output_size) - - def forward(self, Q, K, V, seq_mask=None): - heads = [] - for i in range(len(self.attes)): - j = i * 3 - qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V) - headi = self.attes[i](qi, ki, vi, seq_mask) - heads.append(headi) - output = torch.cat(heads, dim=2) - return self.out_linear(output) + self.input_size = model_size + self.key_size = key_size + self.value_size = value_size + self.num_head = num_head + + in_size = key_size * num_head + self.q_in = nn.Linear(model_size, in_size) + self.k_in = nn.Linear(model_size, in_size) + self.v_in = nn.Linear(model_size, in_size) + self.attention = DotAtte(key_size=key_size, value_size=value_size) + self.out = nn.Linear(value_size * num_head, model_size) + self.drop = TimestepDropout(dropout) + self.reset_parameters() + + def reset_parameters(self): + sqrt = math.sqrt + nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size))) + nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size))) + nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size))) + nn.init.xavier_normal_(self.out.weight) + + def forward(self, Q, K, V, atte_mask_out=None): + """ + :param Q: [batch, seq_len, model_size] + :param K: [batch, seq_len, model_size] + :param V: [batch, seq_len, model_size] + :param seq_mask: [batch, seq_len] + """ + batch, seq_len, _ = Q.size() + d_k, d_v, n_head = self.key_size, self.value_size, self.num_head + # input linear + q = self.q_in(Q).view(batch, seq_len, n_head, d_k) + k = self.k_in(K).view(batch, seq_len, n_head, d_k) + v = self.v_in(V).view(batch, seq_len, n_head, d_k) + + # transpose q, k and v to do batch attention + q = q.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) + k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) + v = v.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_v) + if atte_mask_out is not None: + atte_mask_out = atte_mask_out.repeat(n_head, 1, 1) + atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, seq_len, d_v) + + # concat all heads, do output linear + atte = 
atte.permute(1, 2, 0, 3).contiguous().view(batch, seq_len, -1) + output = self.drop(self.out(atte)) + return output class Bi_Attention(nn.Module): def __init__(self): diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py index 615a6f34..92ccc3fe 100644 --- a/fastNLP/modules/encoder/transformer.py +++ b/fastNLP/modules/encoder/transformer.py @@ -1,29 +1,57 @@ +import torch from torch import nn from ..aggregator.attention import MultiHeadAtte -from ..other_modules import LayerNormalization +from ..dropout import TimestepDropout class TransformerEncoder(nn.Module): class SubLayer(nn.Module): - def __init__(self, input_size, output_size, key_size, value_size, num_atte): + def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): + """ + + :param model_size: int, 输入维度的大小。同时也是输出维度的大小。 + :param inner_size: int, FFN层的hidden大小 + :param key_size: int, 每个head的维度大小。 + :param value_size: int,每个head中value的维度。 + :param num_head: int,head的数量。 + :param dropout: float。 + """ super(TransformerEncoder.SubLayer, self).__init__() - self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte) - self.norm1 = LayerNormalization(output_size) - self.ffn = nn.Sequential(nn.Linear(output_size, output_size), + self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) + self.norm1 = nn.LayerNorm(model_size) + self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), nn.ReLU(), - nn.Linear(output_size, output_size)) - self.norm2 = LayerNormalization(output_size) + nn.Linear(inner_size, model_size), + TimestepDropout(dropout),) + self.norm2 = nn.LayerNorm(model_size) + + def forward(self, input, seq_mask=None, atte_mask_out=None): + """ - def forward(self, input, seq_mask): - attention = self.atte(input) + :param input: [batch, seq_len, model_size] + :param seq_mask: [batch, seq_len] + :return: [batch, seq_len, model_size] + """ + attention = self.atte(input, input, input, atte_mask_out) norm_atte = self.norm1(attention + input) + attention *= seq_mask output = self.ffn(norm_atte) - return self.norm2(output + norm_atte) + output = self.norm2(output + norm_atte) + output *= seq_mask + return output def __init__(self, num_layers, **kargs): super(TransformerEncoder, self).__init__() - self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)]) + self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) def forward(self, x, seq_mask=None): - return self.layers(x, seq_mask) + output = x + if seq_mask is None: + atte_mask_out = None + else: + atte_mask_out = (seq_mask < 1)[:,None,:] + seq_mask = seq_mask[:,:,None] + for layer in self.layers: + output = layer(output, seq_mask, atte_mask_out) + return output diff --git a/reproduction/Biaffine_parser/cfg.cfg b/reproduction/Biaffine_parser/cfg.cfg index 9b00c209..ad06598f 100644 --- a/reproduction/Biaffine_parser/cfg.cfg +++ b/reproduction/Biaffine_parser/cfg.cfg @@ -2,7 +2,8 @@ n_epochs = 40 batch_size = 32 use_cuda = true -validate_every = 500 +use_tqdm=true +validate_every = -1 use_golden_train=true [test] @@ -19,15 +20,13 @@ word_vocab_size = -1 word_emb_dim = 100 pos_vocab_size = -1 pos_emb_dim = 100 -word_hid_dim = 100 -pos_hid_dim = 100 rnn_layers = 3 -rnn_hidden_size = 400 +rnn_hidden_size = 256 arc_mlp_size = 500 label_mlp_size = 100 num_label = -1 -dropout = 0.33 -use_var_lstm=true +dropout = 0.3 +encoder="transformer" use_greedy_infer=false [optim] diff --git a/reproduction/Biaffine_parser/main.py 
b/reproduction/Biaffine_parser/main.py index 9028ff80..f4fd5836 100644 --- a/reproduction/Biaffine_parser/main.py +++ b/reproduction/Biaffine_parser/main.py @@ -5,7 +5,7 @@ sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) import torch import argparse -from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag +from fastNLP.io.dataset_loader import ConllxDataLoader, add_seg_tag from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance diff --git a/reproduction/Biaffine_parser/run.py b/reproduction/Biaffine_parser/run.py index 656da201..ded7487d 100644 --- a/reproduction/Biaffine_parser/run.py +++ b/reproduction/Biaffine_parser/run.py @@ -4,20 +4,15 @@ import sys sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) import fastNLP -import torch from fastNLP.core.trainer import Trainer from fastNLP.core.instance import Instance from fastNLP.api.pipeline import Pipeline from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss -from fastNLP.core.vocabulary import Vocabulary -from fastNLP.core.dataset import DataSet from fastNLP.core.tester import Tester from fastNLP.io.config_io import ConfigLoader, ConfigSection from fastNLP.io.model_io import ModelLoader -from fastNLP.io.embed_loader import EmbedLoader -from fastNLP.io.model_io import ModelSaver -from reproduction.Biaffine_parser.util import ConllxDataLoader, MyDataloader +from fastNLP.io.dataset_loader import ConllxDataLoader from fastNLP.api.processor import * BOS = '' @@ -141,7 +136,7 @@ model_args['pos_vocab_size'] = len(pos_v) model_args['num_label'] = len(tag_v) model = BiaffineParser(**model_args.data) -model.reset_parameters() +print(model) word_idxp = IndexerProcessor(word_v, 'words', 'word_seq') pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq') @@ -209,7 +204,8 @@ def save_pipe(path): pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p]) pipe.add_processor(ModelProcessor(model=model, batch_size=32)) pipe.add_processor(label_toword_p) - torch.save(pipe, os.path.join(path, 'pipe.pkl')) + os.makedirs(path, exist_ok=True) + torch.save({'pipeline': pipe}, os.path.join(path, 'pipe.pkl')) def test(path): diff --git a/reproduction/Biaffine_parser/util.py b/reproduction/Biaffine_parser/util.py index 793b1fb2..aa40e4e9 100644 --- a/reproduction/Biaffine_parser/util.py +++ b/reproduction/Biaffine_parser/util.py @@ -1,34 +1,3 @@ -class ConllxDataLoader(object): - def load(self, path): - datalist = [] - with open(path, 'r', encoding='utf-8') as f: - sample = [] - for line in f: - if line.startswith('\n'): - datalist.append(sample) - sample = [] - elif line.startswith('#'): - continue - else: - sample.append(line.split('\t')) - if len(sample) > 0: - datalist.append(sample) - - data = [self.get_one(sample) for sample in datalist] - return list(filter(lambda x: x is not None, data)) - - def get_one(self, sample): - sample = list(map(list, zip(*sample))) - if len(sample) == 0: - return None - for w in sample[7]: - if w == '_': - print('Error Sample {}'.format(sample)) - return None - # return word_seq, pos_seq, head_seq, head_tag_seq - return sample[1], sample[3], list(map(int, sample[6])), sample[7] - - class MyDataloader: def load(self, data_path): with open(data_path, "r", encoding="utf-8") as f: @@ -56,23 +25,3 @@ class MyDataloader: return data -def add_seg_tag(data): - """ - - :param data: list of ([word], [pos], [heads], [head_tags]) - :return: list of ([word], [pos]) - """ - - _processed = [] - for word_list, pos_list, _, _ in 
data: - new_sample = [] - for word, pos in zip(word_list, pos_list): - if len(word) == 1: - new_sample.append((word, 'S-' + pos)) - else: - new_sample.append((word[0], 'B-' + pos)) - for c in word[1:-1]: - new_sample.append((c, 'M-' + pos)) - new_sample.append((word[-1], 'E-' + pos)) - _processed.append(list(map(list, zip(*new_sample)))) - return _processed \ No newline at end of file diff --git a/reproduction/chinese_word_segment/cws.cfg b/reproduction/Chinese_word_segmentation/cws.cfg similarity index 100% rename from reproduction/chinese_word_segment/cws.cfg rename to reproduction/Chinese_word_segmentation/cws.cfg diff --git a/reproduction/chinese_word_segment/cws_io/__init__.py b/reproduction/Chinese_word_segmentation/cws_io/__init__.py similarity index 100% rename from reproduction/chinese_word_segment/cws_io/__init__.py rename to reproduction/Chinese_word_segmentation/cws_io/__init__.py diff --git a/reproduction/Chinese_word_segmentation/cws_io/cws_reader.py b/reproduction/Chinese_word_segmentation/cws_io/cws_reader.py new file mode 100644 index 00000000..b28b04f6 --- /dev/null +++ b/reproduction/Chinese_word_segmentation/cws_io/cws_reader.py @@ -0,0 +1,3 @@ + + + diff --git a/reproduction/chinese_word_segment/models/__init__.py b/reproduction/Chinese_word_segmentation/models/__init__.py similarity index 100% rename from reproduction/chinese_word_segment/models/__init__.py rename to reproduction/Chinese_word_segmentation/models/__init__.py diff --git a/reproduction/chinese_word_segment/models/cws_model.py b/reproduction/Chinese_word_segmentation/models/cws_model.py similarity index 98% rename from reproduction/chinese_word_segment/models/cws_model.py rename to reproduction/Chinese_word_segmentation/models/cws_model.py index c6cf6746..daefc380 100644 --- a/reproduction/chinese_word_segment/models/cws_model.py +++ b/reproduction/Chinese_word_segmentation/models/cws_model.py @@ -1,11 +1,11 @@ -from torch import nn import torch -import torch.nn.functional as F +from torch import nn -from fastNLP.modules.decoder.MLP import MLP from fastNLP.models.base_model import BaseModel -from reproduction.chinese_word_segment.utils import seq_lens_to_mask +from fastNLP.modules.decoder.MLP import MLP +from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask + class CWSBiLSTMEncoder(BaseModel): def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, diff --git a/reproduction/chinese_word_segment/process/__init__.py b/reproduction/Chinese_word_segmentation/process/__init__.py similarity index 100% rename from reproduction/chinese_word_segment/process/__init__.py rename to reproduction/Chinese_word_segmentation/process/__init__.py diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/Chinese_word_segmentation/process/cws_processor.py similarity index 75% rename from reproduction/chinese_word_segment/process/cws_processor.py rename to reproduction/Chinese_word_segmentation/process/cws_processor.py index 9e57d35a..614d9ef5 100644 --- a/reproduction/chinese_word_segment/process/cws_processor.py +++ b/reproduction/Chinese_word_segmentation/process/cws_processor.py @@ -4,7 +4,7 @@ import re from fastNLP.api.processor import Processor from fastNLP.core.dataset import DataSet from fastNLP.core.vocabulary import Vocabulary -from reproduction.chinese_word_segment.process.span_converter import SpanConverter +from reproduction.Chinese_word_segmentation.process.span_converter import SpanConverter 
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' @@ -226,109 +226,6 @@ class Pre2Post2BigramProcessor(BigramProcessor): return bigrams -# 这里需要建立vocabulary了,但是遇到了以下的问题 -# (1) 如果使用Processor的方式的话,但是在这种情况返回的不是dataset。所以建立vocabulary的工作用另外的方式实现,不借用 -# Processor了 -# TODO 如何将建立vocab和index这两步统一了? - -class VocabIndexerProcessor(Processor): - """ - 根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 - new_added_field_name, 则覆盖原有的field_name. - - """ - def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, - verbose=0, is_input=True): - """ - - :param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 - :param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name. - :param min_freq: 创建的Vocabulary允许的单词最少出现次数. - :param max_size: 创建的Vocabulary允许的最大的单词数量 - :param verbose: 0, 不输出任何信息;1,输出信息 - :param bool is_input: - """ - super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) - self.min_freq = min_freq - self.max_size = max_size - - self.verbose =verbose - self.is_input = is_input - - def construct_vocab(self, *datasets): - """ - 使用传入的DataSet创建vocabulary - - :param datasets: DataSet类型的数据,用于构建vocabulary - :return: - """ - self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) - for dataset in datasets: - assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) - self.vocab.build_vocab() - if self.verbose: - print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) - - def process(self, *datasets, only_index_dataset=None): - """ - 若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary - 后,则会index datasets与only_index_dataset。 - - :param datasets: DataSet类型的数据 - :param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 - :return: - """ - if len(datasets)==0 and not hasattr(self,'vocab'): - raise RuntimeError("You have to construct vocabulary first. 
Or you have to pass datasets to construct it.") - if not hasattr(self, 'vocab'): - self.construct_vocab(*datasets) - else: - if self.verbose: - print("Using constructed vocabulary with {} items.".format(len(self.vocab))) - to_index_datasets = [] - if len(datasets)!=0: - for dataset in datasets: - assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) - to_index_datasets.append(dataset) - - if not (only_index_dataset is None): - if isinstance(only_index_dataset, list): - for dataset in only_index_dataset: - assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) - to_index_datasets.append(dataset) - elif isinstance(only_index_dataset, DataSet): - to_index_datasets.append(only_index_dataset) - else: - raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) - - for dataset in to_index_datasets: - assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) - dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], - new_field_name=self.new_added_field_name, is_input=self.is_input) - # 只返回一个,infer时为了跟其他processor保持一致 - if len(to_index_datasets) == 1: - return to_index_datasets[0] - - def set_vocab(self, vocab): - assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) - self.vocab = vocab - - def delete_vocab(self): - del self.vocab - - def get_vocab_size(self): - return len(self.vocab) - - def set_verbose(self, verbose): - """ - 设置processor verbose状态。 - - :param verbose: int, 0,不输出任何信息;1,输出vocab 信息。 - :return: - """ - self.verbose = verbose - class VocabProcessor(Processor): def __init__(self, field_name, min_freq=1, max_size=None): diff --git a/reproduction/chinese_word_segment/process/span_converter.py b/reproduction/Chinese_word_segmentation/process/span_converter.py similarity index 100% rename from reproduction/chinese_word_segment/process/span_converter.py rename to reproduction/Chinese_word_segmentation/process/span_converter.py diff --git a/reproduction/chinese_word_segment/utils.py b/reproduction/Chinese_word_segmentation/utils.py similarity index 100% rename from reproduction/chinese_word_segment/utils.py rename to reproduction/Chinese_word_segmentation/utils.py diff --git a/reproduction/pos_tag_model/pos_processor.py b/reproduction/POS_tagging/pos_processor.py similarity index 100% rename from reproduction/pos_tag_model/pos_processor.py rename to reproduction/POS_tagging/pos_processor.py diff --git a/reproduction/POS_tagging/pos_reader.py b/reproduction/POS_tagging/pos_reader.py new file mode 100644 index 00000000..4ff58f4b --- /dev/null +++ b/reproduction/POS_tagging/pos_reader.py @@ -0,0 +1,29 @@ +from fastNLP.io.dataset_loader import ZhConllPOSReader + + +def cut_long_sentence(sent, max_sample_length=200): + sent_no_space = sent.replace(' ', '') + cutted_sentence = [] + if len(sent_no_space) > max_sample_length: + parts = sent.strip().split() + new_line = '' + length = 0 + for part in parts: + length += len(part) + new_line += part + ' ' + if length > max_sample_length: + new_line = new_line[:-1] + cutted_sentence.append(new_line) + length = 0 + new_line = '' + if new_line != '': + cutted_sentence.append(new_line[:-1]) + else: + cutted_sentence.append(sent) + return cutted_sentence + + +if __name__ == '__main__': + reader = ZhConllPOSReader() + d = reader.load('/home/hyan/train.conllx') + print(d) \ No newline at end of 
file diff --git a/reproduction/pos_tag_model/pos_tag.cfg b/reproduction/POS_tagging/pos_tag.cfg similarity index 100% rename from reproduction/pos_tag_model/pos_tag.cfg rename to reproduction/POS_tagging/pos_tag.cfg diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/POS_tagging/train_pos_tag.py similarity index 95% rename from reproduction/pos_tag_model/train_pos_tag.py rename to reproduction/POS_tagging/train_pos_tag.py index adc9359c..09a9ba02 100644 --- a/reproduction/pos_tag_model/train_pos_tag.py +++ b/reproduction/POS_tagging/train_pos_tag.py @@ -10,13 +10,12 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) from fastNLP.api.pipeline import Pipeline -from fastNLP.api.processor import SeqLenProcessor +from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.trainer import Trainer from fastNLP.io.config_io import ConfigLoader, ConfigSection from fastNLP.models.sequence_modeling import AdvSeqLabel -from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor -from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader +from fastNLP.io.dataset_loader import ZhConllPOSReader from fastNLP.api.processor import ModelProcessor, Index2WordProcessor cfgfile = './pos_tag.cfg' diff --git a/reproduction/pos_tag_model/utils.py b/reproduction/POS_tagging/utils.py similarity index 100% rename from reproduction/pos_tag_model/utils.py rename to reproduction/POS_tagging/utils.py diff --git a/reproduction/chinese_word_segment/cws_io/cws_reader.py b/reproduction/chinese_word_segment/cws_io/cws_reader.py deleted file mode 100644 index 34bcf7dd..00000000 --- a/reproduction/chinese_word_segment/cws_io/cws_reader.py +++ /dev/null @@ -1,197 +0,0 @@ - - -from fastNLP.core.dataset import DataSet -from fastNLP.core.instance import Instance -from fastNLP.io.dataset_loader import DataSetLoader - - -def cut_long_sentence(sent, max_sample_length=200): - """ - 将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length - - :param sent: str. - :param max_sample_length: int. - :return: list of str. - """ - sent_no_space = sent.replace(' ', '') - cutted_sentence = [] - if len(sent_no_space) > max_sample_length: - parts = sent.strip().split() - new_line = '' - length = 0 - for part in parts: - length += len(part) - new_line += part + ' ' - if length > max_sample_length: - new_line = new_line[:-1] - cutted_sentence.append(new_line) - length = 0 - new_line = '' - if new_line != '': - cutted_sentence.append(new_line[:-1]) - else: - cutted_sentence.append(sent) - return cutted_sentence - -class NaiveCWSReader(DataSetLoader): - """ - 这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 - 这是 fastNLP , 一个 非常 good 的 包 . - 或者,即每个part后面还有一个pos tag - 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY - """ - def __init__(self, in_word_splitter=None): - super().__init__() - - self.in_word_splitter = in_word_splitter - - def load(self, filepath, in_word_splitter=None, cut_long_sent=False): - """ - 允许使用的情况有(默认以\t或空格作为seg) - 这是 fastNLP , 一个 非常 good 的 包 . - 和 - 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY - 如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 
例如"也/D".split('/')[0] - :param filepath: - :param in_word_splitter: - :return: - """ - if in_word_splitter == None: - in_word_splitter = self.in_word_splitter - dataset = DataSet() - with open(filepath, 'r') as f: - for line in f: - line = line.strip() - if len(line.replace(' ', ''))==0: # 不能接受空行 - continue - - if not in_word_splitter is None: - words = [] - for part in line.split(): - word = part.split(in_word_splitter)[0] - words.append(word) - line = ' '.join(words) - if cut_long_sent: - sents = cut_long_sentence(line) - else: - sents = [line] - for sent in sents: - instance = Instance(raw_sentence=sent) - dataset.append(instance) - - return dataset - - -class POSCWSReader(DataSetLoader): - """ - 支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限. - 迈 N - 向 N - 充 N - ... - 泽 I-PER - 民 I-PER - - ( N - 一 N - 九 N - ... - - - :param filepath: - :return: - """ - def __init__(self, in_word_splitter=None): - super().__init__() - self.in_word_splitter = in_word_splitter - - def load(self, filepath, in_word_splitter=None, cut_long_sent=False): - if in_word_splitter is None: - in_word_splitter = self.in_word_splitter - dataset = DataSet() - with open(filepath, 'r') as f: - words = [] - for line in f: - line = line.strip() - if len(line) == 0: # new line - if len(words)==0: # 不能接受空行 - continue - line = ' '.join(words) - if cut_long_sent: - sents = cut_long_sentence(line) - else: - sents = [line] - for sent in sents: - instance = Instance(raw_sentence=sent) - dataset.append(instance) - words = [] - else: - line = line.split()[0] - if in_word_splitter is None: - words.append(line) - else: - words.append(line.split(in_word_splitter)[0]) - return dataset - - -class ConllCWSReader(object): - def __init__(self): - pass - - def load(self, path, cut_long_sent=False): - """ - 返回的DataSet只包含raw_sentence这个field,内容为str。 - 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 - 1 编者按 编者按 NN O 11 nmod:topic - 2 : : PU O 11 punct - 3 7月 7月 NT DATE 4 compound:nn - 4 12日 12日 NT DATE 11 nmod:tmod - 5 , , PU O 11 punct - - 1 这 这 DT O 3 det - 2 款 款 M O 1 mark:clf - 3 飞行 飞行 NN O 8 nsubj - 4 从 从 P O 5 case - 5 外型 外型 NN O 8 nmod:prep - """ - datalist = [] - with open(path, 'r', encoding='utf-8') as f: - sample = [] - for line in f: - if line.startswith('\n'): - datalist.append(sample) - sample = [] - elif line.startswith('#'): - continue - else: - sample.append(line.split('\t')) - if len(sample) > 0: - datalist.append(sample) - - ds = DataSet() - for sample in datalist: - # print(sample) - res = self.get_char_lst(sample) - if res is None: - continue - line = ' '.join(res) - if cut_long_sent: - sents = cut_long_sentence(line) - else: - sents = [line] - for raw_sentence in sents: - ds.append(Instance(raw_sentence=raw_sentence)) - - return ds - - def get_char_lst(self, sample): - if len(sample)==0: - return None - text = [] - for w in sample: - t1, t2, t3, t4 = w[1], w[3], w[6], w[7] - if t3 == '_': - return None - text.append(t1) - return text - diff --git a/reproduction/chinese_word_segment/models/cws_transformer.py b/reproduction/chinese_word_segment/models/cws_transformer.py index 3fcf91b5..64f9b09f 100644 --- a/reproduction/chinese_word_segment/models/cws_transformer.py +++ b/reproduction/chinese_word_segment/models/cws_transformer.py @@ -28,8 +28,9 @@ class TransformerCWS(nn.Module): self.fc1 = nn.Linear(input_size, hidden_size) value_size = hidden_size//num_heads - self.transformer = TransformerEncoder(num_layers, input_size=input_size, output_size=hidden_size, - key_size=value_size, value_size=value_size, num_atte=num_heads) + self.transformer = 
TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size, + key_size=value_size, + value_size=value_size, num_head=num_heads) self.fc2 = nn.Linear(hidden_size, tag_size) @@ -39,7 +40,7 @@ class TransformerCWS(nn.Module): def forward(self, chars, target, seq_lens, bigrams=None): seq_lens = seq_lens - masks = seq_len_to_byte_mask(seq_lens) + masks = seq_len_to_byte_mask(seq_lens).float() x = self.embedding(chars) batch_size = x.size(0) length = x.size(1) diff --git a/reproduction/chinese_word_segment/run.py b/reproduction/chinese_word_segment/run.py deleted file mode 100644 index e7804bae..00000000 --- a/reproduction/chinese_word_segment/run.py +++ /dev/null @@ -1,151 +0,0 @@ -import os -import sys - -sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) - -from fastNLP.io.config_io import ConfigLoader, ConfigSection -from fastNLP.core.trainer import SeqLabelTrainer -from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader -from fastNLP.core.utils import load_pickle -from fastNLP.io.model_io import ModelLoader, ModelSaver -from fastNLP.core.tester import SeqLabelTester -from fastNLP.models.sequence_modeling import AdvSeqLabel -from fastNLP.core.predictor import SeqLabelInfer -from fastNLP.core.utils import save_pickle -from fastNLP.core.metrics import SeqLabelEvaluator - -# not in the file's dir -if len(os.path.dirname(__file__)) != 0: - os.chdir(os.path.dirname(__file__)) -datadir = "/home/zyfeng/data/" -cfgfile = './cws.cfg' - -cws_data_path = os.path.join(datadir, "pku_training.utf8") -pickle_path = "save" -data_infer_path = os.path.join(datadir, "infer.utf8") - - -def infer(): - # Config Loader - test_args = ConfigSection() - ConfigLoader().load_config(cfgfile, {"POS_test": test_args}) - - # fetch dictionary size and number of labels from pickle files - word2index = load_pickle(pickle_path, "word2id.pkl") - test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "label2id.pkl") - test_args["num_classes"] = len(index2label) - - # Define the same model - model = AdvSeqLabel(test_args) - - try: - ModelLoader.load_pytorch(model, "./save/trained_model.pkl") - print('model loaded!') - except Exception as e: - print('cannot load model!') - raise - - # Data Loader - infer_data = SeqLabelDataSet(load_func=BaseLoader.load_lines) - infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True) - print('data loaded') - - # Inference interface - infer = SeqLabelInfer(pickle_path) - results = infer.predict(model, infer_data) - - print(results) - print("Inference finished!") - - -def train(): - # Config Loader - train_args = ConfigSection() - test_args = ConfigSection() - ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args}) - - print("loading data set...") - data = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load) - data.load(cws_data_path) - data_train, data_dev = data.split(ratio=0.3) - train_args["vocab_size"] = len(data.word_vocab) - train_args["num_classes"] = len(data.label_vocab) - print("vocab size={}, num_classes={}".format(len(data.word_vocab), len(data.label_vocab))) - - change_field_is_target(data_dev, "truth", True) - save_pickle(data_dev, "./save/", "data_dev.pkl") - save_pickle(data.word_vocab, "./save/", "word2id.pkl") - save_pickle(data.label_vocab, "./save/", "label2id.pkl") - - # Trainer - trainer = SeqLabelTrainer(epochs=train_args["epochs"], batch_size=train_args["batch_size"], - validate=train_args["validate"], - use_cuda=train_args["use_cuda"], 
pickle_path=train_args["pickle_path"], - save_best_dev=True, print_every_step=10, model_name="trained_model.pkl", - evaluator=SeqLabelEvaluator()) - - # Model - model = AdvSeqLabel(train_args) - try: - ModelLoader.load_pytorch(model, "./save/saved_model.pkl") - print('model parameter loaded!') - except Exception as e: - print("No saved model. Continue.") - pass - - # Start training - trainer.train(model, data_train, data_dev) - print("Training finished!") - - # Saver - saver = ModelSaver("./save/trained_model.pkl") - saver.save_pytorch(model) - print("Model saved!") - - -def predict(): - # Config Loader - test_args = ConfigSection() - ConfigLoader().load_config(cfgfile, {"POS_test": test_args}) - - # fetch dictionary size and number of labels from pickle files - word2index = load_pickle(pickle_path, "word2id.pkl") - test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "label2id.pkl") - test_args["num_classes"] = len(index2label) - - # load dev data - dev_data = load_pickle(pickle_path, "data_dev.pkl") - - # Define the same model - model = AdvSeqLabel(test_args) - - # Dump trained parameters into the model - ModelLoader.load_pytorch(model, "./save/trained_model.pkl") - print("model loaded!") - - # Tester - test_args["evaluator"] = SeqLabelEvaluator() - tester = SeqLabelTester(**test_args.data) - - # Start testing - tester.test(model, dev_data) - - -if __name__ == "__main__": - - import argparse - - parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') - parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) - args = parser.parse_args() - if args.mode == 'train': - train() - elif args.mode == 'test': - predict() - elif args.mode == 'infer': - infer() - else: - print('no mode specified for model!') - parser.print_help() diff --git a/reproduction/pos_tag_model/pos_reader.py b/reproduction/pos_tag_model/pos_reader.py deleted file mode 100644 index c0a8c4cd..00000000 --- a/reproduction/pos_tag_model/pos_reader.py +++ /dev/null @@ -1,153 +0,0 @@ - -from fastNLP.core.dataset import DataSet -from fastNLP.core.instance import Instance - -def cut_long_sentence(sent, max_sample_length=200): - sent_no_space = sent.replace(' ', '') - cutted_sentence = [] - if len(sent_no_space) > max_sample_length: - parts = sent.strip().split() - new_line = '' - length = 0 - for part in parts: - length += len(part) - new_line += part + ' ' - if length > max_sample_length: - new_line = new_line[:-1] - cutted_sentence.append(new_line) - length = 0 - new_line = '' - if new_line != '': - cutted_sentence.append(new_line[:-1]) - else: - cutted_sentence.append(sent) - return cutted_sentence - - -class ConllPOSReader(object): - # 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。 - def __init__(self): - pass - - def load(self, path): - datalist = [] - with open(path, 'r', encoding='utf-8') as f: - sample = [] - for line in f: - if line.startswith('\n'): - datalist.append(sample) - sample = [] - elif line.startswith('#'): - continue - else: - sample.append(line.split('\t')) - if len(sample) > 0: - datalist.append(sample) - - ds = DataSet() - for sample in datalist: - # print(sample) - res = self.get_one(sample) - if res is None: - continue - char_seq = [] - pos_seq = [] - for word, tag in zip(res[0], res[1]): - if len(word)==1: - char_seq.append(word) - pos_seq.append('S-{}'.format(tag)) - elif len(word)>1: - pos_seq.append('B-{}'.format(tag)) - for _ in range(len(word)-2): - 
pos_seq.append('M-{}'.format(tag)) - pos_seq.append('E-{}'.format(tag)) - char_seq.extend(list(word)) - else: - raise ValueError("Zero length of word detected.") - - ds.append(Instance(words=char_seq, - tag=pos_seq)) - - return ds - - - -class ZhConllPOSReader(object): - # 中文colln格式reader - def __init__(self): - pass - - def load(self, path): - """ - 返回的DataSet, 包含以下的field - words:list of str, - tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] - 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 - 1 编者按 编者按 NN O 11 nmod:topic - 2 : : PU O 11 punct - 3 7月 7月 NT DATE 4 compound:nn - 4 12日 12日 NT DATE 11 nmod:tmod - 5 , , PU O 11 punct - - 1 这 这 DT O 3 det - 2 款 款 M O 1 mark:clf - 3 飞行 飞行 NN O 8 nsubj - 4 从 从 P O 5 case - 5 外型 外型 NN O 8 nmod:prep - """ - datalist = [] - with open(path, 'r', encoding='utf-8') as f: - sample = [] - for line in f: - if line.startswith('\n'): - datalist.append(sample) - sample = [] - elif line.startswith('#'): - continue - else: - sample.append(line.split('\t')) - if len(sample) > 0: - datalist.append(sample) - - ds = DataSet() - for sample in datalist: - # print(sample) - res = self.get_one(sample) - if res is None: - continue - char_seq = [] - pos_seq = [] - for word, tag in zip(res[0], res[1]): - char_seq.extend(list(word)) - if len(word)==1: - pos_seq.append('S-{}'.format(tag)) - elif len(word)>1: - pos_seq.append('B-{}'.format(tag)) - for _ in range(len(word)-2): - pos_seq.append('M-{}'.format(tag)) - pos_seq.append('E-{}'.format(tag)) - else: - raise ValueError("Zero length of word detected.") - - ds.append(Instance(words=char_seq, - tag=pos_seq)) - - return ds - - def get_one(self, sample): - if len(sample)==0: - return None - text = [] - pos_tags = [] - for w in sample: - t1, t2, t3, t4 = w[1], w[3], w[6], w[7] - if t3 == '_': - return None - text.append(t1) - pos_tags.append(t2) - return text, pos_tags - -if __name__ == '__main__': - reader = ZhConllPOSReader() - d = reader.load('/home/hyan/train.conllx') - print(d) \ No newline at end of file diff --git a/test/core/test_batch.py b/test/core/test_batch.py index 08d803f1..1c4b22f8 100644 --- a/test/core/test_batch.py +++ b/test/core/test_batch.py @@ -1,6 +1,7 @@ import unittest import numpy as np +import torch from fastNLP.core.batch import Batch from fastNLP.core.dataset import DataSet @@ -31,3 +32,47 @@ class TestCase1(unittest.TestCase): self.assertEqual(len(y["y"]), 4) self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) self.assertListEqual(list(y["y"][-1]), [5, 6]) + + def test_list_padding(self): + ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, + "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) + ds.set_input("x") + ds.set_target("y") + iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) + for x, y in iter: + self.assertEqual(x["x"].shape, (4, 4)) + self.assertEqual(y["y"].shape, (4, 4)) + + def test_numpy_padding(self): + ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), + "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) + ds.set_input("x") + ds.set_target("y") + iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) + for x, y in iter: + self.assertEqual(x["x"].shape, (4, 4)) + self.assertEqual(y["y"].shape, (4, 4)) + + def test_list_to_tensor(self): + ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, + "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) + ds.set_input("x") + ds.set_target("y") + iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), 
as_numpy=False) + for x, y in iter: + self.assertTrue(isinstance(x["x"], torch.Tensor)) + self.assertEqual(tuple(x["x"].shape), (4, 4)) + self.assertTrue(isinstance(y["y"], torch.Tensor)) + self.assertEqual(tuple(y["y"].shape), (4, 4)) + + def test_numpy_to_tensor(self): + ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), + "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) + ds.set_input("x") + ds.set_target("y") + iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) + for x, y in iter: + self.assertTrue(isinstance(x["x"], torch.Tensor)) + self.assertEqual(tuple(x["x"].shape), (4, 4)) + self.assertTrue(isinstance(y["y"], torch.Tensor)) + self.assertEqual(tuple(y["y"].shape), (4, 4)) diff --git a/test/models/test_biaffine_parser.py b/test/models/test_biaffine_parser.py index 54935f76..d87000a0 100644 --- a/test/models/test_biaffine_parser.py +++ b/test/models/test_biaffine_parser.py @@ -77,9 +77,10 @@ class TestBiaffineParser(unittest.TestCase): ds, v1, v2, v3 = init_data() model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, pos_vocab_size=len(v2), pos_emb_dim=30, - num_label=len(v3), use_var_lstm=True) + num_label=len(v3), encoder='var-lstm') trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', + batch_size=1, validate_every=10, n_epochs=10, use_cuda=False, use_tqdm=False) trainer.train(load_best_model=False)
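
Note: a minimal usage sketch of VocabIndexerProcessor now that it is exposed from fastNLP.api.processor. The toy DataSet, field names, and printed output below are illustrative and not part of the patch; the keyword new_added_filed_name keeps the spelling used in the source. The vocabulary is built from the positional datasets only, while only_index_dataset is indexed with that vocabulary.

from fastNLP.core.dataset import DataSet
from fastNLP.api.processor import VocabIndexerProcessor

# toy data: each instance holds a tokenized sequence in the "words" field
train = DataSet({"words": [["我", "爱", "fastNLP"], ["你", "好"]]})
dev = DataSet({"words": [["你", "爱", "我"]]})

proc = VocabIndexerProcessor(field_name="words", new_added_filed_name="word_seq", verbose=1)
# the vocabulary is built from `train` only; `dev` is indexed but does not extend the vocabulary
proc.process(train, only_index_dataset=dev)
print(train[0]["word_seq"], dev[0]["word_seq"])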
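
Note: BiaffineParser replaces the old use_var_lstm flag with an encoder argument; 'lstm' and 'var-lstm' are the values handled in the hunk shown here, and the updated cfg selects "transformer". A construction sketch with placeholder vocabulary sizes and dimensions (only the argument names come from the patch):

from fastNLP.models.biaffine_parser import BiaffineParser

# placeholder sizes; `encoder` replaces the old `use_var_lstm` flag
model = BiaffineParser(word_vocab_size=1000, word_emb_dim=100,
                       pos_vocab_size=50, pos_emb_dim=100,
                       num_label=40,
                       rnn_layers=3, rnn_hidden_size=256,
                       arc_mlp_size=500, label_mlp_size=100,
                       dropout=0.3,
                       encoder='var-lstm',   # or 'lstm'
                       use_greedy_infer=False)
print(model)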
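
Note: the rewritten TransformerEncoder and MultiHeadAtte take (model_size, inner_size, key_size, value_size, num_head, dropout) instead of the old (input_size, output_size, key_size, value_size, num_atte) signature, and the encoder now expects a sequence mask at call time. A usage sketch with illustrative shapes; as in the TransformerCWS change above, key_size and value_size are set to model_size // num_head and the mask is passed as a float tensor.

import torch
from fastNLP.modules.encoder.transformer import TransformerEncoder

batch, seq_len, model_size, num_head = 4, 10, 256, 8
x = torch.randn(batch, seq_len, model_size)
seq_mask = torch.ones(batch, seq_len)                 # float mask: 1 = token, 0 = padding

encoder = TransformerEncoder(num_layers=2,
                             model_size=model_size,
                             inner_size=model_size,
                             key_size=model_size // num_head,
                             value_size=model_size // num_head,
                             num_head=num_head,
                             dropout=0.1)
out = encoder(x, seq_mask=seq_mask)                   # -> [batch, seq_len, model_size]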