|
|
|
@@ -15,17 +15,18 @@ |
|
|
|
"""API config""" |
|
|
|
import ast |
|
|
|
from collections import OrderedDict |
|
|
|
from functools import partial |
|
|
|
from importlib import import_module |
|
|
|
import json |
|
|
|
import os |
|
|
|
|
|
|
|
import pasta |
|
|
|
|
|
|
|
from mindinsight.mindconverter.enums import RequriedType |
|
|
|
from mindinsight.mindconverter.common.log import logger |
|
|
|
|
|
|
|
REQUIRED = RequriedType.REQUIRED.name |
|
|
|
UNREQUIRED = RequriedType.UNREQUIRED.name |
|
|
|
|
|
|
|
REQUIRED = 'REQUIRED' |
|
|
|
UNREQUIRED = 'UNREQUIRED' |
|
|
|
FUNC_MODULE = 'mindinsight.mindconverter.funcs' |
|
|
|
|
|
|
|
|
|
|
|
class APIPt: |
|
|
|
@@ -250,88 +251,65 @@ class MappingHelper: |
|
|
|
         return expr_ms


-def gen_explicit_map_nn_sequential(_, args_pt):
-    """
-    Generate explicit_map for nn.Sequential.
-
-    Args:
-        args_pt (dict): Args for APIPt.
-
-    Returns:
-        dict, map between frames.
-    """
-    args = args_pt['*args']
-    return {"*args": "[{}]".format(args)}
-
-
-def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
-    """
-    Generate explicit_map for nn.MaxPool2d.
-
-    Args:
-        params_pt (dict): Params for APIPt.
-        args_pt (dict): Args for APIPt.
-
-    Returns:
-        dict, map between frames.
-    """
-    if 'padding' in args_pt:
-        padding = args_pt['padding']
-    else:
-        padding = params_pt['padding']
-    if padding.strip() in ("0", "(0,0)", "(0, 0)"):
-        pad_mode = "'valid'"
-    else:
-        pad_mode = "'same'"
-    return {"pad_mode": pad_mode}
-
-
-def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
-    """
-    Generate explicit_map for F.MaxPool2d.
-
-    Args:
-        params_pt (dict): Params for APIPt.
-        args_pt (dict): Args for APIPt.
-
-    Returns:
-        dict, map between frames.
-    """
-    if 'padding' in args_pt:
-        padding = args_pt['padding']
-    else:
-        padding = params_pt['padding']
-    if padding.strip() in ("0", "(0,0)", "(0, 0)"):
-        padding = "'valid'"
-    else:
-        padding = "'same'"
-    return {"padding": padding}
-
-
-def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
-    """
-    Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`.
-
-    Args:
-        params_pt (dict): Params for APIPt.
-        args_pt (dict): Args for APIPt.
-
-    Returns:
-        dict, map between frames.
-    """
-    value = args_pt[k_pt] if k_pt in args_pt else params_pt[k_pt]
-    value = value.strip()
-
-    def is_number(string):
-        try:
-            float(string)
-            return True
-        except ValueError:
-            return False
-
-    if is_number(value):
-        return {k_ms: str(1 - float(value))}
-    return {k_ms: "1.0 - " + value}
+def get_ms_api(ms_api_info):
+    """
+    Get an APIMs instance from ms_api_info.
+
+    Args:
+        ms_api_info (list): Info for creating an APIMs instance: the first value in the list is the name for APIMs,
+            the second (if provided) is params for APIMs, the third (if provided) is p_attrs for APIMs.
+
+    Returns:
+        APIMs, instance of APIMs parsed from the given info.
+    """
+    ms_name = ms_api_info[0]
+    ms_params = ms_api_info[1] if len(ms_api_info) >= 2 else None
+    ms_p_attrs = set(ms_api_info[2]) if len(ms_api_info) >= 3 else None
+    ms_api = APIMs(name=ms_name, params=ms_params, p_attrs=ms_p_attrs)
+    return ms_api
+
+
+def get_pt_api(pt_api_info):
+    """
+    Get an APIPt instance from pt_api_info.
+
+    Args:
+        pt_api_info (list): Info for creating an APIPt instance: the first value in the list is the name for APIPt,
+            the second (if provided) is params for APIPt.
+
+    Returns:
+        APIPt, instance of APIPt parsed from the given info.
+    """
+    pt_name = pt_api_info[0]
+    pt_params = pt_api_info[1] if len(pt_api_info) >= 2 else None
+    pt_api = APIPt(name=pt_name, params=pt_params)
+    return pt_api
+
+
+def get_mapping_from_file(path):
+    """
+    Parse mapping info from the given file.
+
+    Args:
+        path (str): The file path.
+
+    Returns:
+        dict, key is op name, value is the relevant instance of MappingHelper.
+    """
+    mapping_info_d = load_json_file(path)
+    parse_mapping_dict = {}
+    for key, value in mapping_info_d.items():
+        ms_api_info = value.pop('ms_api')
+        ms_api = get_ms_api(ms_api_info)
+        pt_api_info = value.pop('pt_api')
+        pt_api = get_pt_api(pt_api_info)
+        gen_explicit_map = value.get('gen_explicit_map')
+        if gen_explicit_map:
+            module_name = import_module(FUNC_MODULE)
+            value['gen_explicit_map'] = getattr(module_name, gen_explicit_map)
+
+        parse_mapping_dict.update({key: MappingHelper(**dict(ms_api=ms_api, pt_api=pt_api), **value)})
+    return parse_mapping_dict
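
# NOTE: the sketch below is an editor's illustration, not part of the module.
# It shows how get_mapping_from_file() is expected to consume one mapping
# entry; the dict literal stands in for a single, hypothetical entry of
# 'mappings/nn_mappings.json' (the real schema is defined by those JSON files,
# which this hunk does not show).
def _demo_parse_one_entry():
    """Walk one assumed JSON entry through the same steps as get_mapping_from_file()."""
    entry = {
        "ms_api": ["nn.ReLU", {}],                  # [name, params, optional p_attrs] -> get_ms_api()
        "pt_api": ["nn.ReLU", {"inplace": False}],  # [name, params] -> get_pt_api()
        "ms2pt_mapping": {},                        # remaining keys are forwarded to MappingHelper
    }
    ms_api = get_ms_api(entry.pop("ms_api"))
    pt_api = get_pt_api(entry.pop("pt_api"))
    # If the entry carried a "gen_explicit_map" name, get_mapping_from_file()
    # would resolve it with getattr(import_module(FUNC_MODULE), name), i.e.
    # look the function up in mindinsight.mindconverter.funcs.
    return MappingHelper(**dict(ms_api=ms_api, pt_api=pt_api), **entry)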
|
|
|
|
|
|
|
|
|
|
|
 def load_json_file(file_path):
|
|
|
@@ -350,244 +328,38 @@ def load_json_file(file_path): |
|
|
|
|
|
|
|
|
|
|
|
 # ---------------------------- mappings ----------------------------
-NN_MAPPING = {
-    'nn.Sequential': MappingHelper(**{"ms_api": APIMs('nn.SequentialCell', OrderedDict([('*args', REQUIRED)])),
-                                      "pt_api": APIPt('nn.Sequential', OrderedDict([('*args', REQUIRED)])),
-                                      "gen_explicit_map": gen_explicit_map_nn_sequential,
-                                      "export_key": False
-                                      }),
-    'nn.Conv2d': MappingHelper(**{"ms_api": APIMs('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
-                                                                           out_channels=REQUIRED,
-                                                                           kernel_size=REQUIRED,
-                                                                           stride=1,
-                                                                           pad_mode='same',
-                                                                           padding=0,
-                                                                           dilation=1,
-                                                                           group=1,
-                                                                           has_bias=False,
-                                                                           weight_init='normal',
-                                                                           bias_init='zeros')),
-                                  "pt_api": APIPt('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
-                                                                           out_channels=REQUIRED,
-                                                                           kernel_size=REQUIRED,
-                                                                           stride=1,
-                                                                           padding=0,
-                                                                           dilation=1,
-                                                                           groups=1,
-                                                                           bias=True,
-                                                                           padding_mode='zeros')),
-                                  "ms2pt_mapping": {'in_channels': 'in_channels',
-                                                    'out_channels': 'out_channels',
-                                                    'kernel_size': 'kernel_size',
-                                                    'stride': 'stride',
-                                                    'padding': 'padding',
-                                                    'dilation': 'dilation',
-                                                    'group': 'groups',
-                                                    'has_bias': 'bias'},
-                                  "gen_explicit_map": (lambda params_pt, args_pt: {"pad_mode": "'pad'"})
-                                  }),
-    'nn.BatchNorm2d': MappingHelper(**{"ms_api": APIMs('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
-                                                                                     eps=1e-5,
-                                                                                     momentum=0.9,
-                                                                                     affine=True,
-                                                                                     gamma_init='ones',
-                                                                                     beta_init='zeros',
-                                                                                     moving_mean_init='zeros',
-                                                                                     moving_var_init='ones',
-                                                                                     use_batch_statistics=True)),
-                                       "pt_api": APIPt('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
-                                                                                     eps=1e-5,
-                                                                                     momentum=0.1,
-                                                                                     affine=True,
-                                                                                     track_running_stats=True)),
-                                       "ms2pt_mapping": {"num_features": "num_features",
-                                                         "eps": "eps",
-                                                         "affine": "affine",
-                                                         "use_batch_statistics": "track_running_stats"},
-                                       "gen_explicit_map": partial(gen_explicit_map_one_delta,
-                                                                   k_ms="momentum", k_pt="momentum")
-                                       }),
-    'nn.ReLU': MappingHelper(**{"ms_api": APIMs('nn.ReLU', OrderedDict()),
-                                "pt_api": APIPt('nn.ReLU', OrderedDict(inplace=False)),
-                                "ms2pt_mapping": {}}),
-    'nn.ReLU6': MappingHelper(**{"ms_api": APIMs('nn.ReLU6', OrderedDict()),
-                                 "pt_api": APIPt('nn.ReLU6', OrderedDict(inplace=False)),
-                                 "ms2pt_mapping": {}}),
-    'nn.Linear': MappingHelper(**{"ms_api": APIMs('nn.Dense', OrderedDict(in_channels=REQUIRED,
-                                                                          out_channels=REQUIRED,
-                                                                          weight_init='normal',
-                                                                          bias_init='zeros',
-                                                                          has_bias=True,
-                                                                          activation=None)),
-                                  "pt_api": APIPt('nn.Linear', OrderedDict(in_features=REQUIRED,
-                                                                           out_features=REQUIRED,
-                                                                           bias=True)),
-                                  "ms2pt_mapping": {"in_channels": "in_features",
-                                                    "out_channels": "out_features",
-                                                    "has_bias": "bias"}
-                                  }),
-    'nn.MaxPool2d': MappingHelper(**{"ms_api": APIMs('nn.MaxPool2d', OrderedDict(kernel_size=1,
-                                                                                 stride=1,
-                                                                                 pad_mode="valid")),
-                                     "pt_api": APIPt('nn.MaxPool2d', OrderedDict(kernel_size=REQUIRED,
-                                                                                 stride=None,
-                                                                                 padding=0,
-                                                                                 dilation=1,
-                                                                                 return_indices=False,
-                                                                                 ceil_mode="False")),
-                                     "ms2pt_mapping": {"kernel_size": "kernel_size",
-                                                       "stride": "stride"},
-                                     "gen_explicit_map": gen_explicit_map_nn_maxpool2d
-                                     }),
-    'nn.AvgPool2d': MappingHelper(**{"ms_api": APIMs('nn.AvgPool2d', OrderedDict(kernel_size=1,
-                                                                                 stride=1,
-                                                                                 pad_mode="valid")),
-                                     "pt_api": APIPt('nn.AvgPool2d', OrderedDict(kernel_size=REQUIRED,
-                                                                                 stride=None,
-                                                                                 padding=0,
-                                                                                 dilation=1,
-                                                                                 return_indices=False,
-                                                                                 ceil_mode="False")),
-                                     "ms2pt_mapping": {"kernel_size": "kernel_size",
-                                                       "stride": "stride"},
-                                     "gen_explicit_map": gen_explicit_map_nn_maxpool2d
-                                     }),
-    'nn.Dropout': MappingHelper(**{"ms_api": APIMs('nn.Dropout', OrderedDict(keep_prob=0.5,
-                                                                             seed0=0,
-                                                                             seed1=0,
-                                                                             dtype="mstype.float32")),
-                                   "pt_api": APIPt('nn.Dropout', OrderedDict(p=0.5,
-                                                                             inplace=False)),
-                                   "ms2pt_mapping": {"keep_prob": "p"},
-                                   "gen_explicit_map": partial(gen_explicit_map_one_delta,
-                                                               k_ms="keep_prob", k_pt="p")
-                                   })
-}
+NN_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/nn_mappings.json'))
+NN_MAPPING = get_mapping_from_file(NN_MAPPING_PATH)
 # update to add key with full api_name, which starts with 'torch.nn.'
 NN_MAPPING.update({"torch." + k: v for k, v in NN_MAPPING.items()})
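
# NOTE: editor's illustration, not part of the module. The nn.BatchNorm2d and
# nn.Dropout entries rely on the "one delta" rule (gen_explicit_map_one_delta):
# the PyTorch value and the MindSpore value are complementary, ms = 1 - pt.
def _demo_one_delta_rule():
    pt_momentum = "0.1"                        # torch.nn.BatchNorm2d(momentum=0.1)
    ms_momentum = str(1 - float(pt_momentum))  # -> "0.9" for MindSpore nn.BatchNorm2d
    pt_p = "0.5"                               # torch.nn.Dropout(p=0.5)
    ms_keep_prob = str(1 - float(pt_p))        # -> "0.5" for MindSpore nn.Dropout(keep_prob=...)
    return {"momentum": ms_momentum, "keep_prob": ms_keep_prob}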
|
|
|
|
|
|
|
+F_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/f_mappings.json'))
+F_MAPPING = get_mapping_from_file(F_MAPPING_PATH)
+# update to add keys starting with 'nn.functional.'
+NN_FUNCTIONAL_D = {"nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()}
+# update to add keys starting with 'torch.nn.functional.'
+TORCH_NN_FUNCTIONAL_D = {"torch.nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()}
+F_MAPPING.update(NN_FUNCTIONAL_D)
+F_MAPPING.update(TORCH_NN_FUNCTIONAL_D)
-F_MAPPING = {
-    'F.relu': MappingHelper(**{"ms_api": APIMs('P.ReLU', OrderedDict(input=REQUIRED)),
-                               "pt_api": APIPt('F.relu', OrderedDict(input=REQUIRED, inplace=False)),
-                               "ms2pt_mapping": {"input": "input"},
-                               }),
-    'F.relu6': MappingHelper(**{"ms_api": APIMs('P.ReLU6', OrderedDict(input=REQUIRED)),
-                                "pt_api": APIPt('F.relu6', OrderedDict(input=REQUIRED, inplace=False)),
-                                "ms2pt_mapping": {"input": "input"},
-                                }),
-    'F.max_pool2d': MappingHelper(**{"ms_api": APIMs('P.MaxPool', OrderedDict(ksize=1,
-                                                                              strides=1,
-                                                                              padding="valid",
-                                                                              input=REQUIRED),
-                                                     p_attrs={"ksize", "strides", "padding"}),
-                                     "pt_api": APIPt('F.max_pool2d', OrderedDict(input=REQUIRED,
-                                                                                 kernel_size=REQUIRED,
-                                                                                 stride=None,
-                                                                                 padding=0,
-                                                                                 dilation=1,
-                                                                                 ceil_mode=False,
-                                                                                 return_indices=False)),
-                                     "ms2pt_mapping": {"ksize": "kernel_size",
-                                                       "strides": "stride",
-                                                       "input": "input",
-                                                       },
-                                     "gen_explicit_map": gen_explicit_map_f_max_pool2d
-                                     }),
-    'F.avg_pool2d': MappingHelper(**{"ms_api": APIMs('P.AvgPool', OrderedDict(ksize=1,
-                                                                              strides=1,
-                                                                              padding="valid",
-                                                                              input=REQUIRED),
-                                                     p_attrs={"ksize", "strides", "padding"}),
-                                     "pt_api": APIPt('F.avg_pool2d', OrderedDict(input=REQUIRED,
-                                                                                 kernel_size=REQUIRED,
-                                                                                 stride=None,
-                                                                                 padding=0,
-                                                                                 dilation=1,
-                                                                                 ceil_mode=False,
-                                                                                 return_indices=False)),
-                                     "ms2pt_mapping": {"ksize": "kernel_size",
-                                                       "strides": "stride",
-                                                       "input": "input",
-                                                       },
-                                     "gen_explicit_map": gen_explicit_map_f_max_pool2d
-                                     }),
-}
-nn_functional_d = {"nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
-torch_nn_functional_d = {"torch.nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
-F_MAPPING.update(nn_functional_d)
-F_MAPPING.update(torch_nn_functional_d)
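
# NOTE: editor's illustration, not part of the module. The two comprehensions
# above register every functional mapping under its aliases, so a call spelled
# either way resolves to the same MappingHelper.
def _demo_functional_aliases():
    key = 'F.relu'
    aliases = ["nn.functional." + key[len('F.'):], "torch.nn.functional." + key[len('F.'):]]
    return aliases  # -> ['nn.functional.relu', 'torch.nn.functional.relu']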
|
|
|
|
|
|
|
|
|
|
|
-TORCH_DOT_MAPPING = {
-    'torch.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
-                                      "pt_api": APIPt('torch.flatten', OrderedDict(input=REQUIRED,
-                                                                                   start_dim=0,
-                                                                                   end_dim=-1)),
-                                      "ms2pt_mapping": {"input": "input"}
-                                      }),
-    'torch.cat': MappingHelper(**{"ms_api": APIMs('P.Concat',
-                                                  OrderedDict(axis=0, input=REQUIRED),
-                                                  p_attrs={"axis"}),
-                                  "pt_api": APIPt('torch.flatten', OrderedDict(tensors=REQUIRED, dim=0, out=None)),
-                                  "ms2pt_mapping": {"input": "tensors",
-                                                    "axis": "dim"}
-                                  }),
-}
-
-
-TENSOR_DOT_MAPPING = {
-    '.view': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
-                              "pt_api": APIPt('.view', OrderedDict([('*shape', REQUIRED)])),
-                              "ms2pt_mapping": {"x": "call_name"},
-                              "gen_explicit_map": (lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
-                              }),
-    '.size': MappingHelper(**{"ms_api": APIMs('P.Shape', OrderedDict(x=REQUIRED)),
-                              "pt_api": APIPt('.size', OrderedDict([('idx', REQUIRED)])),
-                              "ms2pt_mapping": {"x": "call_name"}
-                              }),
-    '.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
-                                 "pt_api": APIPt('.flatten', OrderedDict(start_dim=0,
-                                                                         end_dim=-1)),
-                                 "ms2pt_mapping": {"input": "call_name"}
-                                 }),
-    '.reshape': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
-                                 "pt_api": APIPt('.reshape', OrderedDict([('*shape', REQUIRED)])),
-                                 "ms2pt_mapping": {"x": "call_name"},
-                                 "gen_explicit_map": (
-                                     lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
-                                 }),
-    '.mean': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(keep_dims=False,
-                                                                          input=REQUIRED,
-                                                                          axis=()),
-                                              p_attrs={"keep_dims"}),
-                              "pt_api": APIPt('.mean', OrderedDict(dim=None,
-                                                                   keepdim=False)),
-                              "ms2pt_mapping": {"keep_dims": "keepdim",
-                                                "axis": "dim",
-                                                "input": "call_name"},
-                              }),
-    '.squeeze': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(input=REQUIRED,
-                                                                             axis=()),
-                                                 p_attrs={"axis"}),
-                                 "pt_api": APIPt('.squeeze', OrderedDict(dim=None)),
-                                 "ms2pt_mapping": {"axis": "dim",
-                                                   "input": "call_name"},
-                                 }),
-}
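
# NOTE: editor's illustration, not part of the module. For Tensor methods such
# as '.view', the ms2pt_mapping entry {"x": "call_name"} presumably substitutes
# the object the method was called on, while gen_explicit_map builds the shape
# argument from the raw argument text; the final rendering is done by
# MappingHelper, which is outside this hunk.
def _demo_view_explicit_map():
    gen_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"}
    return gen_map({}, {"*shape": "batch, -1"})  # -> {"shape": "(batch, -1,)"}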
|
|
|
+TORCH_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/torch_dot_mappings.json'))
+TORCH_DOT_MAPPING = get_mapping_from_file(TORCH_DOT_MAPPING_PATH)
+
+TENSOR_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/tensor_dot_mappings.json'))
+TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)

 ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}


 # ---------------------------- api list support or not support ----------------------------
-NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'nn_list.json'))
+NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json'))
 NN_LIST = load_json_file(NN_LIST_PATH)
 NN_LIST += ["torch." + name for name in NN_LIST]
 NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
 NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]


-F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'f_list.json'))
+F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json'))
 F_LIST = load_json_file(F_LIST_PATH)
 F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
           [name[len("torch."):] for name in F_LIST]
|
|
|
@@ -595,7 +367,7 @@ F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING] |
|
|
|
 F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]


-TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'torch_dot_list.json'))
+TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json'))
 TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
|
|
|
|
|
|
|
|
|
|
|
@@ -603,7 +375,7 @@ TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING] |
|
|
|
 TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]


-TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'tensor_dot_list.json'))
+TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json'))
 TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
|
|
|
|
|
|
|
|
|
|
|
@@ -620,5 +392,5 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO |
|
|
|
 UNSUPPORTED_WARN_INFOS = {
     "nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean",
     "F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean",
-    "F.dropout": "please use nn.Dropout in __init__()",
+    "F.dropout": "please use nn.Dropout in __init__()"
 }
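
# NOTE: editor's illustration, not part of the module. UNSUPPORTED_WARN_INFOS
# attaches a human-readable hint to selected unsupported operators; how the
# converter surfaces these hints is outside this hunk. A hypothetical lookup
# using the imported logger could look like:
def _demo_warn_unsupported(op_name):
    hint = UNSUPPORTED_WARN_INFOS.get(op_name)
    if hint:
        logger.warning("%s is not converted automatically: %s", op_name, hint)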