Patch notes (text before the "diff --git" header is ignored by patch tools):
  * AstScaning collects Call nodes whose parent node name is 'Expr' into
    result_express and feeds them through parse_decorators; parse_decorators
    skips calls whose func.attr is not REGISTER_MODULE.
  * Keyword arguments keep their keyword name so the GROUP_KEY and MODULE_CLS
    keywords can be picked out in the registry indexer.
  * FilesAstScaning removes the ('OPTIMIZERS', 'default', 'name') and
    ('LR_SCHEDULER', 'default', 'name') keys from the inverted index.
  * Review fix: __init__ initialises result_express (the name every other
    hunk reads and writes) rather than an unused attribute named express.
  * NOTE(review): line structure and indentation were reconstructed from a
    whitespace-collapsed copy; confirm context-line whitespace against the
    target file before applying. The blob ids on the index line predate the
    review fix.

diff --git a/modelscope/utils/ast_utils.py b/modelscope/utils/ast_utils.py
index 990a9571..263a81b3 100644
--- a/modelscope/utils/ast_utils.py
+++ b/modelscope/utils/ast_utils.py
@@ -36,6 +36,7 @@ SCAN_SUB_FOLDERS = [
 ]
 INDEXER_FILE = 'ast_indexer'
 DECORATOR_KEY = 'decorators'
+EXPRESS_KEY = 'express'
 FROM_IMPORT_KEY = 'from_imports'
 IMPORT_KEY = 'imports'
 FILE_NAME_KEY = 'filepath'
@@ -45,6 +46,9 @@ INDEX_KEY = 'index'
 REQUIREMENT_KEY = 'requirements'
 MODULE_KEY = 'module'
 CLASS_NAME = 'class_name'
+GROUP_KEY = 'group_key'
+MODULE_NAME = 'module_name'
+MODULE_CLS = 'module_cls'
 
 
 class AstScaning(object):
@@ -53,6 +57,7 @@ class AstScaning(object):
         self.result_import = dict()
         self.result_from_import = dict()
         self.result_decorator = []
+        self.result_express = []
 
     def _is_sub_node(self, node: object) -> bool:
         return isinstance(node,
@@ -108,6 +113,7 @@ class AstScaning(object):
         self.result_import = dict()
         self.result_from_import = dict()
         self.result_decorator = []
+        self.result_express = []
 
     def scan_ast(self, node: Union[ast.AST, None, str]):
         self._setup_global()
@@ -243,13 +249,19 @@ class AstScaning(object):
                             setattr(item, CLASS_NAME, node.name)
                         self.result_decorator.extend(attr)
 
+                    if attr != [] and type(
+                            attr
+                    ).__name__ == 'Call' and parent_node_name == 'Expr':
+                        self.result_express.append(attr)
+
                     out += f'{indentstr()}{field}={representation},\n'
 
             out += indentstr() + ')'
             return {
                 IMPORT_KEY: self.result_import,
                 FROM_IMPORT_KEY: self.result_from_import,
-                DECORATOR_KEY: self.result_decorator
+                DECORATOR_KEY: self.result_decorator,
+                EXPRESS_KEY: self.result_express
             }, out
 
     def _parse_decorator(self, node: ast.AST) -> tuple:
@@ -267,7 +279,10 @@ class AstScaning(object):
         def _get_args_name(nodes: list) -> list:
             result = []
             for node in nodes:
-                result.append(_get_attribute_item(node))
+                if type(node).__name__ == 'Str':
+                    result.append((node.s, None))
+                else:
+                    result.append(_get_attribute_item(node))
             return result
 
         def _get_keyword_name(nodes: ast.AST) -> list:
@@ -276,9 +291,11 @@ class AstScaning(object):
                 if type(node).__name__ == 'keyword':
                     attribute_node = getattr(node, 'value')
                     if type(attribute_node).__name__ == 'Str':
-                        result.append((attribute_node.s, None))
+                        result.append((getattr(node,
+                                               'arg'), attribute_node.s, None))
                     else:
-                        result.append(_get_attribute_item(attribute_node))
+                        result.append((getattr(node, 'arg'), )
+                                      + _get_attribute_item(attribute_node))
             return result
 
         functions = _get_attribute_item(node.func)
@@ -315,10 +332,26 @@ class AstScaning(object):
             args_list.append(default_group)
         if len(keyword_list) == 0 and len(args_list) == 1:
             args_list.append(class_name)
-        if len(keyword_list) == 1 and len(args_list) == 0:
+
+        if len(keyword_list) > 0 and len(args_list) == 0:
+            remove_group_item = None
+            for item in keyword_list:
+                key, name, attr = item
+                if key == GROUP_KEY:
+                    args_list.append((name, attr))
+                    remove_group_item = item
+            if remove_group_item is not None:
+                keyword_list.remove(remove_group_item)
+
+        if len(args_list) == 0:
             args_list.append(default_group)
 
-        args_list.extend(keyword_list)
+        for item in keyword_list:
+            key, name, attr = item
+            if key == MODULE_CLS:
+                class_name = name
+            else:
+                args_list.append((name, attr))
 
         for item in args_list:
             # the case empty input
@@ -347,9 +380,14 @@ class AstScaning(object):
         for node in nodes:
             if type(node).__name__ != 'Call':
                 continue
+            class_name = getattr(node, CLASS_NAME, None)
+            func = getattr(node, 'func')
+
+            if getattr(func, 'attr', None) != REGISTER_MODULE:
+                continue
+
             parse_output = self._parse_decorator(node)
-            index = self._registry_indexer(parse_output,
-                                           getattr(node, CLASS_NAME))
+            index = self._registry_indexer(parse_output, class_name)
             if None is not index:
                 results.append(index)
         return results
@@ -363,6 +401,8 @@ class AstScaning(object):
         node = gast.parse(data)
         output, _ = self.scan_import(node, indent=' ', show_offsets=False)
         output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
+        output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY])
+        output[DECORATOR_KEY].extend(output[EXPRESS_KEY])
         return output
 
 
@@ -481,6 +521,13 @@ class FilesAstScaning(object):
             module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY]
         return module_import
 
+    def _ignore_useless_keys(self, inverted_index):
+        if ('OPTIMIZERS', 'default', 'name') in inverted_index:
+            del inverted_index[('OPTIMIZERS', 'default', 'name')]
+        if ('LR_SCHEDULER', 'default', 'name') in inverted_index:
+            del inverted_index[('LR_SCHEDULER', 'default', 'name')]
+        return inverted_index
+
     def get_files_scan_results(self,
                                target_dir=MODELSCOPE_PATH,
                                target_folders=SCAN_SUB_FOLDERS):
@@ -514,6 +561,8 @@ class FilesAstScaning(object):
                 MODULE_KEY: module_name
             }
         inverted_index_with_results = self._inverted_index(result)
+        inverted_index_with_results = self._ignore_useless_keys(
+            inverted_index_with_results)
         module_import = self._module_import(result)
         index = {
             INDEX_KEY: inverted_index_with_results,