
stack_embedding.py

  1. """
  2. .. todo::
  3. doc
  4. """
  5. __all__ = [
  6. "StackEmbedding",
  7. ]
  8. from typing import List
  9. import torch
  10. from torch import nn as nn
  11. from .embedding import TokenEmbedding
class StackEmbedding(TokenEmbedding):
    """
    Stacks several embeddings into a single embedding by concatenating their outputs.

    Example::

        >>> from fastNLP import Vocabulary
        >>> from fastNLP.embeddings import StaticEmbedding, StackEmbedding
        >>> vocab = Vocabulary().add_word_lst("The weather is good .".split())
        >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True)
        >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
        >>> embed = StackEmbedding([embed_1, embed_2])

    :param embeds: a list of TokenEmbedding objects; every TokenEmbedding must use the same word vocabulary.
    :param float word_dropout: probability of replacing a word with unk. This both trains the unk representation
        and acts as a form of regularization. The same positions are set to unknown across all stacked embeddings.
        If word_dropout is set here, the constituent embeddings should not set it again themselves.
    :param float dropout: probability of applying Dropout to the embedding output; 0.1 means 10% of the values
        are randomly set to 0.
    """
    def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
        vocabs = []
        for embed in embeds:
            if hasattr(embed, 'get_word_vocab'):
                vocabs.append(embed.get_word_vocab())
        _vocab = vocabs[0]
        for vocab in vocabs[1:]:
            assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."

        super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
        assert isinstance(embeds, list)
        for embed in embeds:
            assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
        self.embeds = nn.ModuleList(embeds)
        self._embed_size = sum([embed.embed_size for embed in self.embeds])

    def append(self, embed: TokenEmbedding):
        """
        Append an embedding to the end of the stack.

        :param embed: the TokenEmbedding to append; it must use the same word vocabulary as the existing embeddings.
        :return:
        """
        assert isinstance(embed, TokenEmbedding)
        self._embed_size += embed.embed_size  # keep the concatenated size in sync
        self.embeds.append(embed)

    def pop(self):
        """
        Remove and return the last embedding in the stack.

        :return: the removed TokenEmbedding
        """
        embed = self.embeds.pop()
        self._embed_size -= embed.embed_size  # keep the concatenated size in sync
        return embed

    @property
    def embed_size(self):
        return self._embed_size

    @property
    def requires_grad(self):
        """
        Whether the embedding parameters are trainable. True: all parameters are trainable;
        False: no parameters are trainable; None: some are trainable and some are not.

        :return:
        """
        requires_grads = set([embed.requires_grad for embed in self.embeds])
        if len(requires_grads) == 1:
            return requires_grads.pop()
        else:
            return None

    @requires_grad.setter
    def requires_grad(self, value):
        for embed in self.embeds:
            embed.requires_grad = value

    def forward(self, words):
        """
        Run each embedding on the input and concatenate the results along the last dimension, in order.

        :param words: batch_size x max_len
        :return: the output shape depends on the embeddings that make up this stack embedding,
            i.e. batch_size x max_len x sum of the stacked embedding sizes.
        """
        outputs = []
        words = self.drop_word(words)
        for embed in self.embeds:
            outputs.append(embed(words))
        outputs = self.dropout(torch.cat(outputs, dim=-1))
        return outputs
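
To make the intended usage concrete, here is a minimal sketch (not part of the original file). It assumes a standard fastNLP installation and, to avoid downloading pretrained vectors, stacks two randomly initialized StaticEmbedding objects (model_dir_or_name=None) instead of the GloVe/word2vec embeddings shown in the class docstring.

# Minimal usage sketch for StackEmbedding (assumption: fastNLP is installed;
# randomly initialized StaticEmbedding stands in for pretrained vectors).
import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding, StackEmbedding

vocab = Vocabulary().add_word_lst("the weather is good .".split())
embed_1 = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50)
embed_2 = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=30)

# The stacked embedding concatenates the two outputs: embed_size == 50 + 30.
stack = StackEmbedding([embed_1, embed_2], word_dropout=0.01, dropout=0.1)
print(stack.embed_size)  # 80

# words: batch_size x max_len of indices from the shared vocabulary
words = torch.LongTensor([[vocab.to_index(w) for w in "the weather is good .".split()]])
out = stack(words)
print(out.size())  # torch.Size([1, 5, 80])

Because every stacked embedding must share the same vocabulary, the same word_dropout mask is applied once in forward() and affects all constituent embeddings at the same positions.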