Browse Source

fix lineage collection for network, loss_fn and dataset path

tags/v0.3.0-alpha
luopengting 5 years ago
parent
commit
bd65d62e55
5 changed files with 97 additions and 12 deletions
  1. +19
    -8
      mindinsight/lineagemgr/collection/model/model_lineage.py
  2. +2
    -2
      mindinsight/lineagemgr/common/validator/validate.py
  3. +3
    -1
      tests/ut/lineagemgr/collection/model/test_model_lineage.py
  4. +1
    -1
      tests/utils/mindspore/dataset/engine/__init__.py
  5. +72
    -0
      tests/utils/mindspore/dataset/engine/datasets.py

+ 19
- 8
mindinsight/lineagemgr/collection/model/model_lineage.py View File

@@ -39,7 +39,8 @@ try:
from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep
from mindspore.nn import Cell, Optimizer
from mindspore.nn.loss.loss import _Loss
from mindspore.dataset.engine import Dataset, MindDataset
from mindspore.dataset.engine import Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, \
VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset
import mindspore.dataset as ds
except (ImportError, ModuleNotFoundError):
log.warning('MindSpore Not Found!')
@@ -432,7 +433,7 @@ class AnalyzeObject:
if hasattr(network, backbone_key):
backbone = getattr(network, backbone_key)
backbone_name = type(backbone).__name__
elif network is not None:
if backbone_name is None and network is not None:
backbone_name = type(network).__name__
return backbone_name

@@ -498,8 +499,8 @@ class AnalyzeObject:
log.debug('dataset_batch_num: %d', batch_num)
log.debug('dataset_batch_size: %d', batch_size)
dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset)
if dataset_path:
dataset_path = '/'.join(dataset_path.split('/')[:-1])
if dataset_path and os.path.isfile(dataset_path):
dataset_path, _ = os.path.split(dataset_path)

dataset_size = int(batch_num * batch_size)
if dataset_type == 'train':
@@ -516,14 +517,24 @@ class AnalyzeObject:
Get dataset path of MindDataset object.

Args:
output_dataset (Union[MindDataset, Dataset]): See
mindspore.dataengine.datasets.Dataset.
output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset,
VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]):
See mindspore.dataengine.datasets.Dataset.

Returns:
str, dataset path.
"""
if isinstance(output_dataset, MindDataset):
dataset_dir_set = (ImageFolderDatasetV2, MnistDataset, Cifar10Dataset,
Cifar100Dataset, VOCDataset, CelebADataset)
dataset_file_set = (MindDataset, ManifestDataset)
dataset_files_set = (TFRecordDataset, TextFileDataset)

if isinstance(output_dataset, dataset_file_set):
return output_dataset.dataset_file
if isinstance(output_dataset, dataset_dir_set):
return output_dataset.dataset_dir
if isinstance(output_dataset, dataset_files_set):
return output_dataset.dataset_files[0]
return self.get_dataset_path(output_dataset.input[0])

@staticmethod
@@ -544,7 +555,7 @@ class AnalyzeObject:
dataset_path = AnalyzeObject().get_dataset_path(dataset)
except IndexError:
dataset_path = None
validate_file_path(dataset_path, allow_empty=True)
dataset_path = validate_file_path(dataset_path, allow_empty=True)
return dataset_path

@staticmethod


+ 2
- 2
mindinsight/lineagemgr/common/validator/validate.py View File

@@ -182,8 +182,8 @@ def validate_file_path(file_path, allow_empty=False):
"""
try:
if allow_empty and not file_path:
return
safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
return file_path
return safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
except ValidationError as error:
log.error(str(error))
raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR,


+ 3
- 1
tests/ut/lineagemgr/collection/model/test_model_lineage.py View File

@@ -323,11 +323,13 @@ class TestAnalyzer(TestCase):
res = self.analyzer.get_dataset_path_wrapped(dataset)
assert res == '/path/to/cifar10'

@mock.patch('os.path.isfile')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.get_dataset_path_wrapped')
def test_analyze_dataset(self, mock_get_path):
def test_analyze_dataset(self, mock_get_path, mock_isfile):
"""Test analyze_dataset method."""
mock_get_path.return_value = '/path/to/mindinsightset'
mock_isfile.return_value = True
dataset = MindDataset(
dataset_size=10,
dataset_file='/path/to/mindinsightset'


+ 1
- 1
tests/utils/mindspore/dataset/engine/__init__.py View File

@@ -13,5 +13,5 @@
# limitations under the License.
# ============================================================================
"""Mock mindspore.dataset.engine."""
from .datasets import Dataset, MindDataset
from .datasets import *
from .serializer_deserializer import serialize

+ 72
- 0
tests/utils/mindspore/dataset/engine/datasets.py View File

@@ -38,3 +38,75 @@ class MindDataset(Dataset):
def __init__(self, dataset_size=None, dataset_file=None):
super(MindDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file


class ImageFolderDatasetV2(Dataset):
"""Mock the MindSpore ImageFolderDatasetV2 class."""

def __init__(self, dataset_size=None, dataset_file=None):
super(ImageFolderDatasetV2, self).__init__(dataset_size)
self.dataset_file = dataset_file


class MnistDataset(Dataset):
"""Mock the MindSpore MnistDataset class."""

def __init__(self, dataset_size=None, dataset_file=None):
super(MnistDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file


class Cifar10Dataset(Dataset):
"""Mock the MindSpore Cifar10Dataset class."""

def __init__(self, dataset_size=None, dataset_file=None):
super(Cifar10Dataset, self).__init__(dataset_size)
self.dataset_file = dataset_file


class Cifar100Dataset(Dataset):
"""Mock the MindSpore Cifar100Dataset class."""

def __init__(self, dataset_size=None, dataset_file=None):
super(Cifar100Dataset, self).__init__(dataset_size)
self.dataset_file = dataset_file


class VOCDataset(Dataset):
"""Mock the MindSpore VOCDataset class."""

def __init__(self, dataset_size=None, dataset_file=None):
super(VOCDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file


class CelebADataset(Dataset):
"""Mock the MindSpore CelebADataset class."""

def __init__(self, dataset_size=None, dataset_file=None):
super(CelebADataset, self).__init__(dataset_size)
self.dataset_file = dataset_file


class ManifestDataset(Dataset):
"""Mock the MindSpore ManifestDataset class."""

def __init__(self, dataset_size=None, dataset_file=None):
super(ManifestDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file


class TFRecordDataset(Dataset):
"""Mock the MindSpore TFRecordDataset class."""

def __init__(self, dataset_size=None, dataset_file=None):
super(TFRecordDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file


class TextFileDataset(Dataset):
"""Mock the MindSpore TextFileDataset class."""

def __init__(self, dataset_size=None, dataset_file=None):
super(TextFileDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file

Loading…
Cancel
Save