Browse Source

tf.nn.ctc_greedy_decoder #473

tags/v0.20
Oceania2018 5 years ago
parent
commit
9d5bb8f1e2
8 changed files with 142 additions and 2 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.graph.cs
  2. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  3. +31
    -0
      src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs
  4. +67
    -0
      src/TensorFlowNET.Core/Operations/ctc_ops.cs
  5. +38
    -0
      src/TensorFlowNET.Core/Operations/gen_ctc_ops.cs
  6. +0
    -0
      src/TensorFlowNET.Core/Operations/gen_image_ops.cs
  7. +0
    -0
      src/TensorFlowNET.Core/Operations/gen_io_ops.cs
  8. +2
    -1
      src/TensorFlowNet.Benchmarks/Benchmark.csproj

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.graph.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow
public partial class tensorflow
{
public graph_util_impl graph_util => new graph_util_impl();
public GraphTransformer graph_transforms => new GraphTransformer();
public GraphKeys GraphKeys { get; } = new GraphKeys();

public void reset_default_graph()


+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -46,6 +46,9 @@ namespace Tensorflow
return gen_nn_ops.conv2d(parameters);
}

public Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = null)
=> gen_ctc_ops.ctc_greedy_decoder(inputs, sequence_length, merge_repeated: merge_repeated, name: name);

/// <summary>
/// Computes dropout.
/// </summary>


+ 31
- 0
src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class GraphTransformer
{
/// <summary>
/// Graph Transform Tool
/// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md
/// </summary>
/// <param name="input_graph_def">GraphDef object containing a model to be transformed</param>
/// <param name="inputs">the model inputs</param>
/// <param name="outputs">the model outputs</param>
/// <param name="transforms">transform names and parameters</param>
/// <returns></returns>
public GraphDef TransformGraph(GraphDef input_graph_def,
string[] inputs,
string[] outputs,
string[] transforms)
{
var input_graph_def_string = input_graph_def.ToString();
var inputs_string = string.Join(",", inputs);
var outputs_string = string.Join(",", outputs);
var transforms_string = string.Join(",", transforms);

throw new NotImplementedException("");
}
}
}

+ 67
- 0
src/TensorFlowNET.Core/Operations/ctc_ops.cs View File

@@ -0,0 +1,67 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Linq;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class ctc_ops
{
/// <summary>
/// Performs greedy decoding on the logits given in inputs.
/// </summary>
/// <param name="inputs">
/// 3-D, shape: <c>(max_time x batch_size x num_classes)</c>, the logits.
/// </param>
/// <param name="sequence_length">
/// A vector containing sequence lengths, size <c>(batch_size)</c>.
/// </param>
/// <param name="name">
/// If specified, the created operation in the graph will be this one, otherwise it will be named 'CTCGreedyDecoder'.
/// </param>
/// <param name="merge_repeated">
/// If True, merge repeated classes in output.
/// </param>
/// <returns>
/// Returns a tuple with multiple values, as follows:
/// decoded_indices : Indices matrix, size <c>(total_decoded_outputs x 2)</c>,
/// of a <c>SparseTensor&amp;lt;int64, 2&amp;gt;</c>. The rows store: [batch, time].
/// decoded_values : Values vector, size: <c>(total_decoded_outputs)</c>,
/// of a <c>SparseTensor&amp;lt;int64, 2&amp;gt;</c>. The vector stores the decoded classes.
/// decoded_shape : Shape vector, size <c>(2)</c>, of the decoded SparseTensor.
/// Values are: <c>[batch_size, max_decoded_length]</c>.
/// log_probability : Matrix, size <c>(batch_size x 1)</c>, containing sequence
/// log-probabilities.
/// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property.
/// </returns>
/// <remarks>
/// A note about the attribute merge_repeated: if enabled, when
/// consecutive logits' maximum indices are the same, only the first of
/// these is emitted. Labeling the blank '*', the sequence "A B B * B B"
/// becomes "A B B" if merge_repeated = True and "A B B B B" if
/// merge_repeated = False.
///
/// Regardless of the value of merge_repeated, if the maximum index of a given
/// time and batch corresponds to the blank, index <c>(num_classes - 1)</c>, no new
/// element is emitted.
/// </remarks>
public Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = null)
=> gen_ctc_ops.ctc_greedy_decoder(inputs, sequence_length, merge_repeated: merge_repeated, name: name);
}
}

+ 38
- 0
src/TensorFlowNET.Core/Operations/gen_ctc_ops.cs View File

@@ -0,0 +1,38 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

namespace Tensorflow
{
public class gen_ctc_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

public static Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = "CTCGreedyDecoder")
{
var op = _op_def_lib._apply_op_helper("CTCGreedyDecoder", name: name, args: new
{
inputs,
sequence_length,
merge_repeated
});
/*var decoded_indices = op.outputs[0];
var decoded_values = op.outputs[1];
var decoded_shape = op.outputs[2];
var log_probability = op.outputs[3];*/
return op.outputs;
}
}
}

src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs → src/TensorFlowNET.Core/Operations/gen_image_ops.cs View File


src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs → src/TensorFlowNET.Core/Operations/gen_io_ops.cs View File


+ 2
- 1
src/TensorFlowNet.Benchmarks/Benchmark.csproj View File

@@ -19,7 +19,8 @@

<ItemGroup>
<PackageReference Include="BenchmarkDotNet" Version="0.12.0" />
<PackageReference Include="TensorFlow.NET" Version="0.12.0" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="1.14.1" />
<PackageReference Include="TensorFlow.NET" Version="0.13.0" />
</ItemGroup>

</Project>

Loading…
Cancel
Save