
1. Add BERT NER under sequence labelling; 2. Replace print with logger

tags/v0.4.10
yh_cc committed 5 years ago
commit 74934271dc
28 changed files with 182 additions and 427 deletions
  1. fastNLP/core/batch.py (+2, -2)
  2. fastNLP/core/callback.py (+9, -9)
  3. fastNLP/core/dataset.py (+6, -5)
  4. fastNLP/core/dist_trainer.py (+0, -1)
  5. fastNLP/core/field.py (+10, -9)
  6. fastNLP/core/tester.py (+1, -1)
  7. fastNLP/core/utils.py (+2, -2)
  8. fastNLP/core/vocabulary.py (+4, -3)
  9. fastNLP/embeddings/bert_embedding.py (+11, -5)
  10. fastNLP/embeddings/char_embedding.py (+5, -4)
  11. fastNLP/embeddings/contextual_embedding.py (+5, -5)
  12. fastNLP/embeddings/elmo_embedding.py (+5, -5)
  13. fastNLP/embeddings/embedding.py (+3, -1)
  14. fastNLP/embeddings/static_embedding.py (+5, -4)
  15. fastNLP/io/embed_loader.py (+4, -4)
  16. fastNLP/io/file_reader.py (+5, -4)
  17. fastNLP/io/file_utils.py (+7, -6)
  18. fastNLP/io/pipe/classification.py (+1, -1)
  19. fastNLP/io/utils.py (+3, -3)
  20. fastNLP/modules/encoder/bert.py (+7, -8)
  21. reproduction/seqence_labelling/ner/data/Conll2003Loader.py (+0, -93)
  22. reproduction/seqence_labelling/ner/data/OntoNoteLoader.py (+0, -152)
  23. reproduction/seqence_labelling/ner/data/utils.py (+0, -49)
  24. reproduction/seqence_labelling/ner/model/bert_crf.py (+31, -0)
  25. reproduction/seqence_labelling/ner/test/__init__.py (+0, -0)
  26. reproduction/seqence_labelling/ner/test/test.py (+0, -33)
  27. reproduction/seqence_labelling/ner/train_bert.py (+52, -0)
  28. reproduction/seqence_labelling/ner/train_idcnn.py (+4, -18)
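Change (2) is mechanical and repeats across every file below: each module imports the shared logger (from ._logger import logger inside fastNLP.core, from ..core import logger elsewhere) and swaps bare print calls for logger.info/warn/error. A minimal sketch of the pattern, assuming the fastNLP.core._logger module shown in the diff; the helper name and its converter argument are illustrative only, not part of the commit:

from fastNLP.core._logger import logger   # the shared logger this commit switches to

def convert_field(field_name, value, converter):
    """Hypothetical helper: convert one field, logging (not printing) on failure."""
    try:
        return converter(value)
    except TypeError as e:
        # previously: print(f"Field {field_name} cannot be converted to torch.tensor.")
        logger.error(f"Field {field_name} cannot be converted to torch.tensor.")
        raise e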

fastNLP/core/batch.py (+2, -2)

@@ -17,7 +17,7 @@ from numbers import Number

from .sampler import SequentialSampler
from .dataset import DataSet
from ._logger import logger
_python_is_exit = False


@@ -75,7 +75,7 @@ class DataSetGetter:
try:
data, flag = _to_tensor(data, f.dtype)
except TypeError as e:
print(f"Field {n} cannot be converted to torch.tensor.")
logger.error(f"Field {n} cannot be converted to torch.tensor.")
raise e
batch_dict[n] = data
return batch_dict


fastNLP/core/callback.py (+9, -9)

@@ -83,7 +83,6 @@ try:
except:
tensorboardX_flag = False

from ..io.model_io import ModelSaver, ModelLoader
from .dataset import DataSet
from .tester import Tester
from ._logger import logger
@@ -505,7 +504,7 @@ class EarlyStopCallback(Callback):
def on_exception(self, exception):
if isinstance(exception, EarlyStopError):
print("Early Stopping triggered in epoch {}!".format(self.epoch))
logger.info("Early Stopping triggered in epoch {}!".format(self.epoch))
else:
raise exception # re-raise errors we do not recognize

@@ -752,8 +751,7 @@ class LRFinder(Callback):
self.smooth_value = SmoothValue(0.8)
self.opt = None
self.find = None
self.loader = ModelLoader()

@property
def lr_gen(self):
scale = (self.end_lr - self.start_lr) / self.batch_per_epoch
@@ -768,7 +766,7 @@ class LRFinder(Callback):
self.opt = self.trainer.optimizer # pytorch optimizer
self.opt.param_groups[0]["lr"] = self.start_lr
# save model
ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
torch.save(self.model.state_dict(), 'tmp')
self.find = True
def on_backward_begin(self, loss):
@@ -797,7 +795,9 @@ class LRFinder(Callback):
self.opt.param_groups[0]["lr"] = self.best_lr
self.find = False
# reset model
ModelLoader().load_pytorch(self.trainer.model, "tmp")
states = torch.load('tmp')
self.model.load_state_dict(states)
os.remove('tmp')
self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr))


@@ -988,14 +988,14 @@ class SaveModelCallback(Callback):
try:
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
except Exception as e:
print(f"The following exception:{e} happens when save model to {self.save_dir}.")
logger.error(f"The following exception:{e} happens when save model to {self.save_dir}.")
if delete_pair:
try:
delete_model_path = os.path.join(self.save_dir, delete_pair[1])
if os.path.exists(delete_model_path):
os.remove(delete_model_path)
except Exception as e:
print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
logger.error(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")

def on_exception(self, exception):
if self.save_on_exception:
@@ -1032,7 +1032,7 @@ class EchoCallback(Callback):

def __getattribute__(self, item):
if item.startswith('on_'):
print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
logger.info('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
file=self.out)
return super(EchoCallback, self).__getattribute__(item)
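Apart from the logging switch, the LRFinder hunks above drop the ModelSaver/ModelLoader dependency and snapshot the model with plain torch.save/torch.load. A self-contained sketch of that round trip, using a stand-in model and the same temporary file name as the diff:

import os
import torch
from torch import nn

model = nn.Linear(4, 2)                  # stand-in for self.trainer.model
torch.save(model.state_dict(), 'tmp')    # on_epoch_begin: snapshot the weights

# ... sweep learning rates, which may leave the weights in a bad state ...

states = torch.load('tmp')               # on_epoch_end: restore the snapshot
model.load_state_dict(states)
os.remove('tmp')                         # clean up the temporary file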



fastNLP/core/dataset.py (+6, -5)

@@ -300,6 +300,7 @@ from .utils import _get_func_signature
from .field import AppendToTargetOrInputException
from .field import SetInputOrTargetException
from .const import Const
from ._logger import logger

class DataSet(object):
"""
@@ -452,7 +453,7 @@ class DataSet(object):
try:
self.field_arrays[name].append(field)
except AppendToTargetOrInputException as e:
print(f"Cannot append to field:{name}.")
logger.error(f"Cannot append to field:{name}.")
raise e
def add_fieldarray(self, field_name, fieldarray):
@@ -609,7 +610,7 @@ class DataSet(object):
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self.field_arrays[name].is_target = flag
except SetInputOrTargetException as e:
print(f"Cannot set field:{name} as target.")
logger.error(f"Cannot set field:{name} as target.")
raise e
else:
raise KeyError("{} is not a valid field name.".format(name))
@@ -633,7 +634,7 @@ class DataSet(object):
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self.field_arrays[name].is_input = flag
except SetInputOrTargetException as e:
print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.")
logger.error(f"Cannot set field:{name} as input, exception happens at the {e.index} value.")
raise e
else:
raise KeyError("{} is not a valid field name.".format(name))
@@ -728,7 +729,7 @@ class DataSet(object):
results.append(func(ins[field_name]))
except Exception as e:
if idx != -1:
print("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
@@ -795,7 +796,7 @@ class DataSet(object):
results.append(func(ins))
except BaseException as e:
if idx != -1:
print("Exception happens at the `{}`th instance.".format(idx))
logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e

# results = [func(ins) for ins in self._inner_iter()]


fastNLP/core/dist_trainer.py (+0, -1)

@@ -54,7 +54,6 @@ class DistTrainer():
num_workers=1, drop_last=False,
dev_data=None, metrics=None, metric_key=None,
update_every=1, print_every=10, validate_every=-1,
log_path=None,
save_every=-1, save_path=None, device='auto',
fp16='', backend=None, init_method=None):



fastNLP/core/field.py (+10, -9)

@@ -12,6 +12,7 @@ from abc import abstractmethod
from copy import deepcopy
from collections import Counter
from .utils import _is_iterable
from ._logger import logger


class SetInputOrTargetException(Exception):
@@ -39,7 +40,7 @@ class FieldArray:
try:
_content = list(_content)
except BaseException as e:
print(f"Cannot convert content(of type:{type(content)}) into list.")
logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
raise e
self.name = name
self.content = _content
@@ -263,7 +264,7 @@ class FieldArray:
try:
new_contents.append(cell.split(sep))
except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -283,8 +284,8 @@ class FieldArray:
else:
new_contents.append(int(cell))
except Exception as e:
print(f"Exception happens when process value in index {index}.")
print(e)
logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
def float(self, inplace=True):
@@ -303,7 +304,7 @@ class FieldArray:
else:
new_contents.append(float(cell))
except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -323,7 +324,7 @@ class FieldArray:
else:
new_contents.append(bool(cell))
except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -344,7 +345,7 @@ class FieldArray:
else:
new_contents.append(cell.lower())
except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -364,7 +365,7 @@ class FieldArray:
else:
new_contents.append(cell.upper())
except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -401,7 +402,7 @@ class FieldArray:
self.is_input = self.is_input
self.is_target = self.is_input
except SetInputOrTargetException as e:
print("The newly generated field cannot be set as input or target.")
logger.error("The newly generated field cannot be set as input or target.")
raise e
return self
else:


fastNLP/core/tester.py (+1, -1)

@@ -192,7 +192,7 @@ class Tester(object):
dataset=self.data, check_level=0)
if self.verbose >= 1:
print("[tester] \n{}".format(self._format_eval_results(eval_results)))
logger.info("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False)
return eval_results


fastNLP/core/utils.py (+2, -2)

@@ -145,7 +145,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
with open(cache_filepath, 'rb') as f:
results = _pickle.load(f)
if verbose == 1:
print("Read cache from {}.".format(cache_filepath))
logger.info("Read cache from {}.".format(cache_filepath))
refresh_flag = False
if refresh_flag:
@@ -156,7 +156,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
_prepare_cache_filepath(cache_filepath)
with open(cache_filepath, 'wb') as f:
_pickle.dump(results, f)
print("Save cache to {}.".format(cache_filepath))
logger.info("Save cache to {}.".format(cache_filepath))
return results


fastNLP/core/vocabulary.py (+4, -3)

@@ -10,6 +10,7 @@ from .utils import Option
from functools import partial
import numpy as np
from .utils import _is_iterable
from ._logger import logger

class VocabularyOption(Option):
def __init__(self,
@@ -49,7 +50,7 @@ def _check_build_status(func):
if self.rebuild is False:
self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size:
print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__))
return func(self, *args, **kwargs)
@@ -297,7 +298,7 @@ class Vocabulary(object):
for f_n, n_f_n in zip(field_name, new_field_name):
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
except Exception as e:
print("When processing the `{}` dataset, the following error occurred.".format(idx))
logger.info("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e
else:
raise RuntimeError("Only DataSet type is allowed.")
@@ -353,7 +354,7 @@ class Vocabulary(object):
try:
dataset.apply(construct_vocab)
except BaseException as e:
print("When processing the `{}` dataset, the following error occurred:".format(idx))
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e
else:
raise TypeError("Only DataSet type is allowed.")


fastNLP/embeddings/bert_embedding.py (+11, -5)

@@ -21,6 +21,7 @@ from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MOD
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
from .contextual_embedding import ContextualEmbedding
import warnings
from ..core import logger


class BertEmbedding(ContextualEmbedding):
@@ -125,8 +126,10 @@ class BertEmbedding(ContextualEmbedding):
with torch.no_grad():
if self._word_sep_index: # the sep token must not be dropped
sep_mask = words.eq(self._word_sep_index)
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.full_like(words, fill_value=self.word_dropout)
mask = torch.bernoulli(mask).eq(1) # the larger dropout_word is, the more positions become 1
pad_mask = words.ne(0)
mask = pad_mask.__and__(mask) # pad positions are never replaced with unk
words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index:
words.masked_fill_(sep_mask, self._word_sep_index)
@@ -182,6 +185,7 @@ class BertWordPieceEncoder(nn.Module):
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
self._sep_index = self.model._sep_index
self._wordpiece_pad_index = self.model._wordpiece_pad_index
self._wordpiece_unk_index = self.model._wordpiece_unknown_index
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
self.requires_grad = requires_grad
@@ -263,8 +267,10 @@ class BertWordPieceEncoder(nn.Module):
with torch.no_grad():
if self._word_sep_index: # the sep token must not be dropped
sep_mask = words.eq(self._wordpiece_unk_index)
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.full_like(words, fill_value=self.word_dropout)
mask = torch.bernoulli(mask).eq(1) # the larger dropout_word is, the more positions become 1
pad_mask = words.ne(self._wordpiece_pad_index)
mask = pad_mask.__and__(mask) # pad positions are never replaced with unk
words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index:
words.masked_fill_(sep_mask, self._wordpiece_unk_index)
@@ -297,7 +303,7 @@ class _WordBertModel(nn.Module):
self.auto_truncate = auto_truncate
# compute the word pieces for every word in the vocab; [CLS] and [SEP] need extra handling
print("Start to generating word pieces for word.")
logger.info("Start to generating word pieces for word.")
# first collect the required word pieces, then create the new embed and word_piece vocab, then fill in the values
word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # word pieces in use, plus newly added ones
found_count = 0
@@ -356,10 +362,10 @@ class _WordBertModel(nn.Module):
self._sep_index = self.tokenzier.vocab['[SEP]']
self._word_pad_index = vocab.padding_idx
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # needed when generating word pieces
print("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
self.word_to_wordpieces = np.array(word_to_wordpieces)
self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False)
print("Successfully generate word pieces.")
logger.debug("Successfully generate word pieces.")
def forward(self, words):
"""


fastNLP/embeddings/char_embedding.py (+5, -4)

@@ -19,6 +19,7 @@ from ..core.vocabulary import Vocabulary
from .embedding import TokenEmbedding
from .utils import _construct_char_vocab_from_vocab
from .utils import get_embeddings
from ..core import logger


class CNNCharEmbedding(TokenEmbedding):
@@ -81,11 +82,11 @@ class CNNCharEmbedding(TokenEmbedding):
raise Exception(
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
print("Start constructing character vocabulary.")
logger.info("Start constructing character vocabulary.")
# build the character vocabulary
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
self.char_pad_index = self.char_vocab.padding_idx
print(f"In total, there are {len(self.char_vocab)} distinct characters.")
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
# index the vocab
max_word_len = max(map(lambda x: len(x[0]), vocab))
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len),
@@ -236,11 +237,11 @@ class LSTMCharEmbedding(TokenEmbedding):
raise Exception(
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
print("Start constructing character vocabulary.")
logger.info("Start constructing character vocabulary.")
# build the character vocabulary
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
self.char_pad_index = self.char_vocab.padding_idx
print(f"In total, there are {len(self.char_vocab)} distinct characters.")
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
# index the vocab
self.max_word_len = max(map(lambda x: len(x[0]), vocab))
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),


fastNLP/embeddings/contextual_embedding.py (+5, -5)

@@ -16,7 +16,7 @@ from ..core.batch import DataSetIter
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding
from ..core import logger

class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):
@@ -37,14 +37,14 @@ class ContextualEmbedding(TokenEmbedding):
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed."
assert 'words' in dataset.get_input_name(), "`words` field has to be set as input."
except Exception as e:
print(f"Exception happens at {index} dataset.")
logger.error(f"Exception happens at {index} dataset.")
raise e
sent_embeds = {}
_move_model_to_device(self, device=device)
device = _get_model_device(self)
pad_index = self._word_vocab.padding_idx
print("Start to calculate sentence representations.")
logger.info("Start to calculate sentence representations.")
with torch.no_grad():
for index, dataset in enumerate(datasets):
try:
@@ -64,9 +64,9 @@ class ContextualEmbedding(TokenEmbedding):
else:
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length]
except Exception as e:
print(f"Exception happens at {index} dataset.")
logger.error(f"Exception happens at {index} dataset.")
raise e
print("Finish calculating sentence representations.")
logger.info("Finish calculating sentence representations.")
self.sent_embeds = sent_embeds
if delete_weights:
self._delete_model_weights()


fastNLP/embeddings/elmo_embedding.py (+5, -5)

@@ -18,7 +18,7 @@ from ..core.vocabulary import Vocabulary
from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR
from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder
from .contextual_embedding import ContextualEmbedding
from ..core import logger

class ElmoEmbedding(ContextualEmbedding):
"""
@@ -243,7 +243,7 @@ class _ElmoModel(nn.Module):
index_in_pre = char_lexicon[OOV_TAG]
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]
print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
logger.info(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
# build the mapping from words to chars
max_chars = config['char_cnn']['max_characters_per_token']
@@ -281,7 +281,7 @@ class _ElmoModel(nn.Module):
if cache_word_reprs:
if config['char_cnn']['embedding']['dim'] > 0: # only useful when character information is used
print("Start to generate cache word representations.")
logger.info("Start to generate cache word representations.")
batch_size = 320
# bos eos
word_size = self.words_to_chars_embedding.size(0)
@@ -299,10 +299,10 @@ class _ElmoModel(nn.Module):
chars).detach() # batch_size x 1 x config['encoder']['projection_dim']
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1)
print("Finish generating cached word representations. Going to delete the character encoder.")
logger.info("Finish generating cached word representations. Going to delete the character encoder.")
del self.token_embedder, self.words_to_chars_embedding
else:
print("There is no need to cache word representations, since no character information is used.")
logger.info("There is no need to cache word representations, since no character information is used.")
def forward(self, words):
"""


fastNLP/embeddings/embedding.py (+3, -1)

@@ -138,8 +138,10 @@ class TokenEmbedding(nn.Module):
:return:
"""
if self.word_dropout > 0 and self.training:
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.full_like(words, fill_value=self.word_dropout)
mask = torch.bernoulli(mask).eq(1) # the larger dropout_word is, the more positions become 1
pad_mask = words.ne(self._word_pad_index)
mask = mask.__and__(pad_mask)
words = words.masked_fill(mask, self._word_unk_index)
return words
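The other recurring change in the embedding files is the word-dropout mask: torch.full_like builds the Bernoulli probabilities directly instead of ones_like(...).float() * p, and a pad mask keeps padding positions from ever being replaced by unk. A standalone sketch of that masking (the indices are made up; an explicit float dtype is used here because torch.bernoulli expects a tensor of probabilities):

import torch

words = torch.tensor([[5, 7, 2, 0, 0]])       # 0 is the pad index
word_dropout, unk_index, pad_index = 0.3, 1, 0

mask = torch.full_like(words, fill_value=word_dropout, dtype=torch.float)
mask = torch.bernoulli(mask).eq(1)            # larger word_dropout -> more True positions
pad_mask = words.ne(pad_index)
mask = mask & pad_mask                        # never drop padding positions
print(words.masked_fill(mask, unk_index))     # some real tokens become unk, pads stay pads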


fastNLP/embeddings/static_embedding.py (+5, -4)

@@ -19,6 +19,7 @@ from .embedding import TokenEmbedding
from ..modules.utils import _get_file_name_base_on_postfix
from copy import deepcopy
from collections import defaultdict
from ..core import logger


class StaticEmbedding(TokenEmbedding):
@@ -112,7 +113,7 @@ class StaticEmbedding(TokenEmbedding):
truncated_words_to_words = torch.arange(len(vocab)).long()
for word, index in vocab:
truncated_words_to_words[index] = truncated_vocab.to_index(word)
print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
vocab = truncated_vocab
self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
@@ -124,7 +125,7 @@ class StaticEmbedding(TokenEmbedding):
lowered_vocab.add_word(word.lower(), no_create_entry=True)
else:
lowered_vocab.add_word(word.lower()) # add the words that need their own entries first
print(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
f"unique lowered words.")
if model_path:
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
@@ -265,9 +266,9 @@ class StaticEmbedding(TokenEmbedding):
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
print("Error occurred at the {} line.".format(idx))
logger.error("Error occurred at the {} line.".format(idx))
raise e
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
for word, index in vocab:
if index not in matrix and not vocab._is_word_no_create_entry(word):
if found_unknown: # if an unknown vector was found, initialize with it


fastNLP/io/embed_loader.py (+4, -4)

@@ -11,7 +11,7 @@ import numpy as np
from ..core.vocabulary import Vocabulary
from .data_bundle import BaseLoader
from ..core.utils import Option
import logging

class EmbeddingOption(Option):
def __init__(self,
@@ -91,10 +91,10 @@ class EmbedLoader(BaseLoader):
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
print("Error occurred at the {} line.".format(idx))
logging.error("Error occurred at the {} line.".format(idx))
raise e
total_hits = sum(hit_flags)
print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
logging.info("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
if init_method is None:
found_vectors = matrix[hit_flags]
if len(found_vectors) != 0:
@@ -157,7 +157,7 @@ class EmbedLoader(BaseLoader):
warnings.warn("Error occurred at the {} line.".format(idx))
pass
else:
print("Error occurred at the {} line.".format(idx))
logging.error("Error occurred at the {} line.".format(idx))
raise e
if dim == -1:
raise RuntimeError("{} is an empty file.".format(embed_filepath))


fastNLP/io/file_reader.py (+5, -4)

@@ -2,7 +2,8 @@
This module provides file-reading helpers for other modules; it does not expose a user-facing API.
"""
import json
import warnings
from ..core import logger


def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
"""
@@ -103,9 +104,9 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
yield line_idx, res
except Exception as e:
if dropna:
warnings.warn('Invalid instance ends at line: {} has been dropped.'.format(line_idx))
logger.warn('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
continue
raise ValueError('Invalid instance ends at line: {}'.format(line_idx))
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
elif line.startswith('#'):
continue
else:
@@ -117,5 +118,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
return
print('invalid instance ends at line: {}'.format(line_idx))
logger.error('invalid instance ends at line: {}'.format(line_idx))
raise e

fastNLP/io/file_utils.py (+7, -6)

@@ -7,6 +7,7 @@ import tempfile
from tqdm import tqdm
import shutil
from requests import HTTPError
from ..core import logger

PRETRAINED_BERT_MODEL_DIR = {
'en': 'bert-base-cased.zip',
@@ -336,7 +337,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total, unit_scale=1)
print("%s not found in cache, downloading to %s" % (url, temp_filename))
logger.info("%s not found in cache, downloading to %s" % (url, temp_filename))

with open(temp_filename, "wb") as temp_file:
for chunk in req.iter_content(chunk_size=1024 * 16):
@@ -344,12 +345,12 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
print(f"Finish download from {url}")
logger.info(f"Finish download from {url}")

# start decompressing
if suffix in ('.zip', '.tar.gz', '.gz'):
uncompress_temp_dir = tempfile.mkdtemp()
print(f"Start to uncompress file to {uncompress_temp_dir}")
logger.debug(f"Start to uncompress file to {uncompress_temp_dir}")
if suffix == '.zip':
unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
elif suffix == '.gz':
@@ -362,13 +363,13 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])

cache_path.mkdir(parents=True, exist_ok=True)
print("Finish un-compressing file.")
logger.debug("Finish un-compressing file.")
else:
uncompress_temp_dir = temp_filename
cache_path = str(cache_path) + suffix

# copy to the target location
print(f"Copy file to {cache_path}")
logger.info(f"Copy file to {cache_path}")
if os.path.isdir(uncompress_temp_dir):
for filename in os.listdir(uncompress_temp_dir):
if os.path.isdir(os.path.join(uncompress_temp_dir, filename)):
@@ -379,7 +380,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
shutil.copyfile(uncompress_temp_dir, cache_path)
success = True
except Exception as e:
print(e)
logger.error(e)
raise e
finally:
if not success:


fastNLP/io/pipe/classification.py (+1, -1)

@@ -11,7 +11,7 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_insta
from .pipe import Pipe
import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
from ...core.utils import cache_results

class _CLSPipe(Pipe):
"""


fastNLP/io/utils.py (+3, -3)

@@ -2,7 +2,7 @@ import os

from typing import Union, Dict
from pathlib import Path
from ..core import logger

def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
"""
@@ -70,8 +70,8 @@ def get_tokenizer():
import spacy
spacy.prefer_gpu()
en = spacy.load('en')
print('use spacy tokenizer')
logger.info('use spacy tokenizer')
return lambda x: [w.text for w in en.tokenizer(x)]
except Exception as e:
print('use raw tokenizer')
logger.error('use raw tokenizer')
return lambda x: x.split()

fastNLP/modules/encoder/bert.py (+7, -8)

@@ -17,8 +17,7 @@ import os

import torch
from torch import nn
import sys

from ...core import logger
from ..utils import _get_file_name_base_on_postfix

CONFIG_FILE = 'bert_config.json'
@@ -489,10 +488,10 @@ class BertModel(nn.Module):

load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
print("Weights of {} not initialized from pretrained model: {}".format(
logger.warn("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
print("Weights from pretrained model not used in {}: {}".format(
logger.warn("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
return model

@@ -799,7 +798,7 @@ class BertTokenizer(object):
for token in tokens:
ids.append(self.vocab[token])
if len(ids) > self.max_len:
print(
logger.warn(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
@@ -823,7 +822,7 @@ class BertTokenizer(object):
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
print("Saving vocabulary to {}: vocabulary indices are not consecutive."
logger.warn("Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file))
index = token_index
writer.write(token + u'\n')
@@ -837,7 +836,7 @@ class BertTokenizer(object):

"""
pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt')
print("loading vocabulary file {}".format(pretrained_model_name_or_path))
logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path))
max_len = 512
kwargs['max_len'] = min(kwargs.get('max_position_embeddings', int(1e12)), max_len)
# Instantiate tokenizer.
@@ -901,7 +900,7 @@ class _WordPieceBertModel(nn.Module):
is_input=True)
dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
except Exception as e:
print(f"Exception happens when processing the {index} dataset.")
logger.error(f"Exception happens when processing the {index} dataset.")
raise e

def forward(self, word_pieces, token_type_ids=None):


reproduction/seqence_labelling/ner/data/Conll2003Loader.py (+0, -93)

@@ -1,93 +0,0 @@

from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from typing import Union, Dict
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths

from fastNLP.io import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2


class Conll2003DataLoader(DataSetLoader):
def __init__(self, task:str='ner', encoding_type:str='bioes'):
"""
Loads English data in the CoNLL-2003 format; information about the dataset can be found at https://www.clips.uantwerpen.be/conll2003/ner/. When task is pos, the target in the returned DataSet comes from the 2nd column; when task is chunk, from the 3rd column; when task is ner, from the 4th column. All "-DOCSTART- -X- O O" lines are ignored, so the number of instances is smaller than reported in much of the literature; since "-DOCSTART- -X- O O" is only a document separator and should not be predicted, lines starting with -DOCSTART- are dropped.
For the ner and chunk tasks the loaded target uses the given encoding_type; for the pos task it is simply the pos column.

:param task: the tagging task to load. One of ner, pos, chunk
"""
assert task in ('ner', 'pos', 'chunk')
index = {'ner':3, 'pos':1, 'chunk':2}[task]
self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
self._tag_converters = []
if task in ('ner', 'chunk'):
self._tag_converters = [iob2]
if encoding_type == 'bioes':
self._tag_converters.append(iob2bioes)

def load(self, path: str):
dataset = self._loader.load(path)
def convert_tag_schema(tags):
for converter in self._tag_converters:
tags = converter(tags)
return tags
if self._tag_converters:
dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
return dataset

def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=False):
"""
Read and process the data. Lines starting with '-DOCSTART-' are ignored.

:param paths:
:param word_vocab_opt: options used to initialize the word Vocabulary
:param lower: whether to lowercase all words.
:return:
"""
# read the data
paths = check_dataloader_paths(paths)
data = DataBundle()
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET, Const.INPUT_LEN]
for name, path in paths.items():
dataset = self.load(path)
dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
if lower:
dataset.words.lower()
data.datasets[name] = dataset

# construct the word vocab
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
data.vocabs[Const.INPUT] = word_vocab

# cap words
cap_word_vocab = Vocabulary()
cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words',
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
input_fields.append('cap_words')
data.vocabs['cap_words'] = cap_word_vocab

# build the target vocab
target_vocab = Vocabulary(unknown=None, padding=None)
target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
data.vocabs[Const.TARGET] = target_vocab

for name, dataset in data.datasets.items():
dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
dataset.set_input(*input_fields)
dataset.set_target(*target_fields)

return data

if __name__ == '__main__':
pass

reproduction/seqence_labelling/ner/data/OntoNoteLoader.py (+0, -152)

@@ -1,152 +0,0 @@
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from typing import Union, Dict
from fastNLP import DataSet
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths

from fastNLP.io import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2

class OntoNoteNERDataLoader(DataSetLoader):
"""
Reads OntoNotes data that has already been converted to the CoNLL format. For converting OntoNotes to the CoNLL format, see https://github.com/yhcc/OntoNotes-5.0-NER.

"""
def __init__(self, encoding_type:str='bioes'):
assert encoding_type in ('bioes', 'bio')
self.encoding_type = encoding_type
if encoding_type=='bioes':
self.encoding_method = iob2bioes
else:
self.encoding_method = iob2

def load(self, path:str)->DataSet:
"""
Read data from the given file path. The returned DataSet contains the following fields:
raw_words: List[str]
target: List[str]

:param path:
:return:
"""
dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path)
def convert_to_bio(tags):
bio_tags = []
flag = None
for tag in tags:
label = tag.strip("()*")
if '(' in tag:
bio_label = 'B-' + label
flag = label
elif flag:
bio_label = 'I-' + flag
else:
bio_label = 'O'
if ')' in tag:
flag = None
bio_tags.append(bio_label)
return self.encoding_method(bio_tags)

def convert_word(words):
converted_words = []
for word in words:
word = word.replace('/.', '.') # some trailing periods appear as /.
if not word.startswith('-'):
converted_words.append(word)
continue
# the following symbols were escaped; convert them back
tfrs = {'-LRB-':'(',
'-RRB-': ')',
'-LSB-': '[',
'-RSB-': ']',
'-LCB-': '{',
'-RCB-': '}'
}
if word in tfrs:
converted_words.append(tfrs[word])
else:
converted_words.append(word)
return converted_words

dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words')
dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target')

return dataset

def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
lower:bool=True)->DataBundle:
"""
Read and process the data. The returned DataInfo contains:
vocabs:
word: Vocabulary
target: Vocabulary
datasets:
train: DataSet
words: List[int], set as input
target: int. The label, set as both input and target
seq_len: int. Sentence length, set as both input and target
raw_words: List[str]
xxx (may vary depending on the paths passed in)

:param paths:
:param word_vocab_opt: options used to initialize the word Vocabulary
:param lower: whether to lowercase
:return:
"""
paths = check_dataloader_paths(paths)
data = DataBundle()
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET, Const.INPUT_LEN]
for name, path in paths.items():
dataset = self.load(path)
dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
if lower:
dataset.words.lower()
data.datasets[name] = dataset

# construct the word vocab
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
data.vocabs[Const.INPUT] = word_vocab

# cap words
cap_word_vocab = Vocabulary()
cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
input_fields.append('cap_words')
data.vocabs['cap_words'] = cap_word_vocab

# build the target vocab
target_vocab = Vocabulary(unknown=None, padding=None)
target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
data.vocabs[Const.TARGET] = target_vocab

for name, dataset in data.datasets.items():
dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
dataset.set_input(*input_fields)
dataset.set_target(*target_fields)

return data


if __name__ == '__main__':
loader = OntoNoteNERDataLoader()
dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt')
print(dataset.target.value_count())
print(dataset[:4])


"""
train 115812 2200752
development 15680 304684
test 12217 230111

train 92403 1901772
valid 13606 279180
test 10258 204135
"""

reproduction/seqence_labelling/ner/data/utils.py (+0, -49)

@@ -1,49 +0,0 @@
from typing import List

def iob2(tags:List[str])->List[str]:
"""
Check that the tags are valid IOB data; IOB1 is automatically converted to IOB2.

:param tags: the tags to convert
"""
for i, tag in enumerate(tags):
if tag == "O":
continue
split = tag.split("-")
if len(split) != 2 or split[0] not in ["I", "B"]:
raise TypeError("The encoding schema is not a valid IOB type.")
if split[0] == "B":
continue
elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2
tags[i] = "B" + tag[1:]
elif tags[i - 1][1:] == tag[1:]:
continue
else: # conversion IOB1 to IOB2
tags[i] = "B" + tag[1:]
return tags

def iob2bioes(tags:List[str])->List[str]:
"""
Convert IOB tags to the BMESO (BIOES) encoding
:param tags:
:return:
"""
new_tags = []
for i, tag in enumerate(tags):
if tag == 'O':
new_tags.append(tag)
else:
split = tag.split('-')[0]
if split == 'B':
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('B-', 'S-'))
elif split == 'I':
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('I-', 'E-'))
else:
raise TypeError("Invalid IOB format.")
return new_tags

reproduction/seqence_labelling/ner/model/bert_crf.py (+31, -0)

@@ -0,0 +1,31 @@


from torch import nn
from fastNLP.modules import ConditionalRandomField, allowed_transitions
import torch.nn.functional as F

class BertCRF(nn.Module):
def __init__(self, embed, tag_vocab, encoding_type='bio'):
super().__init__()
self.embed = embed
self.fc = nn.Linear(self.embed.embed_size, len(tag_vocab))
trans = allowed_transitions(tag_vocab, encoding_type=encoding_type, include_start_end=True)
self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=trans)

def _forward(self, words, target):
mask = words.ne(0)
words = self.embed(words)
words = self.fc(words)
logits = F.log_softmax(words, dim=-1)
if target is not None:
loss = self.crf(logits, target, mask)
return {'loss': loss}
else:
paths, _ = self.crf.viterbi_decode(logits, mask)
return {'pred': paths}

def forward(self, words, target):
return self._forward(words, target)

def predict(self, words):
return self._forward(words, None)
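The new BertCRF is a thin linear-plus-CRF head over any embedding that exposes embed_size; train_bert.py below plugs a real BertEmbedding into it. A hypothetical smoke test of the call pattern, with a dummy embedding and a made-up tag set that are not part of the commit:

import torch
from torch import nn
from fastNLP import Vocabulary

tag_vocab = Vocabulary(unknown=None, padding=None)
tag_vocab.add_word_lst(['O', 'B-PER', 'I-PER'])

class DummyEmbed(nn.Module):
    """Stand-in for BertEmbedding: maps (batch, seq_len) ids to (batch, seq_len, embed_size)."""
    embed_size = 8
    def forward(self, words):
        return torch.randn(words.size(0), words.size(1), self.embed_size)

# BertCRF is the class defined in bert_crf.py above
model = BertCRF(DummyEmbed(), tag_vocab, encoding_type='bio')
words = torch.tensor([[2, 3, 4, 0]])      # index 0 is treated as padding by the mask
target = torch.tensor([[0, 1, 2, 0]])
print(model(words, target)['loss'])       # training mode: CRF negative log-likelihood
print(model.predict(words)['pred'])       # inference mode: Viterbi-decoded tag ids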

reproduction/seqence_labelling/ner/test/__init__.py (+0, -0)


reproduction/seqence_labelling/ner/test/test.py (+0, -33)

@@ -1,33 +0,0 @@

from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader
from reproduction.seqence_labelling.ner.data.Conll2003Loader import iob2, iob2bioes
import unittest

class TestTagSchemaConverter(unittest.TestCase):
def test_iob2(self):
tags = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
golden = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
self.assertListEqual(golden, iob2(tags))

tags = ['I-ORG', 'O']
golden = ['B-ORG', 'O']
self.assertListEqual(golden, iob2(tags))

tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O']
golden = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
self.assertListEqual(golden, iob2(tags))

def test_iob2bemso(self):
tags = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
golden = ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O']
self.assertListEqual(golden, iob2bioes(tags))


def test_conll2003_loader():
path = '/hdd/fudanNLP/fastNLP/others/data/conll2003/train.txt'
loader = Conll2003DataLoader().load(path)
print(loader[:3])


if __name__ == '__main__':
test_conll2003_loader()

reproduction/seqence_labelling/ner/train_bert.py (+52, -0)

@@ -0,0 +1,52 @@


"""
English named entity recognition with BERT

"""

import sys

sys.path.append('../../../')

from reproduction.seqence_labelling.ner.model.bert_crf import BertCRF
from fastNLP.embeddings import BertEmbedding
from fastNLP import Trainer, Const
from fastNLP import BucketSampler, SpanFPreRecMetric, GradientClipCallback
from fastNLP.core.callback import WarmupCallback
from fastNLP.core.optimizer import AdamW
from fastNLP.io import Conll2003NERPipe

from fastNLP import cache_results, EvaluateCallback

encoding_type = 'bioes'

@cache_results('caches/conll2003.pkl', _refresh=False)
def load_data():
# replace with your data path
paths = 'data/conll2003'
data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths)
return data
data = load_data()
print(data)

embed = BertEmbedding(data.get_vocab(Const.INPUT), model_dir_or_name='en-base-cased',
pool_method='max', requires_grad=True, layers='11', include_cls_sep=False, dropout=0.5,
word_dropout=0.01)

callbacks = [
GradientClipCallback(clip_type='norm', clip_value=1),
WarmupCallback(warmup=0.1, schedule='linear'),
EvaluateCallback(data.get_dataset('test'))
]

model = BertCRF(embed, tag_vocab=data.get_vocab('target'), encoding_type=encoding_type)
optimizer = AdamW(model.parameters(), lr=2e-5)

trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
device=0, dev_data=data.datasets['dev'], batch_size=6,
metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
loss=None, callbacks=callbacks, num_workers=2, n_epochs=5,
check_code_level=0, update_every=3, test_use_tqdm=False)
trainer.train()


reproduction/seqence_labelling/ner/train_idcnn.py (+4, -18)

@@ -1,4 +1,4 @@
from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader
from fastNLP.io import OntoNotesNERPipe
from fastNLP.core.callback import LRScheduler
from fastNLP import GradientClipCallback
from torch.optim.lr_scheduler import LambdaLR
@@ -10,14 +10,10 @@ from fastNLP import Trainer, Tester
from fastNLP.core.metrics import MetricBase
from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN
from fastNLP.core.utils import Option
from fastNLP.embeddings.embedding import StaticEmbedding
from fastNLP.embeddings import StaticEmbedding
from fastNLP.core.utils import cache_results
from fastNLP.core.vocabulary import VocabularyOption
import torch.cuda
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

encoding_type = 'bioes'

@@ -40,18 +36,8 @@ ops = Option(
@cache_results('ontonotes-case-cache')
def load_data():
print('loading data')
data = OntoNoteNERDataLoader(encoding_type=encoding_type).process(
paths = get_path('workdir/datasets/ontonotes-v4'),
lower=False,
word_vocab_opt=VocabularyOption(min_freq=0),
)
# data = Conll2003DataLoader(task='ner', encoding_type=encoding_type).process(
# paths=get_path('workdir/datasets/conll03'),
# lower=False, word_vocab_opt=VocabularyOption(min_freq=0)
# )

# char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
# kernel_sizes=[3])
data = OntoNotesNERPipe(encoding_type=encoding_type).process_from_file(
paths = get_path('workdir/datasets/ontonotes-v4'))
print('loading embedding')
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
model_dir_or_name='en-glove-840b-300',

