Browse Source

fix bugs

tags/v0.2.0^2
yunfan 6 years ago
parent
commit
5edd9de841
2 changed files with 1 additions and 24 deletions
  1. +1
    -1
      fastNLP/core/dataset.py
  2. +0
    -23
      fastNLP/models/cnn_text_classification.py

+ 1
- 1
fastNLP/core/dataset.py View File

@@ -67,8 +67,8 @@ class DataSet(object):
self.dataset = dataset
self.idx = idx
def __getitem__(self, item):
assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx)
assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
return self.dataset.field_arrays[item][self.idx]
def __repr__(self):
return self.dataset[self.idx].__repr__()


+ 0
- 23
fastNLP/models/cnn_text_classification.py View File

@@ -33,7 +33,6 @@ class CNNText(torch.nn.Module):
padding=padding)
self.dropout = nn.Dropout(dropout)
self.fc = encoder.Linear(sum(kernel_nums), num_classes)
self._loss = nn.CrossEntropyLoss()

def forward(self, word_seq):
"""
@@ -56,25 +55,3 @@ class CNNText(torch.nn.Module):
output = self(word_seq)
_, predict = output['output'].max(dim=1)
return {'predict': predict}

def get_loss(self, output, label_seq):
"""

:param output: output of forward(), [batch_size, seq_len]
:param label_seq: true label in DataSet, [batch_size, seq_len]
:return loss: torch.Tensor
"""
return self._loss(output, label_seq)

def evaluate(self, predict, label_seq):
"""

:param predict: iterable predict tensors
:param label_seq: iterable true label tensors
:return accuracy: dict of float
"""
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0)
predict, label_seq = predict.squeeze(), label_seq.squeeze()
correct = (predict == label_seq).long().sum().item()
total = label_seq.size(0)
return {'acc': 1.0 * correct / total}

Loading…
Cancel
Save