@@ -39,7 +39,8 @@ try:
     from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep
     from mindspore.nn import Cell, Optimizer
     from mindspore.nn.loss.loss import _Loss
-    from mindspore.dataset.engine import Dataset, MindDataset
+    from mindspore.dataset.engine import Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, \
+        VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset
     import mindspore.dataset as ds
 except (ImportError, ModuleNotFoundError):
     log.warning('MindSpore Not Found!')
@@ -432,7 +433,7 @@ class AnalyzeObject:
         if hasattr(network, backbone_key):
             backbone = getattr(network, backbone_key)
             backbone_name = type(backbone).__name__
-        elif network is not None:
+        if backbone_name is None and network is not None:
             backbone_name = type(network).__name__
         return backbone_name
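Replacing `elif` with a standalone `if backbone_name is None` check makes the fall-back to the network's own class name depend on whether a backbone name was actually resolved, rather than on the outcome of the last `hasattr` test alone, which matters when the lookup is attempted for more than one candidate attribute. A minimal standalone sketch of that fall-back, with a hypothetical helper name and candidate keys:

    def resolve_backbone_name(network, candidate_keys=('_backbone', 'backbone')):
        """Prefer the backbone's class name, else fall back to the network's own."""
        backbone_name = None
        for key in candidate_keys:
            if hasattr(network, key):
                backbone_name = type(getattr(network, key)).__name__
        if backbone_name is None and network is not None:
            backbone_name = type(network).__name__
        return backbone_name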
@@ -498,8 +499,8 @@ class AnalyzeObject:
         log.debug('dataset_batch_num: %d', batch_num)
         log.debug('dataset_batch_size: %d', batch_size)
         dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset)
-        if dataset_path:
-            dataset_path = '/'.join(dataset_path.split('/')[:-1])
+        if dataset_path and os.path.isfile(dataset_path):
+            dataset_path, _ = os.path.split(dataset_path)
         dataset_size = int(batch_num * batch_size)
         if dataset_type == 'train':
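Trimming the final path component only when `os.path.isfile()` confirms a regular file keeps directory-backed dataset paths (a CIFAR-10 folder, for example) intact, while file-backed paths (a MindRecord or manifest file) are reduced to their containing directory; `os.path.split` also replaces the hand-rolled `'/'.join(...)` slicing. A short illustration with made-up paths:

    import os

    # os.path.split separates the final component from its parent directory.
    head, tail = os.path.split('/path/to/imagenet.mindrecord')
    assert (head, tail) == ('/path/to', 'imagenet.mindrecord')

    # A directory such as '/path/to/cifar10' is left untouched by the new guard,
    # because os.path.isfile() is False for directories (and for missing paths).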
@@ -516,14 +517,24 @@ class AnalyzeObject:
         Get dataset path of MindDataset object.
 
         Args:
-            output_dataset (Union[MindDataset, Dataset]): See
-                mindspore.dataengine.datasets.Dataset.
+            output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset,
+                VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]):
+                See mindspore.dataengine.datasets.Dataset.
 
         Returns:
             str, dataset path.
         """
-        if isinstance(output_dataset, MindDataset):
+        dataset_dir_set = (ImageFolderDatasetV2, MnistDataset, Cifar10Dataset,
+                           Cifar100Dataset, VOCDataset, CelebADataset)
+        dataset_file_set = (MindDataset, ManifestDataset)
+        dataset_files_set = (TFRecordDataset, TextFileDataset)
+        if isinstance(output_dataset, dataset_file_set):
             return output_dataset.dataset_file
+        if isinstance(output_dataset, dataset_dir_set):
+            return output_dataset.dataset_dir
+        if isinstance(output_dataset, dataset_files_set):
+            return output_dataset.dataset_files[0]
         return self.get_dataset_path(output_dataset.input[0])
 
     @staticmethod
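The three tuples group the leaf dataset types by the attribute that carries their source location: `dataset_file` for MindDataset and ManifestDataset, `dataset_dir` for the folder-backed datasets, and the first entry of `dataset_files` for TFRecordDataset and TextFileDataset. Anything else is treated as a wrapper node (the result of `.batch()`, `.map()` and similar operations) and unwrapped by recursing into `input[0]`. A self-contained sketch of that recursion using stand-in classes, not MindSpore's own:

    class LeafDataset:
        """Stands in for a folder-backed dataset such as Cifar10Dataset."""
        def __init__(self, dataset_dir):
            self.dataset_dir = dataset_dir

    class WrapperDataset:
        """Stands in for a batch/map node that keeps its source dataset in `input`."""
        def __init__(self, child):
            self.input = [child]

    def find_dataset_dir(node):
        if hasattr(node, 'dataset_dir'):
            return node.dataset_dir
        return find_dataset_dir(node.input[0])

    nested = WrapperDataset(WrapperDataset(LeafDataset('/path/to/cifar10')))
    assert find_dataset_dir(nested) == '/path/to/cifar10'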
@@ -544,7 +555,7 @@ class AnalyzeObject:
             dataset_path = AnalyzeObject().get_dataset_path(dataset)
         except IndexError:
             dataset_path = None
-        validate_file_path(dataset_path, allow_empty=True)
+        dataset_path = validate_file_path(dataset_path, allow_empty=True)
         return dataset_path
 
     @staticmethod
@@ -182,8 +182,8 @@ def validate_file_path(file_path, allow_empty=False):
     """
     try:
         if allow_empty and not file_path:
-            return
-        safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
+            return file_path
+        return safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
     except ValidationError as error:
         log.error(str(error))
         raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR,
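`validate_file_path` now returns what `safe_normalize_path` produced (or the original value when empty input is allowed), and `get_dataset_path_wrapped` re-assigns the result, so the recorded lineage stores the normalized path rather than the raw one. A hedged sketch of the pattern, with `posixpath.normpath` standing in for the project's `safe_normalize_path`:

    import posixpath

    def validate_path(path, allow_empty=False):
        """Return the normalized path; pass empty values through when allowed."""
        if allow_empty and not path:
            return path
        # Stand-in for safe_normalize_path: lexical normalization only.
        return posixpath.normpath(path)

    assert validate_path('', allow_empty=True) == ''
    assert validate_path('/data//train/../train/cifar10') == '/data/train/cifar10'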
@@ -323,11 +323,13 @@ class TestAnalyzer(TestCase):
         res = self.analyzer.get_dataset_path_wrapped(dataset)
         assert res == '/path/to/cifar10'
 
+    @mock.patch('os.path.isfile')
     @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
                 'AnalyzeObject.get_dataset_path_wrapped')
-    def test_analyze_dataset(self, mock_get_path):
+    def test_analyze_dataset(self, mock_get_path, mock_isfile):
         """Test analyze_dataset method."""
         mock_get_path.return_value = '/path/to/mindinsightset'
+        mock_isfile.return_value = True
         dataset = MindDataset(
             dataset_size=10,
             dataset_file='/path/to/mindinsightset'
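Stacked `mock.patch` decorators are applied bottom-up, so the generated mocks are injected innermost-first after `self`: `mock_get_path` binds to the `get_dataset_path_wrapped` patch and `mock_isfile` to `os.path.isfile`. Setting `mock_isfile.return_value = True` lets the new `os.path.isfile` guard take the file branch even though `/path/to/mindinsightset` does not exist on the test machine. A minimal reminder of the ordering rule:

    from unittest import mock
    import os.path

    @mock.patch('os.path.isfile')   # outermost patch -> last mock argument
    @mock.patch('os.path.isdir')    # innermost patch -> first mock argument
    def probe(mock_isdir, mock_isfile):
        mock_isdir.return_value = True
        mock_isfile.return_value = True
        return os.path.isdir('/no/such/dir') and os.path.isfile('/no/such/file')

    assert probe() is True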
@@ -13,5 +13,5 @@
 # limitations under the License.
 # ============================================================================
 """Mock mindspore.dataset.engine."""
-from .datasets import Dataset, MindDataset
+from .datasets import *
 from .serializer_deserializer import serialize
@@ -38,3 +38,75 @@ class MindDataset(Dataset):
     def __init__(self, dataset_size=None, dataset_file=None):
         super(MindDataset, self).__init__(dataset_size)
         self.dataset_file = dataset_file
+
+
+class ImageFolderDatasetV2(Dataset):
+    """Mock the MindSpore ImageFolderDatasetV2 class."""
+    def __init__(self, dataset_size=None, dataset_dir=None):
+        super(ImageFolderDatasetV2, self).__init__(dataset_size)
+        self.dataset_dir = dataset_dir
+
+
+class MnistDataset(Dataset):
+    """Mock the MindSpore MnistDataset class."""
+    def __init__(self, dataset_size=None, dataset_dir=None):
+        super(MnistDataset, self).__init__(dataset_size)
+        self.dataset_dir = dataset_dir
+
+
+class Cifar10Dataset(Dataset):
+    """Mock the MindSpore Cifar10Dataset class."""
+    def __init__(self, dataset_size=None, dataset_dir=None):
+        super(Cifar10Dataset, self).__init__(dataset_size)
+        self.dataset_dir = dataset_dir
+
+
+class Cifar100Dataset(Dataset):
+    """Mock the MindSpore Cifar100Dataset class."""
+    def __init__(self, dataset_size=None, dataset_dir=None):
+        super(Cifar100Dataset, self).__init__(dataset_size)
+        self.dataset_dir = dataset_dir
+
+
+class VOCDataset(Dataset):
+    """Mock the MindSpore VOCDataset class."""
+    def __init__(self, dataset_size=None, dataset_dir=None):
+        super(VOCDataset, self).__init__(dataset_size)
+        self.dataset_dir = dataset_dir
+
+
+class CelebADataset(Dataset):
+    """Mock the MindSpore CelebADataset class."""
+    def __init__(self, dataset_size=None, dataset_dir=None):
+        super(CelebADataset, self).__init__(dataset_size)
+        self.dataset_dir = dataset_dir
+
+
+class ManifestDataset(Dataset):
+    """Mock the MindSpore ManifestDataset class."""
+    def __init__(self, dataset_size=None, dataset_file=None):
+        super(ManifestDataset, self).__init__(dataset_size)
+        self.dataset_file = dataset_file
+
+
+class TFRecordDataset(Dataset):
+    """Mock the MindSpore TFRecordDataset class."""
+    def __init__(self, dataset_size=None, dataset_files=None):
+        super(TFRecordDataset, self).__init__(dataset_size)
+        self.dataset_files = dataset_files
+
+
+class TextFileDataset(Dataset):
+    """Mock the MindSpore TextFileDataset class."""
+    def __init__(self, dataset_size=None, dataset_files=None):
+        super(TextFileDataset, self).__init__(dataset_size)
+        self.dataset_files = dataset_files
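With the mocks exposing the same attributes the production dispatch reads (`dataset_dir`, `dataset_file`, `dataset_files`), a test for the directory-backed branch could follow the existing CIFAR-10 case; the method below is an illustrative sketch in the style of TestAnalyzer, not part of the change:

    def test_get_dataset_path_dir_dataset(self):
        """Directory-backed datasets should yield dataset_dir directly."""
        dataset = Cifar10Dataset(dataset_size=10, dataset_dir='/path/to/cifar10')
        res = self.analyzer.get_dataset_path(dataset)
        assert res == '/path/to/cifar10'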