|
|
|
@@ -710,12 +710,28 @@ class ReplacePath(BasePath): |
|
|
|
Args: |
|
|
|
increment_idx (int): To deduplicate module name. |
|
|
|
""" |
|
|
|
src = ",".join(self.topo_order_bef_repl) |
|
|
|
tgt = self.pattern.pattern |
|
|
|
md_name = f"Module{increment_idx}" |
|
|
|
src_aft_repl = src.replace(tgt, md_name) |
|
|
|
if src != src_aft_repl: |
|
|
|
new_order = [] |
|
|
|
idx = 0 |
|
|
|
merged = False |
|
|
|
while idx < len(self.topo_order_bef_repl): |
|
|
|
op_node = self.topo_order_bef_repl[idx] |
|
|
|
if op_node.op_type == self.pattern.head: |
|
|
|
windowed_nodes = self.topo_order_bef_repl[idx:idx + self.pattern.ptn_length] |
|
|
|
cur_window = ",".join([nd.op_type for nd in windowed_nodes]) |
|
|
|
if cur_window == self.pattern.pattern: |
|
|
|
merged_node = MergedONNXNode(name="", module_name=md_name, |
|
|
|
ori_nodes=windowed_nodes, inputs="", outputs="", |
|
|
|
known_module_name=self.pattern.known_module_name) |
|
|
|
new_order.append(merged_node) |
|
|
|
idx += self.pattern.ptn_length |
|
|
|
merged = True |
|
|
|
continue |
|
|
|
new_order.append(op_node) |
|
|
|
idx += 1 |
|
|
|
|
|
|
|
if merged: |
|
|
|
self.pattern.module_name = md_name |
|
|
|
self.topo_order_aft_repl = src_aft_repl.split(",") |
|
|
|
self.topo_order_aft_repl = new_order |
|
|
|
return md_name |
|
|
|
return None |