
Add data loading and model code for NER; fix a typo in the metrics module; change the default LSTM initialization to set the forget-gate bias to 1.

tags/v0.4.10
yh_cc 6 years ago
parent commit 9a8fe42cd4
12 changed files with 469 additions and 17 deletions
  1. fastNLP/core/metrics.py (+5, -5)
  2. fastNLP/modules/encoder/embedding.py (+11, -8)
  3. fastNLP/modules/encoder/lstm.py (+6, -4)
  4. reproduction/seqence_labelling/ner/__init__.py (+0, -0)
  5. reproduction/seqence_labelling/ner/data/Conll2003Loader.py (+92, -0)
  6. reproduction/seqence_labelling/ner/data/OntoNoteLoader.py (+130, -0)
  7. reproduction/seqence_labelling/ner/data/utils.py (+49, -0)
  8. reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py (+62, -0)
  9. reproduction/seqence_labelling/ner/test/__init__.py (+0, -0)
  10. reproduction/seqence_labelling/ner/test/test.py (+33, -0)
  11. reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py (+42, -0)
  12. reproduction/seqence_labelling/ner/train_ontonote.py (+39, -0)

fastNLP/core/metrics.py (+5, -5)

@@ -428,16 +428,16 @@ def _bioes_tag_to_spans(tags, ignore_labels=None):
     prev_bioes_tag = None
     for idx, tag in enumerate(tags):
         tag = tag.lower()
-        bieso_tag, label = tag[:1], tag[2:]
-        if bieso_tag in ('b', 's'):
+        bioes_tag, label = tag[:1], tag[2:]
+        if bioes_tag in ('b', 's'):
             spans.append((label, [idx, idx]))
-        elif bieso_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
+        elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
             spans[-1][1][1] = idx
-        elif bieso_tag == 'o':
+        elif bioes_tag == 'o':
             pass
         else:
             spans.append((label, [idx, idx]))
-        prev_bioes_tag = bieso_tag
+        prev_bioes_tag = bioes_tag
     return [(span[0], (span[1][0], span[1][1] + 1))
             for span in spans
             if span[0] not in ignore_labels
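
For reference, a minimal check of what the corrected helper produces; _bioes_tag_to_spans is an internal helper of fastNLP.core.metrics, and the expected output below is worked out from the code in this hunk:

    from fastNLP.core.metrics import _bioes_tag_to_spans

    # 'S-' marks a single-token span, 'B-'...'E-' delimit a multi-token span; spans are half-open (start, end).
    tags = ['S-LOC', 'O', 'B-PER', 'E-PER']
    print(_bioes_tag_to_spans(tags))
    # expected: [('loc', (0, 1)), ('per', (2, 4))]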


fastNLP/modules/encoder/embedding.py (+11, -8)

@@ -500,8 +500,8 @@ class CNNCharEmbedding(TokenEmbedding):
     """
     Alias: :class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding`

-    Generates character embeddings with a CNN. The pipeline is CNN(x) -> activation(x) -> pool -> fc. The outputs of
-    filters with different kernel sizes are concatenated.
+    Generates character embeddings with a CNN. The pipeline is embed(x) -> Dropout(x) -> CNN(x) -> activation(x)
+    -> pool -> fc. The outputs of filters with different kernel sizes are concatenated.

     Example::

@@ -511,13 +511,14 @@ class CNNCharEmbedding(TokenEmbedding):
     :param vocab: the vocabulary
     :param embed_size: size of the word embedding. Default: 50.
     :param char_emb_size: size of the character embedding; characters are taken from the vocab. Default: 50.
+    :param dropout: dropout probability
     :param filter_nums: number of filters; must have the same length as kernels. Default: [40, 30, 20].
     :param kernel_sizes: kernel sizes. Default: [5, 3, 1].
     :param pool_method: pooling method used to compose the character representations; supports 'avg', 'max'.
     :param activation: activation applied after the CNN; supports 'relu', 'sigmoid', 'tanh' or a custom callable.
     :param min_char_freq: minimum character frequency. Default: 2.
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50,
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5,
                  filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max',
                  activation='relu', min_char_freq: int=2):
         super(CNNCharEmbedding, self).__init__(vocab)

@@ -526,6 +527,7 @@ class CNNCharEmbedding(TokenEmbedding):
             assert kernel % 2 == 1, "Only odd kernel is allowed."

         assert pool_method in ('max', 'avg')
+        self.dropout = nn.Dropout(dropout, inplace=True)
         self.pool_method = pool_method
         # activation function
         if isinstance(activation, str):
@@ -583,7 +585,7 @@ class CNNCharEmbedding(TokenEmbedding):
         # positions equal to 1 are the mask
         chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len ; a value of 0 marks a padding position
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
+        chars = self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1)
         reshaped_chars = reshaped_chars.transpose(1, 2)  # B' x E x M
         conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
@@ -635,7 +637,7 @@ class LSTMCharEmbedding(TokenEmbedding):
     """
     Alias: :class:`fastNLP.modules.LSTMCharEmbedding` :class:`fastNLP.modules.encoder.embedding.LSTMCharEmbedding`

-    Encodes characters with an LSTM.
+    Encodes characters with an LSTM. The pipeline is embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool

     Example::

@@ -644,13 +646,14 @@ class LSTMCharEmbedding(TokenEmbedding):
     :param vocab: the vocabulary
     :param embed_size: size of the embedding. Default: 50.
     :param char_emb_size: size of the character embedding. Default: 50.
+    :param dropout: dropout probability
     :param hidden_size: hidden size of the LSTM; halved if bidirectional. Default: 50.
     :param pool_method: supports 'max', 'avg'
     :param activation: activation function; supports 'relu', 'sigmoid', 'tanh', or a custom callable.
     :param min_char_freq: minimum character frequency. Default: 2.
     :param bidirectional: whether to use a bidirectional LSTM for encoding. Default: True.
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, hidden_size=50,
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5, hidden_size=50,
                  pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True):
         super(LSTMCharEmbedding, self).__init__(vocab)

@@ -658,7 +661,7 @@ class LSTMCharEmbedding(TokenEmbedding):

         assert pool_method in ('max', 'avg')
         self.pool_method = pool_method
+        self.dropout = nn.Dropout(dropout, inplace=True)
         # activation function
         if isinstance(activation, str):
             if activation.lower() == 'relu':
@@ -715,7 +718,7 @@ class LSTMCharEmbedding(TokenEmbedding):
         # masked positions are 1
         chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len ; a value of 0 marks a padding position
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
+        chars = self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
         char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
         lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
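
A minimal usage sketch for the new dropout argument; the toy vocabulary below is illustrative, and the constructor arguments follow the signature shown in the diff:

    from fastNLP import Vocabulary
    from fastNLP.modules.encoder.embedding import CNNCharEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a small example vocabulary".split())
    # character embeddings are now dropped out with the given probability before the CNN
    char_embed = CNNCharEmbedding(vocab=vocab, embed_size=50, char_emb_size=30, dropout=0.3,
                                  filter_nums=[30], kernel_sizes=[3])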


fastNLP/modules/encoder/lstm.py (+6, -4)

@@ -40,12 +40,14 @@ class LSTM(nn.Module):

     def init_param(self):
         for name, param in self.named_parameters():
-            if 'bias_i' in name:
-                param.data.fill_(1)
-            elif 'bias_h' in name:
+            if 'bias' in name:
+                # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
                 param.data.fill_(0)
+                n = param.size(0)
+                start, end = n // 4, n // 2
+                param.data[start:end].fill_(1)
             else:
-                nn.init.xavier_normal_(param)
+                nn.init.xavier_uniform_(param)

     def forward(self, x, seq_len=None, h0=None, c0=None):
         """


reproduction/seqence_labelling/ner/__init__.py (+0, -0)


reproduction/seqence_labelling/ner/data/Conll2003Loader.py (+92, -0)

@@ -0,0 +1,92 @@

from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths

from fastNLP.io.dataset_loader import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2


class Conll2003DataLoader(DataSetLoader):
    def __init__(self, task:str='ner', encoding_type:str='bioes'):
        """
        Loads English data in the Conll2003 format; the dataset is described at https://www.clips.uantwerpen.be/conll2003/ner/.
        When task is 'pos', the target of the returned DataSet is taken from column 2; when task is 'chunk', from
        column 3; when task is 'ner', from column 4. All "-DOCSTART- -X- O O" lines are ignored, which makes the number
        of samples smaller than reported in many papers, but since "-DOCSTART- -X- O O" is only a document separator and
        should not be predicted, we drop it.
        For the ner and chunk tasks the loaded target uses the given encoding_type; for pos the target is the pos column as-is.

        :param task: the labelling task. One of ner, pos, chunk
        """
        assert task in ('ner', 'pos', 'chunk')
        index = {'ner':3, 'pos':1, 'chunk':2}[task]
        self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
        self._tag_converters = None
        if task in ('ner', 'chunk'):
            self._tag_converters = [iob2]
            if encoding_type == 'bioes':
                self._tag_converters.append(iob2bioes)

    def load(self, path: str):
        dataset = self._loader.load(path)
        def convert_tag_schema(tags):
            for converter in self._tag_converters:
                tags = converter(tags)
            return tags
        if self._tag_converters:
            dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=True):
        """
        Reads and processes the data. Lines starting with '-DOCSTART-' are ignored.

        :param paths:
        :param word_vocab_opt: initialization options for the vocabulary
        :param lower: whether to lowercase all words
        :return:
        """
        # read the data
        paths = check_dataloader_paths(paths)
        data = DataInfo()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, path in paths.items():
            dataset = self.load(path)
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                dataset.apply_field(lambda words:[word.lower() for word in words], field_name=Const.INPUT,
                                    new_field_name=Const.INPUT)
            data.datasets[name] = dataset

        # construct the word vocab
        word_vocab = Vocabulary(min_freq=3) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT)
        word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # cap words
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # build the target vocab
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data

if __name__ == '__main__':
    pass
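
For context, the loader expects the four-column CoNLL-2003 layout (word, POS, chunk, NER), which is why the column index is picked via {'ner':3, 'pos':1, 'chunk':2}; the lines below are illustrative of that layout:

    EU NNP I-NP I-ORG
    rejects VBZ I-VP O
    German JJ I-NP I-MISC
    call NN I-NP O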

reproduction/seqence_labelling/ner/data/OntoNoteLoader.py (+130, -0)

@@ -0,0 +1,130 @@
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict
from fastNLP import DataSet
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths

from fastNLP.io.dataset_loader import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2

class OntoNoteNERDataLoader(DataSetLoader):
    """
    Reads OntoNotes data that has been converted to the Conll format. The conversion of OntoNotes to the conll
    format is described at https://github.com/yhcc/OntoNotes-5.0-NER.

    """
    def __init__(self, encoding_type:str='bioes'):
        assert encoding_type in ('bioes', 'bio')
        self.encoding_type = encoding_type
        if encoding_type=='bioes':
            self.encoding_method = iob2bioes
        else:
            self.encoding_method = iob2

    def load(self, path:str)->DataSet:
        """
        Reads the data at the given file path. The returned DataSet contains the following fields:
            raw_words: List[str]
            target: List[str]

        :param path:
        :return:
        """
        dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path)
        def convert_to_bio(tags):
            bio_tags = []
            flag = None
            for tag in tags:
                label = tag.strip("()*")
                if '(' in tag:
                    bio_label = 'B-' + label
                    flag = label
                elif flag:
                    bio_label = 'I-' + flag
                else:
                    bio_label = 'O'
                if ')' in tag:
                    flag = None
                bio_tags.append(bio_label)
            return self.encoding_method(bio_tags)

        dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target')

        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
                lower:bool=True)->DataInfo:
        """
        Reads and processes the data. The returned DataInfo contains:
            vocabs:
                word: Vocabulary
                target: Vocabulary
            datasets:
                train: DataSet
                    words: List[int], set as input
                    target: int. The label, set as both input and target
                    seq_len: int. The sentence length, set as both input and target
                    raw_words: List[str]
                xxx (may vary depending on the given paths)

        :param paths:
        :param word_vocab_opt: initialization options for the vocabulary
        :param lower: whether to lowercase
        :return:
        """
        paths = check_dataloader_paths(paths)
        data = DataInfo()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, path in paths.items():
            dataset = self.load(path)
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                dataset.apply_field(lambda words:[word.lower() for word in words], field_name=Const.INPUT,
                                    new_field_name=Const.INPUT)
            data.datasets[name] = dataset

        # construct the word vocab
        word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name='raw_words')
        word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # cap words
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words')
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # build the target vocab
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data


if __name__ == '__main__':
    loader = OntoNoteNERDataLoader()
    dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt')
    print(dataset.target.value_count())
    print(dataset[:4])


"""
train 115812 2200752
development 15680 304684
test 12217 230111

train 92403 1901772
valid 13606 279180
test 10258 204135
"""

reproduction/seqence_labelling/ner/data/utils.py (+49, -0)

@@ -0,0 +1,49 @@
from typing import List

def iob2(tags:List[str])->List[str]:
    """
    Checks whether the tags are valid IOB data; IOB1 is automatically converted to IOB2.

    :param tags: the tags to convert
    """
    for i, tag in enumerate(tags):
        if tag == "O":
            continue
        split = tag.split("-")
        if len(split) != 2 or split[0] not in ["I", "B"]:
            raise TypeError("The encoding schema is not a valid IOB type.")
        if split[0] == "B":
            continue
        elif i == 0 or tags[i - 1] == "O":  # conversion IOB1 to IOB2
            tags[i] = "B" + tag[1:]
        elif tags[i - 1][1:] == tag[1:]:
            continue
        else:  # conversion IOB1 to IOB2
            tags[i] = "B" + tag[1:]
    return tags

def iob2bioes(tags:List[str])->List[str]:
    """
    Converts IOB tags to the BIOES encoding.
    :param tags:
    :return:
    """
    new_tags = []
    for i, tag in enumerate(tags):
        if tag == 'O':
            new_tags.append(tag)
        else:
            split = tag.split('-')[0]
            if split == 'B':
                if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('B-', 'S-'))
            elif split == 'I':
                if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('I-', 'E-'))
            else:
                raise TypeError("Invalid IOB format.")
    return new_tags
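
A quick sanity check of the two helpers; the expected outputs are derived from the code above:

    from reproduction.seqence_labelling.ner.data.utils import iob2, iob2bioes

    print(iob2(['I-LOC', 'B-LOC', 'I-LOC']))            # ['B-LOC', 'B-LOC', 'I-LOC']  (IOB1 start repaired)
    print(iob2bioes(['B-LOC', 'O', 'B-PER', 'I-PER']))  # ['S-LOC', 'O', 'B-PER', 'E-PER']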

reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py (+62, -0)

@@ -0,0 +1,62 @@

import torch
from torch import nn
from fastNLP import seq_len_to_mask
from fastNLP.modules import Embedding
from fastNLP.modules import LSTM
from fastNLP.modules import ConditionalRandomField, allowed_transitions, TimestepDropout
import torch.nn.functional as F
from fastNLP import Const

class CNNBiLSTMCRF(nn.Module):
    def __init__(self, embed, char_embed, hidden_size, num_layers, tag_vocab, dropout=0.5, encoding_type='bioes'):
        super().__init__()

        self.embedding = Embedding(embed, dropout=0.5)
        self.char_embedding = Embedding(char_embed, dropout=0.5)
        self.lstm = LSTM(input_size=self.embedding.embedding_dim+self.char_embedding.embedding_dim,
                         hidden_size=hidden_size//2, num_layers=num_layers,
                         bidirectional=True, batch_first=True, dropout=dropout)
        self.forward_fc = nn.Linear(hidden_size//2, len(tag_vocab))
        self.backward_fc = nn.Linear(hidden_size//2, len(tag_vocab))

        transitions = allowed_transitions(tag_vocab.idx2word, encoding_type=encoding_type, include_start_end=False)
        self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=False, allowed_transitions=transitions)

        self.dropout = TimestepDropout(dropout, inplace=True)

        for name, param in self.named_parameters():
            if 'ward_fc' in name:
                if param.data.dim()>1:
                    nn.init.xavier_normal_(param)
                else:
                    nn.init.constant_(param, 0)
            if 'crf' in name:
                nn.init.zeros_(param)

    def _forward(self, words, cap_words, seq_len, target=None):
        words = self.embedding(words)
        chars = self.char_embedding(cap_words)
        words = torch.cat([words, chars], dim=-1)
        outputs, _ = self.lstm(words, seq_len)
        self.dropout(outputs)
        forwards, backwards = outputs.chunk(2, dim=-1)

        # forward_logits = F.log_softmax(self.forward_fc(forwards), dim=-1)
        # backward_logits = F.log_softmax(self.backward_fc(backwards), dim=-1)

        logits = self.forward_fc(forwards) + self.backward_fc(backwards)
        self.dropout(logits)

        if target is not None:
            loss = self.crf(logits, target, seq_len_to_mask(seq_len))
            return {Const.LOSS: loss}
        else:
            pred, _ = self.crf.viterbi_decode(logits, seq_len_to_mask(seq_len))
            return {Const.OUTPUT: pred}

    def forward(self, words, cap_words, seq_len, target):
        return self._forward(words, cap_words, seq_len, target)

    def predict(self, words, cap_words, seq_len):
        return self._forward(words, cap_words, seq_len, None)

reproduction/seqence_labelling/ner/test/__init__.py (+0, -0)


reproduction/seqence_labelling/ner/test/test.py (+33, -0)

@@ -0,0 +1,33 @@

from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader
from reproduction.seqence_labelling.ner.data.Conll2003Loader import iob2, iob2bioes
import unittest

class TestTagSchemaConverter(unittest.TestCase):
    def test_iob2(self):
        tags = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
        golden = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
        self.assertListEqual(golden, iob2(tags))

        tags = ['I-ORG', 'O']
        golden = ['B-ORG', 'O']
        self.assertListEqual(golden, iob2(tags))

        tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O']
        golden = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
        self.assertListEqual(golden, iob2(tags))

    def test_iob2bemso(self):
        tags = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
        golden = ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O']
        self.assertListEqual(golden, iob2bioes(tags))


def test_conll2003_loader():
    path = '/hdd/fudanNLP/fastNLP/others/data/conll2003/train.txt'
    loader = Conll2003DataLoader().load(path)
    print(loader[:3])


if __name__ == '__main__':
    test_conll2003_loader()

reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py (+42, -0)

@@ -0,0 +1,42 @@


from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding, BertEmbedding
from fastNLP.core.vocabulary import VocabularyOption

from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF
from fastNLP import Trainer
from fastNLP import SpanFPreRecMetric
from fastNLP import BucketSampler
from fastNLP import Const
from torch.optim import SGD, Adam
from fastNLP import GradientClipCallback
from fastNLP.core.callback import FitlogCallback
import fitlog
fitlog.debug()

from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader

encoding_type = 'bioes'

data = Conll2003DataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/conll2003',
                                                                word_vocab_opt=VocabularyOption(min_freq=3))
print(data)
char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
                              kernel_sizes=[3])
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
                             model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt',
                             requires_grad=True)
word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std()

model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=400, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
                     encoding_type=encoding_type)

optimizer = Adam(model.parameters(), lr=0.001)

callbacks = [GradientClipCallback(clip_type='value'), FitlogCallback({'test':data.datasets['test']}, verbose=1)]

trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
                  device=0, dev_data=data.datasets['dev'], batch_size=32,
                  metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
                  callbacks=callbacks, num_workers=1, n_epochs=100)
trainer.train()

reproduction/seqence_labelling/ner/train_ontonote.py (+39, -0)

@@ -0,0 +1,39 @@


from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding

from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF
from fastNLP import Trainer
from fastNLP import SpanFPreRecMetric
from fastNLP import BucketSampler
from fastNLP import Const
from torch.optim import SGD, Adam
from fastNLP import GradientClipCallback
from fastNLP.core.callback import FitlogCallback
import fitlog
fitlog.debug()

from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader

encoding_type = 'bioes'

data = OntoNoteNERDataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/v4/english')
print(data)
char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
                              kernel_sizes=[3])
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
                             model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt',
                             requires_grad=True)

model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
                     encoding_type=encoding_type)

optimizer = Adam(model.parameters(), lr=0.001)

callbacks = [GradientClipCallback(), FitlogCallback(data.datasets['test'], verbose=1)]

trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
                  device=1, dev_data=data.datasets['dev'], batch_size=32,
                  metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
                  callbacks=callbacks, num_workers=1, n_epochs=100)
trainer.train()
