|
|
@@ -74,9 +74,9 @@ def get_embeddings(init_embed): |
|
|
|
""" |
|
|
|
根据输入的init_embed生成nn.Embedding对象。 |
|
|
|
|
|
|
|
:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 |
|
|
|
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, |
|
|
|
此时就以传入的对象作为embedding |
|
|
|
:param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 |
|
|
|
nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始 |
|
|
|
化; 传入orch.Tensor, 将使用传入的值作为Embedding初始化。 |
|
|
|
:return nn.Embedding embeddings: |
|
|
|
""" |
|
|
|
if isinstance(init_embed, tuple): |
|
|
|