Browse Source

VariableTest #888

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
e40be93380
7 changed files with 48 additions and 12 deletions
  1. +9
    -0
      src/TensorFlowNET.Core/APIs/tf.compat.v1.cs
  2. +1
    -4
      src/TensorFlowNET.Core/APIs/tf.variable.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  4. +6
    -0
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  5. +1
    -3
      src/TensorFlowNET.Core/tensorflow.cs
  6. +2
    -2
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  7. +26
    -0
      test/TensorFlowNET.Graph.UnitTest/Basics/VariableTest.cs

+ 9
- 0
src/TensorFlowNET.Core/APIs/tf.compat.v1.cs View File

@@ -47,5 +47,14 @@ namespace Tensorflow
trainable: trainable,
collections: collections);
}

public Operation global_variables_initializer()
{
var g = variables.global_variables();
return variables.variables_initializer(g.ToArray());
}

public Session Session()
=> new Session().as_default();
}
}

+ 1
- 4
src/TensorFlowNET.Core/APIs/tf.variable.cs View File

@@ -37,10 +37,7 @@ namespace Tensorflow
=> variables.variables_initializer(var_list, name: name);

public Operation global_variables_initializer()
{
var g = variables.global_variables();
return variables.variables_initializer(g.ToArray());
}
=> tf.compat.v1.global_variables_initializer();

/// <summary>
/// Returns all variables created with `trainable=True`.


+ 3
- 3
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.60.5</Version>
<Version>0.60.6</Version>
<LangVersion>9.0</LangVersion>
<Nullable>enable</Nullable>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
@@ -20,7 +20,7 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.60.5.0</AssemblyVersion>
<AssemblyVersion>0.60.6.0</AssemblyVersion>
<PackageReleaseNotes>tf.net 0.60.x and above are based on tensorflow native 2.6.0

* Eager Mode is added finally.
@@ -35,7 +35,7 @@ Keras API is a separate package released as TensorFlow.Keras.
tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.
tf.net 0.5x.x aligns with TensorFlow v2.5.x native library.
tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.</PackageReleaseNotes>
<FileVersion>0.60.5.0</FileVersion>
<FileVersion>0.60.6.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>


+ 6
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -17,6 +17,7 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -229,5 +230,10 @@ namespace Tensorflow

throw new NotImplementedException("to_proto RefVariable");
}

public NDArray eval(Session session = null)
{
return _graph_element.eval(session);
}
}
}

+ 1
- 3
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -93,9 +93,7 @@ namespace Tensorflow
=> ops.get_default_session();

public Session Session()
{
return new Session().as_default();
}
=> compat.v1.Session();

public Session Session(Graph graph, ConfigProto config = null)
{


+ 2
- 2
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -37,8 +37,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<AssemblyVersion>0.6.5.0</AssemblyVersion>
<FileVersion>0.6.5.0</FileVersion>
<AssemblyVersion>0.6.6.0</AssemblyVersion>
<FileVersion>0.6.6.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
</PropertyGroup>



+ 26
- 0
test/TensorFlowNET.Graph.UnitTest/Basics/VariableTest.cs View File

@@ -0,0 +1,26 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class VariableTest : GraphModeTestBase
{
[TestMethod]
public void InitVariable()
{
var v = tf.Variable(new[] { 1, 2 });
var init = tf.compat.v1.global_variables_initializer();

using var sess = tf.compat.v1.Session();
sess.run(init);
// Usage passing the session explicitly.
print(v.eval(sess));
// Usage with the default session. The 'with' block
// above makes 'sess' the default session.
print(v.eval());
}
}
}

Loading…
Cancel
Save