Browse Source

add the test for modules.utils.summary

tags/v0.4.10
ChenXin 5 years ago
parent
commit
b874fba8f2
1 changed files with 13 additions and 2 deletions
  1. +13
    -2
      test/modules/test_utils.py

+ 13
- 2
test/modules/test_utils.py View File

@@ -1,9 +1,20 @@
import unittest

import torch
from fastNLP.modules.utils import get_dropout_mask

from fastNLP.models import CNNText
from fastNLP.modules.utils import get_dropout_mask, summary


class TestUtil(unittest.TestCase):
def test_get_dropout_mask(self):
tensor = torch.randn(3, 4)
mask = get_dropout_mask(0.3, tensor)
self.assertSequenceEqual(mask.size(), torch.Size([3, 4]))
self.assertSequenceEqual(mask.size(), torch.Size([3, 4]))
def test_summary(self):
model = CNNText(embed=(4, 4), num_classes=2, kernel_nums=(9,5), kernel_sizes=(1,3))
# 4 * 4 + 4 * (9 * 1 + 5 * 3) + 2 * (9 + 5 + 1) = 142
self.assertSequenceEqual((142, 142, 0), summary(model))
model.embed.requires_grad = False
self.assertSequenceEqual((142, 126, 16), summary(model))

Loading…
Cancel
Save