Browse Source

fix _constant_if_small for zeros.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
5720dfd68e
9 changed files with 38 additions and 49 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  2. +1
    -2
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  3. +1
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  4. +13
    -33
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  5. +0
    -2
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  6. +1
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  7. +11
    -0
      test/TensorFlowNET.UnitTest/ConstantTest.cs
  8. +1
    -1
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
  9. +8
    -9
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 2
- 0
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -322,6 +322,8 @@ namespace Tensorflow
attr_value.Shape = val1.as_proto(); attr_value.Shape = val1.as_proto();
else if(value is long[] val2) else if(value is long[] val2)
attr_value.Shape = tensor_util.as_shape(val2); attr_value.Shape = tensor_util.as_shape(val2);
else if (value is int[] val3)
attr_value.Shape = tensor_util.as_shape(val3);


break; break;
default: default:


+ 1
- 2
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -34,10 +34,9 @@ namespace Tensorflow
private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name) private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name)
{ {
Tensor tShape = null; Tensor tShape = null;
var nd = np.zeros<T>(shape);
if (shape.Size < 1000) if (shape.Size < 1000)
{ {
return constant_op.constant(nd, name: name);
return constant_op.constant(value, shape: shape, dtype: dtype, name: name);
} }
else else
{ {


+ 1
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -43,7 +43,7 @@ TensorFlow 1.13 RC.</PackageReleaseNotes>


<ItemGroup> <ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.6.1" /> <PackageReference Include="Google.Protobuf" Version="3.6.1" />
<PackageReference Include="NumSharp" Version="0.7.0" />
<PackageReference Include="NumSharp" Version="0.7.1" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


+ 13
- 33
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -56,9 +56,9 @@ namespace Tensorflow


switch (values) switch (values)
{ {
/*case bool boolVal:
case bool boolVal:
nparray = boolVal; nparray = boolVal;
break;*/
break;
case int intVal: case int intVal:
nparray = intVal; nparray = intVal;
break; break;
@@ -74,6 +74,9 @@ namespace Tensorflow
case string strVal: case string strVal:
nparray = strVal; nparray = strVal;
break; break;
case string[] strVals:
nparray = strVals;
break;
default: default:
throw new Exception("make_tensor_proto Not Implemented"); throw new Exception("make_tensor_proto Not Implemented");
} }
@@ -100,7 +103,8 @@ namespace Tensorflow
} }
else else
{ {
throw new NotImplementedException("make_tensor_proto shape not implemented");
shape_size = new TensorShape(shape).Size;
is_same_size = shape_size == nparray.size;
} }


var tensor_proto = new tensor_pb2.TensorProto var tensor_proto = new tensor_pb2.TensorProto
@@ -111,41 +115,17 @@ namespace Tensorflow


if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1)
{ {
var bytes = new List<byte>();
var nd2 = nparray.ravel();
switch (nparray.dtype.Name)
{
case "Int32":
nd2.Data<int>().Select(x =>
{
bytes.AddRange(BitConverter.GetBytes(x));
return x;
}).ToArray();
break;
case "Single":
nd2.Data<float>().Select(x =>
{
bytes.AddRange(BitConverter.GetBytes(x));
return x;
}).ToArray();
break;
case "Double":
nd2.Data<double>().Select(x =>
{
bytes.AddRange(BitConverter.GetBytes(x));
return x;
}).ToArray();
break;
default:
throw new Exception("make_tensor_proto Not Implemented");
}
byte[] bytes = nparray.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray());
return tensor_proto; return tensor_proto;
} }


if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str)
if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray))
{ {
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
if (values is string str)
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
else if (values is string[] str_values)
tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
return tensor_proto; return tensor_proto;
} }




+ 0
- 2
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -123,8 +123,6 @@ namespace Tensorflow
Version = _write_version Version = _write_version
}; };
}); });

} }


public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables)


+ 1
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -6,7 +6,7 @@
</PropertyGroup> </PropertyGroup>


<ItemGroup> <ItemGroup>
<PackageReference Include="NumSharp" Version="0.7.0" />
<PackageReference Include="NumSharp" Version="0.7.1" />
<PackageReference Include="TensorFlow.NET" Version="0.1.0" /> <PackageReference Include="TensorFlow.NET" Version="0.1.0" />
</ItemGroup> </ItemGroup>




+ 11
- 0
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -58,6 +58,7 @@ namespace TensorFlowNET.UnitTest


var data = result.Data<int>(); var data = result.Data<int>();
Assert.AreEqual(0, data[0]); Assert.AreEqual(0, data[0]);
Assert.AreEqual(0, data[500]);
Assert.AreEqual(0, data[result.size - 1]); Assert.AreEqual(0, data[result.size - 1]);
}); });
} }
@@ -109,5 +110,15 @@ namespace TensorFlowNET.UnitTest


//c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); //c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
} }

/// <summary>
/// tensorflow\c\c_api_test.cc
/// TestEncodeDecode
/// </summary>
[TestMethod]
public void EncodeDecode()
{

}
} }
} }

+ 1
- 1
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -19,7 +19,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> <PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
<PackageReference Include="NumSharp" Version="0.7.0" />
<PackageReference Include="NumSharp" Version="0.7.1" />
<PackageReference Include="TensorFlow.NET" Version="0.1.0" /> <PackageReference Include="TensorFlow.NET" Version="0.1.0" />
</ItemGroup> </ItemGroup>




+ 8
- 9
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -10,11 +10,11 @@ namespace TensorFlowNET.UnitTest
public class TrainSaverTest : Python public class TrainSaverTest : Python
{ {
[TestMethod] [TestMethod]
public void WriteGraph()
public void ExportGraph()
{ {
var v = tf.Variable(0, name: "my_variable"); var v = tf.Variable(0, name: "my_variable");
var sess = tf.Session(); var sess = tf.Session();
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt");
tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt");
} }


[TestMethod] [TestMethod]
@@ -22,14 +22,13 @@ namespace TensorFlowNET.UnitTest
{ {
var v = tf.Variable(0, name: "my_variable"); var v = tf.Variable(0, name: "my_variable");
var sess = tf.Session(); var sess = tf.Session();
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt");
tf.train.write_graph(sess.graph, "/tmp/my-model", "train2.pbtxt");
} }


[TestMethod] [TestMethod]
public void SaveSimple()
public void Save1()
{ {
var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1");
var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2");
var w1 = tf.Variable(0, name: "save1");


var init_op = tf.global_variables_initializer(); var init_op = tf.global_variables_initializer();


@@ -41,13 +40,13 @@ namespace TensorFlowNET.UnitTest
sess.run(init_op); sess.run(init_op);


// Save the variables to disk. // Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model.ckpt");
var save_path = saver.save(sess, "/tmp/model1.ckpt");
Console.WriteLine($"Model saved in path: {save_path}"); Console.WriteLine($"Model saved in path: {save_path}");
}); });
} }


[TestMethod] [TestMethod]
public void Save()
public void Save2()
{ {
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);
@@ -69,7 +68,7 @@ namespace TensorFlowNET.UnitTest
dec_v2.op.run(); dec_v2.op.run();


// Save the variables to disk. // Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model.ckpt");
var save_path = saver.save(sess, "/tmp/model2.ckpt");
Console.WriteLine($"Model saved in path: {save_path}"); Console.WriteLine($"Model saved in path: {save_path}");
}); });
} }


Loading…
Cancel
Save