Browse Source

support import CondContext, upgrade to tensorflow v1.14rc0

tags/v0.9
Oceania2018 6 years ago
parent
commit
2e45a0cf0b
18 changed files with 108 additions and 52 deletions
  1. +0
    -1
      .gitignore
  2. +6
    -0
      TensorFlow.NET.sln
  3. +0
    -0
      graph/README.md
  4. BIN
      graph/cond_test.meta
  5. BIN
      graph/kmeans.meta
  6. +0
    -23
      src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  8. +19
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  9. +34
    -0
      src/TensorFlowNET.Core/Operations/Operation.Implicit.cs
  10. +0
    -21
      src/TensorFlowNET.Core/Operations/Operation.cs
  11. +5
    -0
      src/TensorFlowNET.Core/Python.cs
  12. +6
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  13. +4
    -0
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  14. +4
    -0
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  15. BIN
      tensorflowlib/runtimes/win-x64/native/tensorflow.dll
  16. +27
    -5
      test/TensorFlowNET.Examples/KMeansClustering.cs
  17. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  18. +1
    -0
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

+ 0
- 1
.gitignore View File

@@ -62,7 +62,6 @@ StyleCopReport.xml
*_p.c
*_i.h
*.ilk
*.meta
*.obj
*.iobj
*.pch


+ 6
- 0
TensorFlow.NET.sln View File

@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{DE97EAD7-B92C-4112-9690-91C40A97179E}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -27,6 +29,10 @@ Global
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
{DE97EAD7-B92C-4112-9690-91C40A97179E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{DE97EAD7-B92C-4112-9690-91C40A97179E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{DE97EAD7-B92C-4112-9690-91C40A97179E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{DE97EAD7-B92C-4112-9690-91C40A97179E}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 0
- 0
graph/README.md View File


BIN
graph/cond_test.meta View File


BIN
graph/kmeans.meta View File


+ 0
- 23
src/TensorFlowNET.Core/Framework/op_def_registry.py.cs View File

@@ -20,32 +20,9 @@ namespace Tensorflow

foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;

if (!_registered_ops.ContainsKey("NearestNeighbors"))
_registered_ops["NearestNeighbors"] = op_NearestNeighbors();
}

return _registered_ops;
}

/// <summary>
/// Doesn't work because the op can't be found on binary
/// </summary>
/// <returns></returns>
private static OpDef op_NearestNeighbors()
{
var def = new OpDef
{
Name = "NearestNeighbors"
};

def.InputArg.Add(new ArgDef { Name = "points", Type = DataType.DtFloat });
def.InputArg.Add(new ArgDef { Name = "centers", Type = DataType.DtFloat });
def.InputArg.Add(new ArgDef { Name = "k", Type = DataType.DtInt64 });
def.OutputArg.Add(new ArgDef { Name = "nearest_center_indices", Type = DataType.DtInt64 });
def.OutputArg.Add(new ArgDef { Name = "nearest_center_distances", Type = DataType.DtFloat });

return def;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -335,7 +335,7 @@ namespace Tensorflow.Operations
ret.Enter();
foreach (var nested_def in proto.NestedContexts)
throw new NotImplementedException("");
from_control_flow_context_def(nested_def, import_scope: import_scope);
ret.Exit();
return ret;
}


+ 19
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -3,7 +3,8 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations.ControlFlows;

using static Tensorflow.ControlFlowContextDef;
namespace Tensorflow.Operations
{
/// <summary>
@@ -184,6 +185,23 @@ namespace Tensorflow.Operations
return null;
}

/// <summary>
/// Deserializes `context_def` into the appropriate ControlFlowContext.
/// </summary>
/// <param name="context_def">ControlFlowContextDef proto</param>
/// <param name="import_scope">Name scope to add</param>
/// <returns>A ControlFlowContext subclass</returns>
protected ControlFlowContext from_control_flow_context_def(ControlFlowContextDef context_def, string import_scope = "")
{
switch (context_def.CtxtCase)
{
case CtxtOneofCase.CondCtxt:
return new CondContext().from_proto(context_def.CondCtxt, import_scope: import_scope);
}
throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}");
}

public object to_proto()
{
throw new NotImplementedException();


+ 34
- 0
src/TensorFlowNET.Core/Operations/Operation.Implicit.cs View File

@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Convert to other datatype implicitly
/// </summary>
public partial class Operation
{
public static implicit operator Operation(IntPtr handle) => new Operation(handle);
public static implicit operator IntPtr(Operation op) => op._handle;
public static implicit operator Tensor(Operation op) => op.output;

public override string ToString()
{
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
}

public override bool Equals(object obj)
{
switch (obj)
{
case IntPtr val:
return val == _handle;
case Operation val:
return val._handle == _handle;
}

return base.Equals(obj);
}
}
}

+ 0
- 21
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -248,27 +248,6 @@ namespace Tensorflow
s.Check();
return NodeDef.Parser.ParseFrom(buffer);
}
}

public override string ToString()
{
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
}

public static implicit operator Operation(IntPtr handle) => new Operation(handle);
public static implicit operator IntPtr(Operation op) => op._handle;

public override bool Equals(object obj)
{
switch (obj)
{
case IntPtr val:
return val == _handle;
case Operation val:
return val._handle == _handle;
}

return base.Equals(obj);
}
/// <summary>


+ 5
- 0
src/TensorFlowNET.Core/Python.cs View File

@@ -27,6 +27,11 @@ namespace Tensorflow
return Enumerable.Range(0, end);
}

protected IEnumerable<int> range(int start, int end)
{
return Enumerable.Range(start, end);
}

public static T New<T>(object args) where T : IPyClass
{
var instance = Activator.CreateInstance<T>();


+ 6
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -204,6 +204,12 @@ namespace Tensorflow

switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
var bools = new bool[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.Data();
// wired, don't know why we have to start from offset 9.


+ 4
- 0
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -57,4 +57,8 @@ More math/ linalg APIs.</PackageReleaseNotes>
<Folder Include="Keras\Initializers\" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
</ItemGroup>

</Project>

+ 4
- 0
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow
{
@@ -115,6 +116,9 @@ namespace Tensorflow
case List<ITensorOrOperation> values:
foreach (var element in values) ;
break;
case List<CondContext> values:
foreach (var element in values) ;
break;
default:
throw new NotImplementedException("_build_internal.check_collection_list");
}


BIN
tensorflowlib/runtimes/win-x64/native/tensorflow.dll View File


+ 27
- 5
test/TensorFlowNET.Examples/KMeansClustering.cs View File

@@ -33,19 +33,41 @@ namespace TensorFlowNET.Examples

public bool Run()
{
PrepareData();

var graph = tf.Graph().as_default();

tf.train.import_meta_graph("kmeans.meta");

// Input images
var X = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
var X = graph.get_operation_by_name("Placeholder").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
// Labels (for assigning a label to a centroid and testing)
var Y = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes));
var Y = graph.get_operation_by_name("Placeholder_1").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes));

// K-Means Parameters
var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true);
//var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true);

// Build KMeans graph
var training_graph = kmeans.training_graph();
//var training_graph = kmeans.training_graph();

var init_vars = tf.global_variables_initializer();
Tensor init_op = graph.get_operation_by_name("cond/Merge");
var train_op = graph.get_operation_by_name("group_deps");
Tensor avg_distance = graph.get_operation_by_name("Mean");
Tensor cluster_idx = graph.get_operation_by_name("Squeeze_1");

with(tf.Session(graph), sess =>
{
sess.run(init_vars, new FeedItem(X, full_data_x));
sess.run(init_op, new FeedItem(X, full_data_x));

// Training
foreach(var i in range(1, num_steps + 1))
{
var result = sess.run(new Tensor[] { avg_distance, cluster_idx }, new FeedItem(X, full_data_x));
}
});

return false;
}



+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -13,6 +13,7 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>



+ 1
- 0
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -22,6 +22,7 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
<ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" />
</ItemGroup>


Loading…
Cancel
Save