@@ -322,6 +322,8 @@ namespace Tensorflow | |||||
attr_value.Shape = val1.as_proto(); | attr_value.Shape = val1.as_proto(); | ||||
else if(value is long[] val2) | else if(value is long[] val2) | ||||
attr_value.Shape = tensor_util.as_shape(val2); | attr_value.Shape = tensor_util.as_shape(val2); | ||||
else if (value is int[] val3) | |||||
attr_value.Shape = tensor_util.as_shape(val3); | |||||
break; | break; | ||||
default: | default: | ||||
@@ -34,10 +34,9 @@ namespace Tensorflow | |||||
private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name) | private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name) | ||||
{ | { | ||||
Tensor tShape = null; | Tensor tShape = null; | ||||
var nd = np.zeros<T>(shape); | |||||
if (shape.Size < 1000) | if (shape.Size < 1000) | ||||
{ | { | ||||
return constant_op.constant(nd, name: name); | |||||
return constant_op.constant(value, shape: shape, dtype: dtype, name: name); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -43,7 +43,7 @@ TensorFlow 1.13 RC.</PackageReleaseNotes> | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Google.Protobuf" Version="3.6.1" /> | <PackageReference Include="Google.Protobuf" Version="3.6.1" /> | ||||
<PackageReference Include="NumSharp" Version="0.7.0" /> | |||||
<PackageReference Include="NumSharp" Version="0.7.1" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -56,9 +56,9 @@ namespace Tensorflow | |||||
switch (values) | switch (values) | ||||
{ | { | ||||
/*case bool boolVal: | |||||
case bool boolVal: | |||||
nparray = boolVal; | nparray = boolVal; | ||||
break;*/ | |||||
break; | |||||
case int intVal: | case int intVal: | ||||
nparray = intVal; | nparray = intVal; | ||||
break; | break; | ||||
@@ -74,6 +74,9 @@ namespace Tensorflow | |||||
case string strVal: | case string strVal: | ||||
nparray = strVal; | nparray = strVal; | ||||
break; | break; | ||||
case string[] strVals: | |||||
nparray = strVals; | |||||
break; | |||||
default: | default: | ||||
throw new Exception("make_tensor_proto Not Implemented"); | throw new Exception("make_tensor_proto Not Implemented"); | ||||
} | } | ||||
@@ -100,7 +103,8 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
throw new NotImplementedException("make_tensor_proto shape not implemented"); | |||||
shape_size = new TensorShape(shape).Size; | |||||
is_same_size = shape_size == nparray.size; | |||||
} | } | ||||
var tensor_proto = new tensor_pb2.TensorProto | var tensor_proto = new tensor_pb2.TensorProto | ||||
@@ -111,41 +115,17 @@ namespace Tensorflow | |||||
if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) | if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) | ||||
{ | { | ||||
var bytes = new List<byte>(); | |||||
var nd2 = nparray.ravel(); | |||||
switch (nparray.dtype.Name) | |||||
{ | |||||
case "Int32": | |||||
nd2.Data<int>().Select(x => | |||||
{ | |||||
bytes.AddRange(BitConverter.GetBytes(x)); | |||||
return x; | |||||
}).ToArray(); | |||||
break; | |||||
case "Single": | |||||
nd2.Data<float>().Select(x => | |||||
{ | |||||
bytes.AddRange(BitConverter.GetBytes(x)); | |||||
return x; | |||||
}).ToArray(); | |||||
break; | |||||
case "Double": | |||||
nd2.Data<double>().Select(x => | |||||
{ | |||||
bytes.AddRange(BitConverter.GetBytes(x)); | |||||
return x; | |||||
}).ToArray(); | |||||
break; | |||||
default: | |||||
throw new Exception("make_tensor_proto Not Implemented"); | |||||
} | |||||
byte[] bytes = nparray.ToByteArray(); | |||||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); | tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); | ||||
return tensor_proto; | return tensor_proto; | ||||
} | } | ||||
if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str) | |||||
if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray)) | |||||
{ | { | ||||
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); | |||||
if (values is string str) | |||||
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); | |||||
else if (values is string[] str_values) | |||||
tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); | |||||
return tensor_proto; | return tensor_proto; | ||||
} | } | ||||
@@ -123,8 +123,6 @@ namespace Tensorflow | |||||
Version = _write_version | Version = _write_version | ||||
}; | }; | ||||
}); | }); | ||||
} | } | ||||
public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | ||||
@@ -6,7 +6,7 @@ | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="NumSharp" Version="0.7.0" /> | |||||
<PackageReference Include="NumSharp" Version="0.7.1" /> | |||||
<PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | <PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -58,6 +58,7 @@ namespace TensorFlowNET.UnitTest | |||||
var data = result.Data<int>(); | var data = result.Data<int>(); | ||||
Assert.AreEqual(0, data[0]); | Assert.AreEqual(0, data[0]); | ||||
Assert.AreEqual(0, data[500]); | |||||
Assert.AreEqual(0, data[result.size - 1]); | Assert.AreEqual(0, data[result.size - 1]); | ||||
}); | }); | ||||
} | } | ||||
@@ -109,5 +110,15 @@ namespace TensorFlowNET.UnitTest | |||||
//c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | //c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | ||||
} | } | ||||
/// <summary> | |||||
/// tensorflow\c\c_api_test.cc | |||||
/// TestEncodeDecode | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void EncodeDecode() | |||||
{ | |||||
} | |||||
} | } | ||||
} | } |
@@ -19,7 +19,7 @@ | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | ||||
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
<PackageReference Include="NumSharp" Version="0.7.0" /> | |||||
<PackageReference Include="NumSharp" Version="0.7.1" /> | |||||
<PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | <PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -10,11 +10,11 @@ namespace TensorFlowNET.UnitTest | |||||
public class TrainSaverTest : Python | public class TrainSaverTest : Python | ||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void WriteGraph() | |||||
public void ExportGraph() | |||||
{ | { | ||||
var v = tf.Variable(0, name: "my_variable"); | var v = tf.Variable(0, name: "my_variable"); | ||||
var sess = tf.Session(); | var sess = tf.Session(); | ||||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt"); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -22,14 +22,13 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
var v = tf.Variable(0, name: "my_variable"); | var v = tf.Variable(0, name: "my_variable"); | ||||
var sess = tf.Session(); | var sess = tf.Session(); | ||||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train2.pbtxt"); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void SaveSimple() | |||||
public void Save1() | |||||
{ | { | ||||
var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1"); | |||||
var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2"); | |||||
var w1 = tf.Variable(0, name: "save1"); | |||||
var init_op = tf.global_variables_initializer(); | var init_op = tf.global_variables_initializer(); | ||||
@@ -41,13 +40,13 @@ namespace TensorFlowNET.UnitTest | |||||
sess.run(init_op); | sess.run(init_op); | ||||
// Save the variables to disk. | // Save the variables to disk. | ||||
var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||||
var save_path = saver.save(sess, "/tmp/model1.ckpt"); | |||||
Console.WriteLine($"Model saved in path: {save_path}"); | Console.WriteLine($"Model saved in path: {save_path}"); | ||||
}); | }); | ||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void Save() | |||||
public void Save2() | |||||
{ | { | ||||
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | ||||
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | ||||
@@ -69,7 +68,7 @@ namespace TensorFlowNET.UnitTest | |||||
dec_v2.op.run(); | dec_v2.op.run(); | ||||
// Save the variables to disk. | // Save the variables to disk. | ||||
var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||||
var save_path = saver.save(sess, "/tmp/model2.ckpt"); | |||||
Console.WriteLine($"Model saved in path: {save_path}"); | Console.WriteLine($"Model saved in path: {save_path}"); | ||||
}); | }); | ||||
} | } | ||||