Browse Source

report do not contain if not convert

add for run test

add for test

use another script

line endswith \n

continue

sorted convert info

annotated import

startswith

add info in list

func ok

report more clear and pylint fix

delete devil figure

comment format more legal

use strip to define start info of line
tags/v0.5.0-beta
quyongxiu1 5 years ago
parent
commit
cc64d2eac8
1 changed files with 102 additions and 47 deletions
  1. +102
    -47
      mindinsight/mindconverter/converter.py

+ 102
- 47
mindinsight/mindconverter/converter.py View File

@@ -17,7 +17,6 @@ import copy
import importlib
import inspect
import os
import re
import stat

from mindinsight.mindconverter.config import ALL_MAPPING
@@ -29,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall

LINE_NO_INDEX_DIFF = 1


class Converter:
"""Convert class"""
@@ -198,6 +199,7 @@ class Converter:
raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = self.find_right_parentheses(code, left)
end = right

expr = code[start:end + 1]
args_str = code[left:right + 1]

@@ -337,6 +339,96 @@ class Converter:
mapping.update(convert_fun(*args))
return mapping

@staticmethod
def get_code_start_line_num(source_lines):
"""
Get the start code line number exclude comments.

Args:
source_lines (list[str]): Split results of original code.

Returns:
int, the start line number.
"""
stack = []
index = 0
for i, line in enumerate(source_lines):
if line.strip().startswith('#'):
continue
if line.strip().startswith('"""'):
if not line.endswith('"""\n'):
stack.append('"""')
continue
if line.strip().startswith("'''"):
if not line.endswith("'''\n"):
stack.append("'''")
continue
if line.endswith('"""\n') or line.endswith("'''\n"):
stack.pop()
continue
if line.strip() != '' and not stack:
index = i
break
return index

def update_code_and_convert_info(self, code, mapping):
"""
Replace code according to mapping, and update convert info.

Args:
code (str): The code to replace.
mapping (dict): Mapping for original code and the replaced code.

Returns:
str, the replaced code.
"""

for key, value in mapping.items():
code = code.replace(key, value)

source_lines = code.splitlines(keepends=True)
start_line_number = self.get_code_start_line_num(source_lines)
add_import_infos = ['import mindspore\n',
'import mindspore.nn as nn\n',
'import mindspore.ops.operations as P\n']
for i, add_import_info in enumerate(add_import_infos):
source_lines.insert(start_line_number + i, add_import_info)
self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip())

insert_count = len(add_import_infos)
line_diff = insert_count - LINE_NO_INDEX_DIFF

for i in range(start_line_number + insert_count, len(source_lines)):
line = source_lines[i]

if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'):
new_line = '# ' + line
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip())
if line.strip().startswith('class') and '(nn.Module)' in line:
new_line = line.replace('nn.Module', 'nn.Cell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff)
if line.strip().startswith('def forward('):
new_line = line.replace('forward', 'construct')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff)
if 'nn.Linear' in line:
new_line = line.replace('nn.Linear', 'nn.Dense')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff)
if '(nn.Sequential)' in line:
new_line = line.replace('nn.Sequential', 'nn.SequentialCell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff)
if 'nn.init.' in line:
new_line = line.replace('nn.init', 'pass # nn.init')
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init')

code = ''.join(source_lines)
return code

def convert(self, import_name, output_dir, report_dir):
"""
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
@@ -347,10 +439,10 @@ class Converter:
report_dir (str): The path to save report file.
"""
logger.info("Start converting %s", import_name)
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name)
start_info = '[Start Convert]\n'
module_info = 'The module is {}.\n'.format(import_name)

import_mod = importlib.import_module(import_name)

srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile)

@@ -359,50 +451,14 @@ class Converter:

# replace python function under nn.Module
mapping = self.get_mapping(import_mod, forward_list)

code = inspect.getsource(import_mod)
for key, value in mapping.items():
code = code.replace(key, value)

source_lines = code.splitlines(keepends=True)
valid_line_num = 0

# find the first valid code line of the source
for num, source in enumerate(source_lines):
if re.search(r'^[a-z]\w+', source):
valid_line_num = num
break
source_lines.insert(valid_line_num, 'import mindspore.ops.operations as P\n')
source_lines.insert(valid_line_num, 'import mindspore.nn as nn\n')
source_lines.insert(valid_line_num, 'import mindspore\n')

code = ''.join(source_lines)

self.convert_info += '||[Import Add] Add follow import sentences:\n'
self.convert_info += 'import mindspore.ops.operations as P\n'
self.convert_info += 'import mindspore.nn as nn\n'
self.convert_info += 'import mindspore\n\n'

code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')

self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n'
self.convert_info += 'import sentence on torch as follows are annotated:\n'
self.convert_info += 'import torch\n'
self.convert_info += 'from torch ...\n'

self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n'
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n'
self.convert_info += '[forward] is converted to [construct]\n'
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n'
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n'
self.convert_info += '[nn.init] is not converted and annotated\n'
self.convert_info += '[Convert over]'
code = self.update_code_and_convert_info(code, mapping)
convert_info_split = self.convert_info.splitlines(keepends=True)
convert_info_split = sorted(convert_info_split)
convert_info_split.insert(0, start_info)
convert_info_split.insert(1, module_info)
convert_info_split.append('[Convert Over]')
self.convert_info = ''.join(convert_info_split)

dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
@@ -439,7 +495,6 @@ def _path_split(file):

Returns:
list[str], list of file tail

"""
file_dir, name = os.path.split(file)
if file_dir:


Loading…
Cancel
Save