@@ -17,7 +17,7 @@ from numbers import Number | |||
from .sampler import SequentialSampler | |||
from .dataset import DataSet | |||
from ._logger import logger | |||
_python_is_exit = False | |||
@@ -75,7 +75,7 @@ class DataSetGetter: | |||
try: | |||
data, flag = _to_tensor(data, f.dtype) | |||
except TypeError as e: | |||
print(f"Field {n} cannot be converted to torch.tensor.") | |||
logger.error(f"Field {n} cannot be converted to torch.tensor.") | |||
raise e | |||
batch_dict[n] = data | |||
return batch_dict | |||
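The hunks in this patch follow one pattern: route diagnostics through the shared module logger (imported above as `from ._logger import logger`) instead of `print`, then re-raise the original exception. A minimal sketch of that pattern, using the standard-library logger as a stand-in and a hypothetical `to_tensor` callable:

```python
import logging

logger = logging.getLogger("fastNLP")  # stand-in for `from ._logger import logger`

def convert_field(name, data, to_tensor):
    """Hypothetical helper mirroring the try/except in the hunk above."""
    try:
        return to_tensor(data)
    except TypeError:
        logger.error(f"Field {name} cannot be converted to torch.tensor.")
        raise  # re-raise so the caller still sees the original failure
```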
@@ -83,7 +83,6 @@ try: | |||
except: | |||
tensorboardX_flag = False | |||
from ..io.model_io import ModelSaver, ModelLoader | |||
from .dataset import DataSet | |||
from .tester import Tester | |||
from ._logger import logger | |||
@@ -505,7 +504,7 @@ class EarlyStopCallback(Callback): | |||
def on_exception(self, exception): | |||
if isinstance(exception, EarlyStopError): | |||
print("Early Stopping triggered in epoch {}!".format(self.epoch)) | |||
logger.info("Early Stopping triggered in epoch {}!".format(self.epoch)) | |||
else: | |||
raise exception  # re-raise unrecognized exceptions
@@ -752,8 +751,7 @@ class LRFinder(Callback): | |||
self.smooth_value = SmoothValue(0.8) | |||
self.opt = None | |||
self.find = None | |||
self.loader = ModelLoader() | |||
@property | |||
def lr_gen(self): | |||
scale = (self.end_lr - self.start_lr) / self.batch_per_epoch | |||
@@ -768,7 +766,7 @@ class LRFinder(Callback): | |||
self.opt = self.trainer.optimizer # pytorch optimizer | |||
self.opt.param_groups[0]["lr"] = self.start_lr | |||
# save model | |||
ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True) | |||
torch.save(self.model.state_dict(), 'tmp') | |||
self.find = True | |||
def on_backward_begin(self, loss): | |||
@@ -797,7 +795,9 @@ class LRFinder(Callback): | |||
self.opt.param_groups[0]["lr"] = self.best_lr | |||
self.find = False | |||
# reset model | |||
ModelLoader().load_pytorch(self.trainer.model, "tmp") | |||
states = torch.load('tmp') | |||
self.model.load_state_dict(states) | |||
os.remove('tmp') | |||
self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr)) | |||
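The LRFinder hunks above replace the ModelSaver/ModelLoader helpers with a plain state-dict round trip and clean up the temporary file once the best learning rate is found. A hedged sketch of that save/restore cycle (the helper names are illustrative):

```python
import os
import torch
from torch import nn

def snapshot(model: nn.Module, path: str = "tmp") -> None:
    torch.save(model.state_dict(), path)  # parameters only, like the old param_only=True

def restore(model: nn.Module, path: str = "tmp") -> None:
    model.load_state_dict(torch.load(path))  # reset the weights to the snapshot
    os.remove(path)                          # the temporary checkpoint is no longer needed
```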
@@ -988,14 +988,14 @@ class SaveModelCallback(Callback): | |||
try: | |||
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | |||
except Exception as e: | |||
print(f"The following exception:{e} happens when save model to {self.save_dir}.") | |||
logger.error(f"The following exception:{e} happens when save model to {self.save_dir}.") | |||
if delete_pair: | |||
try: | |||
delete_model_path = os.path.join(self.save_dir, delete_pair[1]) | |||
if os.path.exists(delete_model_path): | |||
os.remove(delete_model_path) | |||
except Exception as e: | |||
print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.") | |||
logger.error(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.") | |||
def on_exception(self, exception): | |||
if self.save_on_exception: | |||
@@ -1032,7 +1032,7 @@ class EchoCallback(Callback): | |||
def __getattribute__(self, item): | |||
if item.startswith('on_'): | |||
print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()), | |||
file=self.out)
logger.info('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()))
return super(EchoCallback, self).__getattribute__(item) | |||
@@ -300,6 +300,7 @@ from .utils import _get_func_signature | |||
from .field import AppendToTargetOrInputException | |||
from .field import SetInputOrTargetException | |||
from .const import Const | |||
from ._logger import logger | |||
class DataSet(object): | |||
""" | |||
@@ -452,7 +453,7 @@ class DataSet(object): | |||
try: | |||
self.field_arrays[name].append(field) | |||
except AppendToTargetOrInputException as e: | |||
print(f"Cannot append to field:{name}.") | |||
logger.error(f"Cannot append to field:{name}.") | |||
raise e | |||
def add_fieldarray(self, field_name, fieldarray): | |||
@@ -609,7 +610,7 @@ class DataSet(object): | |||
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
self.field_arrays[name].is_target = flag | |||
except SetInputOrTargetException as e: | |||
print(f"Cannot set field:{name} as target.") | |||
logger.error(f"Cannot set field:{name} as target.") | |||
raise e | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
@@ -633,7 +634,7 @@ class DataSet(object): | |||
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
self.field_arrays[name].is_input = flag | |||
except SetInputOrTargetException as e: | |||
print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.") | |||
logger.error(f"Cannot set field:{name} as input, exception happens at the {e.index} value.") | |||
raise e | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
@@ -728,7 +729,7 @@ class DataSet(object): | |||
results.append(func(ins[field_name])) | |||
except Exception as e: | |||
if idx != -1: | |||
print("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) | |||
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) | |||
raise e | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
@@ -795,7 +796,7 @@ class DataSet(object): | |||
results.append(func(ins)) | |||
except BaseException as e: | |||
if idx != -1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
logger.error("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
@@ -54,7 +54,6 @@ class DistTrainer(): | |||
num_workers=1, drop_last=False, | |||
dev_data=None, metrics=None, metric_key=None, | |||
update_every=1, print_every=10, validate_every=-1, | |||
log_path=None, | |||
save_every=-1, save_path=None, device='auto', | |||
fp16='', backend=None, init_method=None): | |||
@@ -12,6 +12,7 @@ from abc import abstractmethod | |||
from copy import deepcopy | |||
from collections import Counter | |||
from .utils import _is_iterable | |||
from ._logger import logger | |||
class SetInputOrTargetException(Exception): | |||
@@ -39,7 +40,7 @@ class FieldArray: | |||
try: | |||
_content = list(_content) | |||
except BaseException as e: | |||
print(f"Cannot convert content(of type:{type(content)}) into list.") | |||
logger.error(f"Cannot convert content(of type:{type(content)}) into list.") | |||
raise e | |||
self.name = name | |||
self.content = _content | |||
@@ -263,7 +264,7 @@ class FieldArray: | |||
try: | |||
new_contents.append(cell.split(sep)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
@@ -283,8 +284,8 @@ class FieldArray: | |||
else: | |||
new_contents.append(int(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
print(e) | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def float(self, inplace=True): | |||
@@ -303,7 +304,7 @@ class FieldArray: | |||
else: | |||
new_contents.append(float(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
@@ -323,7 +324,7 @@ class FieldArray: | |||
else: | |||
new_contents.append(bool(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
@@ -344,7 +345,7 @@ class FieldArray: | |||
else: | |||
new_contents.append(cell.lower()) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
@@ -364,7 +365,7 @@ class FieldArray: | |||
else: | |||
new_contents.append(cell.upper()) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
@@ -401,7 +402,7 @@ class FieldArray: | |||
self.is_input = self.is_input | |||
self.is_target = self.is_input | |||
except SetInputOrTargetException as e: | |||
print("The newly generated field cannot be set as input or target.") | |||
logger.error("The newly generated field cannot be set as input or target.") | |||
raise e | |||
return self | |||
else: | |||
@@ -192,7 +192,7 @@ class Tester(object): | |||
dataset=self.data, check_level=0) | |||
if self.verbose >= 1: | |||
print("[tester] \n{}".format(self._format_eval_results(eval_results))) | |||
logger.info("[tester] \n{}".format(self._format_eval_results(eval_results))) | |||
self._mode(network, is_test=False) | |||
return eval_results | |||
@@ -145,7 +145,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
with open(cache_filepath, 'rb') as f: | |||
results = _pickle.load(f) | |||
if verbose == 1: | |||
print("Read cache from {}.".format(cache_filepath)) | |||
logger.info("Read cache from {}.".format(cache_filepath)) | |||
refresh_flag = False | |||
if refresh_flag: | |||
@@ -156,7 +156,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
_prepare_cache_filepath(cache_filepath) | |||
with open(cache_filepath, 'wb') as f: | |||
_pickle.dump(results, f) | |||
print("Save cache to {}.".format(cache_filepath)) | |||
logger.info("Save cache to {}.".format(cache_filepath)) | |||
return results | |||
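For context, a hedged sketch of how the `cache_results` decorator above is used (the training script near the end of this diff uses it the same way; the cache path here is illustrative). The first call computes the result, pickles it, and logs "Save cache to ..."; later calls log "Read cache from ..." and return the cached object unless `_refresh=True`.

```python
from fastNLP import cache_results

@cache_results('caches/demo.pkl', _refresh=False)
def load_data():
    # stand-in for expensive preprocessing
    return {'train': list(range(10))}

data = load_data()  # computed and cached on the first run, read from the cache afterwards
```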
@@ -10,6 +10,7 @@ from .utils import Option | |||
from functools import partial | |||
import numpy as np | |||
from .utils import _is_iterable | |||
from ._logger import logger | |||
class VocabularyOption(Option): | |||
def __init__(self, | |||
@@ -49,7 +50,7 @@ def _check_build_status(func): | |||
if self.rebuild is False: | |||
self.rebuild = True | |||
if self.max_size is not None and len(self.word_count) >= self.max_size: | |||
print("[Warning] Vocabulary has reached the max size {} when calling {} method. " | |||
logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. " | |||
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( | |||
self.max_size, func.__name__)) | |||
return func(self, *args, **kwargs) | |||
@@ -297,7 +298,7 @@ class Vocabulary(object): | |||
for f_n, n_f_n in zip(field_name, new_field_name): | |||
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) | |||
except Exception as e: | |||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
logger.info("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
raise e | |||
else: | |||
raise RuntimeError("Only DataSet type is allowed.") | |||
@@ -353,7 +354,7 @@ class Vocabulary(object): | |||
try: | |||
dataset.apply(construct_vocab) | |||
except BaseException as e: | |||
print("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||
log("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||
raise e | |||
else: | |||
raise TypeError("Only DataSet type is allowed.") | |||
@@ -21,6 +21,7 @@ from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MOD | |||
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer | |||
from .contextual_embedding import ContextualEmbedding | |||
import warnings | |||
from ..core import logger | |||
class BertEmbedding(ContextualEmbedding): | |||
@@ -125,8 +126,10 @@ class BertEmbedding(ContextualEmbedding): | |||
with torch.no_grad(): | |||
if self._word_sep_index:  # the [SEP] token must never be dropped
sep_mask = words.eq(self._word_sep_index)
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.full_like(words, fill_value=self.word_dropout)
mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
pad_mask = words.ne(0)
mask = pad_mask.__and__(mask)  # padding positions are never replaced with unk
words = words.masked_fill(mask, self._word_unk_index) | |||
if self._word_sep_index: | |||
words.masked_fill_(sep_mask, self._word_sep_index) | |||
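The word-dropout change above builds the Bernoulli mask with `torch.full_like` and explicitly keeps padding (and `[SEP]`) positions untouched. A standalone sketch of the masking logic with illustrative `pad_index`/`unk_index` values; note the probability tensor is created as float here, since `torch.bernoulli` expects floating-point input:

```python
import torch

def drop_words(words: torch.LongTensor, word_dropout: float,
               pad_index: int = 0, unk_index: int = 1) -> torch.LongTensor:
    prob = torch.full_like(words, fill_value=word_dropout, dtype=torch.float)
    mask = torch.bernoulli(prob).eq(1)   # the larger word_dropout is, the more positions are selected
    mask = mask & words.ne(pad_index)    # never turn padding into unk
    return words.masked_fill(mask, unk_index)

words = torch.LongTensor([[5, 6, 7, 0, 0]])
print(drop_words(words, word_dropout=0.5))
```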
@@ -182,6 +185,7 @@ class BertWordPieceEncoder(nn.Module): | |||
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls) | |||
self._sep_index = self.model._sep_index | |||
self._wordpiece_pad_index = self.model._wordpiece_pad_index | |||
self._wordpiece_unk_index = self.model._wordpiece_unknown_index | |||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
self.requires_grad = requires_grad | |||
@@ -263,8 +267,10 @@ class BertWordPieceEncoder(nn.Module): | |||
with torch.no_grad(): | |||
if self._word_sep_index:  # the [SEP] token must never be dropped
sep_mask = words.eq(self._wordpiece_unk_index)
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.full_like(words, fill_value=self.word_dropout)
mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
pad_mask = words.ne(self._wordpiece_pad_index)
mask = pad_mask.__and__(mask)  # padding positions are never replaced with unk
words = words.masked_fill(mask, self._word_unk_index) | |||
if self._word_sep_index: | |||
words.masked_fill_(sep_mask, self._wordpiece_unk_index) | |||
@@ -297,7 +303,7 @@ class _WordBertModel(nn.Module): | |||
self.auto_truncate = auto_truncate | |||
# compute the word pieces for every word in the vocab; [CLS] and [SEP] need extra handling
print("Start to generating word pieces for word.")
logger.info("Start to generating word pieces for word.")
# first collect the needed word pieces, then build the new embed and word_piece_vocab and fill in the values
word_piece_dict = {'[CLS]': 1, '[SEP]': 1}  # word pieces in use, plus newly added ones
found_count = 0 | |||
@@ -356,10 +362,10 @@ class _WordBertModel(nn.Module): | |||
self._sep_index = self.tokenzier.vocab['[SEP]'] | |||
self._word_pad_index = vocab.padding_idx | |||
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word pieces
print("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab))) | |||
logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab))) | |||
self.word_to_wordpieces = np.array(word_to_wordpieces) | |||
self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False) | |||
print("Successfully generate word pieces.") | |||
logger.debug("Successfully generate word pieces.") | |||
def forward(self, words): | |||
""" | |||
@@ -19,6 +19,7 @@ from ..core.vocabulary import Vocabulary | |||
from .embedding import TokenEmbedding | |||
from .utils import _construct_char_vocab_from_vocab | |||
from .utils import get_embeddings | |||
from ..core import logger | |||
class CNNCharEmbedding(TokenEmbedding): | |||
@@ -81,11 +82,11 @@ class CNNCharEmbedding(TokenEmbedding): | |||
raise Exception( | |||
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") | |||
print("Start constructing character vocabulary.") | |||
logger.info("Start constructing character vocabulary.") | |||
# build the character vocabulary
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
self.char_pad_index = self.char_vocab.padding_idx
print(f"In total, there are {len(self.char_vocab)} distinct characters.")
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
# index the vocabulary
max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len), | |||
@@ -236,11 +237,11 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
raise Exception( | |||
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") | |||
print("Start constructing character vocabulary.") | |||
logger.info("Start constructing character vocabulary.") | |||
# build the character vocabulary
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
self.char_pad_index = self.char_vocab.padding_idx
print(f"In total, there are {len(self.char_vocab)} distinct characters.")
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
# index the vocabulary
self.max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len), | |||
@@ -16,7 +16,7 @@ from ..core.batch import DataSetIter | |||
from ..core.sampler import SequentialSampler | |||
from ..core.utils import _move_model_to_device, _get_model_device | |||
from .embedding import TokenEmbedding | |||
from ..core import logger | |||
class ContextualEmbedding(TokenEmbedding): | |||
def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0): | |||
@@ -37,14 +37,14 @@ class ContextualEmbedding(TokenEmbedding): | |||
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed." | |||
assert 'words' in dataset.get_input_name(), "`words` field has to be set as input." | |||
except Exception as e: | |||
print(f"Exception happens at {index} dataset.") | |||
logger.error(f"Exception happens at {index} dataset.") | |||
raise e | |||
sent_embeds = {} | |||
_move_model_to_device(self, device=device) | |||
device = _get_model_device(self) | |||
pad_index = self._word_vocab.padding_idx | |||
print("Start to calculate sentence representations.") | |||
logger.info("Start to calculate sentence representations.") | |||
with torch.no_grad(): | |||
for index, dataset in enumerate(datasets): | |||
try: | |||
@@ -64,9 +64,9 @@ class ContextualEmbedding(TokenEmbedding): | |||
else: | |||
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length] | |||
except Exception as e: | |||
print(f"Exception happens at {index} dataset.") | |||
logger.error(f"Exception happens at {index} dataset.") | |||
raise e | |||
print("Finish calculating sentence representations.") | |||
logger.info("Finish calculating sentence representations.") | |||
self.sent_embeds = sent_embeds | |||
if delete_weights: | |||
self._delete_model_weights() | |||
@@ -18,7 +18,7 @@ from ..core.vocabulary import Vocabulary | |||
from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR | |||
from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder | |||
from .contextual_embedding import ContextualEmbedding | |||
from ..core import logger | |||
class ElmoEmbedding(ContextualEmbedding): | |||
""" | |||
@@ -243,7 +243,7 @@ class _ElmoModel(nn.Module): | |||
index_in_pre = char_lexicon[OOV_TAG] | |||
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre] | |||
print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") | |||
logger.info(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") | |||
# build the mapping from words to chars
max_chars = config['char_cnn']['max_characters_per_token'] | |||
@@ -281,7 +281,7 @@ class _ElmoModel(nn.Module): | |||
if cache_word_reprs: | |||
if config['char_cnn']['embedding']['dim'] > 0:  # only useful when character features are used
print("Start to generate cache word representations.") | |||
logger.info("Start to generate cache word representations.") | |||
batch_size = 320 | |||
# bos eos | |||
word_size = self.words_to_chars_embedding.size(0) | |||
@@ -299,10 +299,10 @@ class _ElmoModel(nn.Module): | |||
chars).detach() # batch_size x 1 x config['encoder']['projection_dim'] | |||
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1) | |||
print("Finish generating cached word representations. Going to delete the character encoder.") | |||
logger.info("Finish generating cached word representations. Going to delete the character encoder.") | |||
del self.token_embedder, self.words_to_chars_embedding | |||
else: | |||
print("There is no need to cache word representations, since no character information is used.") | |||
logger.info("There is no need to cache word representations, since no character information is used.") | |||
def forward(self, words): | |||
""" | |||
@@ -138,8 +138,10 @@ class TokenEmbedding(nn.Module): | |||
:return: | |||
""" | |||
if self.word_dropout > 0 and self.training: | |||
mask = torch.ones_like(words).float() * self.word_dropout | |||
mask = torch.full_like(words, fill_value=self.word_dropout) | |||
mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
pad_mask = words.ne(self._word_pad_index) | |||
mask = mask.__and__(pad_mask) | |||
words = words.masked_fill(mask, self._word_unk_index) | |||
return words | |||
@@ -19,6 +19,7 @@ from .embedding import TokenEmbedding | |||
from ..modules.utils import _get_file_name_base_on_postfix | |||
from copy import deepcopy | |||
from collections import defaultdict | |||
from ..core import logger | |||
class StaticEmbedding(TokenEmbedding): | |||
@@ -112,7 +113,7 @@ class StaticEmbedding(TokenEmbedding): | |||
truncated_words_to_words = torch.arange(len(vocab)).long() | |||
for word, index in vocab: | |||
truncated_words_to_words[index] = truncated_vocab.to_index(word) | |||
print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") | |||
logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") | |||
vocab = truncated_vocab | |||
self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False) | |||
@@ -124,7 +125,7 @@ class StaticEmbedding(TokenEmbedding): | |||
lowered_vocab.add_word(word.lower(), no_create_entry=True) | |||
else: | |||
lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的 | |||
print(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} " | |||
logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} " | |||
f"unique lowered words.") | |||
if model_path: | |||
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | |||
@@ -265,9 +266,9 @@ class StaticEmbedding(TokenEmbedding): | |||
if error == 'ignore': | |||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||
else: | |||
print("Error occurred at the {} line.".format(idx)) | |||
logger.error("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||
logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||
for word, index in vocab: | |||
if index not in matrix and not vocab._is_word_no_create_entry(word): | |||
if found_unknown:  # if an unknown vector was found, use it for initialization
@@ -11,7 +11,7 @@ import numpy as np | |||
from ..core.vocabulary import Vocabulary | |||
from .data_bundle import BaseLoader | |||
from ..core.utils import Option | |||
import logging | |||
class EmbeddingOption(Option): | |||
def __init__(self, | |||
@@ -91,10 +91,10 @@ class EmbedLoader(BaseLoader): | |||
if error == 'ignore': | |||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||
else: | |||
print("Error occurred at the {} line.".format(idx)) | |||
logging.error("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
total_hits = sum(hit_flags) | |||
print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab))) | |||
logging.info("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab))) | |||
if init_method is None: | |||
found_vectors = matrix[hit_flags] | |||
if len(found_vectors) != 0: | |||
@@ -157,7 +157,7 @@ class EmbedLoader(BaseLoader): | |||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||
pass | |||
else: | |||
print("Error occurred at the {} line.".format(idx)) | |||
logging.error("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
if dim == -1: | |||
raise RuntimeError("{} is an empty file.".format(embed_filepath)) | |||
@@ -2,7 +2,8 @@ | |||
This module provides file-reading helpers for other modules; it does not expose a user-facing API.
""" | |||
import json | |||
import warnings | |||
from ..core import logger | |||
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||
""" | |||
@@ -103,9 +104,9 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
yield line_idx, res | |||
except Exception as e: | |||
if dropna: | |||
warnings.warn('Invalid instance ends at line: {} has been dropped.'.format(line_idx)) | |||
logger.warn('Invalid instance which ends at line: {} has been dropped.'.format(line_idx)) | |||
continue | |||
raise ValueError('Invalid instance ends at line: {}'.format(line_idx)) | |||
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
@@ -117,5 +118,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
except Exception as e: | |||
if dropna: | |||
return | |||
print('invalid instance ends at line: {}'.format(line_idx)) | |||
logger.error('invalid instance ends at line: {}'.format(line_idx)) | |||
raise e |
@@ -7,6 +7,7 @@ import tempfile | |||
from tqdm import tqdm | |||
import shutil | |||
from requests import HTTPError | |||
from ..core import logger | |||
PRETRAINED_BERT_MODEL_DIR = { | |||
'en': 'bert-base-cased.zip', | |||
@@ -336,7 +337,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
content_length = req.headers.get("Content-Length") | |||
total = int(content_length) if content_length is not None else None | |||
progress = tqdm(unit="B", total=total, unit_scale=1) | |||
print("%s not found in cache, downloading to %s" % (url, temp_filename)) | |||
logger.info("%s not found in cache, downloading to %s" % (url, temp_filename)) | |||
with open(temp_filename, "wb") as temp_file: | |||
for chunk in req.iter_content(chunk_size=1024 * 16): | |||
@@ -344,12 +345,12 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
progress.update(len(chunk)) | |||
temp_file.write(chunk) | |||
progress.close() | |||
print(f"Finish download from {url}") | |||
logger.info(f"Finish download from {url}") | |||
# start uncompressing
if suffix in ('.zip', '.tar.gz', '.gz'): | |||
uncompress_temp_dir = tempfile.mkdtemp() | |||
print(f"Start to uncompress file to {uncompress_temp_dir}") | |||
logger.debug(f"Start to uncompress file to {uncompress_temp_dir}") | |||
if suffix == '.zip': | |||
unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||
elif suffix == '.gz': | |||
@@ -362,13 +363,13 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) | |||
cache_path.mkdir(parents=True, exist_ok=True) | |||
print("Finish un-compressing file.") | |||
logger.debug("Finish un-compressing file.") | |||
else: | |||
uncompress_temp_dir = temp_filename | |||
cache_path = str(cache_path) + suffix | |||
# copy to the target location
print(f"Copy file to {cache_path}") | |||
logger.info(f"Copy file to {cache_path}") | |||
if os.path.isdir(uncompress_temp_dir): | |||
for filename in os.listdir(uncompress_temp_dir): | |||
if os.path.isdir(os.path.join(uncompress_temp_dir, filename)): | |||
@@ -379,7 +380,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
shutil.copyfile(uncompress_temp_dir, cache_path) | |||
success = True | |||
except Exception as e: | |||
print(e) | |||
logger.error(e) | |||
raise e | |||
finally: | |||
if not success: | |||
@@ -11,7 +11,7 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_insta | |||
from .pipe import Pipe | |||
import re | |||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||
from ...core.utils import cache_results | |||
class _CLSPipe(Pipe): | |||
""" | |||
@@ -2,7 +2,7 @@ import os | |||
from typing import Union, Dict | |||
from pathlib import Path | |||
from ..core import logger | |||
def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||
""" | |||
@@ -70,8 +70,8 @@ def get_tokenizer(): | |||
import spacy | |||
spacy.prefer_gpu() | |||
en = spacy.load('en') | |||
print('use spacy tokenizer') | |||
logger.info('use spacy tokenizer') | |||
return lambda x: [w.text for w in en.tokenizer(x)] | |||
except Exception as e: | |||
print('use raw tokenizer') | |||
logger.error('use raw tokenizer') | |||
return lambda x: x.split() |
@@ -17,8 +17,7 @@ import os | |||
import torch | |||
from torch import nn | |||
import sys | |||
from ...core import logger | |||
from ..utils import _get_file_name_base_on_postfix | |||
CONFIG_FILE = 'bert_config.json' | |||
@@ -489,10 +488,10 @@ class BertModel(nn.Module): | |||
load(model, prefix='' if hasattr(model, 'bert') else 'bert.') | |||
if len(missing_keys) > 0: | |||
print("Weights of {} not initialized from pretrained model: {}".format( | |||
logger.warn("Weights of {} not initialized from pretrained model: {}".format( | |||
model.__class__.__name__, missing_keys)) | |||
if len(unexpected_keys) > 0: | |||
print("Weights from pretrained model not used in {}: {}".format( | |||
logger.warn("Weights from pretrained model not used in {}: {}".format( | |||
model.__class__.__name__, unexpected_keys)) | |||
return model | |||
@@ -799,7 +798,7 @@ class BertTokenizer(object): | |||
for token in tokens: | |||
ids.append(self.vocab[token]) | |||
if len(ids) > self.max_len: | |||
print( | |||
logger.warn( | |||
"Token indices sequence length is longer than the specified maximum " | |||
" sequence length for this BERT model ({} > {}). Running this" | |||
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len) | |||
@@ -823,7 +822,7 @@ class BertTokenizer(object): | |||
with open(vocab_file, "w", encoding="utf-8") as writer: | |||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): | |||
if index != token_index: | |||
print("Saving vocabulary to {}: vocabulary indices are not consecutive." | |||
logger.warn("Saving vocabulary to {}: vocabulary indices are not consecutive." | |||
" Please check that the vocabulary is not corrupted!".format(vocab_file)) | |||
index = token_index | |||
writer.write(token + u'\n') | |||
@@ -837,7 +836,7 @@ class BertTokenizer(object): | |||
""" | |||
pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt') | |||
print("loading vocabulary file {}".format(pretrained_model_name_or_path)) | |||
logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path)) | |||
max_len = 512 | |||
kwargs['max_len'] = min(kwargs.get('max_position_embeddings', int(1e12)), max_len) | |||
# Instantiate tokenizer. | |||
@@ -901,7 +900,7 @@ class _WordPieceBertModel(nn.Module): | |||
is_input=True) | |||
dataset.set_pad_val('word_pieces', self._wordpiece_pad_index) | |||
except Exception as e: | |||
print(f"Exception happens when processing the {index} dataset.") | |||
logger.error(f"Exception happens when processing the {index} dataset.") | |||
raise e | |||
def forward(self, word_pieces, token_type_ids=None): | |||
@@ -1,93 +0,0 @@ | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from typing import Union, Dict | |||
from fastNLP import Vocabulary | |||
from fastNLP import Const | |||
from reproduction.utils import check_dataloader_paths | |||
from fastNLP.io import ConllLoader | |||
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 | |||
class Conll2003DataLoader(DataSetLoader): | |||
def __init__(self, task:str='ner', encoding_type:str='bioes'): | |||
""" | |||
Load the English Conll2003-format corpus; information about the dataset can be found at https://www.clips.uantwerpen.be/conll2003/ner/.
When task is pos, the target of the returned DataSet comes from column 2; when task is chunk it comes from column 3; when task is
ner it comes from column 4. All "-DOCSTART- -X- O O" lines are ignored, which makes the instance counts smaller than the values
reported in much of the literature; since "-DOCSTART- -X- O O" is only a document-separator symbol and should not be predicted,
lines starting with -DOCSTART- are dropped.
For the ner and chunk tasks the target of the loaded data follows encoding_type; for the pos task it is simply the pos column.
:param task: the tagging task to load; one of ner, pos, chunk
""" | |||
assert task in ('ner', 'pos', 'chunk') | |||
index = {'ner':3, 'pos':1, 'chunk':2}[task] | |||
self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index]) | |||
self._tag_converters = [] | |||
if task in ('ner', 'chunk'): | |||
self._tag_converters = [iob2] | |||
if encoding_type == 'bioes': | |||
self._tag_converters.append(iob2bioes) | |||
def load(self, path: str): | |||
dataset = self._loader.load(path) | |||
def convert_tag_schema(tags): | |||
for converter in self._tag_converters: | |||
tags = converter(tags) | |||
return tags | |||
if self._tag_converters: | |||
dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
return dataset | |||
def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=False): | |||
""" | |||
Read and process the data. Lines starting with '-DOCSTART-' are ignored.
:param paths:
:param word_vocab_opt: initialization options for the Vocabulary
:param lower: whether to lowercase all letters.
:return: | |||
""" | |||
# load the data
paths = check_dataloader_paths(paths) | |||
data = DataBundle() | |||
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT) | |||
if lower: | |||
dataset.words.lower() | |||
data.datasets[name] = dataset | |||
# construct the word vocabulary
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) | |||
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT, | |||
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) | |||
word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
data.vocabs[Const.INPUT] = word_vocab | |||
# cap words | |||
cap_word_vocab = Vocabulary() | |||
cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words', | |||
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) | |||
cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words') | |||
input_fields.append('cap_words') | |||
data.vocabs['cap_words'] = cap_word_vocab | |||
# build the target vocabulary
target_vocab = Vocabulary(unknown=None, padding=None) | |||
target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET) | |||
target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET) | |||
data.vocabs[Const.TARGET] = target_vocab | |||
for name, dataset in data.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN) | |||
dataset.set_input(*input_fields) | |||
dataset.set_target(*target_fields) | |||
return data | |||
if __name__ == '__main__': | |||
pass |
@@ -1,152 +0,0 @@ | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from typing import Union, Dict | |||
from fastNLP import DataSet | |||
from fastNLP import Vocabulary | |||
from fastNLP import Const | |||
from reproduction.utils import check_dataloader_paths | |||
from fastNLP.io import ConllLoader | |||
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 | |||
class OntoNoteNERDataLoader(DataSetLoader): | |||
""" | |||
Reads OntoNotes data that has been converted to the Conll format. See https://github.com/yhcc/OntoNotes-5.0-NER for how to convert OntoNotes to that format.
""" | |||
def __init__(self, encoding_type:str='bioes'): | |||
assert encoding_type in ('bioes', 'bio') | |||
self.encoding_type = encoding_type | |||
if encoding_type=='bioes': | |||
self.encoding_method = iob2bioes | |||
else: | |||
self.encoding_method = iob2 | |||
def load(self, path:str)->DataSet: | |||
""" | |||
Given a file path, read the data. The returned DataSet contains the following fields:
raw_words: List[str] | |||
target: List[str] | |||
:param path: | |||
:return: | |||
""" | |||
dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path) | |||
def convert_to_bio(tags): | |||
bio_tags = [] | |||
flag = None | |||
for tag in tags: | |||
label = tag.strip("()*") | |||
if '(' in tag: | |||
bio_label = 'B-' + label | |||
flag = label | |||
elif flag: | |||
bio_label = 'I-' + flag | |||
else: | |||
bio_label = 'O' | |||
if ')' in tag: | |||
flag = None | |||
bio_tags.append(bio_label) | |||
return self.encoding_method(bio_tags) | |||
def convert_word(words): | |||
converted_words = [] | |||
for word in words: | |||
word = word.replace('/.', '.')  # some trailing periods appear in the escaped '/.' form
if not word.startswith('-'): | |||
converted_words.append(word) | |||
continue | |||
# the following symbols were escaped in the corpus; convert them back
tfrs = {'-LRB-':'(', | |||
'-RRB-': ')', | |||
'-LSB-': '[', | |||
'-RSB-': ']', | |||
'-LCB-': '{', | |||
'-RCB-': '}' | |||
} | |||
if word in tfrs: | |||
converted_words.append(tfrs[word]) | |||
else: | |||
converted_words.append(word) | |||
return converted_words | |||
dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words') | |||
dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target') | |||
return dataset | |||
def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, | |||
lower:bool=True)->DataBundle: | |||
""" | |||
Read and process the data. The returned DataInfo contains the following:
vocabs:
word: Vocabulary
target: Vocabulary
datasets:
train: DataSet
words: List[int], set as input
target: int. The label, set as both input and target
seq_len: int. Sentence length, set as both input and target
raw_words: List[str]
xxx (may vary depending on the given paths)
:param paths:
:param word_vocab_opt: initialization options for the Vocabulary
:param lower: whether to lowercase
:return: | |||
""" | |||
paths = check_dataloader_paths(paths) | |||
data = DataBundle() | |||
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT) | |||
if lower: | |||
dataset.words.lower() | |||
data.datasets[name] = dataset | |||
# construct the word vocabulary
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) | |||
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT, | |||
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) | |||
word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
data.vocabs[Const.INPUT] = word_vocab | |||
# cap words | |||
cap_word_vocab = Vocabulary() | |||
cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words') | |||
cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words') | |||
input_fields.append('cap_words') | |||
data.vocabs['cap_words'] = cap_word_vocab | |||
# build the target vocabulary
target_vocab = Vocabulary(unknown=None, padding=None) | |||
target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET) | |||
target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET) | |||
data.vocabs[Const.TARGET] = target_vocab | |||
for name, dataset in data.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN) | |||
dataset.set_input(*input_fields) | |||
dataset.set_target(*target_fields) | |||
return data | |||
if __name__ == '__main__': | |||
loader = OntoNoteNERDataLoader() | |||
dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt') | |||
print(dataset.target.value_count()) | |||
print(dataset[:4]) | |||
""" | |||
train 115812 2200752 | |||
development 15680 304684 | |||
test 12217 230111 | |||
train 92403 1901772 | |||
valid 13606 279180 | |||
test 10258 204135 | |||
""" |
@@ -1,49 +0,0 @@ | |||
from typing import List | |||
def iob2(tags:List[str])->List[str]: | |||
""" | |||
Check that the tags are valid IOB data; IOB1 is automatically converted to IOB2.
:param tags: the tags to convert
""" | |||
for i, tag in enumerate(tags): | |||
if tag == "O": | |||
continue | |||
split = tag.split("-") | |||
if len(split) != 2 or split[0] not in ["I", "B"]: | |||
raise TypeError("The encoding schema is not a valid IOB type.") | |||
if split[0] == "B": | |||
continue | |||
elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 | |||
tags[i] = "B" + tag[1:] | |||
elif tags[i - 1][1:] == tag[1:]: | |||
continue | |||
else: # conversion IOB1 to IOB2 | |||
tags[i] = "B" + tag[1:] | |||
return tags | |||
def iob2bioes(tags:List[str])->List[str]: | |||
""" | |||
Convert IOB tags to the BIOES (BMESO) encoding.
:param tags: | |||
:return: | |||
""" | |||
new_tags = [] | |||
for i, tag in enumerate(tags): | |||
if tag == 'O': | |||
new_tags.append(tag) | |||
else: | |||
split = tag.split('-')[0] | |||
if split == 'B': | |||
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': | |||
new_tags.append(tag) | |||
else: | |||
new_tags.append(tag.replace('B-', 'S-')) | |||
elif split == 'I': | |||
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I': | |||
new_tags.append(tag) | |||
else: | |||
new_tags.append(tag.replace('I-', 'E-')) | |||
else: | |||
raise TypeError("Invalid IOB format.") | |||
return new_tags |
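For reference, the converters defined in this removed module behave as follows (assuming `iob2` and `iob2bioes` from above are in scope; the expected outputs match the unit test removed further below):

```python
tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O']
tags = iob2(tags)        # -> ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
print(iob2bioes(tags))   # -> ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O']
```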
@@ -0,0 +1,31 @@ | |||
from torch import nn | |||
from fastNLP.modules import ConditionalRandomField, allowed_transitions | |||
import torch.nn.functional as F | |||
class BertCRF(nn.Module): | |||
def __init__(self, embed, tag_vocab, encoding_type='bio'): | |||
super().__init__() | |||
self.embed = embed | |||
self.fc = nn.Linear(self.embed.embed_size, len(tag_vocab)) | |||
trans = allowed_transitions(tag_vocab, encoding_type=encoding_type, include_start_end=True) | |||
self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=trans) | |||
def _forward(self, words, target): | |||
mask = words.ne(0) | |||
words = self.embed(words) | |||
words = self.fc(words) | |||
logits = F.log_softmax(words, dim=-1) | |||
if target is not None: | |||
loss = self.crf(logits, target, mask) | |||
return {'loss': loss} | |||
else: | |||
paths, _ = self.crf.viterbi_decode(logits, mask) | |||
return {'pred': paths} | |||
def forward(self, words, target): | |||
return self._forward(words, target) | |||
def predict(self, words): | |||
return self._forward(words, None) |
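The new BertCRF model projects the BERT embeddings to tag logits with a single linear layer and scores them with a CRF whose transitions are restricted by `allowed_transitions`, so tag bigrams that are illegal under the chosen encoding (bio/bioes) are never decoded; `forward` returns the CRF loss during training, while `predict` returns the Viterbi paths.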
@@ -1,33 +0,0 @@ | |||
from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader | |||
from reproduction.seqence_labelling.ner.data.Conll2003Loader import iob2, iob2bioes | |||
import unittest | |||
class TestTagSchemaConverter(unittest.TestCase): | |||
def test_iob2(self): | |||
tags = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O'] | |||
golden = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O'] | |||
self.assertListEqual(golden, iob2(tags)) | |||
tags = ['I-ORG', 'O'] | |||
golden = ['B-ORG', 'O'] | |||
self.assertListEqual(golden, iob2(tags)) | |||
tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O'] | |||
golden = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O'] | |||
self.assertListEqual(golden, iob2(tags)) | |||
def test_iob2bemso(self): | |||
tags = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O'] | |||
golden = ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O'] | |||
self.assertListEqual(golden, iob2bioes(tags)) | |||
def test_conll2003_loader(): | |||
path = '/hdd/fudanNLP/fastNLP/others/data/conll2003/train.txt' | |||
loader = Conll2003DataLoader().load(path) | |||
print(loader[:3]) | |||
if __name__ == '__main__': | |||
test_conll2003_loader() |
@@ -0,0 +1,52 @@ | |||
""" | |||
English named entity recognition with BERT.
""" | |||
import sys | |||
sys.path.append('../../../') | |||
from reproduction.seqence_labelling.ner.model.bert_crf import BertCRF | |||
from fastNLP.embeddings import BertEmbedding | |||
from fastNLP import Trainer, Const | |||
from fastNLP import BucketSampler, SpanFPreRecMetric, GradientClipCallback | |||
from fastNLP.core.callback import WarmupCallback | |||
from fastNLP.core.optimizer import AdamW | |||
from fastNLP.io import Conll2003NERPipe | |||
from fastNLP import cache_results, EvaluateCallback | |||
encoding_type = 'bioes' | |||
@cache_results('caches/conll2003.pkl', _refresh=False) | |||
def load_data(): | |||
# replace with your data path
paths = 'data/conll2003' | |||
data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths) | |||
return data | |||
data = load_data() | |||
print(data) | |||
embed = BertEmbedding(data.get_vocab(Const.INPUT), model_dir_or_name='en-base-cased', | |||
pool_method='max', requires_grad=True, layers='11', include_cls_sep=False, dropout=0.5, | |||
word_dropout=0.01) | |||
callbacks = [ | |||
GradientClipCallback(clip_type='norm', clip_value=1), | |||
WarmupCallback(warmup=0.1, schedule='linear'), | |||
EvaluateCallback(data.get_dataset('test')) | |||
] | |||
model = BertCRF(embed, tag_vocab=data.get_vocab('target'), encoding_type=encoding_type) | |||
optimizer = AdamW(model.parameters(), lr=2e-5) | |||
trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(), | |||
device=0, dev_data=data.datasets['dev'], batch_size=6, | |||
metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), | |||
loss=None, callbacks=callbacks, num_workers=2, n_epochs=5, | |||
check_code_level=0, update_every=3, test_use_tqdm=False) | |||
trainer.train() | |||
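Note that with `batch_size=6` and `update_every=3` the gradients are accumulated over three steps, so the effective batch size during training is 18.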
@@ -1,4 +1,4 @@ | |||
from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader | |||
from fastNLP.io import OntoNotesNERPipe | |||
from fastNLP.core.callback import LRScheduler | |||
from fastNLP import GradientClipCallback | |||
from torch.optim.lr_scheduler import LambdaLR | |||
@@ -10,14 +10,10 @@ from fastNLP import Trainer, Tester | |||
from fastNLP.core.metrics import MetricBase | |||
from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN | |||
from fastNLP.core.utils import Option | |||
from fastNLP.embeddings.embedding import StaticEmbedding | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP.core.utils import cache_results | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
import torch.cuda | |||
import os | |||
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |||
encoding_type = 'bioes' | |||
@@ -40,18 +36,8 @@ ops = Option( | |||
@cache_results('ontonotes-case-cache') | |||
def load_data(): | |||
print('loading data') | |||
data = OntoNoteNERDataLoader(encoding_type=encoding_type).process( | |||
paths = get_path('workdir/datasets/ontonotes-v4'), | |||
lower=False, | |||
word_vocab_opt=VocabularyOption(min_freq=0), | |||
) | |||
# data = Conll2003DataLoader(task='ner', encoding_type=encoding_type).process( | |||
# paths=get_path('workdir/datasets/conll03'), | |||
# lower=False, word_vocab_opt=VocabularyOption(min_freq=0) | |||
# ) | |||
# char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], | |||
# kernel_sizes=[3]) | |||
data = OntoNotesNERPipe(encoding_type=encoding_type).process_from_file( | |||
paths = get_path('workdir/datasets/ontonotes-v4')) | |||
print('loading embedding') | |||
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], | |||
model_dir_or_name='en-glove-840b-300', | |||