Browse Source

fix _constant_if_small for zeros.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
5720dfd68e
9 changed files with 38 additions and 49 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  2. +1
    -2
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  3. +1
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  4. +13
    -33
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  5. +0
    -2
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  6. +1
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  7. +11
    -0
      test/TensorFlowNET.UnitTest/ConstantTest.cs
  8. +1
    -1
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
  9. +8
    -9
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 2
- 0
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -322,6 +322,8 @@ namespace Tensorflow
attr_value.Shape = val1.as_proto();
else if(value is long[] val2)
attr_value.Shape = tensor_util.as_shape(val2);
else if (value is int[] val3)
attr_value.Shape = tensor_util.as_shape(val3);

break;
default:


+ 1
- 2
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -34,10 +34,9 @@ namespace Tensorflow
private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name)
{
Tensor tShape = null;
var nd = np.zeros<T>(shape);
if (shape.Size < 1000)
{
return constant_op.constant(nd, name: name);
return constant_op.constant(value, shape: shape, dtype: dtype, name: name);
}
else
{


+ 1
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -43,7 +43,7 @@ TensorFlow 1.13 RC.</PackageReleaseNotes>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.6.1" />
<PackageReference Include="NumSharp" Version="0.7.0" />
<PackageReference Include="NumSharp" Version="0.7.1" />
</ItemGroup>

<ItemGroup>


+ 13
- 33
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -56,9 +56,9 @@ namespace Tensorflow

switch (values)
{
/*case bool boolVal:
case bool boolVal:
nparray = boolVal;
break;*/
break;
case int intVal:
nparray = intVal;
break;
@@ -74,6 +74,9 @@ namespace Tensorflow
case string strVal:
nparray = strVal;
break;
case string[] strVals:
nparray = strVals;
break;
default:
throw new Exception("make_tensor_proto Not Implemented");
}
@@ -100,7 +103,8 @@ namespace Tensorflow
}
else
{
throw new NotImplementedException("make_tensor_proto shape not implemented");
shape_size = new TensorShape(shape).Size;
is_same_size = shape_size == nparray.size;
}

var tensor_proto = new tensor_pb2.TensorProto
@@ -111,41 +115,17 @@ namespace Tensorflow

if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1)
{
var bytes = new List<byte>();
var nd2 = nparray.ravel();
switch (nparray.dtype.Name)
{
case "Int32":
nd2.Data<int>().Select(x =>
{
bytes.AddRange(BitConverter.GetBytes(x));
return x;
}).ToArray();
break;
case "Single":
nd2.Data<float>().Select(x =>
{
bytes.AddRange(BitConverter.GetBytes(x));
return x;
}).ToArray();
break;
case "Double":
nd2.Data<double>().Select(x =>
{
bytes.AddRange(BitConverter.GetBytes(x));
return x;
}).ToArray();
break;
default:
throw new Exception("make_tensor_proto Not Implemented");
}
byte[] bytes = nparray.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray());
return tensor_proto;
}

if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str)
if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray))
{
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
if (values is string str)
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
else if (values is string[] str_values)
tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
return tensor_proto;
}



+ 0
- 2
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -123,8 +123,6 @@ namespace Tensorflow
Version = _write_version
};
});

}

public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables)


+ 1
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -6,7 +6,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="NumSharp" Version="0.7.0" />
<PackageReference Include="NumSharp" Version="0.7.1" />
<PackageReference Include="TensorFlow.NET" Version="0.1.0" />
</ItemGroup>



+ 11
- 0
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -58,6 +58,7 @@ namespace TensorFlowNET.UnitTest

var data = result.Data<int>();
Assert.AreEqual(0, data[0]);
Assert.AreEqual(0, data[500]);
Assert.AreEqual(0, data[result.size - 1]);
});
}
@@ -109,5 +110,15 @@ namespace TensorFlowNET.UnitTest

//c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
}

/// <summary>
/// tensorflow\c\c_api_test.cc
/// TestEncodeDecode
/// </summary>
[TestMethod]
public void EncodeDecode()
{

}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -19,7 +19,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
<PackageReference Include="NumSharp" Version="0.7.0" />
<PackageReference Include="NumSharp" Version="0.7.1" />
<PackageReference Include="TensorFlow.NET" Version="0.1.0" />
</ItemGroup>



+ 8
- 9
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -10,11 +10,11 @@ namespace TensorFlowNET.UnitTest
public class TrainSaverTest : Python
{
[TestMethod]
public void WriteGraph()
public void ExportGraph()
{
var v = tf.Variable(0, name: "my_variable");
var sess = tf.Session();
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt");
tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt");
}

[TestMethod]
@@ -22,14 +22,13 @@ namespace TensorFlowNET.UnitTest
{
var v = tf.Variable(0, name: "my_variable");
var sess = tf.Session();
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt");
tf.train.write_graph(sess.graph, "/tmp/my-model", "train2.pbtxt");
}

[TestMethod]
public void SaveSimple()
public void Save1()
{
var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1");
var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2");
var w1 = tf.Variable(0, name: "save1");

var init_op = tf.global_variables_initializer();

@@ -41,13 +40,13 @@ namespace TensorFlowNET.UnitTest
sess.run(init_op);

// Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model.ckpt");
var save_path = saver.save(sess, "/tmp/model1.ckpt");
Console.WriteLine($"Model saved in path: {save_path}");
});
}

[TestMethod]
public void Save()
public void Save2()
{
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);
@@ -69,7 +68,7 @@ namespace TensorFlowNET.UnitTest
dec_v2.op.run();

// Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model.ckpt");
var save_path = saver.save(sess, "/tmp/model2.ckpt");
Console.WriteLine($"Model saved in path: {save_path}");
});
}


Loading…
Cancel
Save