@@ -14,6 +14,7 @@ The core module implements fastNLP's core framework; commonly used functionality can be imported from fa
 """
 from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC
+from .callback import EvaluateCallback, FitlogCallback, SaveModelCallback
 from .const import Const
 from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
@@ -57,6 +57,7 @@ __all__ = [
     "FitlogCallback",
     "LRScheduler",
     "ControlC",
+    "EvaluateCallback",
     "CallbackException",
     "EarlyStopError"
@@ -504,10 +505,9 @@ class FitlogCallback(Callback):
     and writes the evaluation results to fitlog. The results on these datasets follow the best result on dev:
     if dev reaches its best score at epoch 3, the results recorded in fitlog for these datasets are the ones
     from epoch 3.

-    :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: a DataSet to be evaluated with the Trainer's metrics. To pass
-        several DataSets, use a dict; each dict key is passed to fitlog as the name of that dataset. If tester is
-        not None, data must be passed as a dict. A single DataSet is named `test`.
-    :param ~fastNLP.Tester tester: a Tester object, called at on_valid_end. The DataSet inside the tester is named `test`.
+    :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: a DataSet to be evaluated with the Trainer's metrics. To pass
+        several DataSets, use a dict; each dict key is passed to fitlog as the name of that dataset. The names of
+        data results start with 'data'.
+    :param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: a Tester object (or dict of them), called at on_valid_end.
+        The names of tester results start with 'tester'.
     :param int log_loss_every: record the loss once every this many steps (the average loss over those batches is
         recorded). For large datasets it is advisable to set a large value, otherwise the log file becomes huge.
         Defaults to 0, i.e. do not record the loss.
     :param int verbose: whether to print evaluation results to the terminal; 0 means do not print.
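
To make the new naming scheme concrete, here is a hedged usage sketch (train_data, dev_data, test_data, other_data, model and loss are placeholders, and AccuracyMetric stands in for whatever metric the Trainer uses): extra datasets are reported under keys prefixed with 'data-', extra testers under 'tester-'.

    from fastNLP import Trainer, Tester, AccuracyMetric, FitlogCallback

    extra_tester = Tester(other_data, model, metrics=AccuracyMetric(), verbose=0)
    cb = FitlogCallback(data={'test': test_data}, tester={'other': extra_tester})
    # fitlog records the DataSet results as 'data-test' and the tester results as 'tester-other'
    trainer = Trainer(train_data, model, loss=loss, metrics=AccuracyMetric(),
                      dev_data=dev_data, callbacks=[cb])
    trainer.train()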
@@ -521,20 +521,23 @@ class FitlogCallback(Callback): | |||||
self._log_exception = log_exception | self._log_exception = log_exception | ||||
assert isinstance(log_loss_every, int) and log_loss_every>=0 | assert isinstance(log_loss_every, int) and log_loss_every>=0 | ||||
if tester is not None: | if tester is not None: | ||||
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." | |||||
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." | |||||
if data is not None: | |||||
assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed." | |||||
setattr(tester, 'verbose', 0) | |||||
self.testers['test'] = tester | |||||
if isinstance(tester, dict): | |||||
for name, test in tester.items(): | |||||
if not isinstance(test, Tester): | |||||
raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") | |||||
self.testers['tester-' + name] = test | |||||
if isinstance(tester, Tester): | |||||
self.testers['tester-test'] = tester | |||||
for tester in self.testers.values(): | |||||
setattr(tester, 'verbose', 0) | |||||
if isinstance(data, dict): | if isinstance(data, dict): | ||||
for key, value in data.items(): | for key, value in data.items(): | ||||
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." | assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." | ||||
for key, value in data.items(): | for key, value in data.items(): | ||||
self.datasets[key] = value | |||||
self.datasets['data-' + key] = value | |||||
elif isinstance(data, DataSet): | elif isinstance(data, DataSet): | ||||
self.datasets['test'] = data | |||||
self.datasets['data-test'] = data | |||||
elif data is not None: | elif data is not None: | ||||
raise TypeError("data receives dict[DataSet] or DataSet object.") | raise TypeError("data receives dict[DataSet] or DataSet object.") | ||||
@@ -548,8 +551,11 @@ class FitlogCallback(Callback):
         if len(self.datasets) > 0:
             for key, data in self.datasets.items():
-                tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics,
-                                verbose=0)
+                tester = Tester(data=data, model=self.model,
+                                batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
+                                metrics=self.trainer.metrics,
+                                verbose=0,
+                                use_tqdm=self.trainer.use_tqdm)
                 self.testers[key] = tester

         fitlog.add_progress(total_steps=self.n_steps)
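
The hunk above lets each per-dataset Tester pick up an optional dev_batch_size from the Trainer's keyword arguments. A minimal sketch, assuming the Trainer stores unrecognized keyword arguments in trainer.kwargs as this diff implies (data variables, model, loss are placeholders):

    # evaluation batches can be larger than training batches, since no gradients are kept
    trainer = Trainer(train_data, model, loss=loss, metrics=AccuracyMetric(),
                      dev_data=dev_data, batch_size=32,
                      dev_batch_size=256,  # read via self.trainer.kwargs.get('dev_batch_size', ...)
                      callbacks=[FitlogCallback(data={'test': test_data})])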
@@ -589,6 +595,65 @@ class FitlogCallback(Callback):
         fitlog.add_other(repr(exception), name='except_info')

+class EvaluateCallback(Callback):
+    """
+    Alias: :class:`fastNLP.EvaluateCallback` :class:`fastNLP.core.callback.EvaluateCallback`
+
+    This callback lifts the limitation that Trainer can only evaluate on the dev data during training.
+
+    :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: a DataSet to be evaluated with the Trainer's metrics.
+        To pass several DataSets, use a dict.
+    :param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: a Tester object (or dict of them), called at on_valid_end.
+    """
+
+    def __init__(self, data=None, tester=None):
+        super().__init__()
+        self.datasets = {}
+        self.testers = {}
+        if tester is not None:
+            if isinstance(tester, dict):
+                for name, test in tester.items():
+                    if not isinstance(test, Tester):
+                        raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.")
+                    self.testers['tester-' + name] = test
+            if isinstance(tester, Tester):
+                self.testers['tester-test'] = tester
+            for tester in self.testers.values():
+                setattr(tester, 'verbose', 0)
+        if isinstance(data, dict):
+            for key, value in data.items():
+                assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
+            for key, value in data.items():
+                self.datasets['data-' + key] = value
+        elif isinstance(data, DataSet):
+            self.datasets['data-test'] = data
+        elif data is not None:
+            raise TypeError("data receives dict[DataSet] or DataSet object.")
+
+    def on_train_begin(self):
+        if len(self.datasets) > 0 and self.trainer.dev_data is None:
+            raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.")
+        if len(self.datasets) > 0:
+            for key, data in self.datasets.items():
+                tester = Tester(data=data, model=self.model,
+                                batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
+                                metrics=self.trainer.metrics, verbose=0,
+                                use_tqdm=self.trainer.use_tqdm)
+                self.testers[key] = tester
+
+    def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
+        if len(self.testers) > 0:
+            for key, tester in self.testers.items():
+                try:
+                    eval_result = tester.test()
+                    self.pbar.write("Evaluation on {}:".format(key))
+                    self.pbar.write(tester._format_eval_results(eval_result))
+                except Exception:
+                    self.pbar.write("Exception happened when evaluating on DataSet named `{}`.".format(key))
+
 class LRScheduler(Callback):
     """
     Alias: :class:`fastNLP.LRScheduler` :class:`fastNLP.core.callback.LRScheduler`
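
A hedged usage sketch of the new EvaluateCallback (the datasets, model, loss and metric are placeholders): because the extra evaluation is triggered from on_valid_end, the Trainer still needs dev_data for the callback to fire.

    from fastNLP import Trainer, AccuracyMetric, EvaluateCallback

    # evaluate extra splits with the Trainer's own metrics after every dev evaluation
    cb = EvaluateCallback(data={'test': test_data, 'ood': ood_data})
    trainer = Trainer(train_data, model, loss=loss, metrics=AccuracyMetric(),
                      dev_data=dev_data, callbacks=[cb])
    trainer.train()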
@@ -690,7 +690,7 @@ class Trainer(object):
                         (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                         and self.dev_data is not None:
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
-                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
+                    eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                self.n_steps) + \
                                self.tester._format_eval_results(eval_res)
                     pbar.write(eval_str + '\n')
@@ -74,7 +74,7 @@ class BertEmbedding(ContextualEmbedding):
         self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate)
+                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2)

         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size
@@ -209,7 +209,7 @@ class BertWordPieceEncoder(nn.Module):

 class _WordBertModel(nn.Module):
     def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first',
-                 include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False):
+                 include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False, min_freq=2):
         super().__init__()

         self.tokenzier = BertTokenizer.from_pretrained(model_dir)
@@ -238,9 +238,12 @@ class _WordBertModel(nn.Module):
         word_piece_dict = {'[CLS]':1, '[SEP]':1}  # word pieces in use, plus newly added ones
         found_count = 0
         self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether token_type_ids must be generated for the input
+        if '[sep]' in vocab:
+            warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.")
         if "[CLS]" in vocab:
             warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the beginning "
-                          "and end of the sentence automatically.")
+                          "and end of the input automatically, make sure you don't add [CLS] and [SEP] at the beginning"
+                          " and end.")
         for word, index in vocab:
             if index == vocab.padding_idx:  # pad is a special symbol
                 word = '[PAD]'
@@ -250,7 +253,8 @@ class _WordBertModel(nn.Module):
             if len(word_pieces)==1:
                 if not vocab._is_word_no_create_entry(word):  # the word comes from train but was not found
                     if index!=vocab.unknown_idx and word_pieces[0]=='[UNK]':  # i.e. the word is not in BERT's original vocabulary
-                        word_piece_dict[word] = 1  # add a new entry
+                        if vocab.word_count[word]>=min_freq and not vocab._is_word_no_create_entry(word):  # only add it if it occurs at least min_freq times
+                            word_piece_dict[word] = 1  # add a new entry
                 continue
             for word_piece in word_pieces:
                 word_piece_dict[word_piece] = 1
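
To illustrate the effect of the new min_freq filter, here is a hedged sketch (the vocabulary contents are made up, and the model name is only an assumed example of a fastNLP BERT key): a training word that BERT's word-piece tokenizer maps to [UNK] is promoted to its own word piece only if it occurs at least min_freq times.

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding  # module path may differ across fastNLP versions

    vocab = Vocabulary()
    vocab.add_word_lst(['qqworm'] * 3)  # not in BERT's vocab, frequency 3 -> gets its own word piece
    vocab.add_word_lst(['qqwyrm'])      # not in BERT's vocab, frequency 1 -> stays [UNK] (min_freq=2)
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased')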
@@ -54,7 +54,7 @@ class StaticEmbedding(TokenEmbedding):
     :param int min_freq: words whose frequency in the Vocabulary is lower than this value are mapped to unk.
     """

     def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True,
-                 init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1):
+                 init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
         super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

         # resolve the cache_path
@@ -73,7 +73,7 @@ class StaticEmbedding(TokenEmbedding):
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")

-        # shrink the vocab
+        # shrink the vocab according to min_freq
         truncate_vocab = (vocab.min_freq is None and min_freq>1) or (vocab.min_freq and vocab.min_freq<min_freq)
         if truncate_vocab:
             truncated_vocab = deepcopy(vocab)
@@ -88,6 +88,13 @@ class StaticEmbedding(TokenEmbedding):
                     if lowered_word_count[word.lower()]>=min_freq and word_count<min_freq:
                         truncated_vocab.add_word_lst([word]*(min_freq-word_count),
                                                      no_create_entry=truncated_vocab._is_word_no_create_entry(word))
+            # apply the min_freq filter only to words that occur in train
+            if kwargs.get('only_train_min_freq', False):
+                for word in truncated_vocab.word_count.keys():
+                    if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word]<min_freq:
+                        truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]),
+                                                     no_create_entry=True)
             truncated_vocab.build_vocab()
             truncated_words_to_words = torch.arange(len(vocab)).long()
             for word, index in vocab:
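
A hedged usage sketch of the new keyword (vocabulary contents and the embedding name are illustrative): with only_train_min_freq=True, the min_freq threshold only drops rare words that come from the training split, while no-create-entry words (those seen only in dev/test) are exempted.

    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding  # module path may differ across fastNLP versions

    vocab = Vocabulary()
    vocab.add_word_lst(['rare'])                           # train word, occurs once
    vocab.add_word_lst(['devonly'], no_create_entry=True)  # word seen only in dev/test
    embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d',
                            min_freq=2, only_train_min_freq=True)
    # 'rare' falls below min_freq and maps to unk; 'devonly' is kept despite its low count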
@@ -307,15 +307,15 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
     if not cache_path.exists():
         # Download to temporary file, then copy to cache dir once finished.
         # Otherwise you get corrupt cache entries if the download gets interrupted.
-        fd, temp_filename = tempfile.mkstemp()
-        print("%s not found in cache, downloading to %s" % (url, temp_filename))
         # GET file object
         req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"})
         if req.status_code == 200:
             content_length = req.headers.get("Content-Length")
             total = int(content_length) if content_length is not None else None
             progress = tqdm(unit="B", total=total, unit_scale=1)
+            fd, temp_filename = tempfile.mkstemp()
+            print("%s not found in cache, downloading to %s" % (url, temp_filename))
             with open(temp_filename, "wb") as temp_file:
                 for chunk in req.iter_content(chunk_size=1024 * 16):
                     if chunk:  # filter out keep-alive new chunks
@@ -373,7 +373,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
             os.remove(temp_filename)
         return get_filepath(cache_path)
     else:
-        raise HTTPError(f"Fail to download from {url}.")
+        raise HTTPError(f"Status code:{req.status_code}. Failed to download from {url}.")

 def unzip_file(file: Path, to: Path):