|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """ test container """
- from collections import OrderedDict
- import numpy as np
- import pytest
-
- import mindspore.nn as nn
- from mindspore import Tensor
-
- weight = Tensor(np.ones([2, 2]))
- conv2 = nn.Conv2d(3, 64, (3, 3), stride=2, padding=0)
-
- kernel_size = 3
- stride = 2
- padding = 1
- avg_pool = nn.AvgPool2d(kernel_size, stride)
-
-
- class TestSequentialCell():
- """ TestSequentialCell """
-
- def test_SequentialCell_init(self):
- m = nn.SequentialCell()
- assert type(m).__name__ == 'SequentialCell'
-
- def test_SequentialCell_init2(self):
- m = nn.SequentialCell([conv2])
- assert len(m) == 1
-
- def test_SequentialCell_init3(self):
- m = nn.SequentialCell([conv2, avg_pool])
- assert len(m) == 2
-
- def test_SequentialCell_init4(self):
- m = nn.SequentialCell(OrderedDict(
- [('cov2d', conv2), ('avg_pool', avg_pool)]))
- assert len(m) == 2
-
- def test_getitem1(self):
- m = nn.SequentialCell(OrderedDict(
- [('cov2d', conv2), ('avg_pool', avg_pool)]))
- assert m[0] == conv2
-
- def test_getitem2(self):
- m = nn.SequentialCell(OrderedDict(
- [('cov2d', conv2), ('avg_pool', avg_pool)]))
- assert len(m[0:2]) == 2
- assert m[:2][1] == avg_pool
-
- def test_setitem1(self):
- m = nn.SequentialCell(OrderedDict(
- [('cov2d', conv2), ('avg_pool', avg_pool)]))
- m[1] = conv2
- assert m[1] == m[0]
-
- def test_setitem2(self):
- m = nn.SequentialCell(OrderedDict(
- [('cov2d', conv2), ('avg_pool', avg_pool)]))
- with pytest.raises(TypeError):
- m[1.0] = conv2
-
- def test_delitem1(self):
- m = nn.SequentialCell(OrderedDict(
- [('cov2d', conv2), ('avg_pool', avg_pool)]))
- del m[0]
- assert len(m) == 1
-
- def test_delitem2(self):
- m = nn.SequentialCell(OrderedDict(
- [('cov2d', conv2), ('avg_pool', avg_pool)]))
- del m[:]
- assert type(m).__name__ == 'SequentialCell'
-
-
- class TestCellList():
- """ TestCellList """
-
- def test_init1(self):
- cell_list = nn.CellList([conv2, avg_pool])
- assert len(cell_list) == 2
-
- def test_init2(self):
- with pytest.raises(TypeError):
- nn.CellList(["test"])
-
- def test_getitem(self):
- cell_list = nn.CellList([conv2, avg_pool])
- assert cell_list[0] == conv2
- temp_cells = cell_list[:]
- assert temp_cells[1] == avg_pool
-
- def test_setitem(self):
- cell_list = nn.CellList([conv2, avg_pool])
- cell_list[0] = avg_pool
- assert cell_list[0] == cell_list[1]
-
- def test_delitem(self):
- cell_list = nn.CellList([conv2, avg_pool])
- del cell_list[0]
- assert len(cell_list) == 1
- del cell_list[:]
- assert type(cell_list).__name__ == 'CellList'
-
- def test_iter(self):
- cell_list = nn.CellList([conv2, avg_pool])
- for item in cell_list:
- cell = item
- assert type(cell).__name__ == 'AvgPool2d'
-
- def test_add(self):
- cell_list = nn.CellList([conv2, avg_pool])
- cell_list += [conv2]
- assert len(cell_list) == 3
- assert cell_list[0] == cell_list[2]
-
- def test_insert(self):
- cell_list = nn.CellList([conv2, avg_pool])
- cell_list.insert(0, avg_pool)
- assert len(cell_list) == 3
- assert cell_list[0] == cell_list[2]
-
- def test_append(self):
- cell_list = nn.CellList([conv2, avg_pool])
- cell_list.append(conv2)
- assert len(cell_list) == 3
- assert cell_list[0] == cell_list[2]
-
- def test_extend(self):
- cell_list = nn.CellList()
- cell_list.extend([conv2, avg_pool])
- assert len(cell_list) == 2
- assert cell_list[0] == conv2
|