@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class Graph | |||||
{ | |||||
public void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -72,7 +72,13 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
input_types.Add(value.dtype); | |||||
var base_type = value.dtype; | |||||
// base type | |||||
if ((int)value.dtype > 100) | |||||
{ | |||||
base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString()); | |||||
} | |||||
input_types.Add(base_type); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -124,6 +124,8 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); | |||||
private NodeDef _node_def; | private NodeDef _node_def; | ||||
public NodeDef node_def | public NodeDef node_def | ||||
{ | { | ||||
@@ -99,6 +99,8 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
ops.colocate_with(_initializer_op); | |||||
_snapshot = gen_array_ops.identity(_variable, name = "read"); | _snapshot = gen_array_ops.identity(_variable, name = "read"); | ||||
} | } | ||||
@@ -185,5 +185,16 @@ namespace Tensorflow | |||||
{ | { | ||||
return uid_number++; | return uid_number++; | ||||
} | } | ||||
public static void colocate_with(Operation op, bool ignore_existing = false) | |||||
{ | |||||
_colocate_with_for_gradient(op, null, ignore_existing); | |||||
} | |||||
private static void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false) | |||||
{ | |||||
var default_graph = get_default_graph(); | |||||
default_graph._colocate_with_for_gradient(op, gradient_uid, ignore_existing); | |||||
} | |||||
} | } | ||||
} | } |