- Fixed all test cases to use using(Buffer) - Fixed all test cases to explicitly specify sessiontags/v0.12
@@ -15,58 +15,116 @@ | |||
******************************************************************************/ | |||
using System; | |||
using System.Runtime.CompilerServices; | |||
using System.Runtime.InteropServices; | |||
using NumSharp.Backends.Unmanaged; | |||
using static Tensorflow.c_api; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Represents a TF_Buffer that can be passed to Tensorflow. | |||
/// </summary> | |||
public class Buffer : DisposableObject | |||
{ | |||
private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle); | |||
private unsafe TF_Buffer buffer | |||
{ | |||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
get => *bufferptr; | |||
} | |||
private unsafe TF_Buffer* bufferptr | |||
{ | |||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
get => (TF_Buffer*) _handle; | |||
} | |||
public byte[] Data | |||
/// <summary> | |||
/// The memory block representing this buffer. | |||
/// </summary> | |||
/// <remarks>The deallocator is set to null.</remarks> | |||
public UnmanagedMemoryBlock<byte> MemoryBlock | |||
{ | |||
get | |||
get | |||
{ | |||
var data = new byte[buffer.length]; | |||
if (data.Length > 0) | |||
Marshal.Copy(buffer.data, data, 0, data.Length); | |||
return data; | |||
unsafe | |||
{ | |||
EnsureNotDisposed(); | |||
var buff = (TF_Buffer*) _handle; | |||
return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length); | |||
} | |||
} | |||
} | |||
public int Length => (int)buffer.length; | |||
public Buffer() | |||
/// <summary> | |||
/// The bytes length of this buffer. | |||
/// </summary> | |||
public ulong Length | |||
{ | |||
_handle = c_api.TF_NewBuffer(); | |||
get | |||
{ | |||
EnsureNotDisposed(); | |||
return buffer.length; | |||
} | |||
} | |||
public Buffer(IntPtr handle) | |||
public Buffer() => _handle = TF_NewBuffer(); | |||
internal Buffer(IntPtr handle) | |||
{ | |||
if (handle == IntPtr.Zero) | |||
throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); | |||
_handle = handle; | |||
} | |||
public Buffer(byte[] data) | |||
{ | |||
var dst = Marshal.AllocHGlobal(data.Length); | |||
Marshal.Copy(data, 0, dst, data.Length); | |||
public Buffer(byte[] data) : this(_toBuffer(data)) | |||
{ } | |||
_handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); | |||
private static IntPtr _toBuffer(byte[] data) | |||
{ | |||
if (data == null) | |||
throw new ArgumentNullException(nameof(data)); | |||
Marshal.FreeHGlobal(dst); | |||
unsafe | |||
{ | |||
fixed (byte* src = data) | |||
return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); | |||
} | |||
} | |||
public static implicit operator IntPtr(Buffer buffer) | |||
{ | |||
buffer.EnsureNotDisposed(); | |||
return buffer._handle; | |||
} | |||
public static implicit operator byte[](Buffer buffer) | |||
public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost. | |||
/// <summary> | |||
/// Copies this buffer's contents onto a <see cref="byte"/> array. | |||
/// </summary> | |||
public byte[] ToArray() | |||
{ | |||
return buffer.Data; | |||
EnsureNotDisposed(); | |||
unsafe | |||
{ | |||
var len = buffer.length; | |||
if (len == 0) | |||
return Array.Empty<byte>(); | |||
byte[] data = new byte[len]; | |||
fixed (byte* dst = data) | |||
System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len); | |||
return data; | |||
} | |||
} | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
=> c_api.TF_DeleteBuffer(handle); | |||
{ | |||
TF_DeleteBuffer(handle); | |||
} | |||
} | |||
} | |||
} |
@@ -15,6 +15,8 @@ | |||
******************************************************************************/ | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using Tensorflow.Util; | |||
namespace Tensorflow | |||
{ | |||
@@ -27,12 +29,12 @@ namespace Tensorflow | |||
if(_registered_ops == null) | |||
{ | |||
_registered_ops = new Dictionary<string, OpDef>(); | |||
var handle = c_api.TF_GetAllOpList(); | |||
var buffer = new Buffer(handle); | |||
var op_list = OpList.Parser.ParseFrom(buffer); | |||
foreach (var op_def in op_list.Op) | |||
_registered_ops[op_def.Name] = op_def; | |||
using (var buffer = new Buffer(c_api.TF_GetAllOpList())) | |||
{ | |||
var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
foreach (var op_def in op_list.Op) | |||
_registered_ops[op_def.Name] = op_def; | |||
} | |||
} | |||
return _registered_ops; | |||
@@ -15,6 +15,7 @@ | |||
******************************************************************************/ | |||
using System.Collections.Generic; | |||
using System.Diagnostics.CodeAnalysis; | |||
using System.Linq; | |||
using Tensorflow.Operations; | |||
@@ -66,8 +67,9 @@ namespace Tensorflow | |||
/// within the context should have control dependencies on | |||
/// `control_inputs`. | |||
/// </summary> | |||
[SuppressMessage("ReSharper", "CoVariantArrayConversion")] | |||
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
=> control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray()); | |||
=> control_dependencies((object[])control_inputs); | |||
/// <summary> | |||
/// Returns a context manager that specifies control dependencies. | |||
@@ -14,6 +14,9 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System.IO; | |||
using Tensorflow.Util; | |||
namespace Tensorflow | |||
{ | |||
public partial class Graph | |||
@@ -23,21 +26,19 @@ namespace Tensorflow | |||
var buffer = new Buffer(); | |||
c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
s.Check(true); | |||
// var def = GraphDef.Parser.ParseFrom(buffer); | |||
// buffer.Dispose(); | |||
return buffer; | |||
} | |||
private GraphDef _as_graph_def(bool add_shapes = false) | |||
{ | |||
var status = new Status(); | |||
var buffer = ToGraphDef(status); | |||
status.Check(true); | |||
status.Dispose(); | |||
var def = GraphDef.Parser.ParseFrom(buffer); | |||
buffer.Dispose(); | |||
GraphDef def; | |||
using (var status = new Status()) | |||
using (var buffer = ToGraphDef(status)) | |||
{ | |||
status.Check(true); | |||
def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
// Strip the experimental library field iff it's empty. | |||
// if(def.Library.Function.Count == 0) | |||
@@ -45,7 +46,7 @@ namespace Tensorflow | |||
return def; | |||
} | |||
public GraphDef as_graph_def(bool add_shapes = false) | |||
public GraphDef as_graph_def(bool add_shapes = false) | |||
=> _as_graph_def(add_shapes); | |||
} | |||
} | |||
} |
@@ -18,6 +18,7 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Runtime.InteropServices; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -30,7 +31,7 @@ namespace Tensorflow | |||
using (var status = new Status()) | |||
{ | |||
c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | |||
return OpDef.Parser.ParseFrom(buffer.Data); | |||
return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
} | |||
@@ -71,7 +72,7 @@ namespace Tensorflow | |||
public ITensorOrOperation[] get_operations() | |||
{ | |||
return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
return _nodes_by_name.Values.ToArray(); | |||
} | |||
/// <summary> | |||
@@ -85,7 +86,7 @@ namespace Tensorflow | |||
public ITensorOrOperation _get_operation_by_name_unsafe(string name) | |||
{ | |||
return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; | |||
return _nodes_by_name.TryGetValue(name, out var val) ? val : null; | |||
} | |||
public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) | |||
@@ -17,7 +17,9 @@ | |||
using Google.Protobuf.Collections; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using Tensorflow.Util; | |||
namespace Tensorflow | |||
{ | |||
@@ -226,9 +228,12 @@ namespace Tensorflow | |||
using (var status = new Status()) | |||
using (var buf = new Buffer()) | |||
{ | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
status.Check(true); | |||
x = AttrValue.Parser.ParseFrom(buf); | |||
unsafe | |||
{ | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
status.Check(true); | |||
x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||
} | |||
} | |||
string oneof_value = x.ValueCase.ToString(); | |||
@@ -259,7 +264,7 @@ namespace Tensorflow | |||
{ | |||
c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||
s.Check(); | |||
return NodeDef.Parser.ParseFrom(buffer); | |||
return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
} | |||
@@ -299,8 +304,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public TF_Output _tf_output(int output_idx) | |||
{ | |||
var tf_output = new TF_Output(op, output_idx); | |||
return tf_output; | |||
return new TF_Output(op, output_idx); | |||
} | |||
/// <summary> | |||
@@ -308,8 +312,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public TF_Input _tf_input(int input_idx) | |||
{ | |||
var tf_input = new TF_Input(op, input_idx); | |||
return tf_input; | |||
return new TF_Input(op, input_idx); | |||
} | |||
} | |||
} |
@@ -37,8 +37,8 @@ namespace Tensorflow | |||
public void SetConfig(ConfigProto config) | |||
{ | |||
var bytes = config.ToByteArray(); | |||
var proto = Marshal.AllocHGlobal(bytes.Length); | |||
var bytes = config.ToByteArray(); //TODO! we can use WriteTo | |||
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak | |||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
using (var status = new Status()) | |||
@@ -230,8 +230,8 @@ namespace Tensorflow | |||
// Add attrs | |||
foreach (var attr in node_def.Attr) | |||
{ | |||
var bytes = attr.Value.ToByteArray(); | |||
var proto = Marshal.AllocHGlobal(bytes.Length); | |||
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
uint len = (uint)bytes.Length; | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||
@@ -64,8 +64,7 @@ namespace Tensorflow | |||
public Session Session() | |||
{ | |||
defaultSession = new Session(); | |||
return defaultSession; | |||
return new Session(); | |||
} | |||
public Session Session(Graph graph) | |||
@@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples | |||
// Display logs per epoch step | |||
if ((epoch + 1) % display_step == 0) | |||
print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms"); | |||
print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); | |||
sw.Reset(); | |||
} | |||
@@ -114,8 +114,8 @@ namespace TensorFlowNET.Examples | |||
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | |||
// Calculate accuracy | |||
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | |||
float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||
print($"Accuracy: {acc.ToString("F4")}"); | |||
float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||
print($"Accuracy: {acc:F4}"); | |||
return acc > 0.9; | |||
} | |||
@@ -1,4 +1,5 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Basics | |||
@@ -14,21 +15,22 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var expected = new[] { false, true, false, false, true, false, true }; | |||
var spike = tf.Variable(false); | |||
spike.initializer.run(); | |||
foreach (var i in range(1, 2)) | |||
using (var sess = new Session()) | |||
{ | |||
if (raw_data[i] - raw_data[i - 1] > 5d) | |||
{ | |||
var updater = tf.assign(spike, tf.constant(true)); | |||
updater.eval(); | |||
} | |||
else | |||
spike.initializer.run(session: sess); | |||
foreach (var i in range(1, 2)) | |||
{ | |||
tf.assign(spike, tf.constant(true)).eval(); | |||
} | |||
if (raw_data[i] - raw_data[i - 1] > 5d) | |||
{ | |||
var updater = tf.assign(spike, tf.constant(true)); | |||
updater.eval(sess); | |||
} else | |||
{ | |||
tf.assign(spike, tf.constant(true)).eval(sess); | |||
} | |||
Assert.AreEqual((bool)spike.eval(), expected[i - 1]); | |||
Assert.AreEqual((bool) spike.eval(), expected[i - 1]); | |||
} | |||
} | |||
} | |||
} |
@@ -2,6 +2,7 @@ | |||
using NumSharp; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.Util; | |||
using Buffer = Tensorflow.Buffer; | |||
namespace TensorFlowNET.UnitTest | |||
@@ -45,15 +46,18 @@ namespace TensorFlowNET.UnitTest | |||
private bool GetGraphDef(Graph graph, out GraphDef graph_def) | |||
{ | |||
graph_def = null; | |||
var s = new Status(); | |||
var buffer = new Buffer(); | |||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
bool ret = TF_GetCode(s) == TF_OK; | |||
EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||
if (ret) graph_def = GraphDef.Parser.ParseFrom(buffer.Data); | |||
buffer.Dispose(); | |||
s.Dispose(); | |||
return ret; | |||
using (var s = new Status()) | |||
{ | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
bool ret = TF_GetCode(s) == TF_OK; | |||
EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||
if (ret) | |||
graph_def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
return ret; | |||
} | |||
} | |||
} | |||
private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | |||
@@ -322,7 +322,6 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(feed2, control_inputs[1]); | |||
// Export to a graph def so we can import a graph with control dependencies | |||
graph_def.Dispose(); | |||
graph_def = new Buffer(); | |||
c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
@@ -346,14 +345,10 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(feed4, control_inputs[1]); | |||
c_api.TF_DeleteImportGraphDefOptions(opts); | |||
c_api.TF_DeleteBuffer(graph_def); | |||
// Can add nodes to the imported graph without trouble. | |||
c_test_util.Add(feed, scalar, graph, s); | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
graph.Dispose(); | |||
s.Dispose(); | |||
} | |||
/// <summary> | |||
@@ -1,4 +1,5 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
@@ -42,5 +43,39 @@ namespace TensorFlowNET.UnitTest | |||
Assert.AreEqual("", g._name_stack); | |||
} | |||
[TestMethod] | |||
public void NestedNameScope_Using() | |||
{ | |||
Graph g = tf.Graph().as_default(); | |||
using (var name = new ops.NameScope("scope1")) | |||
{ | |||
Assert.AreEqual("scope1", g._name_stack); | |||
Assert.AreEqual("scope1/", name); | |||
var const1 = tf.constant(1.0); | |||
Assert.AreEqual("scope1/Const:0", const1.name); | |||
using (var name2 = new ops.NameScope("scope2")) | |||
{ | |||
Assert.AreEqual("scope1/scope2", g._name_stack); | |||
Assert.AreEqual("scope1/scope2/", name); | |||
var const2 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/scope2/Const:0", const2.name); | |||
} | |||
Assert.AreEqual("scope1", g._name_stack); | |||
var const3 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/Const_1:0", const3.name); | |||
} | |||
; | |||
g.Dispose(); | |||
Assert.AreEqual("", g._name_stack); | |||
} | |||
} | |||
} |
@@ -4,6 +4,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using NumSharp; | |||
using Tensorflow; | |||
using Tensorflow.Util; | |||
using Buffer = Tensorflow.Buffer; | |||
using static Tensorflow.Binding; | |||
@@ -21,7 +22,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var handle = c_api.TF_GetAllOpList(); | |||
var buffer = new Buffer(handle); | |||
var op_list = OpList.Parser.ParseFrom(buffer); | |||
var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
var _registered_ops = new Dictionary<string, OpDef>(); | |||
foreach (var op_def in op_list.Op) | |||
@@ -165,7 +165,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
var ndarray=tensor.eval(); | |||
var ndarray=tensor.eval(sess); | |||
if (typeof(T) == typeof(double)) | |||
{ | |||
double x = ndarray; | |||
@@ -72,8 +72,6 @@ namespace TensorFlowNET.UnitTest | |||
// Clean up | |||
csession.CloseAndDelete(s); | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
graph.Dispose(); | |||
s.Dispose(); | |||
} | |||
[TestMethod] | |||
@@ -84,7 +82,7 @@ namespace TensorFlowNET.UnitTest | |||
var c = math_ops.matmul(a, b, name: "matmul"); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = c.eval(); | |||
var result = c.eval(sess); | |||
Assert.AreEqual(6, result.Data<double>()[0]); | |||
} | |||
} | |||
@@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
sess.run(init_op); | |||
// o some work with the model. | |||
inc_v1.op.run(); | |||
inc_v1.op.run(session: sess); | |||
} | |||
} | |||
@@ -1,4 +1,6 @@ | |||
using Tensorflow; | |||
using System.Diagnostics.CodeAnalysis; | |||
using Tensorflow; | |||
using Tensorflow.Util; | |||
using Buffer = Tensorflow.Buffer; | |||
namespace TensorFlowNET.UnitTest | |||
@@ -26,12 +28,15 @@ namespace TensorFlowNET.UnitTest | |||
return op; | |||
} | |||
[SuppressMessage("ReSharper", "RedundantAssignment")] | |||
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | |||
{ | |||
var buffer = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
attr_value = AttrValue.Parser.ParseFrom(buffer); | |||
buffer.Dispose(); | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
return s.Code == TF_Code.TF_OK; | |||
} | |||
@@ -42,7 +47,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
s.Check(); | |||
return GraphDef.Parser.ParseFrom(buffer); | |||
return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
} | |||
@@ -24,16 +24,17 @@ namespace TensorFlowNET.UnitTest.ops_test | |||
[TestMethod] | |||
public void TestShape() | |||
{ | |||
var g = tf.Graph().as_default(); | |||
var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); | |||
var op = g._create_op_from_tf_operation(c_op); | |||
Assert.AreEqual("myop", op.name); | |||
Assert.AreEqual("Identity", op.type); | |||
Assert.AreEqual(1, len(op.outputs)); | |||
assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); | |||
using (var g = tf.Graph().as_default()) | |||
{ | |||
var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | |||
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||
var op = g._create_op_from_tf_operation(c_op); | |||
Assert.AreEqual("myop", op.name); | |||
Assert.AreEqual("Identity", op.type); | |||
Assert.AreEqual(1, len(op.outputs)); | |||
assertItemsEqual(new[] {2, 3}, op.outputs[0].shape); | |||
} | |||
} | |||
[TestMethod] | |||