|
import ast
import contextlib
import hashlib
import importlib
import importlib.util
import json
import os
import os.path as osp
import time
import traceback
from functools import reduce
from pathlib import Path
from typing import Generator, Union

import gast

from modelscope import __version__
from modelscope.fileio.file import LocalStorage
from modelscope.metainfo import (Datasets, Heads, Hooks, LR_Schedulers,
                                 Metrics, Models, Optimizers, Pipelines,
                                 Preprocessors, TaskModels, Trainers)
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.file_utils import get_default_cache_dir
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group
-
- logger = get_logger()
- storage = LocalStorage()
- p = Path(__file__)
-
- # get the path of package 'modelscope'
- MODELSCOPE_PATH = p.resolve().parents[1]
- REGISTER_MODULE = 'register_module'
- IGNORED_PACKAGES = ['modelscope', '.']
- SCAN_SUB_FOLDERS = [
- 'models', 'metrics', 'pipelines', 'preprocessors', 'trainers', 'msdatasets'
- ]
- INDEXER_FILE = 'ast_indexer'
- DECORATOR_KEY = 'decorators'
- EXPRESS_KEY = 'express'
- FROM_IMPORT_KEY = 'from_imports'
- IMPORT_KEY = 'imports'
- FILE_NAME_KEY = 'filepath'
- VERSION_KEY = 'version'
- MD5_KEY = 'md5'
- INDEX_KEY = 'index'
- REQUIREMENT_KEY = 'requirements'
- MODULE_KEY = 'module'
- CLASS_NAME = 'class_name'
- GROUP_KEY = 'group_key'
- MODULE_NAME = 'module_name'
- MODULE_CLS = 'module_cls'
-
-
- class AstScaning(object):
-
- def __init__(self) -> None:
- self.result_import = dict()
- self.result_from_import = dict()
- self.result_decorator = []
- self.express = []
-
- def _is_sub_node(self, node: object) -> bool:
- return isinstance(node,
- ast.AST) and not isinstance(node, ast.expr_context)
-
- def _is_leaf(self, node: ast.AST) -> bool:
- for field in node._fields:
- attr = getattr(node, field)
- if self._is_sub_node(attr):
- return False
- elif isinstance(attr, (list, tuple)):
- for val in attr:
- if self._is_sub_node(val):
- return False
- else:
- return True
-
- def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple:
- if show_offsets:
- return n._attributes + n._fields
- else:
- return n._fields
-
- def _leaf(self, node: ast.AST, show_offsets: bool = True) -> str:
- output = dict()
- local_print = list()
- if isinstance(node, ast.AST):
- local_dict = dict()
- for field in self._fields(node, show_offsets=show_offsets):
- field_output, field_prints = self._leaf(
- getattr(node, field), show_offsets=show_offsets)
- local_dict[field] = field_output
- local_print.append('{}={}'.format(field, field_prints))
-
- prints = '{}({})'.format(
- type(node).__name__,
- ', '.join(local_print),
- )
- output[type(node).__name__] = local_dict
- return output, prints
- elif isinstance(node, list):
- if '_fields' not in node:
- return node, repr(node)
- for item in node:
- item_output, item_prints = self._leaf(
- getattr(node, item), show_offsets=show_offsets)
- local_print.append(item_prints)
- return node, '[{}]'.format(', '.join(local_print), )
- else:
- return node, repr(node)
-
- def _refresh(self):
- self.result_import = dict()
- self.result_from_import = dict()
- self.result_decorator = []
- self.result_express = []
-
- def scan_ast(self, node: Union[ast.AST, None, str]):
- self._setup_global()
- self.scan_import(node, indent=' ', show_offsets=False)
-
- def scan_import(
- self,
- node: Union[ast.AST, None, str],
- indent: Union[str, int] = ' ',
- show_offsets: bool = True,
- _indent: int = 0,
- parent_node_name: str = '',
- ) -> tuple:
- if node is None:
- return node, repr(node)
- elif self._is_leaf(node):
- return self._leaf(node, show_offsets=show_offsets)
- else:
- if isinstance(indent, int):
- indent_s = indent * ' '
- else:
- indent_s = indent
-
- class state:
- indent = _indent
-
- @contextlib.contextmanager
- def indented() -> Generator[None, None, None]:
- state.indent += 1
- yield
- state.indent -= 1
-
- def indentstr() -> str:
- return state.indent * indent_s
-
- def _scan_import(el: Union[ast.AST, None, str],
- _indent: int = 0,
- parent_node_name: str = '') -> str:
- return self.scan_import(
- el,
- indent=indent,
- show_offsets=show_offsets,
- _indent=_indent,
- parent_node_name=parent_node_name)
-
- out = type(node).__name__ + '(\n'
- outputs = dict()
- # add relative path expression
- if type(node).__name__ == 'ImportFrom':
- level = getattr(node, 'level')
- if level >= 1:
- path_level = ''.join(['.'] * level)
- setattr(node, 'level', 0)
- module_name = getattr(node, 'module')
- if module_name is None:
- setattr(node, 'module', path_level)
- else:
- setattr(node, 'module', path_level + module_name)
- with indented():
- for field in self._fields(node, show_offsets=show_offsets):
- attr = getattr(node, field)
- if attr == []:
- representation = '[]'
- outputs[field] = []
- elif (isinstance(attr, list) and len(attr) == 1
- and isinstance(attr[0], ast.AST)
- and self._is_leaf(attr[0])):
- local_out, local_print = _scan_import(attr[0])
- representation = f'[{local_print}]'
- outputs[field] = local_out
-
- elif isinstance(attr, list):
- representation = '[\n'
- el_dict = dict()
- with indented():
- for el in attr:
- local_out, local_print = _scan_import(
- el, state.indent,
- type(el).__name__)
- representation += '{}{},\n'.format(
- indentstr(),
- local_print,
- )
- name = type(el).__name__
- if (name == 'Import' or name == 'ImportFrom'
- or parent_node_name == 'ImportFrom'
- or parent_node_name == 'Import'):
- if name not in el_dict:
- el_dict[name] = []
- el_dict[name].append(local_out)
- representation += indentstr() + ']'
- outputs[field] = el_dict
- elif isinstance(attr, ast.AST):
- output, representation = _scan_import(
- attr, state.indent)
- outputs[field] = output
- else:
- representation = repr(attr)
- outputs[field] = attr
-
- if (type(node).__name__ == 'Import'
- or type(node).__name__ == 'ImportFrom'):
- if type(node).__name__ == 'ImportFrom':
- if field == 'module':
- self.result_from_import[
- outputs[field]] = dict()
- if field == 'names':
- if isinstance(outputs[field]['alias'], list):
- item_name = []
- for item in outputs[field]['alias']:
- local_name = item['alias']['name']
- item_name.append(local_name)
- self.result_from_import[
- outputs['module']] = item_name
- else:
- local_name = outputs[field]['alias'][
- 'name']
- self.result_from_import[
- outputs['module']] = [local_name]
-
- if type(node).__name__ == 'Import':
- final_dict = outputs[field]['alias']
- if isinstance(final_dict, list):
- for item in final_dict:
- self.result_import[
- item['alias']['name']] = item['alias']
- else:
- self.result_import[outputs[field]['alias']
- ['name']] = final_dict
-
- if 'decorator_list' == field and attr != []:
- for item in attr:
- setattr(item, CLASS_NAME, node.name)
- self.result_decorator.extend(attr)
-
- if attr != [] and type(
- attr
- ).__name__ == 'Call' and parent_node_name == 'Expr':
- self.result_express.append(attr)
-
- out += f'{indentstr()}{field}={representation},\n'
-
- out += indentstr() + ')'
- return {
- IMPORT_KEY: self.result_import,
- FROM_IMPORT_KEY: self.result_from_import,
- DECORATOR_KEY: self.result_decorator,
- EXPRESS_KEY: self.result_express
- }, out
-
- def _parse_decorator(self, node: ast.AST) -> tuple:
-
- def _get_attribute_item(node: ast.AST) -> tuple:
- value, id, attr = None, None, None
- if type(node).__name__ == 'Attribute':
- value = getattr(node, 'value')
- id = getattr(value, 'id')
- attr = getattr(node, 'attr')
- if type(node).__name__ == 'Name':
- id = getattr(node, 'id')
- return id, attr
-
- def _get_args_name(nodes: list) -> list:
- result = []
- for node in nodes:
- if type(node).__name__ == 'Str':
- result.append((node.s, None))
- else:
- result.append(_get_attribute_item(node))
- return result
-
- def _get_keyword_name(nodes: ast.AST) -> list:
- result = []
- for node in nodes:
- if type(node).__name__ == 'keyword':
- attribute_node = getattr(node, 'value')
- if type(attribute_node).__name__ == 'Str':
- result.append((getattr(node,
- 'arg'), attribute_node.s, None))
- else:
- result.append((getattr(node, 'arg'), )
- + _get_attribute_item(attribute_node))
- return result
-
- functions = _get_attribute_item(node.func)
- args_list = _get_args_name(node.args)
- keyword_list = _get_keyword_name(node.keywords)
- return functions, args_list, keyword_list
-
- def _get_registry_value(self, key_item):
- if key_item is None:
- return None
- if key_item == 'default_group':
- return default_group
- split_list = key_item.split('.')
- # in the case, the key_item is raw data, not registred
- if len(split_list) == 1:
- return key_item
- else:
- return getattr(eval(split_list[0]), split_list[1])
-
- def _registry_indexer(self, parsed_input: tuple, class_name: str) -> tuple:
- """format registry information to a tuple indexer
-
- Return:
- tuple: (MODELS, Tasks.text-classification, Models.structbert)
- """
- functions, args_list, keyword_list = parsed_input
-
- # ignore decocators other than register_module
- if REGISTER_MODULE != functions[1]:
- return None
- output = [functions[0]]
-
- if len(args_list) == 0 and len(keyword_list) == 0:
- args_list.append(default_group)
- if len(keyword_list) == 0 and len(args_list) == 1:
- args_list.append(class_name)
-
- if len(keyword_list) > 0 and len(args_list) == 0:
- remove_group_item = None
- for item in keyword_list:
- key, name, attr = item
- if key == GROUP_KEY:
- args_list.append((name, attr))
- remove_group_item = item
- if remove_group_item is not None:
- keyword_list.remove(remove_group_item)
-
- if len(args_list) == 0:
- args_list.append(default_group)
-
- for item in keyword_list:
- key, name, attr = item
- if key == MODULE_CLS:
- class_name = name
- else:
- args_list.append((name, attr))
-
- for item in args_list:
- # the case empty input
- if item is None:
- output.append(None)
- # the case (default_group)
- elif item[1] is None:
- output.append(item[0])
- elif isinstance(item, str):
- output.append(item)
- else:
- output.append('.'.join(item))
- return (output[0], self._get_registry_value(output[1]),
- self._get_registry_value(output[2]))
-
- def parse_decorators(self, nodes: list) -> list:
- """parse the AST nodes of decorators object to registry indexer
-
- Args:
- nodes (list): list of AST decorator nodes
-
- Returns:
- list: list of registry indexer
- """
- results = []
- for node in nodes:
- if type(node).__name__ != 'Call':
- continue
- class_name = getattr(node, CLASS_NAME, None)
- func = getattr(node, 'func')
-
- if getattr(func, 'attr', None) != REGISTER_MODULE:
- continue
-
- parse_output = self._parse_decorator(node)
- index = self._registry_indexer(parse_output, class_name)
- if None is not index:
- results.append(index)
- return results
-
- def generate_ast(self, file):
- self._refresh()
- with open(file, 'r') as code:
- data = code.readlines()
- data = ''.join(data)
-
- node = gast.parse(data)
- output, _ = self.scan_import(node, indent=' ', show_offsets=False)
- output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
- output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY])
- output[DECORATOR_KEY].extend(output[EXPRESS_KEY])
- return output
-
-
- class FilesAstScaning(object):
-
- def __init__(self) -> None:
- self.astScaner = AstScaning()
- self.file_dirs = []
-
- def _parse_import_path(self,
- import_package: str,
- current_path: str = None) -> str:
- """
- Args:
- import_package (str): relative import or abs import
- current_path (str): path/to/current/file
- """
- if import_package.startswith(IGNORED_PACKAGES[0]):
- return MODELSCOPE_PATH + '/' + '/'.join(
- import_package.split('.')[1:]) + '.py'
- elif import_package.startswith(IGNORED_PACKAGES[1]):
- current_path_list = current_path.split('/')
- import_package_list = import_package.split('.')
- level = 0
- for index, item in enumerate(import_package_list):
- if item != '':
- level = index
- break
-
- abs_path_list = current_path_list[0:-level]
- abs_path_list.extend(import_package_list[index:])
- return '/' + '/'.join(abs_path_list) + '.py'
- else:
- return current_path
-
- def _traversal_import(
- self,
- import_abs_path,
- ):
- pass
-
- def parse_import(self, scan_result: dict) -> list:
- """parse import and from import dicts to a third party package list
-
- Args:
- scan_result (dict): including the import and from import result
-
- Returns:
- list: a list of package ignored 'modelscope' and relative path import
- """
- output = []
- output.extend(list(scan_result[IMPORT_KEY].keys()))
- output.extend(list(scan_result[FROM_IMPORT_KEY].keys()))
-
- # get the package name
- for index, item in enumerate(output):
- if '' == item.split('.')[0]:
- output[index] = '.'
- else:
- output[index] = item.split('.')[0]
-
- ignored = set()
- for item in output:
- for ignored_package in IGNORED_PACKAGES:
- if item.startswith(ignored_package):
- ignored.add(item)
- return list(set(output) - set(ignored))
-
- def traversal_files(self, path, check_sub_dir):
- self.file_dirs = []
- if check_sub_dir is None or len(check_sub_dir) == 0:
- self._traversal_files(path)
-
- for item in check_sub_dir:
- sub_dir = os.path.join(path, item)
- if os.path.isdir(sub_dir):
- self._traversal_files(sub_dir)
-
- def _traversal_files(self, path):
- dir_list = os.scandir(path)
- for item in dir_list:
- if item.name.startswith('__'):
- continue
- if item.is_dir():
- self._traversal_files(item.path)
- elif item.is_file() and item.name.endswith('.py'):
- self.file_dirs.append(item.path)
-
- def _get_single_file_scan_result(self, file):
- try:
- output = self.astScaner.generate_ast(file)
- except Exception as e:
- detail = traceback.extract_tb(e.__traceback__)
- raise Exception(
- f'During ast indexing, error is in the file {detail[-1].filename}'
- f' line: {detail[-1].lineno}: "{detail[-1].line}" with error msg: '
- f'"{type(e).__name__}: {e}"')
-
- import_list = self.parse_import(output)
- return output[DECORATOR_KEY], import_list
-
- def _inverted_index(self, forward_index):
- inverted_index = dict()
- for index in forward_index:
- for item in forward_index[index][DECORATOR_KEY]:
- inverted_index[item] = {
- FILE_NAME_KEY: index,
- IMPORT_KEY: forward_index[index][IMPORT_KEY],
- MODULE_KEY: forward_index[index][MODULE_KEY],
- }
- return inverted_index
-
- def _module_import(self, forward_index):
- module_import = dict()
- for index, value_dict in forward_index.items():
- module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY]
- return module_import
-
- def _ignore_useless_keys(self, inverted_index):
- if ('OPTIMIZERS', 'default', 'name') in inverted_index:
- del inverted_index[('OPTIMIZERS', 'default', 'name')]
- if ('LR_SCHEDULER', 'default', 'name') in inverted_index:
- del inverted_index[('LR_SCHEDULER', 'default', 'name')]
- return inverted_index
-
- def get_files_scan_results(self,
- target_dir=MODELSCOPE_PATH,
- target_folders=SCAN_SUB_FOLDERS):
- """the entry method of the ast scan method
-
- Args:
- target_dir (str, optional): the absolute path of the target directory to be scaned. Defaults to None.
- target_folder (list, optional): the list of
- sub-folders to be scaned in the target folder.
- Defaults to SCAN_SUB_FOLDERS.
-
- Returns:
- dict: indexer of registry
- """
-
- self.traversal_files(target_dir, target_folders)
- start = time.time()
- logger.info(
- f'AST-Scaning the path "{target_dir}" with the following sub folders {target_folders}'
- )
-
- result = dict()
- for file in self.file_dirs:
- filepath = file[file.rfind('modelscope'):]
- module_name = filepath.replace(osp.sep, '.').replace('.py', '')
- decorator_list, import_list = self._get_single_file_scan_result(
- file)
- result[file] = {
- DECORATOR_KEY: decorator_list,
- IMPORT_KEY: import_list,
- MODULE_KEY: module_name
- }
- inverted_index_with_results = self._inverted_index(result)
- inverted_index_with_results = self._ignore_useless_keys(
- inverted_index_with_results)
- module_import = self._module_import(result)
- index = {
- INDEX_KEY: inverted_index_with_results,
- REQUIREMENT_KEY: module_import
- }
- logger.info(
- f'Scaning done! A number of {len(inverted_index_with_results)}'
- f' files indexed! Time consumed {time.time()-start}s')
- return index
-
- def files_mtime_md5(self,
- target_path=MODELSCOPE_PATH,
- target_subfolder=SCAN_SUB_FOLDERS):
- self.file_dirs = []
- self.traversal_files(target_path, target_subfolder)
- files_mtime = []
- for item in self.file_dirs:
- files_mtime.append(os.path.getmtime(item))
- result_str = reduce(lambda x, y: str(x) + str(y), files_mtime, '')
- md5 = hashlib.md5(result_str.encode())
- return md5.hexdigest()
-
-
- file_scanner = FilesAstScaning()
-
-
- def _save_index(index, file_path):
- # convert tuple key to str key
- index[INDEX_KEY] = {str(k): v for k, v in index[INDEX_KEY].items()}
- index[VERSION_KEY] = __version__
- index[MD5_KEY] = file_scanner.files_mtime_md5()
- json_index = json.dumps(index)
- storage.write(json_index.encode(), file_path)
- index[INDEX_KEY] = {
- ast.literal_eval(k): v
- for k, v in index[INDEX_KEY].items()
- }
-
-
- def _load_index(file_path):
- bytes_index = storage.read(file_path)
- wrapped_index = json.loads(bytes_index)
- # convert str key to tuple key
- wrapped_index[INDEX_KEY] = {
- ast.literal_eval(k): v
- for k, v in wrapped_index[INDEX_KEY].items()
- }
- return wrapped_index
-
-
- def load_index(force_rebuild=False):
- """get the index from scan results or cache
-
- Args:
- force_rebuild: If set true, rebuild and load index
- Returns:
- dict: the index information for all registred modules, including key:
- index, requirments, version and md5, the detail is shown below example:
- {
- 'index': {
- ('MODELS', 'nlp', 'bert'):{
- 'filepath' : 'path/to/the/registered/model', 'imports':
- ['os', 'torch', 'typeing'] 'module':
- 'modelscope.models.nlp.bert'
- },
- ...
- }, 'requirments': {
- 'modelscope.models.nlp.bert': ['os', 'torch', 'typeing'],
- 'modelscope.models.nlp.structbert': ['os', 'torch', 'typeing'],
- ...
- }, 'version': '0.2.3', 'md5': '8616924970fe6bc119d1562832625612',
- }
- """
- cache_dir = os.getenv('MODELSCOPE_CACHE', get_default_cache_dir())
- file_path = os.path.join(cache_dir, INDEXER_FILE)
- logger.info(f'Loading ast index from {file_path}')
- index = None
- if not force_rebuild and os.path.exists(file_path):
- wrapped_index = _load_index(file_path)
- md5 = file_scanner.files_mtime_md5()
- if (wrapped_index[VERSION_KEY] == __version__
- and wrapped_index[MD5_KEY] == md5):
- index = wrapped_index
-
- if index is None:
- if force_rebuild:
- logger.info('Force rebuilding ast index')
- else:
- logger.info(
- f'No valid ast index found from {file_path}, rebuilding ast index!'
- )
- index = file_scanner.get_files_scan_results()
- _save_index(index, file_path)
- logger.info(
- f'Loading done! Current index file version is {index[VERSION_KEY]}, '
- f'with md5 {index[MD5_KEY]}')
- return index
-
-
- def check_import_module_avaliable(module_dicts: dict) -> list:
- missed_module = []
- for module in module_dicts.keys():
- loader = importlib.find_loader(module)
- if loader is None:
- missed_module.append(module)
- return missed_module
-
-
- if __name__ == '__main__':
- index = load_index()
- print(index)
|