Browse Source

[to #44236829] record classname as default module name during ast-scanning

master
zhangzhicheng.zzc 3 years ago
parent
commit
67af7ee4fc
1 changed files with 7 additions and 3 deletions
  1. +7
    -3
      modelscope/utils/ast_utils.py

+ 7
- 3
modelscope/utils/ast_utils.py View File

@@ -43,6 +43,7 @@ MD5_KEY = 'md5'
INDEX_KEY = 'index'
REQUIREMENT_KEY = 'requirements'
MODULE_KEY = 'module'
CLASS_NAME = 'class_name'


class AstScaning(object):
@@ -237,6 +238,8 @@ class AstScaning(object):
['name']] = final_dict

if 'decorator_list' == field and attr != []:
for item in attr:
setattr(item, CLASS_NAME, node.name)
self.result_decorator.extend(attr)

out += f'{indentstr()}{field}={representation},\n'
@@ -294,7 +297,7 @@ class AstScaning(object):
else:
return getattr(eval(split_list[0]), split_list[1])

def _registry_indexer(self, parsed_input: tuple) -> tuple:
def _registry_indexer(self, parsed_input: tuple, class_name: str) -> tuple:
"""format registry information to a tuple indexer

Return:
@@ -310,7 +313,7 @@ class AstScaning(object):
if len(args_list) == 0 and len(keyword_list) == 0:
args_list.append(default_group)
if len(keyword_list) == 0 and len(args_list) == 1:
args_list.append(None)
args_list.append(class_name)
if len(keyword_list) == 1 and len(args_list) == 0:
args_list.append(default_group)

@@ -344,7 +347,8 @@ class AstScaning(object):
if type(node).__name__ != 'Call':
continue
parse_output = self._parse_decorator(node)
index = self._registry_indexer(parse_output)
index = self._registry_indexer(parse_output,
getattr(node, CLASS_NAME))
if None is not index:
results.append(index)
return results


Loading…
Cancel
Save