| @@ -71,9 +71,9 @@ class APIPt: | |||||
| or the given args_str not valid. | or the given args_str not valid. | ||||
| """ | """ | ||||
| # expr is REQUIRED to meet (**) format | # expr is REQUIRED to meet (**) format | ||||
| if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"): | |||||
| raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str)) | |||||
| if not (len(args_str) >= 2 and args_str[0] == "(" and args_str.strip()[-1] == ")"): | |||||
| raise ValueError('"{}" is think as args string, it should start with "(" and end with ")" without ' | |||||
| 'considering spaces'.format(args_str)) | |||||
| try: | try: | ||||
| ast_node = ast.parse("whatever_call_name" + args_str) | ast_node = ast.parse("whatever_call_name" + args_str) | ||||
| call_node = ast_node.body[0].value | call_node = ast_node.body[0].value | ||||
| @@ -35,7 +35,14 @@ def gen_explicit_map_f_max_pool2d(params_pt, args_pt): | |||||
| padding = "'valid'" | padding = "'valid'" | ||||
| else: | else: | ||||
| padding = "'same'" | padding = "'same'" | ||||
| return {"padding": padding} | |||||
| if 'stride' in args_pt: | |||||
| strides = args_pt['stride'] | |||||
| else: | |||||
| strides = args_pt['kernel_size'] | |||||
| return {"padding": padding, | |||||
| "strides": strides} | |||||
| def gen_explicit_map_nn_sequential(_, args_pt): | def gen_explicit_map_nn_sequential(_, args_pt): | ||||
| @@ -97,7 +104,14 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt): | |||||
| pad_mode = "'valid'" | pad_mode = "'valid'" | ||||
| else: | else: | ||||
| pad_mode = "'same'" | pad_mode = "'same'" | ||||
| return {"pad_mode": pad_mode} | |||||
| if 'stride' in args_pt: | |||||
| stride = args_pt['stride'] | |||||
| else: | |||||
| stride = args_pt['kernel_size'] | |||||
| return {"pad_mode": pad_mode, | |||||
| "stride": stride} | |||||
| def torch_dot_eye_gen_explicit_map(_, args_pt): | def torch_dot_eye_gen_explicit_map(_, args_pt): | ||||
| @@ -21,14 +21,13 @@ | |||||
| "kernel_size": "REQUIRED", | "kernel_size": "REQUIRED", | ||||
| "stride": null, | "stride": null, | ||||
| "padding": 0, | "padding": 0, | ||||
| "dilation": 1, | |||||
| "ceil_mode": false, | "ceil_mode": false, | ||||
| "return_indices": false | |||||
| "count_include_pad": true, | |||||
| "divisor_override": null | |||||
| } | } | ||||
| ], | ], | ||||
| "ms2pt_mapping": { | "ms2pt_mapping": { | ||||
| "ksize": "kernel_size", | "ksize": "kernel_size", | ||||
| "strides": "stride", | |||||
| "input": "input" | "input": "input" | ||||
| }, | }, | ||||
| "gen_explicit_map": "gen_explicit_map_f_max_pool2d" | "gen_explicit_map": "gen_explicit_map_f_max_pool2d" | ||||
| @@ -62,7 +61,6 @@ | |||||
| ], | ], | ||||
| "ms2pt_mapping": { | "ms2pt_mapping": { | ||||
| "ksize": "kernel_size", | "ksize": "kernel_size", | ||||
| "strides": "stride", | |||||
| "input": "input" | "input": "input" | ||||
| }, | }, | ||||
| "gen_explicit_map": "gen_explicit_map_f_max_pool2d" | "gen_explicit_map": "gen_explicit_map_f_max_pool2d" | ||||
| @@ -16,9 +16,7 @@ | |||||
| "inplace": false | "inplace": false | ||||
| } | } | ||||
| ], | ], | ||||
| "ms2pt_mapping": { | |||||
| "keep_prob": "p" | |||||
| }, | |||||
| "ms2pt_mapping": {}, | |||||
| "gen_explicit_map": "nn_dropout_gen_explicit_map" | "gen_explicit_map": "nn_dropout_gen_explicit_map" | ||||
| }, | }, | ||||
| "nn.AvgPool2d": { | "nn.AvgPool2d": { | ||||
| @@ -36,14 +34,13 @@ | |||||
| "kernel_size": "REQUIRED", | "kernel_size": "REQUIRED", | ||||
| "stride": null, | "stride": null, | ||||
| "padding": 0, | "padding": 0, | ||||
| "dilation": 1, | |||||
| "return_indices": false, | |||||
| "ceil_mode": "False" | |||||
| "ceil_mode": false, | |||||
| "count_include_pad": true, | |||||
| "divisor_override": null | |||||
| } | } | ||||
| ], | ], | ||||
| "ms2pt_mapping": { | "ms2pt_mapping": { | ||||
| "kernel_size": "kernel_size", | |||||
| "stride": "stride" | |||||
| "kernel_size": "kernel_size" | |||||
| }, | }, | ||||
| "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" | "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" | ||||
| }, | }, | ||||
| @@ -68,8 +65,7 @@ | |||||
| } | } | ||||
| ], | ], | ||||
| "ms2pt_mapping": { | "ms2pt_mapping": { | ||||
| "kernel_size": "kernel_size", | |||||
| "stride": "stride" | |||||
| "kernel_size": "kernel_size" | |||||
| }, | }, | ||||
| "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" | "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" | ||||
| }, | }, | ||||
| @@ -64,6 +64,15 @@ class TestConverter: | |||||
| assert replaced_code == code.replace('nn.Softmax(dim=1)', | assert replaced_code == code.replace('nn.Softmax(dim=1)', | ||||
| '{}(axis=1)'.format(expected_ms_api_name)) | '{}(axis=1)'.format(expected_ms_api_name)) | ||||
| def test_convert_api_nn_dropout(self): | |||||
| """Test convert_api function work ok when convert api nn.Dropout""" | |||||
| code = """nn.Dropout(0.3)""" | |||||
| expected_ms_api_name = 'nn.Dropout' | |||||
| replaced_code = self.converter_ins.convert_api(code) | |||||
| assert replaced_code == code.replace('nn.Dropout(0.3)', | |||||
| "{}(keep_prob=0.7)".format(expected_ms_api_name)) | |||||
| # test convert_api with torch dot ops | # test convert_api with torch dot ops | ||||
| def test_convert_api_torch_dot_abs(self): | def test_convert_api_torch_dot_abs(self): | ||||
| """Test convert_api function work ok when convert api torch.abs""" | """Test convert_api function work ok when convert api torch.abs""" | ||||
| @@ -202,6 +211,33 @@ class TestConverter: | |||||
| assert replaced_code == code.replace('F.sigmoid(input)', | assert replaced_code == code.replace('F.sigmoid(input)', | ||||
| '{}()(input)'.format(expected_ms_api_name)) | '{}()(input)'.format(expected_ms_api_name)) | ||||
| def test_convert_api_f_max_pool2d(self): | |||||
| """Test convert_api function work ok when convert api F.max_pool2d""" | |||||
| code = """F.max_pool2d(out, 2)""" | |||||
| expected_ms_api_name = 'P.MaxPool' | |||||
| replaced_code = self.converter_ins.convert_api(code) | |||||
| assert replaced_code == code.replace('F.max_pool2d(out, 2)', | |||||
| "{}(2, 2, 'valid')(out)".format(expected_ms_api_name)) | |||||
| def test_convert_api_f_avg_pool2d_without_strides(self): | |||||
| """Test convert_api function work ok when convert api F.avg_pool2d""" | |||||
| code = """F.avg_pool2d(out, 2)""" | |||||
| expected_ms_api_name = 'P.AvgPool' | |||||
| replaced_code = self.converter_ins.convert_api(code) | |||||
| assert replaced_code == code.replace('F.avg_pool2d(out, 2)', | |||||
| "{}(2, 2, 'valid')(out)".format(expected_ms_api_name)) | |||||
| def test_convert_api_f_avg_pool2d_with_strides(self): | |||||
| """Test convert_api function work ok when convert api F.avg_pool2d""" | |||||
| code = """F.avg_pool2d(out, 2, 3)""" | |||||
| expected_ms_api_name = 'P.AvgPool' | |||||
| replaced_code = self.converter_ins.convert_api(code) | |||||
| assert replaced_code == code.replace('F.avg_pool2d(out, 2, 3)', | |||||
| "{}(2, 3, 'valid')(out)".format(expected_ms_api_name)) | |||||
| # test convert_api with tensor dot ops | # test convert_api with tensor dot ops | ||||
| def test_convert_api_tensor_dot_repeat(self): | def test_convert_api_tensor_dot_repeat(self): | ||||
| """Test convert_api function work ok when convert api .repeat""" | """Test convert_api function work ok when convert api .repeat""" | ||||
| @@ -216,7 +252,6 @@ class TestConverter: | |||||
| """Test convert_api function work ok when convert api .permute""" | """Test convert_api function work ok when convert api .permute""" | ||||
| code = "x.permute(2, 0, 1)" | code = "x.permute(2, 0, 1)" | ||||
| expected_ms_api_name = 'P.Transpose' | expected_ms_api_name = 'P.Transpose' | ||||
| replaced_code = self.converter_ins.convert_api(code) | replaced_code = self.converter_ins.convert_api(code) | ||||
| assert replaced_code == code.replace('x.permute(2, 0, 1)', | assert replaced_code == code.replace('x.permute(2, 0, 1)', | ||||
| '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name)) | '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name)) | ||||