You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

evaluate.py 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from model import *
  2. from train import *
  3. def evaluate(net, dataset, bactch_size=64, use_cuda=False):
  4. dataloader = DataLoader(dataset, batch_size=bactch_size, collate_fn=collate, num_workers=0)
  5. count = 0
  6. if use_cuda:
  7. net.cuda()
  8. for i, batch_samples in enumerate(dataloader):
  9. x, y = batch_samples
  10. doc_list = []
  11. for sample in x:
  12. doc = []
  13. for sent_vec in sample:
  14. if use_cuda:
  15. sent_vec = sent_vec.cuda()
  16. doc.append(Variable(sent_vec, volatile=True))
  17. doc_list.append(pack_sequence(doc))
  18. if use_cuda:
  19. y = y.cuda()
  20. predicts = net(doc_list)
  21. p, idx = torch.max(predicts, dim=1)
  22. idx = idx.data
  23. count += torch.sum(torch.eq(idx, y))
  24. return count
  25. if __name__ == '__main__':
  26. '''
  27. Evaluate the performance of models
  28. '''
  29. from gensim.models import Word2Vec
  30. embed_model = Word2Vec.load('yelp.word2vec')
  31. embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
  32. del embed_model
  33. net = HAN(input_size=200, output_size=5,
  34. word_hidden_size=50, word_num_layers=1, word_context_size=100,
  35. sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
  36. net.load_state_dict(torch.load('models.dict'))
  37. test_dataset = YelpDocSet('reviews', 199, 4, embedding)
  38. correct = evaluate(net, test_dataset, True)
  39. print('accuracy {}'.format(correct / len(test_dataset)))