Browse Source

_registered_ops NearestNeighbors

tags/v0.9
Oceania2018 6 years ago
parent
commit
8cf46d2d23
2 changed files with 26 additions and 0 deletions
  1. +24
    -0
      src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
  2. +2
    -0
      test/TensorFlowNET.Examples/KMeansClustering.cs

+ 24
- 0
src/TensorFlowNET.Core/Framework/op_def_registry.py.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.OpDef.Types;

namespace Tensorflow
{
@@ -19,9 +20,32 @@ namespace Tensorflow

foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;

if (!_registered_ops.ContainsKey("NearestNeighbors"))
_registered_ops["NearestNeighbors"] = op_NearestNeighbors();
}

return _registered_ops;
}

/// <summary>
/// Doesn't work because the op can't be found on binary
/// </summary>
/// <returns></returns>
private static OpDef op_NearestNeighbors()
{
var def = new OpDef
{
Name = "NearestNeighbors"
};

def.InputArg.Add(new ArgDef { Name = "points", Type = DataType.DtFloat });
def.InputArg.Add(new ArgDef { Name = "centers", Type = DataType.DtFloat });
def.InputArg.Add(new ArgDef { Name = "k", Type = DataType.DtInt64 });
def.OutputArg.Add(new ArgDef { Name = "nearest_center_indices", Type = DataType.DtInt64 });
def.OutputArg.Add(new ArgDef { Name = "nearest_center_distances", Type = DataType.DtFloat });

return def;
}
}
}

+ 2
- 0
test/TensorFlowNET.Examples/KMeansClustering.cs View File

@@ -33,6 +33,8 @@ namespace TensorFlowNET.Examples

public bool Run()
{
tf.train.import_meta_graph("kmeans.meta");

// Input images
var X = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
// Labels (for assigning a label to a centroid and testing)


Loading…
Cancel
Save