@@ -864,9 +864,10 @@ class DataSet(object):
     def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN):
         """
-        Applies len() directly to each element of field_name, takes the result as the seqence length, and stores it in the seq_len field.
+        Applies len() directly to each element of field_name, takes the result as the sequence length, and stores it in the seq_len field.
 
         :param field_name: str.
+        :param new_field_name: str. The name of the new field.
         :return:
         """
         if self.has_field(field_name=field_name):
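For context, a minimal usage sketch of the method whose docstring changes above, assuming a small in-memory DataSet; the 'words' field and its contents are made up for illustration:

    from fastNLP import DataSet

    # Hypothetical data: two tokenized sentences stored under the 'words' field.
    ds = DataSet({'words': [['this', 'is', 'fine'], ['ok']]})

    # Default: lengths are written to the field named by Const.INPUT_LEN ('seq_len').
    ds.add_seq_len('words')

    # With the newly documented parameter, lengths go to a custom field instead.
    ds.add_seq_len('words', new_field_name='words_len')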
@@ -120,7 +120,7 @@ DATASET_DIR = {
     # Summarization, English
     "ext-cnndm": "ext-cnndm.zip",
-    # Question & answer
+    # Question & answer, Chinese
     "cmrc2018": "cmrc2018.zip"
 }
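DATASET_DIR is just a name-to-archive registry, so the change above only refines a comment; a hedged sketch of how such a mapping is typically consumed (the lookup helper below is illustrative, not part of fastNLP's API):

    DATASET_DIR = {
        # Question & answer, Chinese
        "cmrc2018": "cmrc2018.zip",
    }

    def archive_for(name: str) -> str:
        # Illustrative helper: resolve a registered dataset name to its archive file name.
        if name not in DATASET_DIR:
            raise KeyError(f"Unknown dataset '{name}'; known: {sorted(DATASET_DIR)}")
        return DATASET_DIR[name]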
@@ -225,7 +225,7 @@ class IMDBLoader(Loader):
             shutil.rmtree(data_dir)
             data_dir = self._get_dataset_path(dataset_name=dataset_name)
-        if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
+        if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
             if dev_ratio > 0:
                 assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
                 try:
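The fix above matters because the loader splits a dev set out of the training data only when the dev file is missing; since it presumably writes dev.txt, checking for dev.csv would trigger a re-split on every call. A standalone sketch of that guard-then-split pattern, assuming plain-text files with one example per line (the file names and helper are illustrative):

    import os
    import random

    def maybe_split_dev(data_dir: str, dev_ratio: float) -> None:
        # Skip the split if dev.txt already exists, mirroring the corrected check.
        if os.path.exists(os.path.join(data_dir, 'dev.txt')):
            return
        assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
        with open(os.path.join(data_dir, 'train.txt'), encoding='utf-8') as f:
            lines = f.readlines()
        random.shuffle(lines)
        n_dev = int(len(lines) * dev_ratio)
        with open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f:
            f.writelines(lines[:n_dev])
        with open(os.path.join(data_dir, 'train.txt'), 'w', encoding='utf-8') as f:
            f.writelines(lines[n_dev:])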
@@ -33,7 +33,6 @@ class CNNText(torch.nn.Module):
     :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the size of the Embedding (when a tuple(int, int) is passed,
         the first int is vocab_size and the second is embed_dim); if a Tensor, Embedding, or ndarray is given, it is used directly to initialize the Embedding
     :param int num_classes: the total number of classes
-    :param int,tuple(int) out_channels: the number of output channels; if a list, its length must match that of kernel_sizes
     :param int,tuple(int) kernel_sizes: the kernel size for each output channel.
     :param float dropout: the dropout rate
     """
@@ -164,6 +164,7 @@ class BiRNN(nn.Module):
         if self.dropout_rate > 0:
             dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training)
             rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes)
+        self.rnn.flatten_parameters()
         output = self.rnn(rnn_input)[0]
         # Unpack everything
         output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
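The added flatten_parameters() call compacts the RNN's weight tensors into a single contiguous block, which avoids cuDNN's non-contiguous-weights warning (and the associated copy) when the module has been moved or replicated, e.g. under DataParallel. A self-contained sketch of the same pattern with packed sequences; the layer sizes are arbitrary:

    import torch
    import torch.nn as nn

    rnn = nn.LSTM(input_size=16, hidden_size=32, batch_first=True, bidirectional=True)

    x = torch.randn(4, 10, 16)             # batch of 4 padded sequences
    lengths = torch.tensor([10, 8, 5, 3])  # true lengths, sorted descending
    packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)

    # Re-compact the weights right before the forward pass, mirroring the added call.
    rnn.flatten_parameters()
    output, _ = rnn(packed)
    output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)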