Browse Source

[to #42322933]allow none decorator registry in ast

master
zhangzhicheng.zzc 3 years ago
parent
commit
af4c6f70c2
1 changed files with 57 additions and 8 deletions
  1. +57
    -8
      modelscope/utils/ast_utils.py

+ 57
- 8
modelscope/utils/ast_utils.py View File

@@ -36,6 +36,7 @@ SCAN_SUB_FOLDERS = [
]
INDEXER_FILE = 'ast_indexer'
DECORATOR_KEY = 'decorators'
EXPRESS_KEY = 'express'
FROM_IMPORT_KEY = 'from_imports'
IMPORT_KEY = 'imports'
FILE_NAME_KEY = 'filepath'
@@ -45,6 +46,9 @@ INDEX_KEY = 'index'
REQUIREMENT_KEY = 'requirements'
MODULE_KEY = 'module'
CLASS_NAME = 'class_name'
GROUP_KEY = 'group_key'
MODULE_NAME = 'module_name'
MODULE_CLS = 'module_cls'


class AstScaning(object):
@@ -53,6 +57,7 @@ class AstScaning(object):
self.result_import = dict()
self.result_from_import = dict()
self.result_decorator = []
self.express = []

def _is_sub_node(self, node: object) -> bool:
return isinstance(node,
@@ -108,6 +113,7 @@ class AstScaning(object):
self.result_import = dict()
self.result_from_import = dict()
self.result_decorator = []
self.result_express = []

def scan_ast(self, node: Union[ast.AST, None, str]):
self._setup_global()
@@ -243,13 +249,19 @@ class AstScaning(object):
setattr(item, CLASS_NAME, node.name)
self.result_decorator.extend(attr)

if attr != [] and type(
attr
).__name__ == 'Call' and parent_node_name == 'Expr':
self.result_express.append(attr)

out += f'{indentstr()}{field}={representation},\n'

out += indentstr() + ')'
return {
IMPORT_KEY: self.result_import,
FROM_IMPORT_KEY: self.result_from_import,
DECORATOR_KEY: self.result_decorator
DECORATOR_KEY: self.result_decorator,
EXPRESS_KEY: self.result_express
}, out

def _parse_decorator(self, node: ast.AST) -> tuple:
@@ -267,7 +279,10 @@ class AstScaning(object):
def _get_args_name(nodes: list) -> list:
result = []
for node in nodes:
result.append(_get_attribute_item(node))
if type(node).__name__ == 'Str':
result.append((node.s, None))
else:
result.append(_get_attribute_item(node))
return result

def _get_keyword_name(nodes: ast.AST) -> list:
@@ -276,9 +291,11 @@ class AstScaning(object):
if type(node).__name__ == 'keyword':
attribute_node = getattr(node, 'value')
if type(attribute_node).__name__ == 'Str':
result.append((attribute_node.s, None))
result.append((getattr(node,
'arg'), attribute_node.s, None))
else:
result.append(_get_attribute_item(attribute_node))
result.append((getattr(node, 'arg'), )
+ _get_attribute_item(attribute_node))
return result

functions = _get_attribute_item(node.func)
@@ -315,10 +332,26 @@ class AstScaning(object):
args_list.append(default_group)
if len(keyword_list) == 0 and len(args_list) == 1:
args_list.append(class_name)
if len(keyword_list) == 1 and len(args_list) == 0:

if len(keyword_list) > 0 and len(args_list) == 0:
remove_group_item = None
for item in keyword_list:
key, name, attr = item
if key == GROUP_KEY:
args_list.append((name, attr))
remove_group_item = item
if remove_group_item is not None:
keyword_list.remove(remove_group_item)

if len(args_list) == 0:
args_list.append(default_group)

args_list.extend(keyword_list)
for item in keyword_list:
key, name, attr = item
if key == MODULE_CLS:
class_name = name
else:
args_list.append((name, attr))

for item in args_list:
# the case empty input
@@ -347,9 +380,14 @@ class AstScaning(object):
for node in nodes:
if type(node).__name__ != 'Call':
continue
class_name = getattr(node, CLASS_NAME, None)
func = getattr(node, 'func')

if getattr(func, 'attr', None) != REGISTER_MODULE:
continue

parse_output = self._parse_decorator(node)
index = self._registry_indexer(parse_output,
getattr(node, CLASS_NAME))
index = self._registry_indexer(parse_output, class_name)
if None is not index:
results.append(index)
return results
@@ -363,6 +401,8 @@ class AstScaning(object):
node = gast.parse(data)
output, _ = self.scan_import(node, indent=' ', show_offsets=False)
output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY])
output[DECORATOR_KEY].extend(output[EXPRESS_KEY])
return output


@@ -481,6 +521,13 @@ class FilesAstScaning(object):
module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY]
return module_import

def _ignore_useless_keys(self, inverted_index):
if ('OPTIMIZERS', 'default', 'name') in inverted_index:
del inverted_index[('OPTIMIZERS', 'default', 'name')]
if ('LR_SCHEDULER', 'default', 'name') in inverted_index:
del inverted_index[('LR_SCHEDULER', 'default', 'name')]
return inverted_index

def get_files_scan_results(self,
target_dir=MODELSCOPE_PATH,
target_folders=SCAN_SUB_FOLDERS):
@@ -514,6 +561,8 @@ class FilesAstScaning(object):
MODULE_KEY: module_name
}
inverted_index_with_results = self._inverted_index(result)
inverted_index_with_results = self._ignore_useless_keys(
inverted_index_with_results)
module_import = self._module_import(result)
index = {
INDEX_KEY: inverted_index_with_results,


Loading…
Cancel
Save