|
@@ -27,4 +27,4 @@ class TestBiLinear(unittest.TestCase): |
|
|
x_right = torch.randn((7, 10, 20, 5)) |
|
|
x_right = torch.randn((7, 10, 20, 5)) |
|
|
y = bl(x_left, x_right) |
|
|
y = bl(x_left, x_right) |
|
|
print(bl) |
|
|
print(bl) |
|
|
bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=False) |
|
|
|
|
|
|
|
|
bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True) |