@@ -15,17 +15,18 @@
"""API config"""
import ast
from collections import OrderedDict
from functools import partial
from importlib import import_module
import json
import os
import pasta
from mindinsight.mindconverter.enums import RequriedType
from mindinsight.mindconverter.common.log import logger
REQUIRED = RequriedType.REQUIRED.name
UNREQUIRED = RequriedType.UNREQUIRED.name
REQUIRED = 'REQUIRED'
UNREQUIRED = 'UNREQUIRED'
FUNC_MODULE = 'mindinsight.mindconverter.funcs'
class APIPt:
@@ -250,88 +251,65 @@ class MappingHelper:
return expr_ms
def gen_explicit_map_nn_sequential(_, args_pt ):
def get_ms_api(ms_api_info ):
"""
Generate explicit_map for nn.Sequential .
Get APIMs instance from ms_api_info .
Args:
args_pt (dict): Args for APIPt.
ms_api_info (list): info for create an APIMs instance, the first value in list is name for APIMs, the second(if
provided) is params for APIMs, the third(if provided) is p_attrs for APIMs.
Returns:
dict, map between frames .
APIMs, instance of APIMs parsed from given info .
"""
args = args_pt['*args']
return {"*args": "[{}]".format(args)}
ms_name = ms_api_info[0]
ms_params = ms_api_info[1] if len(ms_api_info) >= 2 else None
ms_p_attrs = set(ms_api_info[2]) if len(ms_api_info) >= 3 else None
ms_api = APIMs(name=ms_name, params=ms_params, p_attrs=ms_p_attrs)
return ms_api
def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
def get_pt_api(pt_api_info ):
"""
Generate explicit_map for nn.MaxPool2d .
Get APIPt instance from pt_api_info .
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Arg s for APIPt.
pt_api_info (list): info for create an APIMs instance, the first value in list is name for APIPt, the second(if
provided) is param s for APIPt.
Returns:
dict, map between frames.
"""
if 'padding' in args_pt:
padding = args_pt['padding']
else:
padding = params_pt['padding']
if padding.strip() in ("0", "(0,0)", "(0, 0)"):
pad_mode = "'valid'"
else:
pad_mode = "'same'"
return {"pad_mode": pad_mode}
def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
APIMs, instance of APIMs parsed from given info.
"""
Generate explicit_map for F.MaxPool2d.
pt_name = pt_api_info[0]
pt_params = pt_api_info[1] if len(pt_api_info) >= 2 else None
pt_api = APIPt(name=pt_name, params=pt_params)
return pt_api
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
def get_mapping_from_file(path):
"""
if 'padding' in args_pt:
padding = args_pt['padding']
else:
padding = params_pt['padding']
if padding.strip() in ("0", "(0,0)", "(0, 0)"):
padding = "'valid'"
else:
padding = "'same'"
return {"padding": padding}
def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
"""
Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`.
Parse mapping info from given file.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
path (str): The file path.
Returns:
dict, map between frames .
dict, key is op name, value is a relevant instance of MappingHelper.
"""
value = args_pt[k_pt] if k_pt in args_pt else params_pt[k_pt]
value = value.strip()
def is_number(string):
try:
float(string)
return True
except ValueError:
return False
if is_number(value):
return {k_ms: str(1 - float(value))}
return {k_ms: "1.0 - " + value}
mapping_info_d = load_json_file(path)
parse_mapping_dict = {}
for key, value in mapping_info_d.items():
ms_api_info = value.pop('ms_api')
ms_api = get_ms_api(ms_api_info)
pt_api_info = value.pop('pt_api')
pt_api = get_pt_api(pt_api_info)
gen_explicit_map = value.get('gen_explicit_map')
if gen_explicit_map:
module_name = import_module(FUNC_MODULE)
value['gen_explicit_map'] = getattr(module_name, gen_explicit_map)
parse_mapping_dict.update({key: MappingHelper(**dict(ms_api=ms_api, pt_api=pt_api), **value)})
return parse_mapping_dict
def load_json_file(file_path):
@@ -350,244 +328,38 @@ def load_json_file(file_path):
# ---------------------------- mappings ----------------------------
NN_MAPPING = {
'nn.Sequential': MappingHelper(**{"ms_api": APIMs('nn.SequentialCell', OrderedDict([('*args', REQUIRED)])),
"pt_api": APIPt('nn.Sequential', OrderedDict([('*args', REQUIRED)])),
"gen_explicit_map": gen_explicit_map_nn_sequential,
"export_key": False
}),
'nn.Conv2d': MappingHelper(**{"ms_api": APIMs('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
out_channels=REQUIRED,
kernel_size=REQUIRED,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros')),
"pt_api": APIPt('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
out_channels=REQUIRED,
kernel_size=REQUIRED,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros')),
"ms2pt_mapping": {'in_channels': 'in_channels',
'out_channels': 'out_channels',
'kernel_size': 'kernel_size',
'stride': 'stride',
'padding': 'padding',
'dilation': 'dilation',
'group': 'groups',
'has_bias': 'bias'},
"gen_explicit_map": (lambda params_pt, args_pt: {"pad_mode": "'pad'"})
}),
'nn.BatchNorm2d': MappingHelper(**{"ms_api": APIMs('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True)),
"pt_api": APIPt('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True)),
"ms2pt_mapping": {"num_features": "num_features",
"eps": "eps",
"affine": "affine",
"use_batch_statistics": "track_running_stats"},
"gen_explicit_map": partial(gen_explicit_map_one_delta,
k_ms="momentum", k_pt="momentum")
}),
'nn.ReLU': MappingHelper(**{"ms_api": APIMs('nn.ReLU', OrderedDict()),
"pt_api": APIPt('nn.ReLU', OrderedDict(inplace=False)),
"ms2pt_mapping": {}}),
'nn.ReLU6': MappingHelper(**{"ms_api": APIMs('nn.ReLU6', OrderedDict()),
"pt_api": APIPt('nn.ReLU6', OrderedDict(inplace=False)),
"ms2pt_mapping": {}}),
'nn.Linear': MappingHelper(**{"ms_api": APIMs('nn.Dense', OrderedDict(in_channels=REQUIRED,
out_channels=REQUIRED,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None)),
"pt_api": APIPt('nn.Linear', OrderedDict(in_features=REQUIRED,
out_features=REQUIRED,
bias=True)),
"ms2pt_mapping": {"in_channels": "in_features",
"out_channels": "out_features",
"has_bias": "bias"}
}),
'nn.MaxPool2d': MappingHelper(**{"ms_api": APIMs('nn.MaxPool2d', OrderedDict(kernel_size=1,
stride=1,
pad_mode="valid")),
"pt_api": APIPt('nn.MaxPool2d', OrderedDict(kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode="False")),
"ms2pt_mapping": {"kernel_size": "kernel_size",
"stride": "stride"},
"gen_explicit_map": gen_explicit_map_nn_maxpool2d
}),
'nn.AvgPool2d': MappingHelper(**{"ms_api": APIMs('nn.AvgPool2d', OrderedDict(kernel_size=1,
stride=1,
pad_mode="valid")),
"pt_api": APIPt('nn.AvgPool2d', OrderedDict(kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode="False")),
"ms2pt_mapping": {"kernel_size": "kernel_size",
"stride": "stride"},
"gen_explicit_map": gen_explicit_map_nn_maxpool2d
}),
'nn.Dropout': MappingHelper(**{"ms_api": APIMs('nn.Dropout', OrderedDict(keep_prob=0.5,
seed0=0,
seed1=0,
dtype="mstype.float32")),
"pt_api": APIPt('nn.Dropout', OrderedDict(p=0.5,
inplace=False)),
"ms2pt_mapping": {"keep_prob": "p"},
"gen_explicit_map": partial(gen_explicit_map_one_delta,
k_ms="keep_prob", k_pt="p")
})
}
NN_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/nn_mappings.json'))
NN_MAPPING = get_mapping_from_file(NN_MAPPING_PATH)
# update to add key with full api_name, which starts with 'torch.nn.'
NN_MAPPING.update({"torch." + k: v for k, v in NN_MAPPING.items()})
F_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/f_mappings.json'))
F_MAPPING = get_mapping_from_file(F_MAPPING_PATH)
# update to add key starts with 'nn.functional.'
NN_FUNCTIONAL_D = {"nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()}
# update to add key starts with 'torch.nn.functiona.l'
TORCH_NN_FUNCTIONAL_D = {"torch.nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()}
F_MAPPING.update(NN_FUNCTIONAL_D)
F_MAPPING.update(TORCH_NN_FUNCTIONAL_D)
F_MAPPING = {
'F.relu': MappingHelper(**{"ms_api": APIMs('P.ReLU', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('F.relu', OrderedDict(input=REQUIRED, inplace=False)),
"ms2pt_mapping": {"input": "input"},
}),
'F.relu6': MappingHelper(**{"ms_api": APIMs('P.ReLU6', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('F.relu6', OrderedDict(input=REQUIRED, inplace=False)),
"ms2pt_mapping": {"input": "input"},
}),
'F.max_pool2d': MappingHelper(**{"ms_api": APIMs('P.MaxPool', OrderedDict(ksize=1,
strides=1,
padding="valid",
input=REQUIRED),
p_attrs={"ksize", "strides", "padding"}),
"pt_api": APIPt('F.max_pool2d', OrderedDict(input=REQUIRED,
kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False)),
"ms2pt_mapping": {"ksize": "kernel_size",
"strides": "stride",
"input": "input",
},
"gen_explicit_map": gen_explicit_map_f_max_pool2d
}),
'F.avg_pool2d': MappingHelper(**{"ms_api": APIMs('P.AvgPool', OrderedDict(ksize=1,
strides=1,
padding="valid",
input=REQUIRED),
p_attrs={"ksize", "strides", "padding"}),
"pt_api": APIPt('F.avg_pool2d', OrderedDict(input=REQUIRED,
kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False)),
"ms2pt_mapping": {"ksize": "kernel_size",
"strides": "stride",
"input": "input",
},
"gen_explicit_map": gen_explicit_map_f_max_pool2d
}),
}
nn_functional_d = {"nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
torch_nn_functional_d = {"torch.nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
F_MAPPING.update(nn_functional_d)
F_MAPPING.update(torch_nn_functional_d)
TORCH_DOT_MAPPING = {
'torch.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('torch.flatten', OrderedDict(input=REQUIRED,
start_dim=0,
end_dim=-1)),
"ms2pt_mapping": {"input": "input"}
}),
'torch.cat': MappingHelper(**{"ms_api": APIMs('P.Concat',
OrderedDict(axis=0, input=REQUIRED),
p_attrs={"axis"}),
"pt_api": APIPt('torch.flatten', OrderedDict(tensors=REQUIRED, dim=0, out=None)),
"ms2pt_mapping": {"input": "tensors",
"axis": "dim"}
}),
}
TENSOR_DOT_MAPPING = {
'.view': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
"pt_api": APIPt('.view', OrderedDict([('*shape', REQUIRED)])),
"ms2pt_mapping": {"x": "call_name"},
"gen_explicit_map": (lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
}),
'.size': MappingHelper(**{"ms_api": APIMs('P.Shape', OrderedDict(x=REQUIRED)),
"pt_api": APIPt('.size', OrderedDict([('idx', REQUIRED)])),
"ms2pt_mapping": {"x": "call_name"}
}),
'.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('.flatten', OrderedDict(start_dim=0,
end_dim=-1)),
"ms2pt_mapping": {"input": "call_name"}
}),
'.reshape': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
"pt_api": APIPt('.reshape', OrderedDict([('*shape', REQUIRED)])),
"ms2pt_mapping": {"x": "call_name"},
"gen_explicit_map": (
lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
}),
'.mean': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(keep_dims=False,
input=REQUIRED,
axis=()),
p_attrs={"keep_dims"}),
"pt_api": APIPt('.mean', OrderedDict(dim=None,
keepdim=False)),
"ms2pt_mapping": {"keep_dims": "keepdim",
"axis": "dim",
"input": "call_name"},
}),
'.squeeze': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(input=REQUIRED,
axis=()),
p_attrs={"axis"}),
"pt_api": APIPt('.squeeze', OrderedDict(dim=None)),
"ms2pt_mapping": {"axis": "dim",
"input": "call_name"},
}),
}
TORCH_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/torch_dot_mappings.json'))
TORCH_DOT_MAPPING = get_mapping_from_file(TORCH_DOT_MAPPING_PATH)
TENSOR_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/tensor_dot_mappings.json'))
TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)
ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}
# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'nn_list.json'))
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH)
NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]
F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'f_list.json'))
F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH)
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
[name[len("torch."):] for name in F_LIST]
@@ -595,7 +367,7 @@ F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]
TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'torch_dot_list.json'))
TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', ' torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
@@ -603,7 +375,7 @@ TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]
TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'tensor_dot_list.json'))
TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', ' tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
@@ -620,5 +392,5 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO
UNSUPPORTED_WARN_INFOS = {
"nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean",
"F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean",
"F.dropout": "please use nn.Dropout in __init__()",
"F.dropout": "please use nn.Dropout in __init__()"
}