| @@ -39,7 +39,8 @@ try: | |||
| from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep | |||
| from mindspore.nn import Cell, Optimizer | |||
| from mindspore.nn.loss.loss import _Loss | |||
| from mindspore.dataset.engine import Dataset, MindDataset | |||
| from mindspore.dataset.engine import Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, \ | |||
| VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset | |||
| import mindspore.dataset as ds | |||
| except (ImportError, ModuleNotFoundError): | |||
| log.warning('MindSpore Not Found!') | |||
| @@ -432,7 +433,7 @@ class AnalyzeObject: | |||
| if hasattr(network, backbone_key): | |||
| backbone = getattr(network, backbone_key) | |||
| backbone_name = type(backbone).__name__ | |||
| elif network is not None: | |||
| if backbone_name is None and network is not None: | |||
| backbone_name = type(network).__name__ | |||
| return backbone_name | |||
| @@ -498,8 +499,8 @@ class AnalyzeObject: | |||
| log.debug('dataset_batch_num: %d', batch_num) | |||
| log.debug('dataset_batch_size: %d', batch_size) | |||
| dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset) | |||
| if dataset_path: | |||
| dataset_path = '/'.join(dataset_path.split('/')[:-1]) | |||
| if dataset_path and os.path.isfile(dataset_path): | |||
| dataset_path, _ = os.path.split(dataset_path) | |||
| dataset_size = int(batch_num * batch_size) | |||
| if dataset_type == 'train': | |||
| @@ -516,14 +517,24 @@ class AnalyzeObject: | |||
| Get dataset path of MindDataset object. | |||
| Args: | |||
| output_dataset (Union[MindDataset, Dataset]): See | |||
| mindspore.dataengine.datasets.Dataset. | |||
| output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, | |||
| VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]): | |||
| See mindspore.dataengine.datasets.Dataset. | |||
| Returns: | |||
| str, dataset path. | |||
| """ | |||
| if isinstance(output_dataset, MindDataset): | |||
| dataset_dir_set = (ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, | |||
| Cifar100Dataset, VOCDataset, CelebADataset) | |||
| dataset_file_set = (MindDataset, ManifestDataset) | |||
| dataset_files_set = (TFRecordDataset, TextFileDataset) | |||
| if isinstance(output_dataset, dataset_file_set): | |||
| return output_dataset.dataset_file | |||
| if isinstance(output_dataset, dataset_dir_set): | |||
| return output_dataset.dataset_dir | |||
| if isinstance(output_dataset, dataset_files_set): | |||
| return output_dataset.dataset_files[0] | |||
| return self.get_dataset_path(output_dataset.input[0]) | |||
| @staticmethod | |||
| @@ -544,7 +555,7 @@ class AnalyzeObject: | |||
| dataset_path = AnalyzeObject().get_dataset_path(dataset) | |||
| except IndexError: | |||
| dataset_path = None | |||
| validate_file_path(dataset_path, allow_empty=True) | |||
| dataset_path = validate_file_path(dataset_path, allow_empty=True) | |||
| return dataset_path | |||
| @staticmethod | |||
| @@ -182,8 +182,8 @@ def validate_file_path(file_path, allow_empty=False): | |||
| """ | |||
| try: | |||
| if allow_empty and not file_path: | |||
| return | |||
| safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None) | |||
| return file_path | |||
| return safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None) | |||
| except ValidationError as error: | |||
| log.error(str(error)) | |||
| raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR, | |||
| @@ -323,11 +323,13 @@ class TestAnalyzer(TestCase): | |||
| res = self.analyzer.get_dataset_path_wrapped(dataset) | |||
| assert res == '/path/to/cifar10' | |||
| @mock.patch('os.path.isfile') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||
| 'AnalyzeObject.get_dataset_path_wrapped') | |||
| def test_analyze_dataset(self, mock_get_path): | |||
| def test_analyze_dataset(self, mock_get_path, mock_isfile): | |||
| """Test analyze_dataset method.""" | |||
| mock_get_path.return_value = '/path/to/mindinsightset' | |||
| mock_isfile.return_value = True | |||
| dataset = MindDataset( | |||
| dataset_size=10, | |||
| dataset_file='/path/to/mindinsightset' | |||
| @@ -13,5 +13,5 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock mindspore.dataset.engine.""" | |||
| from .datasets import Dataset, MindDataset | |||
| from .datasets import * | |||
| from .serializer_deserializer import serialize | |||
| @@ -38,3 +38,75 @@ class MindDataset(Dataset): | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(MindDataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| class ImageFolderDatasetV2(Dataset): | |||
| """Mock the MindSpore ImageFolderDatasetV2 class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(ImageFolderDatasetV2, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| class MnistDataset(Dataset): | |||
| """Mock the MindSpore MnistDataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(MnistDataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| class Cifar10Dataset(Dataset): | |||
| """Mock the MindSpore Cifar10Dataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(Cifar10Dataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| class Cifar100Dataset(Dataset): | |||
| """Mock the MindSpore Cifar100Dataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(Cifar100Dataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| class VOCDataset(Dataset): | |||
| """Mock the MindSpore VOCDataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(VOCDataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| class CelebADataset(Dataset): | |||
| """Mock the MindSpore CelebADataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(CelebADataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| class ManifestDataset(Dataset): | |||
| """Mock the MindSpore ManifestDataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(ManifestDataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| class TFRecordDataset(Dataset): | |||
| """Mock the MindSpore TFRecordDataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(TFRecordDataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| class TextFileDataset(Dataset): | |||
| """Mock the MindSpore TextFileDataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(TextFileDataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||