@@ -1,5 +1,4 @@ | |||||
//using Newtonsoft.Json; | |||||
using System; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
@@ -15,10 +14,11 @@ namespace Tensorflow | |||||
private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
//[JsonIgnore] | |||||
public Tensor output => _outputs.FirstOrDefault(); | public Tensor output => _outputs.FirstOrDefault(); | ||||
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | ||||
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | ||||
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | ||||
@@ -1,5 +1,6 @@ | |||||
using Google.Protobuf.Collections; | using Google.Protobuf.Collections; | ||||
//using Newtonsoft.Json; | |||||
using Newtonsoft.Json; | |||||
//using Newtonsoft.Json; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
@@ -33,15 +34,15 @@ namespace Tensorflow | |||||
private readonly IntPtr _operDesc; | private readonly IntPtr _operDesc; | ||||
private Graph _graph; | private Graph _graph; | ||||
//[JsonIgnore] | |||||
[JsonIgnore] | |||||
public Graph graph => _graph; | public Graph graph => _graph; | ||||
//[JsonIgnore] | |||||
[JsonIgnore] | |||||
public int _id => _id_value; | public int _id => _id_value; | ||||
//[JsonIgnore] | |||||
[JsonIgnore] | |||||
public int _id_value; | public int _id_value; | ||||
public string type => OpType; | public string type => OpType; | ||||
//[JsonIgnore] | |||||
[JsonIgnore] | |||||
public Operation op => this; | public Operation op => this; | ||||
public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
private Status status = new Status(); | private Status status = new Status(); | ||||
@@ -45,6 +45,7 @@ Bug memory leak issue when allocating Tensor.</PackageReleaseNotes> | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Google.Protobuf" Version="3.7.0" /> | <PackageReference Include="Google.Protobuf" Version="3.7.0" /> | ||||
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -1,4 +1,5 @@ | |||||
//using Newtonsoft.Json; | //using Newtonsoft.Json; | ||||
using Newtonsoft.Json; | |||||
using NumSharp; | using NumSharp; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
@@ -18,13 +19,13 @@ namespace Tensorflow | |||||
private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
private int _id; | private int _id; | ||||
//[JsonIgnore] | |||||
[JsonIgnore] | |||||
public int Id => _id; | public int Id => _id; | ||||
//[JsonIgnore] | |||||
[JsonIgnore] | |||||
public Graph graph => op?.graph; | public Graph graph => op?.graph; | ||||
//[JsonIgnore] | |||||
[JsonIgnore] | |||||
public Operation op { get; } | public Operation op { get; } | ||||
//[JsonIgnore] | |||||
[JsonIgnore] | |||||
public Tensor[] outputs => op.outputs; | public Tensor[] outputs => op.outputs; | ||||
/// <summary> | /// <summary> | ||||
@@ -112,9 +113,6 @@ namespace Tensorflow | |||||
public int NDims => rank; | public int NDims => rank; | ||||
//[JsonIgnore] | |||||
public Operation[] Consumers => consumers(); | |||||
public string Device => op.Device; | public string Device => op.Device; | ||||
public Operation[] consumers() | public Operation[] consumers() | ||||
@@ -1,4 +1,5 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Newtonsoft.Json; | |||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
@@ -14,26 +15,30 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
public void testCondTrue() | public void testCondTrue() | ||||
{ | { | ||||
var graph = tf.Graph().as_default(); | var graph = tf.Graph().as_default(); | ||||
// tf.train.import_meta_graph("cond_test.meta"); | |||||
var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||||
with(tf.Session(graph), sess => | with(tf.Session(graph), sess => | ||||
{ | { | ||||
var x = tf.constant(2); | |||||
var y = tf.constant(5); | |||||
var pred = tf.less(x, y); | |||||
var x = tf.constant(2, name: "x"); // graph.get_operation_by_name("Const").output; | |||||
var y = tf.constant(5, name: "y"); // graph.get_operation_by_name("Const_1").output; | |||||
var pred = tf.less(x, y); // graph.get_operation_by_name("Less").output; | |||||
Func<ITensorOrOperation> if_true = delegate | Func<ITensorOrOperation> if_true = delegate | ||||
{ | { | ||||
return tf.multiply(x, 17); | |||||
return tf.constant(2, name: "t2"); | |||||
}; | }; | ||||
Func<ITensorOrOperation> if_false = delegate | Func<ITensorOrOperation> if_false = delegate | ||||
{ | { | ||||
return tf.add(y, 23); | |||||
return tf.constant(5, name: "f5"); | |||||
}; | }; | ||||
var z = control_flow_ops.cond(pred, if_true, if_false); | |||||
var z = control_flow_ops.cond(pred, if_true, if_false); // graph.get_operation_by_name("cond/Merge").output | |||||
json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||||
int result = z.eval(sess); | int result = z.eval(sess); | ||||
assertEquals(result, 34); | |||||
assertEquals(result, 2); | |||||
}); | }); | ||||
} | } | ||||
@@ -58,25 +63,31 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
print(result == 24) */ | print(result == 24) */ | ||||
var graph = tf.Graph().as_default(); | |||||
//tf.train.import_meta_graph("cond_test.meta"); | |||||
//var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||||
with(tf.Session(), sess => | with(tf.Session(), sess => | ||||
{ | { | ||||
var x = tf.constant(2); | |||||
var y = tf.constant(1); | |||||
var x = tf.constant(2, name: "x"); | |||||
var y = tf.constant(1, name: "y"); | |||||
var pred = tf.less(x, y); | var pred = tf.less(x, y); | ||||
Func<ITensorOrOperation> if_true = delegate | Func<ITensorOrOperation> if_true = delegate | ||||
{ | { | ||||
return tf.multiply(x, 17); | |||||
return tf.constant(2, name: "t2"); | |||||
}; | }; | ||||
Func<ITensorOrOperation> if_false = delegate | Func<ITensorOrOperation> if_false = delegate | ||||
{ | { | ||||
return tf.add(y, 23); | |||||
return tf.constant(1, name: "f1"); | |||||
}; | }; | ||||
var z = control_flow_ops.cond(pred, if_true, if_false); | var z = control_flow_ops.cond(pred, if_true, if_false); | ||||
var json1 = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||||
int result = z.eval(sess); | int result = z.eval(sess); | ||||
assertEquals(result, 24); | |||||
assertEquals(result, 1); | |||||
}); | }); | ||||
} | } | ||||