@@ -15,11 +15,11 @@ namespace Tensorflow | |||||
bool preserve_cardinality = false, | bool preserve_cardinality = false, | ||||
bool use_legacy_function = false) : base(input_dataset) | bool use_legacy_function = false) : base(input_dataset) | ||||
{ | { | ||||
var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}"); | |||||
var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}"); | |||||
func.Enter(); | func.Enter(); | ||||
var inputs = new Tensors(); | var inputs = new Tensors(); | ||||
foreach (var input in input_dataset.element_spec) | foreach (var input in input_dataset.element_spec) | ||||
inputs.Add(tf.placeholder(input.dtype, shape: input.shape)); | |||||
inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); | |||||
var outputs = map_func(inputs); | var outputs = map_func(inputs); | ||||
func.ToGraph(inputs, outputs); | func.ToGraph(inputs, outputs); | ||||
func.Exit(); | func.Exit(); | ||||
@@ -36,7 +36,7 @@ namespace Tensorflow.Functions | |||||
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | ||||
{ | { | ||||
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | |||||
func_graph = new FuncGraph(func_name); | func_graph = new FuncGraph(func_name); | ||||
func_graph.as_default(); | func_graph.as_default(); | ||||
@@ -53,7 +53,7 @@ namespace Tensorflow.Functions | |||||
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | ||||
{ | { | ||||
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | |||||
func_graph = new FuncGraph(func_name); | func_graph = new FuncGraph(func_name); | ||||
func_graph.as_default(); | func_graph.as_default(); | ||||
@@ -74,7 +74,7 @@ namespace Tensorflow.Functions | |||||
public ConcreteFunction(Func<Tensors, Tensors> func, | public ConcreteFunction(Func<Tensors, Tensors> func, | ||||
TF_DataType[] dtypes, TensorShape[] shapes) | TF_DataType[] dtypes, TensorShape[] shapes) | ||||
{ | { | ||||
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | |||||
// IntPtr func_handle; | // IntPtr func_handle; | ||||
func_graph = new FuncGraph(func_name); | func_graph = new FuncGraph(func_name); | ||||
@@ -8,7 +8,7 @@ namespace Tensorflow.Graphs | |||||
{ | { | ||||
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func) | public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func) | ||||
{ | { | ||||
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | |||||
var graph = new FuncGraph(func_name); | var graph = new FuncGraph(func_name); | ||||
graph.as_default(); | graph.as_default(); | ||||
@@ -38,7 +38,7 @@ namespace Tensorflow.Graphs | |||||
public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func) | public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func) | ||||
{ | { | ||||
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | |||||
var graph = new FuncGraph(func_name); | var graph = new FuncGraph(func_name); | ||||
graph.as_default(); | graph.as_default(); | ||||
@@ -22,7 +22,7 @@ namespace Tensorflow.Graphs | |||||
public override void OnEntry(MethodExecutionArgs args) | public override void OnEntry(MethodExecutionArgs args) | ||||
{ | { | ||||
// TODO: func_name can be cache in FullName + Args | // TODO: func_name can be cache in FullName + Args | ||||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{Guid.NewGuid()}"; | |||||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}"; | |||||
if (functions.ContainsKey(func_name)) | if (functions.ContainsKey(func_name)) | ||||
{ | { | ||||
@@ -353,6 +353,10 @@ namespace Tensorflow | |||||
return Interlocked.Increment(ref uid_number); | return Interlocked.Increment(ref uid_number); | ||||
} | } | ||||
static int uid_number_for_function = 0; | |||||
public static int uid_function() | |||||
=> Interlocked.Increment(ref uid_number_for_function); | |||||
public static void reset_uid() | public static void reset_uid() | ||||
{ | { | ||||
uid_number = -1; | uid_number = -1; | ||||