|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """ infer """
- from argparse import ArgumentParser
- import numpy as np
-
- from mindspore import Tensor
- from ....dataset_mock import MindData
-
# Registry mapping a model name to the object returned by calling its
# constructor. The constructor runs once, at module import time, and
# get_model hands out that same object on every lookup.
# NOTE(review): `resnet50` is not imported anywhere in this file, so this
# statement raises NameError at import time. Add the missing import (its
# defining module is not visible here).
__factory = dict(resnet50=resnet50())
-
-
def parse_args(argv=None):
    """Parse the command-line options for this inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), ``argparse`` falls back to ``sys.argv[1:]``, so a bare
            ``parse_args()`` call behaves exactly as before. Passing an
            explicit list lets callers and tests parse arguments without
            touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model`` (str),
        ``phase`` (str), ``file_path`` (str) and ``batch_size`` (int).
    """
    parser = ArgumentParser(description="resnet50 example")

    parser.add_argument("--model", type=str, default="resnet50",
                        help="the network architecture for training or testing")
    parser.add_argument("--phase", type=str, default="test",
                        help="the phase of the model, default is test.")
    parser.add_argument("--file_path", type=str, default="/data/file/test1.txt",
                        help="data directory of training or testing")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="batch size for training or testing ")

    # Forwarding argv (None by default) keeps the original sys.argv behaviour.
    return parser.parse_args(argv)
-
-
def get_model(name):
    """Return the network object registered under ``name``.

    Args:
        name: Key into the module-level ``__factory`` registry
            (e.g. ``"resnet50"``).

    Returns:
        The object stored in ``__factory`` for ``name``. Nothing is
        constructed here; every call returns that same stored object.

    Raises:
        KeyError: If ``name`` is not registered. The single-argument message
            names the unknown model and lists the registered ones.
    """
    try:
        return __factory[name]
    except KeyError:
        # Build one formatted message. The previous KeyError("...", name)
        # carried two args and rendered as a tuple: ('unknown model:', 'x').
        raise KeyError("unknown model: {!r}; available models: {}".format(
            name, list(__factory))) from None
-
-
def get_dataset(batch_size=32):
    """Build the mock ``MindData`` source used by the inference test.

    Args:
        batch_size: Passed through as ``batch_size`` and used as the leading
            dimension of the declared output shape.

    Returns:
        A ``MindData`` instance created with ``size=2``, the given
        ``batch_size``, ``np_types=np.float32``, output shape
        ``(batch_size, 3, 224, 224)`` and ``input_indexs=(0, 1)``.
    """
    # NOTE(review): input_indexs names two indices while only one dtype and
    # one shape are declared. Confirm this matches what MindData expects.
    return MindData(
        size=2,
        batch_size=batch_size,
        np_types=np.float32,
        output_shapes=(batch_size, 3, 224, 224),
        input_indexs=(0, 1),
    )
-
-
# pylint: disable=unused-argument
def test(name, file_path, batch_size):
    """Run one forward pass of the named network on mock data.

    Args:
        name: Model key passed to ``get_model``.
        file_path: Accepted for command-line compatibility but not read
            here, hence the pylint suppression above.
        batch_size: Forwarded to ``get_dataset``.
    """
    net = get_model(name)
    dataset = get_dataset(batch_size=batch_size)

    # Convert every item yielded by the dataset to NumPy and join them along
    # axis 0.
    arrays = [item.asnumpy() for item in dataset]
    stacked = np.concatenate(arrays, axis=0)
    # Permute axes with (0, 3, 1, 2), moving the last axis to position 1.
    # NOTE(review): that is the NHWC -> NCHW permutation, yet get_dataset
    # declares shapes as (batch_size, 3, 224, 224). If items already arrive
    # in that layout, the result is (N, 224, 3, 224). Confirm against
    # MindData.
    images = Tensor(stacked.transpose((0, 3, 1, 2)))
    print(images.shape)
    net(images)
-
-
if __name__ == '__main__':
    # Only the inference path is implemented. Any --phase value other than
    # "train" falls through to test().
    args = parse_args()
    if args.phase == "train":
        # Give the user an actionable message instead of a bare exception.
        raise NotImplementedError(
            "phase 'train' is not supported by this script; use --phase test")
    test(args.model, args.file_path, args.batch_size)
|