@@ -158,7 +158,7 @@ class Callback: | |||
def on_save_model(self, trainer): | |||
""" | |||
当将要保存模型时调用,此刻模型还未保存。 | |||
当调用 Trainer.save_model() 时调用,此刻模型还未保存。 | |||
:param trainer: | |||
:return: | |||
@@ -167,7 +167,7 @@ class Callback: | |||
def on_load_model(self, trainer): | |||
""" | |||
当将要加载模型时调用,此刻模型还未加载。 | |||
当调用 Trainer.load_model() 加载模型时调用,此刻模型还未加载。 | |||
:param trainer: | |||
:return: | |||
@@ -176,7 +176,7 @@ class Callback: | |||
def on_save_checkpoint(self, trainer) -> Dict: | |||
""" | |||
当 Trainer 将要保存 checkpoint 的时候触发,该函数用于保存当前 callback 在恢复需要的相关数据。 | |||
当 Trainer 将要保存 checkpoint 的时候触发 (即调用 Trainer.save_checkpoint() 函数时),该函数用于保存当前 callback 在恢复需要的相关数据。 | |||
:param trainer: | |||
:return: | |||
@@ -185,7 +185,8 @@ class Callback: | |||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
r""" | |||
当 Trainer 要恢复 checkpoint 的时候触发( Trainer 与 Driver 已经加载好自身的状态),参数 states 为 on_save_checkpoint() | |||
当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时 Trainer 与 Driver 已经加载好自身的状态), | |||
参数 states 为 on_save_checkpoint() | |||
的返回值。 | |||
:param trainer: | |||
@@ -50,7 +50,7 @@ class CheckpointCallback(Callback): | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load` 加载该断 | |||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||
@@ -1,9 +1,12 @@ | |||
__all__ = [ | |||
'FitlogCallback' | |||
] | |||
import os | |||
from .has_monitor_callback import HasMonitorCallback | |||
from ...envs import _module_available | |||
from ...envs import get_global_rank | |||
from ..log import logger | |||
if _module_available('fitlog'): | |||
import fitlog | |||
@@ -11,7 +14,9 @@ if _module_available('fitlog'): | |||
class FitlogCallback(HasMonitorCallback): | |||
""" | |||
自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据 | |||
``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。 | |||
``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。同时还会在 ``fitlog`` 的 | |||
``other`` 列中记录一个 ``launch_time`` ,可以通过这个数值找到当前这个脚本的在 save_folder (如果有使用其它需要保存模型的 | |||
``Callback`` ,例如 :class:`~fastNLP.CheckpointCallback` )下的文件夹名称。 | |||
:param monitor: 监控的 metric 值。 | |||
@@ -38,6 +43,14 @@ class FitlogCallback(HasMonitorCallback): | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog | |||
fitlog.debug() | |||
super().on_after_trainer_initialized(trainer, driver) | |||
fitlog.add_other('launch_time', os.environ['FASTNLP_LAUNCH_TIME']) | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res) | |||
if self.monitor is None: | |||
logger.rank_zero_warning(f"No monitor set for {self.__class__.__name__}. Therefore, no best metric will " | |||
f"be logged.") | |||
def on_evaluate_end(self, trainer, results): | |||
results = self.itemize_results(results) | |||
@@ -63,7 +63,7 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load` 加载该断 | |||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
@@ -25,12 +25,12 @@ class Saver: | |||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load` 加载该断 | |||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 | |||
:param kwargs: 更多需要传递给 Trainer.save_checkpoint() 或者 Trainer.save_model() 接口的参数。 | |||
""" | |||
def __init__(self, folder:str=None, save_object:str='model', only_state_dict:bool=True, | |||
model_save_fn:Callable=None, **kwargs): | |||
@@ -48,7 +48,7 @@ class Saver: | |||
self.model_save_fn = model_save_fn | |||
self.kwargs = kwargs | |||
self.save_object = save_object | |||
self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model' | |||
self.save_fn_name = 'save_checkpoint' if save_object == 'trainer' else 'save_model' | |||
self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||
@@ -193,14 +193,14 @@ class TopkSaver(ResultsMonitor, Saver): | |||
:param larger_better: 该 monitor 是否越大越好。 | |||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load` 加载该断 | |||
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||
``fastnlp_evaluate_results.json`` 文件,记录当前的 metric results 。仅在设置了 topk 的场景下有用,默认为 True 。 | |||
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 | |||
:param kwargs: 更多需要传递给 Trainer.save_checkpoint() 或者 Trainer.save_model() 接口的参数。 | |||
""" | |||
def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model', | |||
only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, | |||
@@ -595,7 +595,7 @@ class Trainer(TrainerEventTrigger): | |||
self.driver.barrier() | |||
self.driver.zero_grad(self.set_grad_to_none) | |||
while self.cur_epoch_idx < self.n_epochs: | |||
# 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||
# 这个是防止在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save | |||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||
self.driver.set_model_mode("train") | |||
self.on_train_epoch_begin() | |||
@@ -1018,7 +1018,7 @@ class Trainer(TrainerEventTrigger): | |||
注意您需要在初始化 ``Trainer`` 后再通过 ``trainer`` 实例来调用该函数;这意味着您需要保证在保存和加载时使用的 ``driver`` 是属于同一个 | |||
训练框架的,例如都是 ``pytorch`` 或者 ``paddle``; | |||
注意在大多数情况下您不需要使用该函数,如果您需要断点重训功能,您可以直接使用 ``trainer.load`` 函数; | |||
注意在大多数情况下您不需要使用该函数,如果您需要断点重训功能,您可以直接使用 ``trainer.load_checkpoint`` 函数; | |||
该函数在通常情况下和 ``save_model`` 函数配套使用;其参数均与 ``save_model`` 函数成对应关系; | |||
""" | |||
@@ -1045,9 +1045,10 @@ class Trainer(TrainerEventTrigger): | |||
self.driver.load_model(folder, only_state_dict, **kwargs) | |||
self.driver.barrier() | |||
def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | |||
def save_checkpoint(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | |||
r""" | |||
用于帮助您实现断点重训功能的保存函数; | |||
用于帮助您实现断点重训功能的保存函数;保存内容包括:callback 状态、Trainer 的状态、Sampler 的状态【在恢复的时候才能恢复到特定 batch 】、 | |||
模型参数、optimizer的状态、fp16 Scaler的状态【如果有】。 | |||
:param folder: 保存在哪个文件夹下,会在该文件下声称两个文件:fastnlp_checkpoint.pkl.tar 与 fastnlp_model.pkl.tar 。 | |||
如果 model_save_fn 不为空,则没有 fastnlp_model.pkl.tar 文件; | |||
@@ -1061,12 +1062,12 @@ class Trainer(TrainerEventTrigger): | |||
注意如果您需要在训练的过程中使用断点重训功能,您可以直接使用 **``CheckpointCallback``**; | |||
``CheckpointCallback`` 的使用具体见 :class:`fastNLP.core.callbacks.checkpoint_callback.CheckpointCallback`; | |||
这意味着在大多数时刻您并不需要自己主动地调用该函数来保存 ``Trainer`` 的状态;当然您可以在自己定制的 callback 类中通过直接调用 ``trainer.save`` 来保存 ``Trainer`` 的状态; | |||
这意味着在大多数时刻您并不需要自己主动地调用该函数来保存 ``Trainer`` 的状态;当然您可以在自己定制的 callback 类中通过直接调用 ``trainer.save_checkpoint`` 来保存 ``Trainer`` 的状态; | |||
具体实际的保存状态的操作由具体的 driver 实现,这意味着对于不同的 ``Driver`` 来说,保存的操作可能是不尽相同的, | |||
您如果想要了解保存 ``Trainer`` 状态的更多细节,请直接查看各个 ``Driver`` 的 ``save`` 函数; | |||
``save`` 函数和 ``load`` 函数是配套使用的; | |||
``save_checkpoint`` 函数和 ``load_checkpoint`` 函数是配套使用的; | |||
.. note:: | |||
@@ -1111,14 +1112,14 @@ class Trainer(TrainerEventTrigger): | |||
if not callable(model_save_fn): | |||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||
rank_zero_call(model_save_fn)(folder) | |||
self.driver.save(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs) | |||
self.driver.save_checkpoint(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs) | |||
else: | |||
self.driver.save(folder=folder, dataloader=self.dataloader, states=states, | |||
self.driver.save_checkpoint(folder=folder, dataloader=self.dataloader, states=states, | |||
only_state_dict=only_state_dict, should_save_model=True, **kwargs) | |||
self.driver.barrier() | |||
def load(self, folder: str, resume_training: bool = True, only_state_dict: bool = True, | |||
def load_checkpoint(self, folder: str, resume_training: bool = True, only_state_dict: bool = True, | |||
model_load_fn: Optional[Callable] = None, **kwargs): | |||
r""" | |||
用于帮助您实现断点重训功能的加载函数; | |||
@@ -1128,22 +1129,22 @@ class Trainer(TrainerEventTrigger): | |||
只会加载 ``model`` 和 ``optimizers`` 的状态;而其余对象的值则根据用户的 ``Trainer`` 的初始化直接重置; | |||
:param only_state_dict: 保存的 ``model`` 是否只保存了权重; | |||
:param model_load_fn: 使用的模型加载函数,参数应为一个文件夹,注意该函数不需要返回任何内容;您可以传入该参数来定制自己的加载模型的操作, | |||
当该参数不为 None 时,我们默认加载模型由该函数完成,``trainer.load`` 函数则会把 ``trainer`` 的其余状态加载好; | |||
当该参数不为 None 时,我们默认加载模型由该函数完成,``trainer.load_checkpoint`` 函数则会把 ``trainer`` 的其余状态加载好; | |||
.. note:: | |||
在 fastNLP 中,断点重训的保存和加载的逻辑是完全分离的,这意味着您在第二次训练时可以将 ``CheckpointCallback`` 从 ``trainer`` 中 | |||
去除,而直接使用 ``trainer.load`` 函数加载 ``trainer`` 的状态来进行断点重训; | |||
去除,而直接使用 ``trainer.load_checkpoint`` 函数加载 ``trainer`` 的状态来进行断点重训; | |||
该函数在通常情况下和 ``save`` 函数配套使用;其参数与 ``save`` 函数成对应关系; | |||
该函数在通常情况下和 ``save_checkpoint`` 函数配套使用;其参数与 ``save_checkpoint`` 函数成对应关系; | |||
对于在前后两次训练 ``Driver`` 不同的情况时使用断点重训,请参考 :meth:`fastNLP.core.controllers.trainer.Trainer.load` 函数的 ``warning``; | |||
对于在前后两次训练 ``Driver`` 不同的情况时使用断点重训,请参考 :meth:`fastNLP.core.controllers.trainer.Trainer.load_checkpoint` 函数的 ``warning``; | |||
Example:: | |||
trainer = Trainer(...) | |||
trainer.load(folder='/path-to-your-saved_checkpoint_folder/', ...) | |||
trainer.load_checkpoint(folder='/path-to-your-saved_checkpoint_folder/', ...) | |||
trainer.run() | |||
@@ -1161,9 +1162,9 @@ class Trainer(TrainerEventTrigger): | |||
if not callable(model_load_fn): | |||
raise ValueError("Parameter `model_save_fn` should be `Callable`.") | |||
model_load_fn(folder) | |||
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | |||
states = self.driver.load_checkpoint(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | |||
else: | |||
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | |||
states = self.driver.load_checkpoint(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | |||
except FileNotFoundError as e: | |||
if FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder) and FASTNLP_MODEL_FILENAME in os.listdir(folder): | |||
logger.error("It seems that you are trying to load the trainer checkpoint from a model checkpoint folder.") | |||
@@ -1184,7 +1185,7 @@ class Trainer(TrainerEventTrigger): | |||
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | |||
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | |||
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | |||
# 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||
# 这个是防止用户在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save_checkpoint | |||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||
# 5. 恢复所有 callback 的状态; | |||
@@ -49,7 +49,7 @@ class Driver(ABC): | |||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load_checkpoint 函数在调用; | |||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
@@ -254,39 +254,39 @@ class Driver(ABC): | |||
raise NotImplementedError("Each specific driver should implemented its own `load_model` function.") | |||
@abstractmethod | |||
def save(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
def save_checkpoint(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
r""" | |||
断点重训的保存函数,该函数会负责保存模型和 optimizers, fp16 的 state_dict;以及模型的保存(若 should_save_model 为 True) | |||
:param folder: 保存断点重训的状态的文件夹;save 函数应该在下面新增两(一)个文件 的 FASTNLP_CHECKPOINT_FILENAME 文件与 | |||
:param folder: 保存断点重训的状态的文件夹;save_checkpoint 函数应该在下面新增两(一)个文件 的 FASTNLP_CHECKPOINT_FILENAME 文件与 | |||
FASTNLP_MODEL_FILENAME (如果 should_save_model 为 True )。把 model 相关的内容放入到 FASTNLP_MODEL_FILENAME 文件 | |||
中,将传入的 states 以及自身产生其它状态一并保存在 FASTNLP_CHECKPOINT_FILENAME 里面。 | |||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | |||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load() 返回的值与这里的 | |||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load_checkpoint() 的时候,需要将 states 返回回去,load_checkpoint() 返回的值与这里的 | |||
传入的值保持一致。 | |||
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 | |||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 | |||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `save` function.") | |||
raise NotImplementedError("Each specific driver should implemented its own `save_checkpoint` function.") | |||
@abstractmethod | |||
def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict: | |||
def load_checkpoint(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict: | |||
r""" | |||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和; | |||
其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。 | |||
其它在 Driver.save_checkpoint() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save_checkpoint() 接受到的 states )。 | |||
该函数应该在所有 rank 上执行。 | |||
:param folder: 读取该 folder 下的 FASTNLP_CHECKPOINT_FILENAME 文件与 FASTNLP_MODEL_FILENAME | |||
(如果 should_load_model 为True)。 | |||
:param dataloader: 当前给定 dataloader,需要根据 save 的 dataloader 状态合理设置。若该值为 None ,是不需要返回 'dataloader' | |||
:param dataloader: 当前给定 dataloader,需要根据保存的 dataloader 状态合理设置。若该值为 None ,是不需要返回 'dataloader' | |||
以及 'batch_idx_in_epoch' 这两个值。 | |||
:param only_state_dict: 读取的,当 should_save_model 为 False ,该参数无效。如果为 True ,说明保存的内容为权重;如果为 | |||
False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。 | |||
:param should_load_model: 是否应该加载模型,如果为False,Driver 将不负责加载模型。若该参数为 True ,但在保存的状态中没有 | |||
找到对应的模型状态,则报错。 | |||
:return: 需要返回 save 函数输入的 states 内容 | |||
:return: 需要返回 save_checkpoint 函数输入的 states 内容 | |||
'dataloader',返回的是根据传入的 dataloader 与 保存的状态一起设置为合理的状态,可以返回的对象与传入的dataloader是同一个。 | |||
在保存与当前传入 data sample 数目不一致时报错。 | |||
'batch_idx_in_epoch': int 类型的数据,表明当前 epoch 进行到了进行到了第几个 batch 了。 请注意,该值不能是只能通过保存的 | |||
@@ -297,7 +297,7 @@ class Driver(ABC): | |||
当drop_last为True,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size); | |||
当drop_last为False,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。 | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `load` function.") | |||
raise NotImplementedError("Each specific driver should implemented its own `load_checkpoint` function.") | |||
@staticmethod | |||
def tensor_to_numeric(tensor, reduce: Optional[str]=None): | |||
@@ -109,10 +109,10 @@ class JittorDriver(Driver): | |||
raise FileNotFoundError("Checkpoint at {} not found.".format(filepath)) | |||
return jt.load(filepath) | |||
def save(self): | |||
def save_checkpoint(self): | |||
... | |||
def load(self): | |||
def load_checkpoint(self): | |||
... | |||
def get_evaluate_context(self): | |||
@@ -409,7 +409,7 @@ class PaddleFleetDriver(PaddleDriver): | |||
# 暂时不支持iterableDataset | |||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
"FastNLP does not support `IteratorDataset` now." | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
dist.set_distributed( | |||
num_replicas=self.world_size, | |||
@@ -207,7 +207,7 @@ class PaddleDriver(Driver): | |||
model.load_dict(paddle.load(filepath)) | |||
@rank_zero_call | |||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
r""" | |||
断点重训的保存函数,该函数会负责保存模型和 optimizers, fp16 的 state_dict;以及模型的保存(若 should_save_model 为 True) | |||
@@ -215,7 +215,7 @@ class PaddleDriver(Driver): | |||
FASTNLP_MODEL_FILENAME (如果 should_save_model 为 True )。把 model 相关的内容放入到 FASTNLP_MODEL_FILENAME 文件中, | |||
将传入的 states 以及自身产生其它状态一并保存在 FASTNLP_CHECKPOINT_FILENAME 里面。 | |||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存该对象即可, | |||
Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load() 返回的值与这里的传入的值保持一致。 | |||
Driver 应该不需要理解该对象,同时在 driver.load_checkpoint() 的时候,需要将 states 返回回去,load() 返回的值与这里的传入的值保持一致。 | |||
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 | |||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 | |||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 | |||
@@ -297,7 +297,7 @@ class PaddleDriver(Driver): | |||
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||
@@ -106,7 +106,7 @@ class PaddleSingleDriver(PaddleDriver): | |||
# 暂时不支持iterableDataset | |||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
"FastNLP does not support `IteratorDataset` now." | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
return replace_batch_sampler(dataloader, dist) | |||
elif isinstance(dist, ReproducibleSampler): | |||
@@ -425,7 +425,7 @@ class TorchDDPDriver(TorchDriver): | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | |||
reproducible: bool = False): | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
dist.set_distributed( | |||
@@ -96,7 +96,7 @@ class TorchSingleDriver(TorchDriver): | |||
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, | |||
reproducible: bool = False): | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
return replace_batch_sampler(dataloader, dist) | |||
elif isinstance(dist, ReproducibleSampler): | |||
@@ -179,7 +179,7 @@ class TorchDriver(Driver): | |||
model.load_state_dict(res.state_dict()) | |||
@rank_zero_call | |||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | |||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | |||
@@ -253,7 +253,7 @@ class TorchDriver(Driver): | |||
states["optimizers_state_dict"] = optimizers_state_dict | |||
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||
# 1. 加载 optimizers 的状态; | |||
@@ -0,0 +1,15 @@ | |||
""" | |||
torch 可使用的几种 Embedding 。 | |||
""" | |||
__all__ = [ | |||
"CNNCharEmbedding", | |||
"LSTMCharEmbedding", | |||
"Embedding", | |||
"StackEmbedding", | |||
"StaticEmbedding" | |||
] | |||
from .char_embedding import * | |||
from .embedding import * | |||
from .stack_embedding import * | |||
from .static_embedding import StaticEmbedding |
@@ -0,0 +1,287 @@ | |||
r""" | |||
该文件中主要包含的是character的Embedding,包括基于CNN与LSTM的character Embedding。与其它Embedding一样,这里的Embedding输入也是 | |||
词的index而不需要使用词语中的char的index来获取表达。 | |||
""" | |||
__all__ = [ | |||
"CNNCharEmbedding", | |||
"LSTMCharEmbedding" | |||
] | |||
from typing import List | |||
from ...envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .embedding import TokenEmbedding | |||
from .static_embedding import StaticEmbedding | |||
from .utils import _construct_char_vocab_from_vocab | |||
from .utils import get_embeddings | |||
from ...core import logger | |||
from ...core.vocabulary import Vocabulary | |||
from ...modules.torch.encoder.lstm import LSTM | |||
class CNNCharEmbedding(TokenEmbedding): | |||
r""" | |||
使用 ``CNN`` 生成 ``character embedding``。``CNN`` 的结构为, char_embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout. | |||
不同的 ``kernel`` 大小的 ``fitler`` 结果是拼起来然后通过一层``fully connected layer,`` 然后输出``word``的表示。 | |||
Example:: | |||
>>> import torch | |||
>>> from fastNLP import Vocabulary | |||
>>> from fastNLP.embeddings.torch import CNNCharEmbedding | |||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||
>>> embed = CNNCharEmbedding(vocab, embed_size=50) | |||
>>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]]) | |||
>>> outputs = embed(words) | |||
>>> outputs.size() | |||
# torch.Size([1, 5,50]) | |||
""" | |||
def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0, | |||
dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1), | |||
pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None, | |||
requires_grad:bool=True, include_word_start_end:bool=True): | |||
r""" | |||
:param vocab: 词表 | |||
:param embed_size: 该CNNCharEmbedding的输出维度大小,默认值为50. | |||
:param char_emb_size: character的embed的维度。character是从vocab中生成的。默认值为50. | |||
:param word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param dropout: 以多大的概率drop分布式表示与char embedding的输出。 | |||
:param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20]. | |||
:param kernel_sizes: kernel的大小. 默认值为[5, 3, 1]. | |||
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'. | |||
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数. | |||
:param min_char_freq: character的最少出现次数。默认值为2. | |||
:param pre_train_char_embed: 可以有两种方式调用预训练好的character embedding:第一种是传入embedding文件夹 | |||
(文件夹下应该只有一个以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型, | |||
没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding. | |||
:param requires_grad: 是否更新权重 | |||
:param include_word_start_end: 是否在每个word开始的character前和结束的character增加特殊标示符号; | |||
""" | |||
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
for kernel in kernel_sizes: | |||
assert kernel % 2 == 1, "Only odd kernel is allowed." | |||
assert pool_method in ('max', 'avg') | |||
self.pool_method = pool_method | |||
# activation function | |||
if isinstance(activation, str): | |||
if activation.lower() == 'relu': | |||
self.activation = F.relu | |||
elif activation.lower() == 'sigmoid': | |||
self.activation = F.sigmoid | |||
elif activation.lower() == 'tanh': | |||
self.activation = F.tanh | |||
elif activation is None: | |||
self.activation = lambda x: x | |||
elif callable(activation): | |||
self.activation = activation | |||
else: | |||
raise Exception( | |||
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") | |||
logger.info("Start constructing character vocabulary.") | |||
# 建立char的词表 | |||
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq, | |||
include_word_start_end=include_word_start_end) | |||
self.char_pad_index = self.char_vocab.padding_idx | |||
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||
# 对vocab进行index | |||
max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||
if include_word_start_end: | |||
max_word_len += 2 | |||
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long)) | |||
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long()) | |||
for word, index in vocab: | |||
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的<pad>也是同一个embed | |||
if include_word_start_end: | |||
word = ['<bow>'] + list(word) + ['<eow>'] | |||
self.words_to_chars_embedding[index, :len(word)] = \ | |||
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) | |||
self.word_lengths[index] = len(word) | |||
# self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) | |||
if pre_train_char_embed: | |||
self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed) | |||
else: | |||
self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size)) | |||
self.convs = nn.ModuleList([nn.Conv1d( | |||
self.char_embedding.embedding_dim, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, | |||
padding=kernel_sizes[i] // 2) | |||
for i in range(len(kernel_sizes))]) | |||
self._embed_size = embed_size | |||
self.fc = nn.Linear(sum(filter_nums), embed_size) | |||
self.requires_grad = requires_grad | |||
def forward(self, words): | |||
r""" | |||
输入words的index后,生成对应的words的表示。 | |||
:param words: [batch_size, max_len] | |||
:return: [batch_size, max_len, embed_size] | |||
""" | |||
words = self.drop_word(words) | |||
batch_size, max_len = words.size() | |||
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | |||
word_lengths = self.word_lengths[words] # batch_size x max_len | |||
max_word_len = word_lengths.max() | |||
chars = chars[:, :, :max_word_len] | |||
# 为1的地方为mask | |||
chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了 | |||
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size | |||
chars = self.dropout(chars) | |||
reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1) | |||
reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M | |||
conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1) | |||
for conv in self.convs] | |||
conv_chars = torch.cat(conv_chars, dim=-1).contiguous() # B x max_len x max_word_len x sum(filters) | |||
conv_chars = self.activation(conv_chars) | |||
if self.pool_method == 'max': | |||
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) | |||
chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters) | |||
else: | |||
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0) | |||
chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float() | |||
chars = self.fc(chars) | |||
return self.dropout(chars) | |||
class LSTMCharEmbedding(TokenEmbedding):
    r"""
    Encode each word at the ``character`` level with an ``LSTM``:
    embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout

    Example::

        >>> import torch
        >>> from fastNLP import Vocabulary
        >>> from fastNLP.embeddings.torch import LSTMCharEmbedding
        >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
        >>> embed = LSTMCharEmbedding(vocab, embed_size=50)
        >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
        >>> outputs = embed(words)
        >>> outputs.size()
        >>> # torch.Size([1, 5, 50])

    """
    def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0,
                 dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu',
                 min_char_freq: int = 2, bidirectional=True, pre_train_char_embed: str = None,
                 requires_grad:bool=True, include_word_start_end:bool=True):
        r"""
        :param vocab: the word vocabulary
        :param embed_size: output dimension of LSTMCharEmbedding. Default: 50.
        :param char_emb_size: dimension of the character embedding. Default: 50.
        :param word_dropout: probability of replacing a word with unk; trains unk and acts as a regularizer.
        :param dropout: dropout probability applied to the character embedding output and to the final word output.
        :param hidden_size: hidden size of the LSTM; halved per direction when bidirectional. Default: 50.
        :param pool_method: one of 'max', 'avg'.
        :param activation: activation function; 'relu', 'sigmoid', 'tanh', or a custom callable.
        :param min_char_freq: minimum character frequency. Default: 2.
        :param bidirectional: whether to encode with a bidirectional LSTM. Default: True.
        :param pre_train_char_embed: two ways to use a pre-trained character embedding: pass an embedding
            folder (containing exactly one file with a .txt suffix) or a file path; or pass the name of an
            embedding, in which case the cache is checked first and the file is downloaded when absent.
            If None, a random embedding of dimension embedding_dim is created.
        :param requires_grad: whether the weights are updated.
        :param include_word_start_end: whether to add special start/end marker characters around each word.
        """
        super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
        # FIX: message previously said "Only even kernel is allowed." (copied from the CNN variant);
        # the constraint is on hidden_size so it can be split across the two LSTM directions.
        assert hidden_size % 2 == 0, "Only even hidden_size is allowed."
        assert pool_method in ('max', 'avg')
        self.pool_method = pool_method
        # activation function
        if isinstance(activation, str):
            if activation.lower() == 'relu':
                self.activation = F.relu
            elif activation.lower() == 'sigmoid':
                self.activation = F.sigmoid
            elif activation.lower() == 'tanh':
                self.activation = F.tanh
            else:
                # FIX: an unknown activation string used to fall through silently, leaving
                # self.activation unset and causing an AttributeError only at forward time.
                raise Exception(
                    "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
        elif activation is None:
            self.activation = lambda x: x
        elif callable(activation):
            self.activation = activation
        else:
            raise Exception(
                "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
        logger.info("Start constructing character vocabulary.")
        # build the character vocabulary
        self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq,
                                                           include_word_start_end=include_word_start_end)
        self.char_pad_index = self.char_vocab.padding_idx
        logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
        # index every word of the vocab as a character sequence
        max_word_len = max(map(lambda x: len(x[0]), vocab))
        if include_word_start_end:
            max_word_len += 2
        self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
                                                                    fill_value=self.char_pad_index, dtype=torch.long))
        self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
        for word, index in vocab:
            # pad entries are filled like any other word (deliberately no special-casing of padding_idx)
            if include_word_start_end:
                word = ['<bow>'] + list(word) + ['<eow>']
            self.words_to_chars_embedding[index, :len(word)] = \
                torch.LongTensor([self.char_vocab.to_index(c) for c in word])
            self.word_lengths[index] = len(word)
        if pre_train_char_embed:
            self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
        else:
            self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
        self.fc = nn.Linear(hidden_size, embed_size)
        # per-direction hidden size; concatenated output is hidden_size again when bidirectional
        hidden_size = hidden_size // 2 if bidirectional else hidden_size
        self.lstm = LSTM(self.char_embedding.embedding_dim, hidden_size, bidirectional=bidirectional, batch_first=True)
        self._embed_size = embed_size
        self.bidirectional = bidirectional
        self.requires_grad = requires_grad

    def forward(self, words):
        r"""
        Compute word representations from word indices.

        :param words: [batch_size, max_len]
        :return: [batch_size, max_len, embed_size]
        """
        words = self.drop_word(words)
        batch_size, max_len = words.size()
        chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
        word_lengths = self.word_lengths[words]  # batch_size x max_len
        max_word_len = word_lengths.max()
        chars = chars[:, :, :max_word_len]
        # mask positions are 1
        chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len; 1 marks a padding position
        chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
        chars = self.dropout(chars)
        reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
        char_seq_len = chars_masks.eq(False).sum(dim=-1).reshape(batch_size * max_len)
        lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
        # B x M x M x H
        lstm_chars = self.activation(lstm_chars)
        if self.pool_method == 'max':
            lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
            chars, _ = torch.max(lstm_chars, dim=-2)  # batch_size x max_len x H
        else:
            lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
            chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
        chars = self.fc(chars)
        return self.dropout(chars)
@@ -0,0 +1,220 @@ | |||
r""" | |||
该模块中的Embedding主要用于随机初始化的embedding(更推荐使用 :class:`fastNLP.embeddings.StaticEmbedding` ),或按照预训练权重初始化Embedding。 | |||
""" | |||
__all__ = [ | |||
"Embedding", | |||
] | |||
from abc import abstractmethod | |||
from typing import Union, Tuple | |||
from ...envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch.nn import Module | |||
from torch import nn | |||
else: | |||
from ...core.utils.dummy_class import DummyClass as Module | |||
import numpy as np | |||
from .utils import get_embeddings | |||
class Embedding(Module):
    r"""
    Word embedding wrapper supporting several initialisation schemes. ``self.num_embeddings``
    gives the vocabulary size and ``self.embedding_dim`` the embedding dimension.

    Example::

        >>> import numpy as np
        >>> from fastNLP.embeddings.torch import Embedding
        >>> init_embed = (2000, 100)
        >>> embed = Embedding(init_embed)  # randomly initialise 2000 words as 100-dim vectors
        >>> init_embed = np.zeros((2000, 100))
        >>> embed = Embedding(init_embed)  # initialise an Embedding from a numpy.ndarray

    """
    def __init__(self, init_embed:Union[Tuple[int,int],'torch.FloatTensor','nn.Embedding',np.ndarray],
                 word_dropout:float=0, dropout:float=0.0, unk_index:int=None):
        r"""
        :param init_embed: either the size of the embedding as a tuple ``(vocab_size, embed_dim)``,
            or a Tensor / Embedding / numpy.ndarray whose values initialise the embedding directly.
        :param word_dropout: probability of replacing a word with ``unk_index`` — trains the unk
            token and acts as a regularizer. Requires ``unk_index`` to be set.
        :param dropout: dropout applied to the embedding output.
        :param unk_index: index substituted when a word is dropped. fastNLP's Vocabulary uses 1 by default.
        """
        super(Embedding, self).__init__()
        self.embed = get_embeddings(init_embed)
        self.dropout = nn.Dropout(dropout)
        if isinstance(self.embed, TokenEmbedding):
            # a TokenEmbedding knows its own size and vocabulary
            self._embed_size = self.embed.embed_size
            unk_index = self.embed.get_word_vocab().unknown_idx
        else:
            # probe common size attributes in order, falling back to the weight matrix
            for size_attr in ('embed_size', 'embedding_dim'):
                if hasattr(self.embed, size_attr):
                    self._embed_size = getattr(self.embed, size_attr)
                    break
            else:
                self._embed_size = self.embed.weight.size(1)
            if word_dropout > 0 and not isinstance(unk_index, int):
                raise ValueError("When drop word is set, you need to pass in the unk_index.")
        self.unk_index = unk_index
        self.word_dropout = word_dropout

    def forward(self, words):
        r"""
        :param torch.LongTensor words: [batch, seq_len]
        :return: torch.Tensor : [batch, seq_len, embed_dim]
        """
        if self.word_dropout > 0 and self.training:
            drop_probs = torch.full_like(words, self.word_dropout, dtype=torch.float)
            drop_mask = torch.bernoulli(drop_probs).eq(1)  # larger word_dropout -> more positions set to 1
            words = words.masked_fill(drop_mask, self.unk_index)
        return self.dropout(self.embed(words))

    @property
    def num_embedding(self) -> int:
        inner = self.embed
        return inner.weight.size(0) if isinstance(inner, nn.Embedding) else inner.num_embeddings

    def __len__(self):
        return len(self.embed)

    @property
    def embed_size(self) -> int:
        return self._embed_size

    @property
    def embedding_dim(self) -> int:
        return self._embed_size

    @property
    def requires_grad(self):
        r"""
        Whether the embedding parameters are trainable. True: all trainable;
        False: none trainable; None: mixed.

        :return:
        """
        if isinstance(self.embed, TokenEmbedding):
            return self.embed.requires_grad
        return self.embed.weight.requires_grad

    @requires_grad.setter
    def requires_grad(self, value):
        if isinstance(self.embed, TokenEmbedding):
            self.embed.requires_grad = value
        else:
            self.embed.weight.requires_grad = value

    @property
    def size(self):
        if isinstance(self.embed, TokenEmbedding):
            return self.embed.size
        return self.embed.weight.size()
class TokenEmbedding(Module):
    r"""
    Base class of all fastNLP embeddings.
    """
    def __init__(self, vocab, word_dropout=0.0, dropout=0.0):
        r"""
        :param vocab: the word vocabulary; must have a padding entry, and an unknown
            entry when ``word_dropout`` is used.
        :param word_dropout: probability of replacing a word with the unknown index.
        :param dropout: dropout applied (by subclasses, via :meth:`dropout`) to the embedding output.
        """
        super(TokenEmbedding, self).__init__()
        if vocab.rebuild:
            vocab.build_vocab()
        assert vocab.padding is not None, "Vocabulary must have a padding entry."
        self._word_vocab = vocab
        self._word_pad_index = vocab.padding_idx
        if word_dropout > 0:
            assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word."
        self.word_dropout = word_dropout
        self._word_unk_index = vocab.unknown_idx
        self.dropout_layer = nn.Dropout(dropout)

    def drop_word(self, words):
        r"""
        Randomly replace words with unknown_index according to ``word_dropout``.
        Padding positions are never replaced. Only active during training.

        :param torch.LongTensor words: batch_size x max_len
        :return:
        """
        if self.word_dropout > 0 and self.training:
            mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
            mask = torch.bernoulli(mask).eq(1)  # larger word_dropout -> more positions set to 1
            pad_mask = words.ne(self._word_pad_index)
            mask = mask & pad_mask  # never drop padding positions
            words = words.masked_fill(mask, self._word_unk_index)
        return words

    def dropout(self, words):
        r"""
        Apply dropout to the embedded word representations.

        :param torch.FloatTensor words: batch_size x max_len x embed_size
        :return:
        """
        return self.dropout_layer(words)

    @property
    def requires_grad(self):
        r"""
        Whether the embedding parameters are trainable. True: all trainable;
        False: none trainable; None: mixed.

        :return:
        """
        requires_grads = set([param.requires_grad for param in self.parameters()])
        if len(requires_grads) == 1:
            return requires_grads.pop()
        else:
            return None

    @requires_grad.setter
    def requires_grad(self, value):
        for param in self.parameters():
            param.requires_grad = value

    def __len__(self):
        return len(self._word_vocab)

    @property
    def embed_size(self) -> int:
        return self._embed_size

    @property
    def embedding_dim(self) -> int:
        return self._embed_size

    @property
    def num_embeddings(self) -> int:
        r"""
        May be larger than the actual size of the embedding matrix.

        :return:
        """
        return len(self._word_vocab)

    def get_word_vocab(self):
        r"""
        Return the vocabulary of this embedding.

        :return: Vocabulary
        """
        return self._word_vocab

    @property
    def size(self):
        # FIX: torch.Size is a tuple subclass and takes a single iterable;
        # torch.Size(a, b) raised TypeError.
        return torch.Size([self.num_embeddings, self._embed_size])

    @abstractmethod
    def forward(self, words):
        raise NotImplementedError
@@ -0,0 +1,101 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"StackEmbedding", | |||
] | |||
from typing import List | |||
from ...envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch import nn | |||
from .embedding import TokenEmbedding | |||
from .utils import _check_vocab_has_same_index | |||
class StackEmbedding(TokenEmbedding):
    r"""
    Combine several embeddings into a single embedding by concatenation.

    Example::

        >>> from fastNLP import Vocabulary
        >>> from fastNLP.embeddings.torch import StaticEmbedding, StackEmbedding
        >>> vocab =  Vocabulary().add_word_lst("The whether is good .".split())
        >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True)
        >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
        >>> embed = StackEmbedding([embed_1, embed_2])

    """
    def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
        r"""
        :param embeds: a list of TokenEmbedding; all of them must share the same vocabulary indices.
        :param word_dropout: probability of replacing a word with unk — trains unk and acts as a
            regularizer. All member embeddings see unknown at the same positions. If dropout is set
            here, do not set it again on the member embeddings.
        :param dropout: probability of dropping values of the embedding output; 0.1 zeroes 10% of values.
        """
        # FIX: validate the input before using it — previously these asserts ran after
        # `embeds` had already been iterated and indexed, so a non-list or a non-TokenEmbedding
        # element failed with a confusing TypeError/IndexError instead of the intended message.
        assert isinstance(embeds, list)
        for embed in embeds:
            assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
        vocabs = []
        for embed in embeds:
            if hasattr(embed, 'get_word_vocab'):
                vocabs.append(embed.get_word_vocab())
        _vocab = vocabs[0]
        for vocab in vocabs[1:]:
            if _vocab!=vocab:
                # vocabularies may differ as objects as long as word->index mappings agree
                _check_vocab_has_same_index(_vocab, vocab)

        super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
        self.embeds = nn.ModuleList(embeds)
        self._embed_size = sum([embed.embed_size for embed in self.embeds])

    def append(self, embed: TokenEmbedding):
        r"""
        Append an embedding to the end of the stack.

        :param embed:
        :return: self
        """
        assert isinstance(embed, TokenEmbedding)
        _check_vocab_has_same_index(self.get_word_vocab(), embed.get_word_vocab())
        self._embed_size += embed.embed_size
        self.embeds.append(embed)
        return self

    def pop(self):
        r"""
        Remove and return the last embedding.

        :return: the removed TokenEmbedding
        """
        embed = self.embeds.pop()
        self._embed_size -= embed.embed_size
        return embed

    @property
    def embed_size(self):
        r"""
        Size of the last dimension of the produced vectors.

        :return:
        """
        return self._embed_size

    def forward(self, words):
        r"""
        Run every member embedding and concatenate the results in order.

        :param words: batch_size x max_len
        :return: shape depends on the embeddings composing this stack
        """
        outputs = []
        words = self.drop_word(words)
        for embed in self.embeds:
            outputs.append(embed(words))
        outputs = self.dropout(torch.cat(outputs, dim=-1))
        return outputs
@@ -0,0 +1,407 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"StaticEmbedding" | |||
] | |||
import os | |||
import warnings | |||
from collections import defaultdict | |||
from copy import deepcopy | |||
import json | |||
from typing import Union | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from .embedding import TokenEmbedding | |||
from ...core import logger | |||
from ...core.vocabulary import Vocabulary | |||
from ...io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | |||
from ...io.file_utils import _get_file_name_base_on_postfix | |||
VOCAB_FILENAME = 'vocab.txt' | |||
STATIC_HYPER_FILENAME = 'static_hyper.json' | |||
STATIC_EMBED_FILENAME = 'static.txt' | |||
class StaticEmbedding(TokenEmbedding):
    r"""
    StaticEmbedding component. Given the name or the path of a pre-trained embedding, extract the
    vectors of the words in ``vocab`` from it (only words present in the vocab are extracted; a word
    that is not found gets a randomly initialised vector — unless the word is marked as
    ``no_create_entry``, in which case no private vector is created and the word is mapped to the
    index of ``unk``).

    Pre-trained vectors currently supported for automatic download:

    .. code::

        en: actually en-glove-840b-300d (commonly used)
        en-glove-6b-50d: official 50d glove vectors
        en-glove-6b-100d: official 100d glove vectors
        en-glove-6b-200d: official 200d glove vectors
        en-glove-6b-300d: official 300d glove vectors
        en-glove-42b-300d: official glove vectors trained on 42B tokens
        en-glove-840b-300d:
        en-glove-twitter-27b-25d:
        en-glove-twitter-27b-50d:
        en-glove-twitter-27b-100d:
        en-glove-twitter-27b-200d:
        en-word2vec-300d: official 300d word2vec vectors
        en-fasttext-crawl: official 300d English fasttext vectors
        cn-char-fastnlp-100d: 100d character embedding trained by fastNLP
        cn-bi-fastnlp-100d: 100d bigram embedding trained by fastNLP
        cn-tri-fastnlp-100d: 100d trigram embedding trained by fastNLP
        cn-fasttext: official 300d Chinese fasttext embedding

    Example::

        >>> from fastNLP import Vocabulary
        >>> from fastNLP.embeddings.torch import StaticEmbedding
        >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
        >>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-50d')

        >>> vocab = Vocabulary().add_word_lst(["The", 'the', "THE"])
        >>> embed = StaticEmbedding(vocab, model_dir_or_name="en-glove-50d", lower=True)
        >>> # "the", "The" and "THE" share one vector, looked up by "the" in the pre-trained vocabulary.

        >>> vocab = Vocabulary().add_word_lst(["The", "the", "THE"])
        >>> embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
        >>> words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]])
        >>> embed(words)
        >>> tensor([[[ 0.5773,  0.7251, -0.3104,  0.0777,  0.4849],
                     [ 0.5773,  0.7251, -0.3104,  0.0777,  0.4849],
                     [ 0.5773,  0.7251, -0.3104,  0.0777,  0.4849]]],
                   grad_fn=<EmbeddingBackward>)  # every spelling produces the same output.

    """
    def __init__(self, vocab: Vocabulary, model_dir_or_name: Union[str, None] = 'en', embedding_dim=-1, requires_grad: bool = True,
                 init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
        r"""
        :param Vocabulary vocab: the vocabulary. StaticEmbedding only loads vectors for words contained
            in it; words not found in the pre-trained file are randomly initialised.
        :param model_dir_or_name: two ways to use a pre-trained static embedding: pass an embedding
            folder (containing exactly one file with a .txt suffix) or a file path; or pass the name of
            an embedding, in which case the cache is checked first and the file is downloaded when
            absent. If None, a random embedding of dimension embedding_dim is created.
        :param embedding_dim: dimension of the randomly initialised embedding. When > 0,
            model_dir_or_name is ignored.
        :param requires_grad: whether gradients are required. Default: True.
        :param callable init_method: how to initialise vectors that were not found. Any of the
            torch.nn.init.* methods works; the callable must accept a tensor and modify it in place.
        :param lower: whether to lowercase the words in vocab before matching against the pre-trained
            vocabulary. Set to False when your vocab contains uppercase words that need their own vectors.
        :param dropout: probability of dropping values of the embedding output; 0.1 zeroes 10% of values.
        :param word_dropout: probability of replacing a word with unk — trains unk and acts as a regularizer.
        :param normalize: whether to normalise each vector to unit norm.
        :param min_freq: words whose vocabulary frequency is below this are mapped to unk.
        :param kwargs:
            * only_train_min_freq * (*bool*) -- apply the ``min_freq`` filter only to words from train;
            * only_norm_found_vector * (*bool*) -- default False; normalise only the vectors found in the pre-trained file;
            * only_use_pretrain_word * (*bool*) -- default False; use only words present in the pre-trained
              vocabulary, mapping all other words to unk. Recommended True when the embedding is not updated.
        """
        super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
        if embedding_dim > 0:
            if model_dir_or_name:
                logger.info(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with"
                            f" dimension {embedding_dim}. If you want to use pre-trained embedding, "
                            f"set `embedding_dim` to 0.")
            model_dir_or_name = None

        # resolve the path of the embedding file (cache name / local file / directory)
        if model_dir_or_name is None:
            assert embedding_dim >= 1, "The dimension of embedding should be larger than 1."
            embedding_dim = int(embedding_dim)
            model_path = None
        elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
            model_url = _get_embedding_url('static', model_dir_or_name.lower())
            model_path = cached_path(model_url, name='embedding')
            # check whether the path exists locally
        elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))):
            model_path = os.path.abspath(os.path.expanduser(model_dir_or_name))
        elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
            model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt')
        else:
            raise ValueError(f"Cannot recognize {model_dir_or_name}.")

        kwargs['min_freq'] = min_freq
        kwargs['lower'] = lower
        # shrink the vocab according to min_freq
        truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq)
        if truncate_vocab:
            truncated_vocab = deepcopy(vocab)
            truncated_vocab.min_freq = min_freq
            truncated_vocab.word2idx = None
            if lower:  # when lowering, pool the frequencies of all casings of a word
                lowered_word_count = defaultdict(int)
                for word, count in truncated_vocab.word_count.items():
                    lowered_word_count[word.lower()] += count
                for word in truncated_vocab.word_count.keys():
                    word_count = truncated_vocab.word_count[word]
                    if lowered_word_count[word.lower()] >= min_freq and word_count < min_freq:
                        # pad the count up so the word survives the min_freq rebuild
                        truncated_vocab.add_word_lst([word] * (min_freq - word_count),
                                                     no_create_entry=truncated_vocab._is_word_no_create_entry(word))

            # apply the min_freq filter only to words from the training data
            if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None:
                for word in truncated_vocab.word_count.keys():
                    if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word] < min_freq:
                        truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]),
                                                     no_create_entry=True)
            truncated_vocab.build_vocab()
            truncated_words_to_words = torch.arange(len(vocab)).long()
            for word, index in vocab:
                truncated_words_to_words[index] = truncated_vocab.to_index(word)
            logger.info(f"{len(vocab) - len(truncated_vocab)} words have frequency less than {min_freq}.")
            vocab = truncated_vocab

        self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False)
        self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
        # load the embedding
        if lower:
            lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
            for word, index in vocab:
                if vocab._is_word_no_create_entry(word):
                    lowered_vocab.add_word(word.lower(), no_create_entry=True)
                else:
                    lowered_vocab.add_word(word.lower())  # add entry-creating words first
            logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
                        f"unique lowered words.")
            if model_path:
                # _load_with_vocab registers the 'words_to_words' buffer itself
                embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
            else:
                embedding = self._randomly_init_embed(len(lowered_vocab), embedding_dim, init_method)
                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
            if lowered_vocab.unknown:
                unknown_idx = lowered_vocab.unknown_idx
            else:
                unknown_idx = embedding.size(0) - 1  # otherwise the last row serves as unknown
                # NOTE(review): re-registers the same buffer name; register_buffer overwrites the
                # existing entry, so this appears redundant with the registration above — confirm.
                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
            # remap every original (cased) word to the row of its lowered form
            words_to_words = torch.full((len(vocab),), fill_value=unknown_idx, dtype=torch.long).long()
            for word, index in vocab:
                if word not in lowered_vocab:
                    word = word.lower()
                    if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
                        continue  # no entry required: the word already defaults to unknown
                words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
            self.register_buffer('words_to_words', words_to_words)
            self._word_unk_index = lowered_vocab.unknown_idx  # replace the unknown index
        else:
            if model_path:
                embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
            else:
                embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
        if not self.only_norm_found_vector and normalize:
            embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)

        if truncate_vocab:
            # compose the truncation mapping with the (possibly lowered) row mapping
            for i in range(len(truncated_words_to_words)):
                index_in_truncated_vocab = truncated_words_to_words[i]
                truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
            del self.words_to_words
            self.register_buffer('words_to_words', truncated_words_to_words)
        self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                      padding_idx=vocab.padding_idx,
                                      max_norm=None, norm_type=2, scale_grad_by_freq=False,
                                      sparse=False, _weight=embedding)
        self._embed_size = self.embedding.weight.size(1)
        self.requires_grad = requires_grad
        self.kwargs = kwargs

    @property
    def weight(self):
        return self.embedding.weight

    def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None):
        r"""
        Create a randomly initialised embedding matrix.

        :param int num_embedding: number of embedding entries
        :param int embedding_dim: dimension of each embedding vector
        :param callable init_embed: initialisation method; uniform in +-sqrt(3/dim) when None
        :return: torch.FloatTensor
        """
        embed = torch.zeros(num_embedding, embedding_dim)

        if init_embed is None:
            nn.init.uniform_(embed, -np.sqrt(3 / embedding_dim), np.sqrt(3 / embedding_dim))
        else:
            init_embed(embed)

        return embed

    def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>',
                         error='ignore', init_method=None):
        r"""
        Extract the embeddings of the words of ``vocab`` from the pre-trained file
        ``embed_filepath``. Automatically detects whether the file is in word2vec format
        (first line holds exactly two fields) or glove format.

        :param str embed_filepath: path of the pre-trained embedding file.
        :param vocab: a :class:`~fastNLP.Vocabulary`; embeddings are read for the words it contains.
            Words without a pre-trained vector are sampled so the whole matrix shares one distribution.
        :param dtype: dtype of the loaded vectors.
        :param str padding: the padding token in the pre-trained file.
        :param str unknown: the unknown token in the pre-trained file.
        :param str error: `ignore` or `strict`; with `ignore`, malformed lines (blank lines,
            inconsistent dimensions) are skipped, with `strict` the error is raised.
        :param init_method: how to initialise missing vectors; any of the torch.nn.init.* methods.
        :return torch.tensor: shape [len(vocab), dimension]; the dimension comes from the file.
        """
        assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
        if not os.path.exists(embed_filepath):
            raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
        with open(embed_filepath, 'r', encoding='utf-8') as f:
            line = f.readline().strip()
            parts = line.split()
            start_idx = 0
            if len(parts) == 2:
                # word2vec header: "<count> <dim>"
                dim = int(parts[1])
                start_idx += 1
            else:
                # glove format: every line is "word v1 ... vdim"
                dim = len(parts) - 1
                f.seek(0)
            matrix = {}  # key: index of the word in vocab; value: vector, or None if not found in the file
            if vocab.padding:
                matrix[vocab.padding_idx] = torch.zeros(dim)
            if vocab.unknown:
                matrix[vocab.unknown_idx] = torch.zeros(dim)
            found_count = 0
            found_unknown = False
            for idx, line in enumerate(f, start_idx):
                try:
                    parts = line.strip().split()
                    word = ''.join(parts[:-dim])
                    nums = parts[-dim:]
                    # align the file's unk/pad tokens with the vocab's
                    if word == padding and vocab.padding is not None:
                        word = vocab.padding
                    elif word == unknown and vocab.unknown is not None:
                        word = vocab.unknown
                        found_unknown = True
                    if word in vocab:
                        index = vocab.to_index(word)
                        if index in matrix:
                            warnings.warn(f"Word has more than one vector in embedding file. Set logger level to "
                                          f"DEBUG for detail.")
                            logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)")
                        matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
                        if self.only_norm_found_vector:
                            matrix[index] = matrix[index] / np.linalg.norm(matrix[index])
                        found_count += 1
                except Exception as e:
                    if error == 'ignore':
                        warnings.warn("Error occurred at the {} line.".format(idx))
                    else:
                        logger.error("Error occurred at the {} line.".format(idx))
                        raise e
            logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
            if not self.only_use_pretrain_word:  # if only pre-trained values are used, create no entry for missing words
                for word, index in vocab:
                    if index not in matrix and not vocab._is_word_no_create_entry(word):
                        if found_unknown:  # reuse the unknown vector when the file provided one
                            matrix[index] = matrix[vocab.unknown_idx]
                        else:
                            matrix[index] = None
            # every key of matrix is a word that gets its own entry
            vectors = self._randomly_init_embed(len(matrix), dim, init_method)

            if vocab.unknown is None:  # create a dedicated unknown row
                unknown_idx = len(matrix)
                vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
            else:
                unknown_idx = vocab.unknown_idx
            self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx, dtype=torch.long).long())
            index = 0
            for word, index_in_vocab in vocab:
                if index_in_vocab in matrix:
                    vec = matrix.get(index_in_vocab)
                    if vec is not None:  # use the found vector; None means the row stays random and trainable
                        vectors[index] = vec
                    self.words_to_words[index_in_vocab] = index
                    index += 1

            return vectors

    def forward(self, words):
        r"""
        Look up the embeddings for word indices.

        :param words: torch.LongTensor, [batch_size, max_len]
        :return: torch.FloatTensor, [batch_size, max_len, embed_size]
        """
        if hasattr(self, 'words_to_words'):
            words = self.words_to_words[words]
        words = self.drop_word(words)
        words = self.embedding(words)
        words = self.dropout(words)
        return words

    def save(self, folder):
        """
        Save the embedding into ``folder``; it can be read back with :meth:`load`.

        :param str folder: three files are created inside: vocab.txt, static_embed_hyper.txt and
            static_embed_hyper.json. vocab.txt can be read back through Vocabulary.load;
            the embedding file uses word2vec layout (space-separated, first line holds two fields,
            each following line is a word followed by its vector values); static_embed_hyper.json
            stores the StaticEmbedding hyper-parameters.
        :return:
        """
        os.makedirs(folder, exist_ok=True)

        vocab = self.get_word_vocab()
        vocab_fp = os.path.join(folder, VOCAB_FILENAME)
        vocab.save(vocab_fp)
        kwargs = self.kwargs.copy()
        kwargs['dropout'] = self.dropout_layer.p
        kwargs['word_dropout'] = self.word_dropout
        kwargs['requires_grad'] = self.requires_grad
        kwargs['only_norm_found_vector'] = False
        kwargs['only_use_pretrain_word'] = True

        with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'w', encoding='utf-8') as f:
            json.dump(kwargs, f, indent=2)

        with open(os.path.join(folder, STATIC_EMBED_FILENAME), 'w', encoding='utf-8') as f:
            f.write('{}\n'.format(' '*30))  # leave blank space; the header is written at the end
            word_count = 0
            saved_word = {}
            valid_word_count = 0
            for i in range(len(self.words_to_words)):
                word = vocab.to_word(i)
                if not vocab._is_word_no_create_entry(word):
                    word_count += 1
                    if kwargs['lower']:
                        word = word.lower()
                    if word in saved_word:
                        continue
                    saved_word[word] = 1
                    vec_i = self.words_to_words[i]
                    if vec_i==vocab.unknown_idx and i!=vocab.unknown_idx:
                        # the word shares unk's vector; skip writing a private entry
                        continue
                    vec = self.embedding.weight.data[vec_i].tolist()
                    vec_str = ' '.join(map(str, vec))
                    f.write(f'{word} {vec_str}\n')
                    valid_word_count += 1
            f.seek(0)
            f.write('{} {}'.format(valid_word_count, self.embedding_dim))
        logger.debug(f"StaticEmbedding has been saved to {folder}.")

    @classmethod
    def load(cls, folder):
        """
        Restore a StaticEmbedding previously stored with :meth:`save`.

        :param str folder: must contain vocab.txt, static_embed.txt and static_hyper.json
        :return: the loaded StaticEmbedding
        """
        for name in [VOCAB_FILENAME, STATIC_EMBED_FILENAME, STATIC_HYPER_FILENAME]:
            assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."

        vocab = Vocabulary.load(os.path.join(folder, VOCAB_FILENAME))
        with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'r', encoding='utf-8') as f:
            hyper = json.load(f)

        logger.info(f"Load StaticEmbedding from {folder}.")
        embed = cls(vocab=vocab, model_dir_or_name=os.path.join(folder, STATIC_EMBED_FILENAME), **hyper)

        return embed
@@ -0,0 +1,106 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
import numpy as np | |||
from ...envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch import nn as nn | |||
from ...core.vocabulary import Vocabulary | |||
__all__ = [ | |||
'get_embeddings', | |||
'get_sinusoid_encoding_table' | |||
] | |||
def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True): | |||
r""" | |||
给定一个word的vocabulary生成character的vocabulary. | |||
:param vocab: 从vocab | |||
:param min_freq: | |||
:param include_word_start_end: 是否需要包含特殊的<bow>和<eos> | |||
:return: | |||
""" | |||
char_vocab = Vocabulary(min_freq=min_freq) | |||
for word, index in vocab: | |||
if not vocab._is_word_no_create_entry(word): | |||
char_vocab.add_word_lst(list(word)) | |||
if include_word_start_end: | |||
char_vocab.add_word_lst(['<bow>', '<eow>']) | |||
return char_vocab | |||
def get_embeddings(init_embed, padding_idx=None): | |||
r""" | |||
根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray | |||
的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 | |||
返回原对象。 | |||
:param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 | |||
nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始化; | |||
传入torch.Tensor, 将使用传入的值作为Embedding初始化。 | |||
:param padding_idx: 当传入tuple时,padding_idx有效 | |||
:return nn.Embedding: embeddings | |||
""" | |||
if isinstance(init_embed, tuple): | |||
res = nn.Embedding( | |||
num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx) | |||
nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)), | |||
b=np.sqrt(3 / res.weight.data.size(1))) | |||
elif isinstance(init_embed, nn.Module): | |||
res = init_embed | |||
elif isinstance(init_embed, torch.Tensor): | |||
res = nn.Embedding.from_pretrained(init_embed, freeze=False) | |||
elif isinstance(init_embed, np.ndarray): | |||
init_embed = torch.tensor(init_embed, dtype=torch.float32) | |||
res = nn.Embedding.from_pretrained(init_embed, freeze=False) | |||
else: | |||
raise TypeError( | |||
'invalid init_embed type: {}'.format((type(init_embed)))) | |||
return res | |||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
""" | |||
sinusoid的embedding,其中position的表示中,偶数维(0,2,4,...)是sin, 奇数(1,3,5...)是cos | |||
:param int n_position: 一共多少个position | |||
:param int d_hid: 多少维度,需要为偶数 | |||
:param padding_idx: | |||
:return: torch.FloatTensor, shape为n_position x d_hid | |||
""" | |||
def cal_angle(position, hid_idx): | |||
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||
def get_posi_angle_vec(position): | |||
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||
if padding_idx is not None: | |||
# zero vector for padding dimension | |||
sinusoid_table[padding_idx] = 0. | |||
return torch.FloatTensor(sinusoid_table) | |||
def _check_vocab_has_same_index(vocab, other_vocab): | |||
""" | |||
检查两个vocabulary是否含有相同的word idx | |||
:param Vocabulary vocab: | |||
:param Vocabulary other_vocab: | |||
:return: | |||
""" | |||
if other_vocab != vocab: | |||
for word, word_ix in vocab: | |||
other_word_idx = other_vocab.to_index(word) | |||
assert other_word_idx == word_ix, f"Word {word} has different index in vocabs, {word_ix} Vs. {other_word_idx}." |
@@ -109,13 +109,9 @@ __all__ = [ | |||
"CMRC2018BertPipe", | |||
'ModelLoader', | |||
'ModelSaver', | |||
] | |||
from .data_bundle import DataBundle | |||
from .embed_loader import EmbedLoader | |||
from .loader import * | |||
from .model_io import ModelLoader, ModelSaver | |||
from .pipe import * |
@@ -1,71 +0,0 @@ | |||
r""" | |||
用于载入和保存模型 | |||
""" | |||
__all__ = [ | |||
"ModelLoader", | |||
"ModelSaver" | |||
] | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
class ModelLoader: | |||
r""" | |||
用于读取模型 | |||
""" | |||
def __init__(self): | |||
super(ModelLoader, self).__init__() | |||
@staticmethod | |||
def load_pytorch(empty_model, model_path): | |||
r""" | |||
从 ".pkl" 文件读取 PyTorch 模型 | |||
:param empty_model: 初始化参数的 PyTorch 模型 | |||
:param str model_path: 模型保存的路径 | |||
""" | |||
empty_model.load_state_dict(torch.load(model_path)) | |||
@staticmethod | |||
def load_pytorch_model(model_path): | |||
r""" | |||
读取整个模型 | |||
:param str model_path: 模型保存的路径 | |||
""" | |||
return torch.load(model_path) | |||
class ModelSaver(object): | |||
r""" | |||
用于保存模型 | |||
Example:: | |||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||
saver.save_pytorch(model) | |||
""" | |||
def __init__(self, save_path): | |||
r""" | |||
:param save_path: 模型保存的路径 | |||
""" | |||
self.save_path = save_path | |||
def save_pytorch(self, model, param_only=True): | |||
r""" | |||
把 PyTorch 模型存入 ".pkl" 文件 | |||
:param model: PyTorch 模型 | |||
:param bool param_only: 是否只保存模型的参数(否则保存整个模型) | |||
""" | |||
if param_only is True: | |||
torch.save(model.state_dict(), self.save_path) | |||
else: | |||
torch.save(model, self.save_path) |
@@ -160,7 +160,7 @@ class TestSetDistReproDataloader: | |||
""" | |||
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | |||
此时对应 driver.load 中的情况 | |||
此时对应 driver.load_checkpoint 中的情况 | |||
""" | |||
@magic_argv_env_context | |||
@@ -186,7 +186,7 @@ class TestSetDistReproDataloader: | |||
""" | |||
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | |||
此时对应 driver.load 中的情况 | |||
此时对应 driver.load_checkpoint 中的情况 | |||
""" | |||
@magic_argv_env_context | |||
@@ -0,0 +1,29 @@ | |||
import pytest | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from fastNLP import Vocabulary, DataSet, Instance | |||
from fastNLP.embeddings.torch.char_embedding import LSTMCharEmbedding, CNNCharEmbedding | |||
class TestCharEmbed: | |||
@pytest.mark.test | |||
def test_case_1(self): | |||
ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])]) | |||
vocab = Vocabulary().from_dataset(ds, field_name='words') | |||
assert len(vocab)==5 | |||
embed = LSTMCharEmbedding(vocab, embed_size=3) | |||
x = torch.LongTensor([[2, 1, 0], [4, 3, 4]]) | |||
y = embed(x) | |||
assert tuple(y.size()) == (2, 3, 3) | |||
@pytest.mark.test | |||
def test_case_2(self): | |||
ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])]) | |||
vocab = Vocabulary().from_dataset(ds, field_name='words') | |||
assert len(vocab)==5 | |||
embed = CNNCharEmbedding(vocab, embed_size=3) | |||
x = torch.LongTensor([[2, 1, 0], [4, 3, 4]]) | |||
y = embed(x) | |||
assert tuple(y.size()) == (2, 3, 3) |
@@ -0,0 +1,195 @@ | |||
import pytest | |||
import os | |||
from fastNLP.embeddings.torch import StaticEmbedding | |||
from fastNLP import Vocabulary | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
import numpy as np | |||
tests_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) | |||
@pytest.mark.torch | |||
class TestLoad: | |||
def test_norm1(self): | |||
# 测试只对可以找到的norm | |||
vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile']) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
only_norm_found_vector=True) | |||
assert round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4) == 1 | |||
assert torch.norm(embed(torch.LongTensor([[4]]))).item() != 1 | |||
def test_norm2(self): | |||
# 测试对所有都norm | |||
vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile']) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
normalize=True) | |||
assert round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4) == 1 | |||
assert round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4) == 1 | |||
def test_dropword(self): | |||
# 测试是否可以通过drop word | |||
vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)]) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4) | |||
for i in range(10): | |||
length = torch.randint(1, 50, (1,)).item() | |||
batch = torch.randint(1, 4, (1,)).item() | |||
words = torch.randint(1, 200, (batch, length)).long() | |||
embed(words) | |||
def test_only_use_pretrain_word(self): | |||
def check_word_unk(words, vocab, embed): | |||
for word in words: | |||
assert embed(torch.LongTensor([vocab.to_index(word)])).tolist()[0] == embed(torch.LongTensor([1])).tolist()[0] | |||
def check_vector_equal(words, vocab, embed, embed_dict, lower=False): | |||
for word in words: | |||
index = vocab.to_index(word) | |||
v1 = embed(torch.LongTensor([index])).tolist()[0] | |||
if lower: | |||
word = word.lower() | |||
v2 = embed_dict[word] | |||
for v1i, v2i in zip(v1, v2): | |||
assert np.allclose(v1i, v2i) | |||
embed_dict = read_static_embed(tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt') | |||
# 测试是否只使用pretrain的word | |||
vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile']) | |||
vocab.add_word('of', no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
only_use_pretrain_word=True) | |||
# notinfile应该被置为unk | |||
check_vector_equal(['the', 'a', 'of'], vocab, embed, embed_dict) | |||
check_word_unk(['notinfile'], vocab, embed) | |||
# 测试在大小写情况下的使用 | |||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile']) | |||
vocab.add_word('Of', no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
only_use_pretrain_word=True) | |||
check_word_unk(['The', 'Of', 'notinfile'], vocab, embed) # 这些词应该找不到 | |||
check_vector_equal(['a'], vocab, embed, embed_dict) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
only_use_pretrain_word=True, lower=True) | |||
check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True) | |||
check_word_unk(['notinfile'], vocab, embed) | |||
# 测试min_freq | |||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2']) | |||
vocab.add_word('Of', no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True) | |||
check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True) | |||
check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed) | |||
def test_sequential_index(self): | |||
# 当不存在no_create_entry时,words_to_words应该是顺序的 | |||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2']) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt') | |||
for index,i in enumerate(embed.words_to_words): | |||
assert index==i | |||
embed_dict = read_static_embed(tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt') | |||
for word, index in vocab: | |||
if word in embed_dict: | |||
index = vocab.to_index(word) | |||
v1 = embed(torch.LongTensor([index])).tolist()[0] | |||
v2 = embed_dict[word] | |||
for v1i, v2i in zip(v1, v2): | |||
assert np.allclose(v1i, v2i) | |||
def test_save_load_static_embed(self): | |||
static_test_folder = 'static_save_test' | |||
try: | |||
# 测试包含no_create_entry | |||
os.makedirs(static_test_folder, exist_ok=True) | |||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) | |||
vocab.add_word_lst(['notinfile2', 'notinfile2'], no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt') | |||
embed.save(static_test_folder) | |||
load_embed = StaticEmbedding.load(static_test_folder) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
assert (embed(words) - load_embed(words)).sum() == 0 | |||
# 测试不包含no_create_entry | |||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt') | |||
embed.save(static_test_folder) | |||
load_embed = StaticEmbedding.load(static_test_folder) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
assert (embed(words) - load_embed(words)).sum() == 0 | |||
# 测试lower, min_freq | |||
vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder+'/helpers/data/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', min_freq=2, lower=True) | |||
embed.save(static_test_folder) | |||
load_embed = StaticEmbedding.load(static_test_folder) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
assert (embed(words) - load_embed(words)).sum() == 0 | |||
# 测试random的embedding | |||
vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) | |||
vocab = vocab.add_word_lst(['b'], no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=4, min_freq=2, lower=True, | |||
normalize=True) | |||
embed.weight.data += 0.2 # 使得它不是normalize | |||
embed.save(static_test_folder) | |||
load_embed = StaticEmbedding.load(static_test_folder) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
assert (embed(words) - load_embed(words)).sum()==0 | |||
finally: | |||
if os.path.isdir(static_test_folder): | |||
import shutil | |||
shutil.rmtree(static_test_folder) | |||
def read_static_embed(fp): | |||
""" | |||
:param str fp: embedding的路径 | |||
:return: {}, key是word, value是vector | |||
""" | |||
embed = {} | |||
with open(fp, 'r') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split() | |||
vector = list(map(float, parts[1:])) | |||
word = parts[0] | |||
embed[word] = vector | |||
return embed | |||
@pytest.mark.torch | |||
class TestRandomSameEntry: | |||
def test_same_vector(self): | |||
vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"]) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]]) | |||
words = embed(words) | |||
embed_0 = words[0, 0] | |||
for i in range(1, 3): | |||
assert torch.sum(embed_0==words[0, i]).eq(len(embed_0)) | |||
embed_0 = words[0, 3] | |||
for i in range(3, 5): | |||
assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0)) |
@@ -0,0 +1,6 @@ | |||
the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581 | |||
of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375 | |||
to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044 | |||
and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097 | |||
in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285 | |||
a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796 |
@@ -0,0 +1,7 @@ | |||
5 50 | |||
the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581 | |||
of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375 | |||
to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044 | |||
and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097 | |||
in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285 | |||
a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796 |