@@ -17,7 +17,7 @@ from numbers import Number
from .sampler import SequentialSampler
from .dataset import DataSet
+from ._logger import logger
_python_is_exit = False
@@ -75,7 +75,7 @@ class DataSetGetter:
try:
data, flag = _to_tensor(data, f.dtype)
except TypeError as e:
-print(f"Field {n} cannot be converted to torch.tensor.")
+logger.error(f"Field {n} cannot be converted to torch.tensor.")
raise e
batch_dict[n] = data
return batch_dict
@@ -83,7 +83,6 @@ try:
except:
tensorboardX_flag = False
-from ..io.model_io import ModelSaver, ModelLoader
from .dataset import DataSet
from .tester import Tester
from ._logger import logger
@@ -505,7 +504,7 @@ class EarlyStopCallback(Callback):
def on_exception(self, exception):
if isinstance(exception, EarlyStopError):
-print("Early Stopping triggered in epoch {}!".format(self.epoch))
+logger.info("Early Stopping triggered in epoch {}!".format(self.epoch))
else:
raise exception  # re-raise unrecognized errors
@@ -752,8 +751,7 @@ class LRFinder(Callback):
self.smooth_value = SmoothValue(0.8)
self.opt = None
self.find = None
-self.loader = ModelLoader()
@property
def lr_gen(self):
scale = (self.end_lr - self.start_lr) / self.batch_per_epoch
@@ -768,7 +766,7 @@ class LRFinder(Callback):
self.opt = self.trainer.optimizer # pytorch optimizer
self.opt.param_groups[0]["lr"] = self.start_lr
# save model
-ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
+torch.save(self.model.state_dict(), 'tmp')
self.find = True
def on_backward_begin(self, loss):
@@ -797,7 +795,9 @@ class LRFinder(Callback):
self.opt.param_groups[0]["lr"] = self.best_lr
self.find = False
# reset model
-ModelLoader().load_pytorch(self.trainer.model, "tmp")
+states = torch.load('tmp')
+self.model.load_state_dict(states)
+os.remove('tmp')
self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr))
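# Illustrative sketch (not part of the patch): the LRFinder hunks above replace the
# ModelSaver/ModelLoader helpers with plain torch.save/torch.load of a state_dict.
# The same save-sweep-restore pattern in isolation, using a throwaway linear model
# and a tempfile path (both hypothetical):
import os, tempfile
import torch
from torch import nn

model = nn.Linear(4, 2)
ckpt = os.path.join(tempfile.gettempdir(), 'lr_finder_tmp.pt')
torch.save(model.state_dict(), ckpt)        # snapshot the weights before the LR sweep
with torch.no_grad():
    model.weight.add_(1.0)                  # stand-in for the destructive LR search
model.load_state_dict(torch.load(ckpt))     # restore the snapshot afterwards
os.remove(ckpt)                             # clean up, as the patch does with 'tmp'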
@@ -988,14 +988,14 @@ class SaveModelCallback(Callback):
try:
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
except Exception as e:
-print(f"The following exception:{e} happens when save model to {self.save_dir}.")
+logger.error(f"The following exception:{e} happens when save model to {self.save_dir}.")
if delete_pair:
try:
delete_model_path = os.path.join(self.save_dir, delete_pair[1])
if os.path.exists(delete_model_path):
os.remove(delete_model_path)
except Exception as e:
-print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
+logger.error(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
def on_exception(self, exception):
if self.save_on_exception:
@@ -1032,7 +1032,7 @@ class EchoCallback(Callback):
def __getattribute__(self, item):
if item.startswith('on_'):
-print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
+logger.info('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
file=self.out)
return super(EchoCallback, self).__getattribute__(item)
@@ -300,6 +300,7 @@ from .utils import _get_func_signature
from .field import AppendToTargetOrInputException
from .field import SetInputOrTargetException
from .const import Const
+from ._logger import logger
class DataSet(object):
"""
@@ -452,7 +453,7 @@ class DataSet(object):
try:
self.field_arrays[name].append(field)
except AppendToTargetOrInputException as e:
-print(f"Cannot append to field:{name}.")
+logger.error(f"Cannot append to field:{name}.")
raise e
def add_fieldarray(self, field_name, fieldarray):
@@ -609,7 +610,7 @@ class DataSet(object):
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self.field_arrays[name].is_target = flag
except SetInputOrTargetException as e:
-print(f"Cannot set field:{name} as target.")
+logger.error(f"Cannot set field:{name} as target.")
raise e
else:
raise KeyError("{} is not a valid field name.".format(name))
@@ -633,7 +634,7 @@ class DataSet(object):
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self.field_arrays[name].is_input = flag
except SetInputOrTargetException as e:
-print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.")
+logger.error(f"Cannot set field:{name} as input, exception happens at the {e.index} value.")
raise e
else:
raise KeyError("{} is not a valid field name.".format(name))
@@ -728,7 +729,7 @@ class DataSet(object):
results.append(func(ins[field_name]))
except Exception as e:
if idx != -1:
-print("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
+logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
@@ -795,7 +796,7 @@ class DataSet(object):
results.append(func(ins))
except BaseException as e:
if idx != -1:
-print("Exception happens at the `{}`th instance.".format(idx))
+logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e
# results = [func(ins) for ins in self._inner_iter()]
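# Illustrative sketch (not part of the patch): the apply/apply_field hunks above only
# change how a failing instance is reported. A minimal, assumed usage of
# DataSet.apply_field, where the logged index would point at the offending row:
from fastNLP import DataSet

ds = DataSet({'raw_words': ["the cat sat", "on the mat"]})
# tokenize the raw text into a new field; a broken row would be logged with its index
ds.apply_field(lambda x: x.split(), field_name='raw_words', new_field_name='words')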
@@ -54,7 +54,6 @@ class DistTrainer():
num_workers=1, drop_last=False,
dev_data=None, metrics=None, metric_key=None,
update_every=1, print_every=10, validate_every=-1,
-log_path=None,
save_every=-1, save_path=None, device='auto',
fp16='', backend=None, init_method=None):
@@ -12,6 +12,7 @@ from abc import abstractmethod
from copy import deepcopy
from collections import Counter
from .utils import _is_iterable
+from ._logger import logger
class SetInputOrTargetException(Exception):
@@ -39,7 +40,7 @@ class FieldArray:
try:
_content = list(_content)
except BaseException as e:
-print(f"Cannot convert content(of type:{type(content)}) into list.")
+logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
raise e
self.name = name
self.content = _content
@@ -263,7 +264,7 @@ class FieldArray:
try:
new_contents.append(cell.split(sep))
except Exception as e:
-print(f"Exception happens when process value in index {index}.")
+logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -283,8 +284,8 @@ class FieldArray:
else:
new_contents.append(int(cell))
except Exception as e:
-print(f"Exception happens when process value in index {index}.")
-print(e)
+logger.error(f"Exception happens when process value in index {index}.")
+raise e
return self._after_process(new_contents, inplace=inplace)
def float(self, inplace=True):
@@ -303,7 +304,7 @@ class FieldArray:
else:
new_contents.append(float(cell))
except Exception as e:
-print(f"Exception happens when process value in index {index}.")
+logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -323,7 +324,7 @@ class FieldArray:
else:
new_contents.append(bool(cell))
except Exception as e:
-print(f"Exception happens when process value in index {index}.")
+logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -344,7 +345,7 @@ class FieldArray:
else:
new_contents.append(cell.lower())
except Exception as e:
-print(f"Exception happens when process value in index {index}.")
+logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -364,7 +365,7 @@ class FieldArray:
else:
new_contents.append(cell.upper())
except Exception as e:
-print(f"Exception happens when process value in index {index}.")
+logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -401,7 +402,7 @@ class FieldArray:
self.is_input = self.is_input
self.is_target = self.is_input
except SetInputOrTargetException as e:
-print("The newly generated field cannot be set as input or target.")
+logger.error("The newly generated field cannot be set as input or target.")
raise e
return self
else:
@@ -192,7 +192,7 @@ class Tester(object):
dataset=self.data, check_level=0)
if self.verbose >= 1:
-print("[tester] \n{}".format(self._format_eval_results(eval_results)))
+logger.info("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False)
return eval_results
@@ -145,7 +145,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
with open(cache_filepath, 'rb') as f:
results = _pickle.load(f)
if verbose == 1:
-print("Read cache from {}.".format(cache_filepath))
+logger.info("Read cache from {}.".format(cache_filepath))
refresh_flag = False
if refresh_flag:
@@ -156,7 +156,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
_prepare_cache_filepath(cache_filepath)
with open(cache_filepath, 'wb') as f:
_pickle.dump(results, f)
-print("Save cache to {}.".format(cache_filepath))
+logger.info("Save cache to {}.".format(cache_filepath))
return results
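# Illustrative sketch (not part of the patch): cache_results is the decorator whose
# messages are routed to the logger above. An assumed usage, caching the result of
# an expensive data-preparation function to a local pickle file:
from fastNLP import cache_results

@cache_results('caches/prepared_data.pkl', _refresh=False, _verbose=1)
def prepare_data():
    # pretend this is slow: load files, build vocabularies, index datasets ...
    return {'answer': 42}

data = prepare_data()  # first call logs "Save cache to ...", later calls log "Read cache from ..."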
@@ -10,6 +10,7 @@ from .utils import Option
from functools import partial
import numpy as np
from .utils import _is_iterable
+from ._logger import logger
class VocabularyOption(Option):
def __init__(self,
@@ -49,7 +50,7 @@ def _check_build_status(func):
if self.rebuild is False:
self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size:
-print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
+logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__))
return func(self, *args, **kwargs)
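# Illustrative sketch (not part of the patch): the warning above fires when words are
# added to a Vocabulary that has already been built and is at its max_size. A minimal,
# assumed reproduction:
from fastNLP import Vocabulary

vocab = Vocabulary(max_size=2, unknown=None, padding=None)
vocab.add_word_lst(['a', 'b'])
vocab.to_index('a')   # triggers the build; the vocabulary is now "built"
vocab.add_word('c')   # rebuild path: logs the max-size warning shown above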
@@ -297,7 +298,7 @@ class Vocabulary(object):
for f_n, n_f_n in zip(field_name, new_field_name):
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
except Exception as e:
-print("When processing the `{}` dataset, the following error occurred.".format(idx))
+logger.info("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e
else:
raise RuntimeError("Only DataSet type is allowed.")
@@ -353,7 +354,7 @@ class Vocabulary(object):
try:
dataset.apply(construct_vocab)
except BaseException as e:
-print("When processing the `{}` dataset, the following error occurred:".format(idx))
+logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e
else:
raise TypeError("Only DataSet type is allowed.")
@@ -21,6 +21,7 @@ from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MOD
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
from .contextual_embedding import ContextualEmbedding
import warnings
+from ..core import logger
class BertEmbedding(ContextualEmbedding):
@@ -125,8 +126,10 @@ class BertEmbedding(ContextualEmbedding):
with torch.no_grad():
if self._word_sep_index: # [SEP] must not be dropped
sep_mask = words.eq(self._word_sep_index)
-mask = torch.ones_like(words).float() * self.word_dropout
+mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float)
mask = torch.bernoulli(mask).eq(1) # the larger word_dropout is, the more positions become 1
+pad_mask = words.ne(0)
+mask = pad_mask.__and__(mask) # padding positions are never replaced with unk
words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index:
words.masked_fill_(sep_mask, self._word_sep_index)
@@ -182,6 +185,7 @@ class BertWordPieceEncoder(nn.Module):
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
self._sep_index = self.model._sep_index
+self._wordpiece_pad_index = self.model._wordpiece_pad_index
self._wordpiece_unk_index = self.model._wordpiece_unknown_index
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
self.requires_grad = requires_grad
@@ -263,8 +267,10 @@ class BertWordPieceEncoder(nn.Module):
with torch.no_grad():
if self._word_sep_index: # [SEP] must not be dropped
sep_mask = words.eq(self._wordpiece_unk_index)
-mask = torch.ones_like(words).float() * self.word_dropout
+mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float)
mask = torch.bernoulli(mask).eq(1) # the larger word_dropout is, the more positions become 1
+pad_mask = words.ne(self._wordpiece_pad_index)
+mask = pad_mask.__and__(mask) # padding positions are never replaced with unk
words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index:
words.masked_fill_(sep_mask, self._wordpiece_unk_index)
@@ -297,7 +303,7 @@ class _WordBertModel(nn.Module):
self.auto_truncate = auto_truncate
# compute the word pieces for every word in the vocab; [CLS] and [SEP] need extra handling
-print("Start to generating word pieces for word.")
+logger.info("Start to generating word pieces for word.")
# step 1: collect the required word pieces, then build the new embedding and word-piece vocab and fill in the values
word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # word pieces in use plus newly added ones
found_count = 0
@@ -356,10 +362,10 @@ class _WordBertModel(nn.Module):
self._sep_index = self.tokenzier.vocab['[SEP]']
self._word_pad_index = vocab.padding_idx
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # needed when generating word pieces
-print("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
+logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
self.word_to_wordpieces = np.array(word_to_wordpieces)
self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False)
-print("Successfully generate word pieces.")
+logger.debug("Successfully generate word pieces.")
def forward(self, words):
"""
@@ -19,6 +19,7 @@ from ..core.vocabulary import Vocabulary
from .embedding import TokenEmbedding
from .utils import _construct_char_vocab_from_vocab
from .utils import get_embeddings
+from ..core import logger
class CNNCharEmbedding(TokenEmbedding):
@@ -81,11 +82,11 @@ class CNNCharEmbedding(TokenEmbedding):
raise Exception(
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
-print("Start constructing character vocabulary.")
+logger.info("Start constructing character vocabulary.")
# build the character vocabulary
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
self.char_pad_index = self.char_vocab.padding_idx
-print(f"In total, there are {len(self.char_vocab)} distinct characters.")
+logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
# index the vocab
max_word_len = max(map(lambda x: len(x[0]), vocab))
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len),
@@ -236,11 +237,11 @@ class LSTMCharEmbedding(TokenEmbedding):
raise Exception(
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
-print("Start constructing character vocabulary.")
+logger.info("Start constructing character vocabulary.")
# build the character vocabulary
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
self.char_pad_index = self.char_vocab.padding_idx
-print(f"In total, there are {len(self.char_vocab)} distinct characters.")
+logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
# index the vocab
self.max_word_len = max(map(lambda x: len(x[0]), vocab))
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
@@ -16,7 +16,7 @@ from ..core.batch import DataSetIter
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding
+from ..core import logger
class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):
@@ -37,14 +37,14 @@ class ContextualEmbedding(TokenEmbedding):
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed."
assert 'words' in dataset.get_input_name(), "`words` field has to be set as input."
except Exception as e:
-print(f"Exception happens at {index} dataset.")
+logger.error(f"Exception happens at {index} dataset.")
raise e
sent_embeds = {}
_move_model_to_device(self, device=device)
device = _get_model_device(self)
pad_index = self._word_vocab.padding_idx
-print("Start to calculate sentence representations.")
+logger.info("Start to calculate sentence representations.")
with torch.no_grad():
for index, dataset in enumerate(datasets):
try:
@@ -64,9 +64,9 @@ class ContextualEmbedding(TokenEmbedding):
else:
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length]
except Exception as e:
-print(f"Exception happens at {index} dataset.")
+logger.error(f"Exception happens at {index} dataset.")
raise e
-print("Finish calculating sentence representations.")
+logger.info("Finish calculating sentence representations.")
self.sent_embeds = sent_embeds
if delete_weights:
self._delete_model_weights()
@@ -18,7 +18,7 @@ from ..core.vocabulary import Vocabulary
from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR
from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder
from .contextual_embedding import ContextualEmbedding
+from ..core import logger
class ElmoEmbedding(ContextualEmbedding):
"""
@@ -243,7 +243,7 @@ class _ElmoModel(nn.Module):
index_in_pre = char_lexicon[OOV_TAG]
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]
-print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
+logger.info(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
# build the word-to-characters mapping
max_chars = config['char_cnn']['max_characters_per_token']
@@ -281,7 +281,7 @@ class _ElmoModel(nn.Module):
if cache_word_reprs:
if config['char_cnn']['embedding']['dim'] > 0: # only useful when character information is used
-print("Start to generate cache word representations.")
+logger.info("Start to generate cache word representations.")
batch_size = 320
# bos eos
word_size = self.words_to_chars_embedding.size(0)
@@ -299,10 +299,10 @@ class _ElmoModel(nn.Module):
chars).detach() # batch_size x 1 x config['encoder']['projection_dim']
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1)
-print("Finish generating cached word representations. Going to delete the character encoder.")
+logger.info("Finish generating cached word representations. Going to delete the character encoder.")
del self.token_embedder, self.words_to_chars_embedding
else:
-print("There is no need to cache word representations, since no character information is used.")
+logger.info("There is no need to cache word representations, since no character information is used.")
def forward(self, words):
"""
@@ -138,8 +138,10 @@ class TokenEmbedding(nn.Module):
:return:
"""
if self.word_dropout > 0 and self.training:
-mask = torch.ones_like(words).float() * self.word_dropout
+mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float)
mask = torch.bernoulli(mask).eq(1) # the larger word_dropout is, the more positions become 1
+pad_mask = words.ne(self._word_pad_index)
+mask = mask.__and__(pad_mask)
words = words.masked_fill(mask, self._word_unk_index)
return words
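# Illustrative sketch (not part of the patch): the drop_word changes above keep
# padding positions from being replaced by the unk index. The same masking logic
# on a toy batch, with an assumed pad index of 0 and unk index of 1:
import torch

words = torch.tensor([[5, 6, 7, 0, 0],
                      [8, 9, 0, 0, 0]])           # 0 = padding
drop_prob, pad_index, unk_index = 0.3, 0, 1
mask = torch.full_like(words, fill_value=drop_prob, dtype=torch.float)
mask = torch.bernoulli(mask).eq(1)                # positions selected for word dropout
mask = mask & words.ne(pad_index)                 # never touch padding positions
dropped = words.masked_fill(mask, unk_index)      # selected real tokens become unk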
@@ -19,6 +19,7 @@ from .embedding import TokenEmbedding
from ..modules.utils import _get_file_name_base_on_postfix
from copy import deepcopy
from collections import defaultdict
+from ..core import logger
class StaticEmbedding(TokenEmbedding):
@@ -112,7 +113,7 @@ class StaticEmbedding(TokenEmbedding):
truncated_words_to_words = torch.arange(len(vocab)).long()
for word, index in vocab:
truncated_words_to_words[index] = truncated_vocab.to_index(word)
-print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
+logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
vocab = truncated_vocab
self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
@@ -124,7 +125,7 @@ class StaticEmbedding(TokenEmbedding):
lowered_vocab.add_word(word.lower(), no_create_entry=True)
else:
lowered_vocab.add_word(word.lower()) # add the words that need their own entries first
-print(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
+logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
f"unique lowered words.")
if model_path:
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
@@ -265,9 +266,9 @@ class StaticEmbedding(TokenEmbedding):
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
-print("Error occurred at the {} line.".format(idx))
+logger.error("Error occurred at the {} line.".format(idx))
raise e
-print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
+logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
for word, index in vocab:
if index not in matrix and not vocab._is_word_no_create_entry(word):
if found_unknown: # if an unknown vector exists, use it for initialization
@@ -11,7 +11,7 @@ import numpy as np
from ..core.vocabulary import Vocabulary
from .data_bundle import BaseLoader
from ..core.utils import Option
+import logging
class EmbeddingOption(Option):
def __init__(self,
@@ -91,10 +91,10 @@ class EmbedLoader(BaseLoader):
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
-print("Error occurred at the {} line.".format(idx))
+logging.error("Error occurred at the {} line.".format(idx))
raise e
total_hits = sum(hit_flags)
-print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
+logging.info("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
if init_method is None:
found_vectors = matrix[hit_flags]
if len(found_vectors) != 0:
@@ -157,7 +157,7 @@ class EmbedLoader(BaseLoader):
warnings.warn("Error occurred at the {} line.".format(idx))
pass
else:
-print("Error occurred at the {} line.".format(idx))
+logging.error("Error occurred at the {} line.".format(idx))
raise e
if dim == -1:
raise RuntimeError("{} is an empty file.".format(embed_filepath))
@@ -2,7 +2,8 @@
This module provides file-reading helpers for other modules; it is not part of the user-facing API.
"""
import json
-import warnings
+from ..core import logger
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
"""
@@ -103,9 +104,9 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
yield line_idx, res
except Exception as e:
if dropna:
-warnings.warn('Invalid instance ends at line: {} has been dropped.'.format(line_idx))
+logger.warn('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
continue
-raise ValueError('Invalid instance ends at line: {}'.format(line_idx))
+raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
elif line.startswith('#'):
continue
else:
@@ -117,5 +118,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
return
-print('invalid instance ends at line: {}'.format(line_idx))
+logger.error('invalid instance ends at line: {}'.format(line_idx))
raise e
@@ -7,6 +7,7 @@ import tempfile
from tqdm import tqdm
import shutil
from requests import HTTPError
+from ..core import logger
PRETRAINED_BERT_MODEL_DIR = {
'en': 'bert-base-cased.zip',
@@ -336,7 +337,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total, unit_scale=1)
-print("%s not found in cache, downloading to %s" % (url, temp_filename))
+logger.info("%s not found in cache, downloading to %s" % (url, temp_filename))
with open(temp_filename, "wb") as temp_file:
for chunk in req.iter_content(chunk_size=1024 * 16):
@@ -344,12 +345,12 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
-print(f"Finish download from {url}")
+logger.info(f"Finish download from {url}")
# start uncompressing
if suffix in ('.zip', '.tar.gz', '.gz'):
uncompress_temp_dir = tempfile.mkdtemp()
-print(f"Start to uncompress file to {uncompress_temp_dir}")
+logger.debug(f"Start to uncompress file to {uncompress_temp_dir}")
if suffix == '.zip':
unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
elif suffix == '.gz':
@@ -362,13 +363,13 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])
cache_path.mkdir(parents=True, exist_ok=True)
-print("Finish un-compressing file.")
+logger.debug("Finish un-compressing file.")
else:
uncompress_temp_dir = temp_filename
cache_path = str(cache_path) + suffix
# copy to the target location
-print(f"Copy file to {cache_path}")
+logger.info(f"Copy file to {cache_path}")
if os.path.isdir(uncompress_temp_dir):
for filename in os.listdir(uncompress_temp_dir):
if os.path.isdir(os.path.join(uncompress_temp_dir, filename)):
@@ -379,7 +380,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
shutil.copyfile(uncompress_temp_dir, cache_path)
success = True
except Exception as e:
-print(e)
+logger.error(e)
raise e
finally:
if not success:
@@ -11,7 +11,7 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_insta
from .pipe import Pipe
import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
-from ...core.utils import cache_results
class _CLSPipe(Pipe):
"""
@@ -2,7 +2,7 @@ import os
from typing import Union, Dict
from pathlib import Path
+from ..core import logger
def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
"""
@@ -70,8 +70,8 @@ def get_tokenizer():
import spacy
spacy.prefer_gpu()
en = spacy.load('en')
-print('use spacy tokenizer')
+logger.info('use spacy tokenizer')
return lambda x: [w.text for w in en.tokenizer(x)]
except Exception as e:
-print('use raw tokenizer')
+logger.error('use raw tokenizer')
return lambda x: x.split()
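# Illustrative sketch (not part of the patch): get_tokenizer above tries spaCy and
# falls back to whitespace splitting. The same guarded-import pattern in isolation
# (the model name 'en' follows the original; newer spaCy releases use 'en_core_web_sm'):
def get_word_splitter():
    try:
        import spacy
        nlp = spacy.load('en')
        return lambda text: [tok.text for tok in nlp.tokenizer(text)]
    except Exception:
        return lambda text: text.split()  # raw whitespace tokenizer as a fallback

split = get_word_splitter()
tokens = split("fastNLP falls back to simple splitting.")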
@@ -17,8 +17,7 @@ import os
import torch
from torch import nn
-import sys
+from ...core import logger
from ..utils import _get_file_name_base_on_postfix
CONFIG_FILE = 'bert_config.json'
@@ -489,10 +488,10 @@ class BertModel(nn.Module):
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
-print("Weights of {} not initialized from pretrained model: {}".format(
+logger.warn("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
-print("Weights from pretrained model not used in {}: {}".format(
+logger.warn("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
return model
@@ -799,7 +798,7 @@ class BertTokenizer(object):
for token in tokens:
ids.append(self.vocab[token])
if len(ids) > self.max_len:
-print(
+logger.warn(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
@@ -823,7 +822,7 @@ class BertTokenizer(object):
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
-print("Saving vocabulary to {}: vocabulary indices are not consecutive."
+logger.warn("Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file))
index = token_index
writer.write(token + u'\n')
@@ -837,7 +836,7 @@ class BertTokenizer(object):
"""
pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt')
-print("loading vocabulary file {}".format(pretrained_model_name_or_path))
+logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path))
max_len = 512
kwargs['max_len'] = min(kwargs.get('max_position_embeddings', int(1e12)), max_len)
# Instantiate tokenizer.
@@ -901,7 +900,7 @@ class _WordPieceBertModel(nn.Module):
is_input=True)
dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
except Exception as e:
-print(f"Exception happens when processing the {index} dataset.")
+logger.error(f"Exception happens when processing the {index} dataset.")
raise e
def forward(self, word_pieces, token_type_ids=None):
@@ -1,93 +0,0 @@
-from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.data_bundle import DataSetLoader, DataBundle
-from typing import Union, Dict
-from fastNLP import Vocabulary
-from fastNLP import Const
-from reproduction.utils import check_dataloader_paths
-from fastNLP.io import ConllLoader
-from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
-class Conll2003DataLoader(DataSetLoader):
-def __init__(self, task:str='ner', encoding_type:str='bioes'):
-"""
-Loads English data in the Conll2003 format; details about the dataset can be found at https://www.clips.uantwerpen.be/conll2003/ner/. When task is pos,
-the target of the returned DataSet comes from column 2; when task is chunk, it comes from column 3; when task is ner, it comes
-from column 4. All "-DOCSTART- -X- O O" lines are ignored, which makes the number of instances smaller than many published figures, but
-since "-DOCSTART- -X- O O" is only a document separator and should not be predicted, lines starting with -DOCSTART- are dropped.
-For the ner and chunk tasks the loaded target follows encoding_type; for the pos task it is simply the pos column.
-:param task: the labelling task, one of ner, pos, chunk
-"""
-assert task in ('ner', 'pos', 'chunk')
-index = {'ner':3, 'pos':1, 'chunk':2}[task]
-self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
-self._tag_converters = []
-if task in ('ner', 'chunk'):
-self._tag_converters = [iob2]
-if encoding_type == 'bioes':
-self._tag_converters.append(iob2bioes)
-def load(self, path: str):
-dataset = self._loader.load(path)
-def convert_tag_schema(tags):
-for converter in self._tag_converters:
-tags = converter(tags)
-return tags
-if self._tag_converters:
-dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
-return dataset
-def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=False):
-"""
-Read and process the data. Lines starting with '-DOCSTART-' are ignored.
-:param paths:
-:param word_vocab_opt: initialization options for the vocabulary
-:param lower: whether to lowercase all letters
-:return:
-"""
-# read the data
-paths = check_dataloader_paths(paths)
-data = DataBundle()
-input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
-target_fields = [Const.TARGET, Const.INPUT_LEN]
-for name, path in paths.items():
-dataset = self.load(path)
-dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
-if lower:
-dataset.words.lower()
-data.datasets[name] = dataset
-# construct the word vocabulary
-word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
-word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
-no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
-word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
-data.vocabs[Const.INPUT] = word_vocab
-# cap words
-cap_word_vocab = Vocabulary()
-cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words',
-no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
-cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
-input_fields.append('cap_words')
-data.vocabs['cap_words'] = cap_word_vocab
-# build the target vocabulary
-target_vocab = Vocabulary(unknown=None, padding=None)
-target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
-target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
-data.vocabs[Const.TARGET] = target_vocab
-for name, dataset in data.datasets.items():
-dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
-dataset.set_input(*input_fields)
-dataset.set_target(*target_fields)
-return data
-if __name__ == '__main__':
-pass
@@ -1,152 +0,0 @@
-from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.data_bundle import DataSetLoader, DataBundle
-from typing import Union, Dict
-from fastNLP import DataSet
-from fastNLP import Vocabulary
-from fastNLP import Const
-from reproduction.utils import check_dataloader_paths
-from fastNLP.io import ConllLoader
-from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
-class OntoNoteNERDataLoader(DataSetLoader):
-"""
-Reads OntoNotes data that has been converted to the Conll format. See https://github.com/yhcc/OntoNotes-5.0-NER for how to convert OntoNotes to the conll format.
-"""
-def __init__(self, encoding_type:str='bioes'):
-assert encoding_type in ('bioes', 'bio')
-self.encoding_type = encoding_type
-if encoding_type=='bioes':
-self.encoding_method = iob2bioes
-else:
-self.encoding_method = iob2
-def load(self, path:str)->DataSet:
-"""
-Read data from the given file path. The returned DataSet contains the following fields
-raw_words: List[str]
-target: List[str]
-:param path:
-:return:
-"""
-dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path)
-def convert_to_bio(tags):
-bio_tags = []
-flag = None
-for tag in tags:
-label = tag.strip("()*")
-if '(' in tag:
-bio_label = 'B-' + label
-flag = label
-elif flag:
-bio_label = 'I-' + flag
-else:
-bio_label = 'O'
-if ')' in tag:
-flag = None
-bio_tags.append(bio_label)
-return self.encoding_method(bio_tags)
-def convert_word(words):
-converted_words = []
-for word in words:
-word = word.replace('/.', '.') # some trailing periods appear as '/.'
-if not word.startswith('-'):
-converted_words.append(word)
-continue
-# the following symbols were escaped; convert them back
-tfrs = {'-LRB-':'(',
-'-RRB-': ')',
-'-LSB-': '[',
-'-RSB-': ']',
-'-LCB-': '{',
-'-RCB-': '}'
-}
-if word in tfrs:
-converted_words.append(tfrs[word])
-else:
-converted_words.append(word)
-return converted_words
-dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words')
-dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target')
-return dataset
-def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
-lower:bool=True)->DataBundle:
-"""
-Read and process the data. The returned DataInfo contains the following
-vocabs:
-word: Vocabulary
-target: Vocabulary
-datasets:
-train: DataSet
-words: List[int], set as input
-target: int. label, set as both input and target
-seq_len: int. sentence length, set as both input and target
-raw_words: List[str]
-xxx (may vary depending on the given paths)
-:param paths:
-:param word_vocab_opt: initialization options for the vocabulary
-:param lower: whether to lowercase
-:return:
-"""
-paths = check_dataloader_paths(paths)
-data = DataBundle()
-input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
-target_fields = [Const.TARGET, Const.INPUT_LEN]
-for name, path in paths.items():
-dataset = self.load(path)
-dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
-if lower:
-dataset.words.lower()
-data.datasets[name] = dataset
-# construct the word vocabulary
-word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
-word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
-no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
-word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
-data.vocabs[Const.INPUT] = word_vocab
-# cap words
-cap_word_vocab = Vocabulary()
-cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
-cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
-input_fields.append('cap_words')
-data.vocabs['cap_words'] = cap_word_vocab
-# build the target vocabulary
-target_vocab = Vocabulary(unknown=None, padding=None)
-target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
-target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
-data.vocabs[Const.TARGET] = target_vocab
-for name, dataset in data.datasets.items():
-dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
-dataset.set_input(*input_fields)
-dataset.set_target(*target_fields)
-return data
-if __name__ == '__main__':
-loader = OntoNoteNERDataLoader()
-dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt')
-print(dataset.target.value_count())
-print(dataset[:4])
-"""
-train 115812 2200752
-development 15680 304684
-test 12217 230111
-train 92403 1901772
-valid 13606 279180
-test 10258 204135
-"""
@@ -1,49 +0,0 @@
from typing import List


def iob2(tags: List[str]) -> List[str]:
    """
    Check that the tags form valid IOB data; IOB1 tags are automatically converted to IOB2.

    :param tags: the tags to convert
    """
    for i, tag in enumerate(tags):
        if tag == "O":
            continue
        split = tag.split("-")
        if len(split) != 2 or split[0] not in ["I", "B"]:
            raise TypeError("The encoding schema is not a valid IOB type.")
        if split[0] == "B":
            continue
        elif i == 0 or tags[i - 1] == "O":  # convert IOB1 to IOB2
            tags[i] = "B" + tag[1:]
        elif tags[i - 1][1:] == tag[1:]:
            continue
        else:  # convert IOB1 to IOB2
            tags[i] = "B" + tag[1:]
    return tags
def iob2bioes(tags: List[str]) -> List[str]:
    """
    Convert IOB tags to the BIOES scheme.

    :param tags: the IOB2 tags to convert
    :return: the converted tags
    """
    new_tags = []
    for i, tag in enumerate(tags):
        if tag == 'O':
            new_tags.append(tag)
        else:
            split = tag.split('-')[0]
            if split == 'B':
                if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('B-', 'S-'))
            elif split == 'I':
                if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('I-', 'E-'))
            else:
                raise TypeError("Invalid IOB format.")
    return new_tags
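For illustration, the two helpers compose as follows; this is a plain example mirroring the unit tests further below, not part of the diff itself.

# IOB1 input -> IOB2 -> BIOES
tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O']
tags = iob2(tags)        # ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
print(iob2bioes(tags))   # ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O']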
@@ -0,0 +1,31 @@
from torch import nn
from fastNLP.modules import ConditionalRandomField, allowed_transitions
import torch.nn.functional as F


class BertCRF(nn.Module):
    def __init__(self, embed, tag_vocab, encoding_type='bio'):
        super().__init__()
        self.embed = embed
        self.fc = nn.Linear(self.embed.embed_size, len(tag_vocab))
        trans = allowed_transitions(tag_vocab, encoding_type=encoding_type, include_start_end=True)
        self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=trans)

    def _forward(self, words, target):
        mask = words.ne(0)  # mask out padding positions (padding index assumed to be 0)
        words = self.embed(words)
        words = self.fc(words)
        logits = F.log_softmax(words, dim=-1)
        if target is not None:
            # training: return the CRF negative log-likelihood as the loss
            loss = self.crf(logits, target, mask)
            return {'loss': loss}
        else:
            # inference: Viterbi decoding over the tag sequence
            paths, _ = self.crf.viterbi_decode(logits, mask)
            return {'pred': paths}

    def forward(self, words, target):
        return self._forward(words, target)

    def predict(self, words):
        return self._forward(words, None)
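A minimal inference sketch for the model above, illustrative only: `embed` and the target vocabulary are the ones built in the training script that follows, the sample word indices are made up, and a padding index of 0 is assumed.

import torch
model = BertCRF(embed, tag_vocab=data.get_vocab('target'), encoding_type='bioes')
model.eval()
with torch.no_grad():
    words = torch.LongTensor([[2, 15, 8, 0, 0]])   # one padded sentence of word indices
    print(model.predict(words)['pred'])            # Viterbi-decoded tag index sequence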
@@ -1,33 +0,0 @@
from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader
from reproduction.seqence_labelling.ner.data.Conll2003Loader import iob2, iob2bioes
import unittest


class TestTagSchemaConverter(unittest.TestCase):
    def test_iob2(self):
        tags = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
        golden = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
        self.assertListEqual(golden, iob2(tags))

        tags = ['I-ORG', 'O']
        golden = ['B-ORG', 'O']
        self.assertListEqual(golden, iob2(tags))

        tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O']
        golden = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
        self.assertListEqual(golden, iob2(tags))

    def test_iob2bemso(self):
        tags = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
        golden = ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O']
        self.assertListEqual(golden, iob2bioes(tags))


def test_conll2003_loader():
    path = '/hdd/fudanNLP/fastNLP/others/data/conll2003/train.txt'
    loader = Conll2003DataLoader().load(path)
    print(loader[:3])


if __name__ == '__main__':
    test_conll2003_loader()
@@ -0,0 +1,52 @@
"""
English named entity recognition with BERT.
"""
import sys

sys.path.append('../../../')

from reproduction.seqence_labelling.ner.model.bert_crf import BertCRF
from fastNLP.embeddings import BertEmbedding
from fastNLP import Trainer, Const
from fastNLP import BucketSampler, SpanFPreRecMetric, GradientClipCallback
from fastNLP.core.callback import WarmupCallback
from fastNLP.core.optimizer import AdamW
from fastNLP.io import Conll2003NERPipe
from fastNLP import cache_results, EvaluateCallback

encoding_type = 'bioes'


@cache_results('caches/conll2003.pkl', _refresh=False)
def load_data():
    # replace this with the path to your CoNLL-2003 data
    paths = 'data/conll2003'
    data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths)
    return data


data = load_data()
print(data)

embed = BertEmbedding(data.get_vocab(Const.INPUT), model_dir_or_name='en-base-cased',
                      pool_method='max', requires_grad=True, layers='11', include_cls_sep=False, dropout=0.5,
                      word_dropout=0.01)

callbacks = [
    GradientClipCallback(clip_type='norm', clip_value=1),
    WarmupCallback(warmup=0.1, schedule='linear'),
    EvaluateCallback(data.get_dataset('test'))
]

model = BertCRF(embed, tag_vocab=data.get_vocab('target'), encoding_type=encoding_type)
optimizer = AdamW(model.parameters(), lr=2e-5)

trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
                  device=0, dev_data=data.datasets['dev'], batch_size=6,
                  metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
                  loss=None, callbacks=callbacks, num_workers=2, n_epochs=5,
                  check_code_level=0, update_every=3, test_use_tqdm=False)

trainer.train()
@@ -1,4 +1,4 @@
from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader
from fastNLP.io import OntoNotesNERPipe
from fastNLP.core.callback import LRScheduler
from fastNLP import GradientClipCallback
from torch.optim.lr_scheduler import LambdaLR
@@ -10,14 +10,10 @@ from fastNLP import Trainer, Tester
from fastNLP.core.metrics import MetricBase
from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN
from fastNLP.core.utils import Option
from fastNLP.embeddings.embedding import StaticEmbedding
from fastNLP.embeddings import StaticEmbedding
from fastNLP.core.utils import cache_results
from fastNLP.core.vocabulary import VocabularyOption
import torch.cuda
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
encoding_type = 'bioes'
@@ -40,18 +36,8 @@ ops = Option(
@cache_results('ontonotes-case-cache')
def load_data():
    print('loading data')
    data = OntoNoteNERDataLoader(encoding_type=encoding_type).process(
        paths=get_path('workdir/datasets/ontonotes-v4'),
        lower=False,
        word_vocab_opt=VocabularyOption(min_freq=0),
    )
    # data = Conll2003DataLoader(task='ner', encoding_type=encoding_type).process(
    #     paths=get_path('workdir/datasets/conll03'),
    #     lower=False, word_vocab_opt=VocabularyOption(min_freq=0)
    # )
    # char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
    #                               kernel_sizes=[3])
    data = OntoNotesNERPipe(encoding_type=encoding_type).process_from_file(
        paths=get_path('workdir/datasets/ontonotes-v4'))
    print('loading embedding')
    word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
                                 model_dir_or_name='en-glove-840b-300',