From 0511614861cb8284e8a2a82d25a4f0a35d827a8c Mon Sep 17 00:00:00 2001 From: haiping008 Date: Fri, 25 Jan 2019 17:22:30 -0600 Subject: [PATCH] ops.colocate_with --- src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs | 14 ++++++++++++++ src/TensorFlowNET.Core/Operations/OpDefLibrary.cs | 8 +++++++- src/TensorFlowNET.Core/Operations/Operation.cs | 2 ++ src/TensorFlowNET.Core/Variables/RefVariable.cs | 2 ++ src/TensorFlowNET.Core/ops.py.cs | 11 +++++++++++ 5 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs b/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs new file mode 100644 index 00000000..764d5bef --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class Graph + { + public void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index b31be3c0..e0e68929 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -72,7 +72,13 @@ namespace Tensorflow } else { - input_types.Add(value.dtype); + var base_type = value.dtype; + // base type + if ((int)value.dtype > 100) + { + base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString()); + } + input_types.Add(base_type); } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 7189747d..352baf93 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -124,6 +124,8 @@ namespace Tensorflow } } + private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); + private NodeDef _node_def; public NodeDef node_def { diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 30f49ebd..8a9a1377 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -99,6 +99,8 @@ namespace Tensorflow } else { + ops.colocate_with(_initializer_op); + _snapshot = gen_array_ops.identity(_variable, name = "read"); } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 9eb07e52..b8c8dc1c 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -185,5 +185,16 @@ namespace Tensorflow { return uid_number++; } + + public static void colocate_with(Operation op, bool ignore_existing = false) + { + _colocate_with_for_gradient(op, null, ignore_existing); + } + + private static void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false) + { + var default_graph = get_default_graph(); + default_graph._colocate_with_for_gradient(op, gradient_uid, ignore_existing); + } } }