@@ -14,6 +14,7 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa | |||
""" | |||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | |||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC | |||
from .callback import EvaluateCallback, FitlogCallback, SaveModelCallback | |||
from .const import Const | |||
from .dataset import DataSet | |||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | |||
@@ -57,6 +57,7 @@ __all__ = [ | |||
"FitlogCallback", | |||
"LRScheduler", | |||
"ControlC", | |||
"EvaluateCallback", | |||
"CallbackException", | |||
"EarlyStopError" | |||
@@ -504,10 +505,9 @@ class FitlogCallback(Callback): | |||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 | |||
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 | |||
dict的方式传入。如果仅传入DataSet, 则被命名为test | |||
:param ~fastNLP.Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为为`test` | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要 | |||
传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。 | |||
:param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: Tester对象,将在on_valid_end时调用。tester的结果的名称以'tester'开头 | |||
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 | |||
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 | |||
:param int verbose: 是否在终端打印evaluation的结果,0不打印。 | |||
@@ -521,20 +521,23 @@ class FitlogCallback(Callback): | |||
self._log_exception = log_exception | |||
assert isinstance(log_loss_every, int) and log_loss_every>=0 | |||
if tester is not None: | |||
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." | |||
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." | |||
if data is not None: | |||
assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed." | |||
setattr(tester, 'verbose', 0) | |||
self.testers['test'] = tester | |||
if isinstance(tester, dict): | |||
for name, test in tester.items(): | |||
if not isinstance(test, Tester): | |||
raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") | |||
self.testers['tester-' + name] = test | |||
if isinstance(tester, Tester): | |||
self.testers['tester-test'] = tester | |||
for tester in self.testers.values(): | |||
setattr(tester, 'verbose', 0) | |||
if isinstance(data, dict): | |||
for key, value in data.items(): | |||
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." | |||
for key, value in data.items(): | |||
self.datasets[key] = value | |||
self.datasets['data-' + key] = value | |||
elif isinstance(data, DataSet): | |||
self.datasets['test'] = data | |||
self.datasets['data-test'] = data | |||
elif data is not None: | |||
raise TypeError("data receives dict[DataSet] or DataSet object.") | |||
@@ -548,8 +551,11 @@ class FitlogCallback(Callback): | |||
if len(self.datasets) > 0: | |||
for key, data in self.datasets.items(): | |||
tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics, | |||
verbose=0) | |||
tester = Tester(data=data, model=self.model, | |||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), | |||
metrics=self.trainer.metrics, | |||
verbose=0, | |||
use_tqdm=self.trainer.use_tqdm) | |||
self.testers[key] = tester | |||
fitlog.add_progress(total_steps=self.n_steps) | |||
@@ -589,6 +595,65 @@ class FitlogCallback(Callback): | |||
fitlog.add_other(repr(exception), name='except_info') | |||
class EvaluateCallback(Callback): | |||
""" | |||
别名: :class:`fastNLP.EvaluateCallback` :class:`fastNLP.core.callback.EvaluateCallback` | |||
该callback用于扩展Trainer训练过程中只能对dev数据进行验证的问题。 | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 | |||
DataSet请通过dict的方式传入。 | |||
:param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: Tester对象,将在on_valid_end时调用。 | |||
""" | |||
def __init__(self, data=None, tester=None): | |||
super().__init__() | |||
self.datasets = {} | |||
self.testers = {} | |||
if tester is not None: | |||
if isinstance(tester, dict): | |||
for name, test in tester.items(): | |||
if not isinstance(test, Tester): | |||
raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") | |||
self.testers['tester-' + name] = test | |||
if isinstance(tester, Tester): | |||
self.testers['tester-test'] = tester | |||
for tester in self.testers.values(): | |||
setattr(tester, 'verbose', 0) | |||
if isinstance(data, dict): | |||
for key, value in data.items(): | |||
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." | |||
for key, value in data.items(): | |||
self.datasets['data-' + key] = value | |||
elif isinstance(data, DataSet): | |||
self.datasets['data-test'] = data | |||
elif data is not None: | |||
raise TypeError("data receives dict[DataSet] or DataSet object.") | |||
def on_train_begin(self): | |||
if len(self.datasets) > 0 and self.trainer.dev_data is None: | |||
raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.") | |||
if len(self.datasets) > 0: | |||
for key, data in self.datasets.items(): | |||
tester = Tester(data=data, model=self.model, | |||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), | |||
metrics=self.trainer.metrics, verbose=0, | |||
use_tqdm=self.trainer.use_tqdm) | |||
self.testers[key] = tester | |||
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): | |||
if len(self.testers) > 0: | |||
for key, tester in self.testers.items(): | |||
try: | |||
eval_result = tester.test() | |||
self.pbar.write("Evaluation on {}:".format(key)) | |||
self.pbar.write(tester._format_eval_results(eval_result)) | |||
except Exception: | |||
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||
class LRScheduler(Callback): | |||
""" | |||
别名::class:`fastNLP.LRScheduler` :class:`fastNLP.core.callback.LRScheduler` | |||
@@ -690,7 +690,7 @@ class Trainer(object): | |||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||
and self.dev_data is not None: | |||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
self.n_steps) + \ | |||
self.tester._format_eval_results(eval_res) | |||
pbar.write(eval_str + '\n') | |||
@@ -74,7 +74,7 @@ class BertEmbedding(ContextualEmbedding): | |||
self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers, | |||
pool_method=pool_method, include_cls_sep=include_cls_sep, | |||
pooled_cls=pooled_cls, auto_truncate=auto_truncate) | |||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2) | |||
self.requires_grad = requires_grad | |||
self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size | |||
@@ -209,7 +209,7 @@ class BertWordPieceEncoder(nn.Module): | |||
class _WordBertModel(nn.Module): | |||
def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', | |||
include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False): | |||
include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False, min_freq=2): | |||
super().__init__() | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir) | |||
@@ -238,9 +238,12 @@ class _WordBertModel(nn.Module): | |||
word_piece_dict = {'[CLS]':1, '[SEP]':1} # 用到的word_piece以及新增的 | |||
found_count = 0 | |||
self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids | |||
if '[sep]' in vocab: | |||
warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.") | |||
if "[CLS]" in vocab: | |||
warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the begin " | |||
"and end of the sentence automatically.") | |||
"and end of the input automatically, make sure you don't add [CLS] and [SEP] at the begin" | |||
" and end.") | |||
for word, index in vocab: | |||
if index == vocab.padding_idx: # pad是个特殊的符号 | |||
word = '[PAD]' | |||
@@ -250,7 +253,8 @@ class _WordBertModel(nn.Module): | |||
if len(word_pieces)==1: | |||
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||
if index!=vocab.unknown_idx and word_pieces[0]=='[UNK]': # 说明这个词不在原始的word里面 | |||
word_piece_dict[word] = 1 # 新增一个值 | |||
if vocab.word_count[word]>=min_freq and not vocab._is_word_no_create_entry(word): #出现次数大于这个次数才新增 | |||
word_piece_dict[word] = 1 # 新增一个值 | |||
continue | |||
for word_piece in word_pieces: | |||
word_piece_dict[word_piece] = 1 | |||
@@ -54,7 +54,7 @@ class StaticEmbedding(TokenEmbedding): | |||
:param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True, | |||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1): | |||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | |||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
# 得到cache_path | |||
@@ -73,7 +73,7 @@ class StaticEmbedding(TokenEmbedding): | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
# 缩小vocab | |||
# 根据min_freq缩小vocab | |||
truncate_vocab = (vocab.min_freq is None and min_freq>1) or (vocab.min_freq and vocab.min_freq<min_freq) | |||
if truncate_vocab: | |||
truncated_vocab = deepcopy(vocab) | |||
@@ -88,6 +88,13 @@ class StaticEmbedding(TokenEmbedding): | |||
if lowered_word_count[word.lower()]>=min_freq and word_count<min_freq: | |||
truncated_vocab.add_word_lst([word]*(min_freq-word_count), | |||
no_create_entry=truncated_vocab._is_word_no_create_entry(word)) | |||
# 只限制在train里面的词语使用min_freq筛选 | |||
if kwargs.get('only_train_min_freq', False): | |||
for word in truncated_vocab.word_count.keys(): | |||
if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word]<min_freq: | |||
truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]), | |||
no_create_entry=True) | |||
truncated_vocab.build_vocab() | |||
truncated_words_to_words = torch.arange(len(vocab)).long() | |||
for word, index in vocab: | |||
@@ -307,15 +307,15 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
if not cache_path.exists(): | |||
# Download to temporary file, then copy to cache dir once finished. | |||
# Otherwise you get corrupt cache entries if the download gets interrupted. | |||
fd, temp_filename = tempfile.mkstemp() | |||
print("%s not found in cache, downloading to %s" % (url, temp_filename)) | |||
# GET file object | |||
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) | |||
if req.status_code == 200: | |||
content_length = req.headers.get("Content-Length") | |||
total = int(content_length) if content_length is not None else None | |||
progress = tqdm(unit="B", total=total, unit_scale=1) | |||
fd, temp_filename = tempfile.mkstemp() | |||
print("%s not found in cache, downloading to %s" % (url, temp_filename)) | |||
with open(temp_filename, "wb") as temp_file: | |||
for chunk in req.iter_content(chunk_size=1024 * 16): | |||
if chunk: # filter out keep-alive new chunks | |||
@@ -373,7 +373,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
os.remove(temp_filename) | |||
return get_filepath(cache_path) | |||
else: | |||
raise HTTPError(f"Fail to download from {url}.") | |||
raise HTTPError(f"Status code:{req.status_code}. Fail to download from {url}.") | |||
def unzip_file(file: Path, to: Path): | |||