@@ -322,6 +322,8 @@ namespace Tensorflow | |||
attr_value.Shape = val1.as_proto(); | |||
else if(value is long[] val2) | |||
attr_value.Shape = tensor_util.as_shape(val2); | |||
else if (value is int[] val3) | |||
attr_value.Shape = tensor_util.as_shape(val3); | |||
break; | |||
default: | |||
@@ -34,10 +34,9 @@ namespace Tensorflow | |||
private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name) | |||
{ | |||
Tensor tShape = null; | |||
var nd = np.zeros<T>(shape); | |||
if (shape.Size < 1000) | |||
{ | |||
return constant_op.constant(nd, name: name); | |||
return constant_op.constant(value, shape: shape, dtype: dtype, name: name); | |||
} | |||
else | |||
{ | |||
@@ -43,7 +43,7 @@ TensorFlow 1.13 RC.</PackageReleaseNotes> | |||
<ItemGroup> | |||
<PackageReference Include="Google.Protobuf" Version="3.6.1" /> | |||
<PackageReference Include="NumSharp" Version="0.7.0" /> | |||
<PackageReference Include="NumSharp" Version="0.7.1" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
@@ -56,9 +56,9 @@ namespace Tensorflow | |||
switch (values) | |||
{ | |||
/*case bool boolVal: | |||
case bool boolVal: | |||
nparray = boolVal; | |||
break;*/ | |||
break; | |||
case int intVal: | |||
nparray = intVal; | |||
break; | |||
@@ -74,6 +74,9 @@ namespace Tensorflow | |||
case string strVal: | |||
nparray = strVal; | |||
break; | |||
case string[] strVals: | |||
nparray = strVals; | |||
break; | |||
default: | |||
throw new Exception("make_tensor_proto Not Implemented"); | |||
} | |||
@@ -100,7 +103,8 @@ namespace Tensorflow | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("make_tensor_proto shape not implemented"); | |||
shape_size = new TensorShape(shape).Size; | |||
is_same_size = shape_size == nparray.size; | |||
} | |||
var tensor_proto = new tensor_pb2.TensorProto | |||
@@ -111,41 +115,17 @@ namespace Tensorflow | |||
if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) | |||
{ | |||
var bytes = new List<byte>(); | |||
var nd2 = nparray.ravel(); | |||
switch (nparray.dtype.Name) | |||
{ | |||
case "Int32": | |||
nd2.Data<int>().Select(x => | |||
{ | |||
bytes.AddRange(BitConverter.GetBytes(x)); | |||
return x; | |||
}).ToArray(); | |||
break; | |||
case "Single": | |||
nd2.Data<float>().Select(x => | |||
{ | |||
bytes.AddRange(BitConverter.GetBytes(x)); | |||
return x; | |||
}).ToArray(); | |||
break; | |||
case "Double": | |||
nd2.Data<double>().Select(x => | |||
{ | |||
bytes.AddRange(BitConverter.GetBytes(x)); | |||
return x; | |||
}).ToArray(); | |||
break; | |||
default: | |||
throw new Exception("make_tensor_proto Not Implemented"); | |||
} | |||
byte[] bytes = nparray.ToByteArray(); | |||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); | |||
return tensor_proto; | |||
} | |||
if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str) | |||
if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray)) | |||
{ | |||
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); | |||
if (values is string str) | |||
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); | |||
else if (values is string[] str_values) | |||
tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); | |||
return tensor_proto; | |||
} | |||
@@ -123,8 +123,6 @@ namespace Tensorflow | |||
Version = _write_version | |||
}; | |||
}); | |||
} | |||
public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | |||
@@ -6,7 +6,7 @@ | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<PackageReference Include="NumSharp" Version="0.7.0" /> | |||
<PackageReference Include="NumSharp" Version="0.7.1" /> | |||
<PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | |||
</ItemGroup> | |||
@@ -58,6 +58,7 @@ namespace TensorFlowNET.UnitTest | |||
var data = result.Data<int>(); | |||
Assert.AreEqual(0, data[0]); | |||
Assert.AreEqual(0, data[500]); | |||
Assert.AreEqual(0, data[result.size - 1]); | |||
}); | |||
} | |||
@@ -109,5 +110,15 @@ namespace TensorFlowNET.UnitTest | |||
//c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | |||
} | |||
/// <summary> | |||
/// tensorflow\c\c_api_test.cc | |||
/// TestEncodeDecode | |||
/// </summary> | |||
[TestMethod] | |||
public void EncodeDecode() | |||
{ | |||
} | |||
} | |||
} |
@@ -19,7 +19,7 @@ | |||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | |||
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | |||
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | |||
<PackageReference Include="NumSharp" Version="0.7.0" /> | |||
<PackageReference Include="NumSharp" Version="0.7.1" /> | |||
<PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | |||
</ItemGroup> | |||
@@ -10,11 +10,11 @@ namespace TensorFlowNET.UnitTest | |||
public class TrainSaverTest : Python | |||
{ | |||
[TestMethod] | |||
public void WriteGraph() | |||
public void ExportGraph() | |||
{ | |||
var v = tf.Variable(0, name: "my_variable"); | |||
var sess = tf.Session(); | |||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt"); | |||
} | |||
[TestMethod] | |||
@@ -22,14 +22,13 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var v = tf.Variable(0, name: "my_variable"); | |||
var sess = tf.Session(); | |||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train2.pbtxt"); | |||
} | |||
[TestMethod] | |||
public void SaveSimple() | |||
public void Save1() | |||
{ | |||
var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1"); | |||
var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2"); | |||
var w1 = tf.Variable(0, name: "save1"); | |||
var init_op = tf.global_variables_initializer(); | |||
@@ -41,13 +40,13 @@ namespace TensorFlowNET.UnitTest | |||
sess.run(init_op); | |||
// Save the variables to disk. | |||
var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||
var save_path = saver.save(sess, "/tmp/model1.ckpt"); | |||
Console.WriteLine($"Model saved in path: {save_path}"); | |||
}); | |||
} | |||
[TestMethod] | |||
public void Save() | |||
public void Save2() | |||
{ | |||
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | |||
@@ -69,7 +68,7 @@ namespace TensorFlowNET.UnitTest | |||
dec_v2.op.run(); | |||
// Save the variables to disk. | |||
var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||
var save_path = saver.save(sess, "/tmp/model2.ckpt"); | |||
Console.WriteLine($"Model saved in path: {save_path}"); | |||
}); | |||
} | |||