|
|
|
@@ -13,6 +13,9 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================== |
|
|
|
"""Graph associated definition module.""" |
|
|
|
|
|
|
|
__all__ = ["GraphFactory", "PyTorchGraphNode"] |
|
|
|
|
|
|
|
from .base import Graph |
|
|
|
from .pytorch_graph import PyTorchGraph |
|
|
|
from .pytorch_graph_node import PyTorchGraphNode |
|
|
|
@@ -44,9 +47,3 @@ class GraphFactory: |
|
|
|
output_nodes=output_nodes, sample_shape=sample_shape) |
|
|
|
|
|
|
|
return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape) |
|
|
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
"GraphFactory", |
|
|
|
"PyTorchGraphNode", |
|
|
|
] |