|
|
|
@@ -83,7 +83,7 @@ def torch_installation_validation(func): |
|
|
|
type, inner function. |
|
|
|
""" |
|
|
|
|
|
|
|
def _f(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str, |
|
|
|
def _f(graph_path: str, input_nodes: dict, output_nodes: List[str], |
|
|
|
output_folder: str, report_folder: str = None): |
|
|
|
# Check whether pytorch is installed. |
|
|
|
error_info = None |
|
|
|
@@ -119,7 +119,7 @@ def torch_installation_validation(func): |
|
|
|
_print_error(error) |
|
|
|
sys.exit(0) |
|
|
|
|
|
|
|
func(graph_path=graph_path, sample_shape=sample_shape, |
|
|
|
func(graph_path=graph_path, |
|
|
|
input_nodes=input_nodes, output_nodes=output_nodes, |
|
|
|
output_folder=output_folder, report_folder=report_folder) |
|
|
|
|
|
|
|
@@ -265,11 +265,12 @@ def main_graph_base_converter(file_config): |
|
|
|
if not file_config.get("shape"): |
|
|
|
raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") |
|
|
|
|
|
|
|
if graph_path.endswith("pth") and not file_config['input_nodes'] and \ |
|
|
|
file_config.get("shape") and len(file_config.get("shape")) == 1: |
|
|
|
if graph_path.endswith("pth") and not file_config.get("input_nodes", []) and \ |
|
|
|
file_config.get("shape") and len(file_config.get("shape", ())) == 1: |
|
|
|
file_config['input_nodes'] = ["input.1"] |
|
|
|
|
|
|
|
if len(file_config['shape']) != len(file_config['input_nodes']) != len(set(file_config['input_nodes'])): |
|
|
|
if len(file_config['shape']) != len(file_config.get("input_nodes", [])) != len( |
|
|
|
set(file_config.get("input_nodes", []))): |
|
|
|
raise BadParamError("`--shape` and `--input_nodes` must have the same length, " |
|
|
|
"and no redundant node in `--input_nodes`.") |
|
|
|
|
|
|
|
|