Browse Source

修改DataSet split的一个注释错误

tags/v0.4.10
yh_cc 5 years ago
parent
commit
acf18e2e89
2 changed files with 11 additions and 5 deletions
  1. +1
    -1
      fastNLP/core/dataset.py
  2. +10
    -4
      fastNLP/modules/encoder/embedding.py

+ 1
- 1
fastNLP/core/dataset.py View File

@@ -805,7 +805,7 @@ class DataSet(object):
""" """
将DataSet按照ratio的比例拆分,返回两个DataSet 将DataSet按照ratio的比例拆分,返回两个DataSet


:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `ratio` 这么多数据,第二个DataSet拥有 `(1-ratio)` 这么多数据
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `(1-ratio)` 这么多数据,第二个DataSet拥有`ratio`这么多数据
:return: [DataSet, DataSet] :return: [DataSet, DataSet]
""" """
assert isinstance(ratio, float) assert isinstance(ratio, float)


+ 10
- 4
fastNLP/modules/encoder/embedding.py View File

@@ -51,6 +51,13 @@ class Embedding(nn.Module):
x = self.embed(x) x = self.embed(x)
return self.dropout(x) return self.dropout(x)


@property
def num_embedding(self)->int:
return len(self)

def __len__(self):
return len(self.embed)

@property @property
def embed_size(self) -> int: def embed_size(self) -> int:
return self._embed_size return self._embed_size
@@ -109,9 +116,8 @@ class TokenEmbedding(nn.Module):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = value param.requires_grad = value


@abstractmethod
def get_original_vocab(self):
pass
def __len__(self):
return len(self._word_vocab)


@property @property
def embed_size(self) -> int: def embed_size(self) -> int:
@@ -505,7 +511,7 @@ class CNNCharEmbedding(TokenEmbedding):
:param embed_size: 该word embedding的大小,默认值为50. :param embed_size: 该word embedding的大小,默认值为50.
:param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50. :param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50.
:param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20]. :param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20].
:param kernels: kernel的大小. 默认值为[5, 3, 1].
:param kernel_sizes: kernel的大小. 默认值为[5, 3, 1].
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'. :param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'.
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数. :param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数.
:param min_char_freq: character的最少出现次数。默认值为2. :param min_char_freq: character的最少出现次数。默认值为2.


Loading…
Cancel
Save