
Merge branch 'master' of https://github.com/SciSharp/TensorFlow.NET

tags/v0.40-tf2.4-tstring
Niklas Gustafsson, 4 years ago
commit 2b62e120cf
65 changed files with 1380 additions and 2670 deletions
1. +5 -4  README.md
2. +24 -20  src/TensorFlowNET.Console/MemoryBasicTest.cs
3. +1 -0  src/TensorFlowNET.Console/Tensorflow.Console.csproj
4. +30 -4  src/TensorFlowNET.Core/APIs/tf.linalg.cs
5. +22 -0  src/TensorFlowNET.Core/APIs/tf.math.cs
6. +6 -1  src/TensorFlowNET.Core/APIs/tf.random.cs
7. +5 -4  src/TensorFlowNET.Core/APIs/tf.sparse.cs
8. +23 -0  src/TensorFlowNET.Core/APIs/tf.strings.cs
9. +0 -13  src/TensorFlowNET.Core/Contexts/AutoModeArgs.cs
10. +17 -49  src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs
11. +26 -2  src/TensorFlowNET.Core/Contexts/Context.cs
12. +25 -0  src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs
13. +12 -12  src/TensorFlowNET.Core/Data/DatasetV2.cs
14. +2 -0  src/TensorFlowNET.Core/Data/IDatasetV2.cs
15. +1 -0  src/TensorFlowNET.Core/Data/OwnedIterator.cs
16. +23 -53  src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
17. +13 -2  src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs
18. +1 -14  src/TensorFlowNET.Core/Eager/IEagerRunner.cs
19. +28 -1  src/TensorFlowNET.Core/Framework/random_seed.py.cs
20. +0 -63  src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs
21. +2 -2  src/TensorFlowNET.Core/Framework/tensor_shape.cs
22. +8 -8  src/TensorFlowNET.Core/Gradients/math_grad.cs
23. +1 -0  src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs
24. +79 -400  src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
25. +6 -2  src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
26. +22 -64  src/TensorFlowNET.Core/Operations/array_ops.cs
27. +2 -29  src/TensorFlowNET.Core/Operations/bitwise_ops.cs
28. +85 -362  src/TensorFlowNET.Core/Operations/dataset_ops.cs
29. +55 -342  src/TensorFlowNET.Core/Operations/gen_array_ops.cs
30. +24 -94  src/TensorFlowNET.Core/Operations/gen_image_ops.cs
31. +2 -3  src/TensorFlowNET.Core/Operations/gen_logging_ops.cs
32. +72 -698  src/TensorFlowNET.Core/Operations/gen_math_ops.cs
33. +1 -8  src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs
34. +8 -92  src/TensorFlowNET.Core/Operations/gen_random_ops.cs
35. +23 -64  src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
36. +39 -44  src/TensorFlowNET.Core/Operations/math_ops.cs
37. +87 -60  src/TensorFlowNET.Core/Operations/string_ops.cs
38. +1 -0  src/TensorFlowNET.Core/Tensorflow.Binding.csproj
39. +1 -1  src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs
40. +0 -7  src/TensorFlowNET.Core/Tensors/ITensor.cs
41. +147 -0  src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
42. +103 -0  src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
43. +76 -0  src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs
44. +1 -2  src/TensorFlowNET.Core/Tensors/Tensor.cs
45. +10 -56  src/TensorFlowNET.Core/Training/gen_training_ops.cs
46. +2 -25  src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
47. +2 -16  src/TensorFlowNET.Keras/Activations/Activations.Relu.cs
48. +2 -16  src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs
49. +2 -16  src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs
50. +13 -1  src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs
51. +10 -0  src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs
52. +19 -0  src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs
53. +30 -0  src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs
54. +16 -0  src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs
55. +55 -0  src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs
56. +23 -0  src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs
57. +15 -4  src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
58. +2 -0  src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
59. +9 -8  src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
60. +27 -1  src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs
61. +1 -0  src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
62. +2 -2  test/TensorFlowNET.UnitTest/Basics/RandomTest.cs
63. +20 -0  test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
64. +8 -0  test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs
65. +3 -1  test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs

README.md (+5 -4)

@@ -9,7 +9,7 @@
[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US)
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab)

*master branch is based on tensorflow 2.3 now, v0.15-tensorflow1.15 is from tensorflow1.15.*
*master branch is based on tensorflow v2.4, v0.3x branch is based on tensorflow v2.3, v0.15-tensorflow1.15 is from tensorflow1.15.*


![tensors_flowing](docs/assets/tensors_flowing.gif)
@@ -30,7 +30,8 @@ Go through the online docs [TensorFlow for .NET](https://scisharp.github.io/tens

| TensorFlow | tf native1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 |
| -------------------------- | ------------- | -------------- | ------------- | ------------- |
| tf.net 0.3x, tf.keras 0.2 | | | x | not compatible |
| tf.net 0.4x, tf.keras 0.5 | | | | x |
| tf.net 0.3x, tf.keras 0.4 | | | x | |
| tf.net 0.2x | | x | x | |
| tf.net 0.15 | x | x | | |
| tf.net 0.14 | x | | | |
@@ -50,10 +51,10 @@ PM> Install-Package TensorFlow.Keras

### Install tensorflow binary
### For CPU version
PM> Install-Package SciSharp.TensorFlow.Redist -Version 2.3.1
PM> Install-Package SciSharp.TensorFlow.Redist

### For GPU version (CUDA and cuDNN are required)
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU -Version 2.3.1
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
```

Import TF.NET and Keras API in your project.
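The hunk stops before the README's actual import snippet; as a minimal sketch (assuming the usual TF.NET static bindings), the imports referred to here are:

```csharp
using static Tensorflow.Binding;   // exposes the tf.* API
using static Tensorflow.KerasApi;  // exposes the keras.* API
```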


src/TensorFlowNET.Console/MemoryBasicTest.cs (+24 -20)

@@ -112,16 +112,18 @@ namespace Tensorflow
var strides = new[] { 1, 1, 1, 1 };
var dilations = new[] { 1, 1, 1, 1 };

var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Conv2D", null,
null,
input, filter,
"strides", strides,
"use_cudnn_on_gpu", true,
"padding", "VALID",
"explicit_paddings", new int[0],
"data_format", "NHWC",
"dilations", dilations);
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Conv2D", null, input, filter)
{
attrs = ConvertToDict(new
{
strides,
use_cudnn_on_gpu = true,
padding = "VALID",
explicit_paddings = new int[0],
data_format = "NHWC",
dilations
})
});
};

public Action<int, int> Conv2DWithVariable
@@ -132,16 +134,18 @@ namespace Tensorflow
var strides = new[] { 1, 1, 1, 1 };
var dilations = new[] { 1, 1, 1, 1 };

var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Conv2D", null,
null,
input, filter,
"strides", strides,
"use_cudnn_on_gpu", true,
"padding", "VALID",
"explicit_paddings", new int[0],
"data_format", "NHWC",
"dilations", dilations);
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Conv2D", null, input, filter)
{
attrs = ConvertToDict(new
{
strides,
use_cudnn_on_gpu = true,
padding = "VALID",
explicit_paddings = new int[0],
data_format = "NHWC",
dilations
})
});
};

public Action<int, int> Dataset
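For reference, a hedged sketch of what the `ConvertToDict(new { ... })` calls in the two rewritten Conv2D call sites above are assumed to produce: a `Dictionary<string, object>` keyed by the anonymous object's property names, which `FastPathOpExecInfo.attrs` then hands to the eager runner.

```csharp
using System.Collections.Generic;
using static Tensorflow.Binding;

// Sketch only; ConvertToDict is assumed to reflect over the anonymous object's
// public properties and emit one dictionary entry per property.
Dictionary<string, object> attrs = ConvertToDict(new
{
    strides = new[] { 1, 1, 1, 1 },
    use_cudnn_on_gpu = true,
    padding = "VALID",
    data_format = "NHWC"
});
// attrs["padding"] == "VALID", attrs["strides"] is the int[4] above, etc.
```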


src/TensorFlowNET.Console/Tensorflow.Console.csproj (+1 -0)

@@ -11,6 +11,7 @@

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<DefineConstants>TRACE;DEBUG</DefineConstants>
<PlatformTarget>x64</PlatformTarget>
</PropertyGroup>

<ItemGroup>


src/TensorFlowNET.Core/APIs/tf.linalg.cs (+30 -4)

@@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -37,8 +38,8 @@ namespace Tensorflow
public Tensor matmul(Tensor a, Tensor b)
=> math_ops.matmul(a, b);

public Tensor batch_matmul(Tensor x, Tensor y)
=> gen_math_ops.batch_mat_mul(x, y);
public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
=> math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name);
}

public Tensor diag(Tensor diagonal, string name = null)
@@ -47,7 +48,32 @@ namespace Tensorflow
public Tensor matmul(Tensor a, Tensor b)
=> math_ops.matmul(a, b);

public Tensor batch_matmul(Tensor x, Tensor y)
=> gen_math_ops.batch_mat_mul(x, y);
/// <summary>
/// Multiply slices of the two matrices "x" and "y".
/// </summary>
/// <remarks>
/// The `BatchMatMul` operation is embedded into the
/// `MatMul` operation on the DLL side. However the expected
/// attributes are not the same, hence we need to expose this
/// method to have the right args list on the `_apply_op_helper`
/// function.
///
/// For each rank > 2 the first rank - 2 dimensions are considered
/// as fixed, and have to be consistent across the two matrices. A
/// common matrix multiplication is then applied over the residual
/// 2 dimensions.
///
/// e.g.
/// x is (3, 6, 12); y is (3, 12, 6)
/// batch_matmul(x, y) ==> (3, 6, 6)
/// </remarks>
/// <param name="x"></param>
/// <param name="y"></param>
/// <param name="adj_x"></param>
/// <param name="adj_y"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
=> math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name);
}
}
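A short usage sketch matching the shapes in the remarks above (helper names such as `tf.ones` are assumed here and are not part of this diff):

```csharp
// x: (3, 6, 12), y: (3, 12, 6)  ->  batch_matmul(x, y): (3, 6, 6)
var x = tf.ones(new TensorShape(3, 6, 12));
var y = tf.ones(new TensorShape(3, 12, 6));
var z = tf.linalg.batch_matmul(x, y);
// adj_x / adj_y transpose the trailing two dimensions of x or y before multiplying.
```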

src/TensorFlowNET.Core/APIs/tf.math.cs (+22 -0)

@@ -32,6 +32,28 @@ namespace Tensorflow
/// <returns></returns>
public Tensor erf(Tensor x, string name = null)
=> math_ops.erf(x, name);

/// <summary>
///
/// </summary>
/// <param name="arr"></param>
/// <param name="weights"></param>
/// <param name="minlength"></param>
/// <param name="maxlength"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <param name="axis"></param>
/// <param name="binary_output"></param>
/// <returns></returns>
public Tensor bincount(Tensor arr, Tensor weights = null,
Tensor minlength = null,
Tensor maxlength = null,
TF_DataType dtype = TF_DataType.TF_INT32,
string name = null,
TensorShape axis = null,
bool binary_output = false)
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength,
dtype: dtype, name: name, axis: axis, binary_output: binary_output);
}

public Tensor abs(Tensor x, string name = null)
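A small worked example of the newly exposed `bincount` (input values and expected counts are illustrative, not taken from the tests):

```csharp
var arr = tf.constant(new[] { 1, 1, 2, 3, 3, 3 });
var counts = tf.math.bincount(arr);
// counts[i] = number of times i occurs in arr: [0, 2, 1, 3]
// Passing `weights` accumulates weights[j] into counts[arr[j]] instead of adding 1.
```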


src/TensorFlowNET.Core/APIs/tf.random.cs (+6 -1)

@@ -93,7 +93,12 @@ namespace Tensorflow
=> random_ops.random_shuffle(value, seed: seed, name: name);

public void set_random_seed(int seed)
=> ops.get_default_graph().seed = seed;
{
if (executing_eagerly())
Context.set_global_seed(seed);
else
ops.get_default_graph().seed = seed;
}

public Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)
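With this change the seed is honored in eager mode as well as graph mode; a sketch of the intended effect (the random-op name below is assumed, not part of this hunk):

```csharp
// Resetting the global seed restarts the per-op seed stream on the Context,
// so two identically seeded runs should draw the same values.
tf.set_random_seed(42);
var a = tf.random.normal(new TensorShape(2, 2));

tf.set_random_seed(42);
var b = tf.random.normal(new TensorShape(2, 2));
// a and b are expected to contain the same values.
```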


src/TensorFlowNET.Core/APIs/tf.sparse.cs (+5 -4)

@@ -14,17 +14,18 @@
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Framework;

namespace Tensorflow
{
public partial class tensorflow
{
public SparseTensor<T> SparseTensor<T>(long[,] indices, T[] values, long[] dense_shape)
=> new SparseTensor<T>(indices, values, dense_shape);
public SparseTensor SparseTensor(long[,] indices, Array values, long[] dense_shape)
=> new SparseTensor(indices, values, dense_shape);

public Tensor sparse_tensor_to_dense<T>(SparseTensor<T> sp_input,
T default_value = default,
public Tensor sparse_tensor_to_dense(SparseTensor sp_input,
Array default_value = default,
bool validate_indices = true,
string name = null)
=> gen_sparse_ops.sparse_to_dense(sp_input.indices,


src/TensorFlowNET.Core/APIs/tf.strings.cs (+23 -0)

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Framework;

namespace Tensorflow
{
public partial class tensorflow
@@ -64,6 +66,27 @@ namespace Tensorflow
public Tensor substr(string input, int pos, int len,
string name = null, string @uint = "BYTE")
=> ops.substr(input, pos, len, @uint: @uint, name: name);

/// <summary>
/// String lengths of `input`.
/// </summary>
/// <param name="input"></param>
/// <param name="name"></param>
/// <param name="unit"></param>
/// <returns></returns>
public Tensor string_length(Tensor input, string name = null, string unit = "BYTE")
=> ops.string_length(input, name: name, unit: unit);

public RaggedTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
=> ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name);

public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding,
string errors = "replace", int replacement_char = 0xFFFD,
bool replace_control_characters = false, string name = null)
=> ops.unicode_decode_with_offsets(input, input_encoding, errors,
replacement_char: replacement_char,
replace_control_characters: replace_control_characters,
name: name);
}
}
}
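A usage sketch of the new string helpers (the constants and expected values are illustrative):

```csharp
var s = tf.constant(new[] { "hello world", "tensorflow dotnet" });

var lengths = tf.strings.string_length(s);            // [11, 17] with unit "BYTE"
RaggedTensor tokens = tf.strings.split(s, sep: " ");
// tokens: [["hello", "world"], ["tensorflow", "dotnet"]] as a RaggedTensor
```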

src/TensorFlowNET.Core/Contexts/AutoModeArgs.cs (+0 -13)

@@ -1,13 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class AutoModeArgs
{
public Func<Operation, object> GetGradientAttrs { get; set; }
public object OpInputArgs { get; set; }
public object OpAttrs { get; set; }
}
}

src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs → src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs (+17 -49)

@@ -30,67 +30,35 @@ namespace Tensorflow.Contexts
public sealed partial class Context
{
// [DebuggerStepThrough]
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params object[] args)
public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args)
{
if (tf.Context.has_graph_arg(args))
Func<Tensors> graphAction = () =>
{
if (executing_eagerly())
{
graph_mode();
var result = graphAction();
restore_mode();
return result;
}
else
{
return graphAction();
}
}
else
{
if (tf.Context.executing_eagerly())
var keywords = new Dictionary<string, object>();
if(args.OpInputArgs != null)
{
return eagerAction();
foreach (var (i, input) in enumerate(args.OpInputArgs))
keywords[$"input_{i}"] = input;
}
else

if(args.OpAttrs != null)
{
return graphAction();
foreach (var attr in args.OpAttrs)
keywords[attr.Key] = attr.Value;
}
}
}
// [DebuggerStepThrough]
public Tensors RunInAutoMode2(string OpType, string Name, AutoModeArgs args)
{
var inputArgs = ConvertToDict(args.OpInputArgs);
var attrDict = ConvertToDict(args.OpAttrs);
Func<Tensor> graphAction = () =>
{
foreach (var attr in attrDict)
inputArgs[attr.Key] = attr.Value;
return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).output;

return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs;
};

Func<Tensor> eagerAction = () =>
Func<Tensors> eagerAction = () =>
{
var attrs = new object[attrDict.Count() * 2];
int i = 0;
foreach(var arg in attrDict)
return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(OpType, Name, args.OpInputArgs)
{
attrs[i]= arg.Key;
attrs[i + 1] = arg.Value;
i += 2;
}

return tf.Runner.TFE_FastPathExecute2(tf.Context, tf.Context.DeviceName,
OpType, Name,
null,
inputArgs.Values.ToArray(),
attrs).FirstOrDefault();
attrs = args.OpAttrs
});
};

if (tf.Context.has_graph_arg(inputArgs.Values))
if (tf.Context.has_graph_arg(args.OpInputArgs))
{
if (executing_eagerly())
{
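The hunk is cut off above, but the remaining dispatch follows the same shape as the old `RunInAutoMode`: if any input still belongs to a graph, run the graph action (temporarily leaving eager mode if needed); otherwise take the eager fast path. A condensed sketch using the helper names that appear in this file (the method name `Dispatch` is hypothetical):

```csharp
// Sketch of the tail of ExecuteOp, not the verbatim implementation.
Tensors Dispatch(Func<Tensors> graphAction, Func<Tensors> eagerAction, object[] opInputArgs)
{
    if (tf.Context.has_graph_arg(opInputArgs))
    {
        if (!executing_eagerly())
            return graphAction();

        graph_mode();                 // temporarily switch out of eager mode
        var result = graphAction();
        restore_mode();
        return result;
    }

    return executing_eagerly() ? eagerAction() : graphAction();
}
```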

src/TensorFlowNET.Core/Contexts/Context.cs (+26 -2)

@@ -42,6 +42,9 @@ namespace Tensorflow.Contexts
SafeContextHandle _handle;
public SafeContextHandle Handle => _handle;

int? _seed;
Random _rng;

public Context()
{
_device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT;
@@ -71,6 +74,24 @@ namespace Tensorflow.Contexts
initialized = true;
}

public void set_global_seed(int? seed)
{
_seed = seed;
if (seed.HasValue)
_rng = new Random(seed.Value);
else
_rng = null;
// Also clear the kernel cache, to reset any existing seeds
if (_handle != null)
c_api.TFE_ContextClearCaches(_handle);
}

public int? global_seed()
=> _seed;

public int? internal_operation_seed()
=> _rng?.Next(0, int.MaxValue);

public void start_step()
=> c_api.TFE_ContextStartStep(_handle);

@@ -86,7 +107,7 @@ namespace Tensorflow.Contexts
{
if(context_switches.Count() == 0)
tf.enable_eager_execution();
return context_switches.Current().EagerMode;
}

@@ -115,7 +136,10 @@ namespace Tensorflow.Contexts
public bool has_graph_arg(params object[] args)
{
var flatten_args = nest.flatten<object>(args);
bool has_graph_arg = false;
/*if (flatten_args.Count(x => x.GetType().IsValueType) == flatten_args.Count())
return tf.Context.executing_eagerly() == false*/

bool has_graph_arg = !tf.Context.executing_eagerly();
foreach (var el in flatten_args)
{
if (el is Tensor tensor && !tensor.IsEagerTensor)


src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs (+25 -0)

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class ExecuteOpArgs
{
public Func<Operation, object> GetGradientAttrs { get; set; }
public object[] OpInputArgs { get; set; }
public Dictionary<string, object> OpAttrs { get; set; }
public ExecuteOpArgs(params object[] inputArgs)
{
OpInputArgs = inputArgs;
}

public ExecuteOpArgs SetAttributes(object attrs)
{
OpAttrs = ConvertToDict(attrs);
return this;
}
}
}
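Call sites elsewhere in this commit use the class like this; a representative sketch mirroring the gen_nn_ops pattern (the `bias` parameter is simplified to `Tensor` here):

```csharp
public static Tensor bias_add(Tensor value, Tensor bias, string data_format = null, string name = null)
    => tf.Context.ExecuteOp("BiasAdd", name,
        new ExecuteOpArgs(value, bias)                                       // positional op inputs
            .SetAttributes(new { data_format = data_format ?? "NHWC" }));    // named attributes
```

The optional `GetGradientAttrs` delegate (used in the array_ops changes below) lets a caller describe which attributes should be read back from the created `Operation` when gradients are recorded.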

src/TensorFlowNET.Core/Data/DatasetV2.cs (+12 -12)

@@ -68,6 +68,17 @@ namespace Tensorflow
public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);

public OwnedIterator make_one_shot_iterator()
{
if (tf.Context.executing_eagerly())
{
// with ops.colocate_with(self._variant_tensor)
return new OwnedIterator(this);
}

throw new NotImplementedException("");
}

public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
=> new FlatMapDataset(this, map_func);

@@ -105,18 +116,7 @@ namespace Tensorflow
}

public Tensor dataset_cardinality(string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DatasetCardinality", name,
null,
variant_tensor);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor));

public override string ToString()
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}";


src/TensorFlowNET.Core/Data/IDatasetV2.cs (+2 -0)

@@ -72,6 +72,8 @@ namespace Tensorflow
IDatasetV2 map(Func<Tensors, Tensors> map_func,
int num_parallel_calls);

OwnedIterator make_one_shot_iterator();

IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);

IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);


src/TensorFlowNET.Core/Data/OwnedIterator.cs (+1 -0)

@@ -26,6 +26,7 @@ namespace Tensorflow
dataset = dataset.apply_options();
_dataset = dataset;
_element_spec = dataset.element_spec;
// _flat_output_types =
(_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
}


src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs (+23 -53)

@@ -15,84 +15,54 @@ namespace Tensorflow.Eager
/// </summary>
public partial class EagerRunner
{
int kFastPathExecuteInputStartIndex = 0;
UnorderedMap<Context, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<Context, SafeOpHandle>();

public Tensor[] TFE_FastPathExecute2(Context ctx,
string device_name,
string opName,
string name,
Action callbacks,
object[] inputArgs,
object[] attrs)
public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info)
{
var args = new List<object>();
args.AddRange(inputArgs);
if (attrs != null)
args.AddRange(attrs);
return TFE_FastPathExecute(ctx, device_name, opName, name, callbacks, args.ToArray());
}

public Tensor[] TFE_FastPathExecute(Context ctx,
string device_name,
string opName,
string name,
Action callbacks,
params object[] args)
{
if (ctx == null)
throw new ValueError("This function does not handle the case of the path where " +
"all inputs are not already EagerTensors.");
if (op_exec_info.ctx == null)
op_exec_info.ctx = tf.Context;
if (string.IsNullOrEmpty(op_exec_info.device_name))
op_exec_info.device_name = tf.Context.DeviceName;

int args_size = args.Length;
var attr_list_sizes = new Dictionary<string, long>();

FastPathOpExecInfo op_exec_info = new FastPathOpExecInfo()
{
ctx = ctx,
args = args,
device_name = device_name,
op_name = opName,
name = name,
};

op_exec_info.run_gradient_callback = HasAccumulatorOrTape();
op_exec_info.run_post_exec_callbacks = callbacks != null;
op_exec_info.run_post_exec_callbacks = op_exec_info.callbacks != null;
op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks;

var status = tf.Status;
using var op = GetOp(ctx, opName, status);
using var op = GetOp(op_exec_info.ctx, op_exec_info.op_name, status);

var op_def = tf.get_default_graph().GetOpDef(opName);
var op_def = tf.get_default_graph().GetOpDef(op_exec_info.op_name);

var flattened_attrs = new List<object>(op_def.Attr.Count * 2);
var flattened_inputs = new List<Tensor>(op_def.InputArg.Count);

// Set non-inferred attrs, including setting defaults if the attr is passed in
// as None.
for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
if(op_exec_info.attrs != null)
{
var attr_name = args[i].ToString();
var attr_value = args[i + 1];

var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr_name);
if (attr != null)
foreach (var attr1 in op_exec_info.attrs)
{
flattened_attrs.Add(attr_name);
flattened_attrs.Add(attr_value);
var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr1.Key);
if (attr != null)
{
flattened_attrs.Add(attr.Name);
flattened_attrs.Add(attr1.Value);

SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
status.Check(true);
SetOpAttrWithDefaults(op_exec_info.ctx, op, attr, attr.Name, attr1.Value, attr_list_sizes, status);
status.Check(true);
}
}
}

c_api.TFE_OpSetDevice(op, device_name, status.Handle);
c_api.TFE_OpSetDevice(op, op_exec_info.device_name, status.Handle);
status.Check(true);

// Add inferred attrs and inputs.
for (int i = 0; i < op_def.InputArg.Count; i++)
{
var input = args[kFastPathExecuteInputStartIndex + i];
var input = op_exec_info.args[i];
var input_arg = op_def.InputArg[i];
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
@@ -107,7 +77,7 @@ namespace Tensorflow.Eager

if (len > 0)
{
var fast_input_array = (object[])args[i];
var fast_input_array = (object[])op_exec_info.args[i];
// First item adds the type attr.
if (!AddInputToOp(fast_input_array[i], true, input_arg, flattened_attrs, flattened_inputs, op, status))
return null;
@@ -151,7 +121,7 @@ namespace Tensorflow.Eager
else
{
// The item is a single item.
AddInputToOp(args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status);
AddInputToOp(op_exec_info.args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status);
}
}

@@ -179,7 +149,7 @@ namespace Tensorflow.Eager
if (op_exec_info.run_callbacks)
{
RunCallbacks(op_exec_info,
kFastPathExecuteInputStartIndex + op_def.InputArg.Count(),
op_def.InputArg.Count(),
flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result);
}



src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs (+13 -2)

@@ -1,6 +1,8 @@
using Tensorflow.Contexts;
using System;
using System.Collections.Generic;
using Tensorflow.Contexts;

namespace Tensorflow.Eager
namespace Tensorflow
{
public class FastPathOpExecInfo
{
@@ -9,8 +11,17 @@ namespace Tensorflow.Eager
public string op_name { get; set; }
public string name { get; set; }
public object[] args { get; set; }
public Dictionary<string, object> attrs { get; set; }
public bool run_gradient_callback { get; set; }
public bool run_post_exec_callbacks { get; set; }
public bool run_callbacks { get; set; }
public Action callbacks { get; set; }

public FastPathOpExecInfo(string opName, string name, params object[] inputArgs)
{
this.op_name = opName;
this.name = name;
this.args = inputArgs;
}
}
}

src/TensorFlowNET.Core/Eager/IEagerRunner.cs (+1 -14)

@@ -16,20 +16,7 @@ namespace Tensorflow.Eager
TF_DataType default_dtype = TF_DataType.DtInvalid,
object[] args = null);

Tensor[] TFE_FastPathExecute2(Context ctx,
string device_name,
string opName,
string name,
Action callbacks,
object[] inputArgs,
object[] attrs);

Tensor[] TFE_FastPathExecute(Context ctx,
string device_name,
string opName,
string name,
Action callbacks,
params object[] args);
Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info);

Tensor[] TFE_Execute(Context ctx,
string device_name,


src/TensorFlowNET.Core/Framework/random_seed.py.cs (+28 -1)

@@ -14,16 +14,43 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class random_seed
{
private static int DEFAULT_GRAPH_SEED = 87654321;
private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>();

public static (int?, int?) get_seed(int? op_seed = null)
{
int? global_seed;

if (tf.executing_eagerly())
global_seed = tf.Context.global_seed();
else
global_seed = ops.get_default_graph().seed;

if (global_seed.HasValue)
{
if (!op_seed.HasValue)
if (tf.executing_eagerly())
op_seed = tf.Context.internal_operation_seed();
else
{
if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed))
seed = 0;
_graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1;
op_seed = seed;
}

return (global_seed, op_seed);
}

if (op_seed.HasValue)
return (DEFAULT_GRAPH_SEED, 0);
return (DEFAULT_GRAPH_SEED, op_seed);
else
return (null, null);
}
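A quick summary of what the rewritten `get_seed` now returns in each situation (the call below is only illustrative):

```csharp
// global seed set, eager:   (global_seed, next value from the Context RNG)
// global seed set, graph:   (global_seed, per-graph counter 0, 1, 2, ...)
// no global seed, op seed:  (DEFAULT_GRAPH_SEED, op_seed)   // previously returned 0 as the op seed
// neither seed set:         (null, null)
var (graphSeed, opSeed) = random_seed.get_seed(op_seed: 1234);
```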


src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs (+0 -63)

@@ -1,63 +0,0 @@
using System;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Framework
{
/// <summary>
/// Represents a sparse tensor.
/// </summary>
public class SparseTensor<T> : CompositeTensor, _TensorLike
{
long[,] _indices;
public Tensor indices;

T[] _values;
public Tensor values;

long[] _dense_shape;
public Tensor dense_shape;

TensorShape _shape;
public TensorShape shape => _shape;

public TF_DataType dtype => dtypes.as_dtype(typeof(T));

public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_)
{
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate
{
indices = ops.convert_to_tensor(
indices_, name: "indices", dtype: dtypes.int64);
values = ops.convert_to_tensor(values_, name: "values");
dense_shape = ops.convert_to_tensor(
dense_shape_, name: "dense_shape", dtype: dtypes.int64);
});

_indices = indices_;
_values = values_;
_dense_shape = dense_shape_;

var indices_shape = indices.TensorShape.with_rank(2);
var values_shape = values.TensorShape.with_rank(1);
var dense_shape_shape = dense_shape.TensorShape.with_rank(1);

indices_shape["0"].merge_with(values_shape[0]);
indices_shape["1"].merge_with(dense_shape_shape[0]);

_shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray());
}
}

public interface _TensorLike
{
}

public static class sparse_tensor_extension
{
public static bool is_sparse(this _TensorLike x)
{
return x.GetType().Name.Contains("SparseTensor");
}
}
}

src/TensorFlowNET.Core/Framework/tensor_shape.cs (+2 -2)

@@ -44,14 +44,14 @@ namespace Tensorflow.Framework
return true;
}

if (other.is_sparse())
if (other.IsSparseTensor)
{
return self.dtype.is_compatible_with(other.dtype);
}

return self.dtype.is_compatible_with(other.dtype) &&
_shape_is_compatible_0dim(self.shape, other.shape) &&
!self.is_sparse();
!self.IsSparseTensor;
}

public static Dimension dimension_at_index(TensorShape shape, int index)


src/TensorFlowNET.Core/Gradients/math_grad.cs (+8 -8)

@@ -291,23 +291,23 @@ namespace Tensorflow.Gradients
var b = math_ops.conj(op.inputs[1]);
if (!t_a && !t_b)
{
grad_a = gen_math_ops.batch_mat_mul(grad, b, adj_y: true);
grad_b = gen_math_ops.batch_mat_mul(a, grad, adj_x: true);
grad_a = math_ops.batch_matmul(grad, b, adj_y: true);
grad_b = math_ops.batch_matmul(a, grad, adj_x: true);
}
else if (!t_a && t_b)
{
grad_a = gen_math_ops.batch_mat_mul(grad, b);
grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true);
grad_a = math_ops.batch_matmul(grad, b);
grad_b = math_ops.batch_matmul(grad, a, adj_x: true);
}
else if (t_a && !t_b)
{
grad_a = gen_math_ops.batch_mat_mul(grad, b);
grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true);
grad_a = math_ops.batch_matmul(grad, b);
grad_b = math_ops.batch_matmul(grad, a, adj_x: true);
}
else if (t_a && t_b)
{
grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_x: true, adj_y: true);
grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true, adj_y: true);
grad_a = math_ops.batch_matmul(b, grad, adj_x: true, adj_y: true);
grad_b = math_ops.batch_matmul(grad, a, adj_x: true, adj_y: true);
}

return new Tensor[] { grad_a, grad_b };
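The adjoint flags above follow the usual matmul gradient identity, applied per batch slice; for the plain (no-transpose) case:

```latex
% C = A B with upstream gradient \bar{C}:
\bar{A} = \bar{C}\,B^{\top}, \qquad \bar{B} = A^{\top}\,\bar{C}
```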


src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs (+1 -0)

@@ -11,5 +11,6 @@ namespace Tensorflow.Keras.ArgsDefinition
public int MaxTokens { get; set; } = -1;
public string OutputMode { get; set; } = "int";
public int OutputSequenceLength { get; set; } = -1;
public string[] Vocabulary { get; set; }
}
}

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs (+79 -400)

@@ -40,37 +40,16 @@ namespace Tensorflow.Operations
/// <param name="parameters"></param>
/// <returns></returns>
public static Tensor conv2d(Conv2dParams parameters)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Conv2D", parameters.Name,
null,
parameters.Input, parameters.Filter,
"strides", parameters.Strides,
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu,
"padding", parameters.Padding,
"explicit_paddings", parameters.ExplicitPaddings,
"data_format", parameters.DataFormat,
"dilations", parameters.Dilations);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Conv2D", name: parameters.Name, args: new
{
input = parameters.Input,
filter = parameters.Filter,
strides = parameters.Strides,
padding = parameters.Padding,
use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
explicit_paddings = parameters.ExplicitPaddings,
data_format = parameters.DataFormat,
dilations = parameters.Dilations
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("Conv2D", parameters.Name, new ExecuteOpArgs(parameters.Input, parameters.Filter)
.SetAttributes(new
{
strides = parameters.Strides,
padding = parameters.Padding,
use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
explicit_paddings = parameters.ExplicitPaddings,
data_format = parameters.DataFormat,
dilations = parameters.Dilations
}));

/// <summary>
/// Computes the gradients of convolution with respect to the filter.
@@ -83,43 +62,16 @@ namespace Tensorflow.Operations
string data_format = "NHWC",
int[] dilations = null,
string name = null)
{
if (explicit_paddings == null)
explicit_paddings = new int[0];
if (dilations == null)
dilations = new int[] { 1, 1, 1, 1 };

if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Conv2DBackpropFilter", name,
null,
input, filter_sizes, out_backprop,
"strides", strides,
"use_cudnn_on_gpu", use_cudnn_on_gpu,
"padding", padding,
"explicit_paddings", explicit_paddings,
"data_format", data_format,
"dilations", dilations);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: name, args: new
{
input,
filter_sizes,
out_backprop,
strides,
padding,
use_cudnn_on_gpu,
explicit_paddings,
data_format,
dilations
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("Conv2DBackpropFilter", name, new ExecuteOpArgs(input, filter_sizes, out_backprop)
.SetAttributes(new
{
strides,
padding,
use_cudnn_on_gpu,
explicit_paddings = explicit_paddings ?? new int[0],
data_format,
dilations = dilations ?? new int[] { 1, 1, 1, 1 }
}));

/// <summary>
/// Computes the gradients of convolution with respect to the input.
@@ -132,99 +84,29 @@ namespace Tensorflow.Operations
string data_format = "NHWC",
int[] dilations = null,
string name = null)
{
if (explicit_paddings == null)
explicit_paddings = new int[0];
if (dilations == null)
dilations = new int[] { 1, 1, 1, 1 };

if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Conv2DBackpropInput", name,
null,
input_sizes, filter, out_backprop,
"strides", strides,
"use_cudnn_on_gpu", use_cudnn_on_gpu,
"padding", padding,
"explicit_paddings", explicit_paddings,
"data_format", data_format,
"dilations", dilations);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: name, args: new
{
input_sizes,
filter,
out_backprop,
strides,
padding,
use_cudnn_on_gpu,
explicit_paddings,
data_format,
dilations
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("Conv2DBackpropInput", name, new ExecuteOpArgs(input_sizes, filter, out_backprop)
.SetAttributes(new
{
strides,
padding,
use_cudnn_on_gpu,
explicit_paddings = explicit_paddings ?? new int[0],
data_format,
dilations = dilations ?? new int[] { 1, 1, 1, 1 }
}));

public static Tensor bias_add(Tensor value,
IVariableV1 bias,
string data_format = null,
string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"BiasAdd", name,
null,
value, bias,
"data_format", data_format);

return results[0];
}

if (data_format == null)
data_format = "NHWC";

var _op = tf.OpDefLib._apply_op_helper("BiasAdd", name: name, args: new
{
value,
bias,
data_format
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("BiasAdd", name, new ExecuteOpArgs(value, bias)
.SetAttributes(new { data_format = data_format ?? "NHWC" }));

public static Tensor bias_add_grad(Tensor out_backprop,
string data_format = "NHWC",
string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"BiasAddGrad", name,
null,
out_backprop,
"data_format", data_format);

return results[0];
}

if (data_format == null)
data_format = "NHWC";

var _op = tf.OpDefLib._apply_op_helper("BiasAddGrad", name: name, args: new
{
out_backprop,
data_format
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("BiasAddGrad", name, new ExecuteOpArgs(out_backprop)
.SetAttributes(new { data_format = data_format ?? "NHWC" }));

/// <summary>
/// Computes exponential linear: <c>exp(features) - 1</c> if &lt; 0, <c>features</c> otherwise.
@@ -269,29 +151,19 @@ namespace Tensorflow.Operations
}

public static Tensor[] fused_batch_norm_grad_v3(FusedBatchNormParams @params)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name: @params.Name,
args: new
{
y_backprop = @params.YBackprop,
x = @params.X,
scale = @params.Scale,
reserve_space_1 = @params.ReserveSpace1,
reserve_space_2 = @params.ReserveSpace2,
reserve_space_3 = @params.ReserveSpace3,
epsilon = @params.Epsilon,
data_format = @params.DataFormat,
is_training = @params.IsTraining
}).outputs, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"FusedBatchNormGradV3", @params.Name,
null,
@params.YBackprop, @params.X, @params.Scale,
@params.ReserveSpace1, @params.ReserveSpace2, @params.ReserveSpace3,
"epsilon", @params.Epsilon,
"data_format", @params.DataFormat,
"is_training", @params.IsTraining),
@params.YBackprop);
=> tf.Context.ExecuteOp("FusedBatchNormGradV3", @params.Name,
new ExecuteOpArgs(@params.YBackprop,
@params.X,
@params.Scale,
@params.ReserveSpace1,
@params.ReserveSpace2,
@params.ReserveSpace3)
.SetAttributes(new
{
epsilon = @params.Epsilon,
data_format = @params.DataFormat,
is_training = @params.IsTraining
}));

public static Tensor[] fused_batch_norm(Tensor x,
Tensor scale,
@@ -328,39 +200,8 @@ namespace Tensorflow.Operations
string data_format = "NHWC",
bool is_training = true,
string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"FusedBatchNormV3", name,
null,
x,
scale,
offset,
mean,
variance,
"epsilon", epsilon,
"exponential_avg_factor", exponential_avg_factor,
"data_format", data_format,
"is_training", is_training);

return results;
}

var _op = tf.OpDefLib._apply_op_helper("FusedBatchNormV3", name: name, args: new
{
x,
scale,
offset,
mean,
variance,
epsilon,
data_format,
is_training
});

return _op.outputs;
}
=> tf.Context.ExecuteOp("FusedBatchNormV3", name, new ExecuteOpArgs(x, scale, offset, mean, variance)
.SetAttributes(new { epsilon, data_format, is_training }));

/// <summary>
/// Local Response Normalization.
@@ -388,14 +229,7 @@ namespace Tensorflow.Operations
}

public static Tensor log_softmax(Tensor logits, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("LogSoftmax", name: name,
args: new { logits }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"LogSoftmax", name,
null,
logits).FirstOrDefault(),
logits);
=> tf.Context.ExecuteOp("LogSoftmax", name, new ExecuteOpArgs(logits));

/// <summary>
/// Says whether the targets are in the top `K` predictions.
@@ -418,19 +252,8 @@ namespace Tensorflow.Operations
}

public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("LeakyRelu", name: name,
args: new
{
features,
alpha
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"LeakyRelu", name,
null,
features,
"alpha", alpha).FirstOrDefault(),
features);
=> tf.Context.ExecuteOp("LeakyRelu", name,
new ExecuteOpArgs(features).SetAttributes(new { alpha }));

public static Tensor max_pool(Tensor input,
int[] ksize,
@@ -438,63 +261,25 @@ namespace Tensorflow.Operations
string padding,
string data_format = "NHWC",
string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"MaxPool", name,
null,
input,
"ksize", ksize,
"strides", strides,
"padding", padding,
"data_format", data_format);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("MaxPool", name: name, args: new
{
input,
ksize,
strides,
padding,
data_format,
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("MaxPool", name, new ExecuteOpArgs(input)
.SetAttributes(new
{
ksize,
strides,
padding,
data_format
}));

public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding,
string data_format = "NHWC", string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"MaxPoolGrad", name,
null,
orig_input, orig_output, grad,
"ksize", ksize,
"strides", strides,
"padding", padding,
"data_format", data_format);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("MaxPoolGrad", name: name, args: new
{
orig_input,
orig_output,
grad,
ksize,
strides,
padding,
data_format
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("MaxPoolGrad", name, new ExecuteOpArgs(orig_input, orig_output, grad)
.SetAttributes(new
{
ksize,
strides,
padding,
data_format
}));

public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null)
{
@@ -509,68 +294,14 @@ namespace Tensorflow.Operations
}

public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ReluGrad", name,
null,
gradients, features);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("ReluGrad", name: name, args: new
{
gradients,
features
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("ReluGrad", name, new ExecuteOpArgs(gradients, features));

public static Tensor leaky_relu_grad(Tensor gradients, Tensor features, float alpha = 0.2f, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"LeakyReluGrad", name,
null,
gradients, features,
"alpha", alpha);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("LeakyReluGrad", name: name, args: new
{
gradients,
features,
alpha
});

return _op.output;
}
=> tf.Context.ExecuteOp("LeakyReluGrad", name, new ExecuteOpArgs(gradients, features)
.SetAttributes(new { alpha }));

public static Tensor softmax(Tensor logits, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Softmax", name,
null,
logits);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Softmax", name: name, args: new
{
logits
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(logits));

/// <summary>
/// Computes softmax cross entropy cost and gradients to backpropagate.
@@ -581,23 +312,9 @@ namespace Tensorflow.Operations
/// <returns></returns>
public static (Tensor, Tensor) softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"SoftmaxCrossEntropyWithLogits", name,
null,
features, labels);

return (results[0], results[1]);
}
var results = tf.Context.ExecuteOp("SoftmaxCrossEntropyWithLogits", name, new ExecuteOpArgs(features, labels));

var _op = tf.OpDefLib._apply_op_helper("SoftmaxCrossEntropyWithLogits", name: name, args: new
{
features,
labels
});

return (_op.outputs[0], _op.outputs[1]);
return (results[0], results[1]);
}

/// <summary>
@@ -629,21 +346,9 @@ namespace Tensorflow.Operations
/// </remarks>
public static (Tensor loss, Tensor backprop) sparse_softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = "SparseSoftmaxCrossEntropyWithLogits")
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"SparseSoftmaxCrossEntropyWithLogits", name,
null,
features, labels);

return (results[0], results[1]);
}

var op = tf.OpDefLib._apply_op_helper("SparseSoftmaxCrossEntropyWithLogits", name: name, args: new { features, labels });
int _idx = 0;
var loss = op.outputs[_idx++];
var backprop = op.outputs[_idx++];
return (loss, backprop);
var results = tf.Context.ExecuteOp("SparseSoftmaxCrossEntropyWithLogits", name, new ExecuteOpArgs(features, labels));

return (results[0], results[1]);
}

/// <summary>
@@ -653,35 +358,9 @@ namespace Tensorflow.Operations
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `features`.</returns>
public static Tensor relu(Tensor features, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Relu", name,
null,
features);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Relu", name: name, args: new { features });
return _op.outputs[0];
}
=> tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features));

public static Tensor tanh(Tensor x, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Tanh", name,
null,
x);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Tanh", name: name, args: new { x });
return _op.outputs[0];
}
=> tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(x));
}
}

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs (+6 -2)

@@ -68,10 +68,10 @@ namespace Tensorflow
string _scope_name = scope;

// Perform input type inference
foreach (var input_arg in op_def.InputArg)
foreach (var (i, input_arg) in enumerate(op_def.InputArg))
{
var input_name = input_arg.Name;
if (keywords.ContainsKey(input_name))
values = keywords[input_name];
else if (keywords.ContainsKey(input_name + "_"))
@@ -79,6 +79,10 @@ namespace Tensorflow
input_name += "_";
values = keywords[input_name];
}
else if (keywords.ContainsKey($"input_{i}"))
{
values = keywords[$"input_{i}"];
}
else
throw new TypeError("No argument for input " + input_name);
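This new branch is what lets `Context.ExecuteOp` pass inputs positionally: the graph path packs them as `input_0`, `input_1`, ... and `_apply_op_helper` now matches `input_{i}` to the i-th `InputArg` of the op definition. A hedged sketch of that hand-off (attribute list abbreviated; `inputTensor`, `filterTensor` and `name` are assumed in scope):

```csharp
// For Conv2D, op_def.InputArg[0] is "input" and op_def.InputArg[1] is "filter",
// so the positional keys below resolve even though the real names are never spelled out.
var keywords = new Dictionary<string, object>
{
    ["input_0"] = inputTensor,            // -> InputArg[0] ("input")
    ["input_1"] = filterTensor,           // -> InputArg[1] ("filter")
    ["strides"] = new[] { 1, 1, 1, 1 },   // remaining attrs, abbreviated
    ["padding"] = "VALID"
};
var op = tf.OpDefLib._apply_op_helper("Conv2D", name, keywords);
```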



src/TensorFlowNET.Core/Operations/array_ops.cs (+22 -64)

@@ -57,20 +57,8 @@ namespace Tensorflow
/// gradients in some corner cases.
/// </remarks>
public static Tensor prevent_gradient(Tensor input, string message = "", string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"PreventGradient", name,
null,
input,
"message", message);
return results[0];
}

var op = tf.OpDefLib._apply_op_helper("PreventGradient", name: name, args: new { input, message });
return op.output;
}
=> tf.Context.ExecuteOp("PreventGradient", name, new ExecuteOpArgs(input)
.SetAttributes(new { message }));

internal static Tensor constant(object value,
TF_DataType dtype = TF_DataType.DtInvalid,
@@ -737,35 +725,27 @@ namespace Tensorflow
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0,
long shrink_axis_mask = 0, string name = null)
=> tf.Context.RunInAutoMode2("StridedSliceGrad", name, new AutoModeArgs
{
OpInputArgs = new
=> tf.Context.ExecuteOp("StridedSliceGrad", name,
new ExecuteOpArgs(shape, begin, end, strides, dy)
{
shape,
begin,
end,
strides,
dy
},
OpAttrs = new
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),
Index = op.get_attr<TF_DataType>("Index"),
begin_mask = op.get_attr<long>("begin_mask"),
end_mask = op.get_attr<long>("end_mask"),
ellipsis_mask = op.get_attr<long>("ellipsis_mask"),
new_axis_mask = op.get_attr<long>("new_axis_mask"),
shrink_axis_mask = op.get_attr<long>("shrink_axis_mask")
}
}.SetAttributes(new
{
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
},
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),
Index = op.get_attr<TF_DataType>("Index"),
begin_mask = op.get_attr<long>("begin_mask"),
end_mask = op.get_attr<long>("end_mask"),
ellipsis_mask = op.get_attr<long>("ellipsis_mask"),
new_axis_mask = op.get_attr<long>("new_axis_mask"),
shrink_axis_mask = op.get_attr<long>("shrink_axis_mask")
}
});
}));

/// <summary>
/// Removes dimensions of size 1 from the shape of a tensor.
@@ -800,38 +780,17 @@ namespace Tensorflow
int num_cols = -1,
float padding_value = 0,
string align = "RIGHT_LEFT")
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"MatrixDiagV3", name,
null,
diagonal, k, num_rows, num_cols, padding_value,
"align", align);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("MatrixDiagV3", name,
new ExecuteOpArgs(diagonal, k, num_rows, num_cols, padding_value)
.SetAttributes(new { align }));

public static Tensor matrix_set_diag(Tensor input,
Tensor diagonal,
string name = "set_diag",
int k = 0,
string align = "RIGHT_LEFT")
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"MatrixSetDiagV3", name,
null,
input, diagonal, k,
"align", align);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("MatrixSetDiagV3", name, new ExecuteOpArgs(input, diagonal, k)
.SetAttributes(new { align }));

/// <summary>
/// Computes the shape of a broadcast given symbolic shapes.
@@ -960,9 +919,8 @@ namespace Tensorflow
=> gen_array_ops.slice(input, begin, size, name: name);

public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
=> tf.Context.RunInAutoMode2("Slice", name, new AutoModeArgs
=> tf.Context.ExecuteOp("Slice", name, new ExecuteOpArgs(input, begin, size)
{
OpInputArgs = new { input, begin, size },
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),


src/TensorFlowNET.Core/Operations/bitwise_ops.cs (+2 -29)

@@ -94,20 +94,7 @@ namespace Tensorflow.Operations
/// <param name="name"></param>
/// <returns></returns>
Tensor unary_op(Tensor x, string opName, string name)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
opName, name,
null,
x);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper(opName, name, args: new { x });
return _op.output;
}
=> tf.Context.ExecuteOp(opName, name, new ExecuteOpArgs(x));

/// <summary>
/// Helper method to invoke binary operator with specified name.
@@ -118,21 +105,7 @@ namespace Tensorflow.Operations
/// <param name="name"></param>
/// <returns></returns>
Tensor binary_op(Tensor x, Tensor y, string opName, string name)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
opName, name,
null,
x, y);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper(opName, name, args: new { x, y });
return _op.output;
}

=> tf.Context.ExecuteOp(opName, name, new ExecuteOpArgs(x, y));
#endregion
}
}

src/TensorFlowNET.Core/Operations/dataset_ops.cs (+85 -362)

@@ -8,26 +8,10 @@ namespace Tensorflow
public class dataset_ops
{
public Tensor tensor_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null)
{
if (tf.Context.executing_eagerly())
=> tf.Context.ExecuteOp("TensorDataset", name, new ExecuteOpArgs()
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"TensorDataset", name,
null,
new object[]
{
components,
"output_shapes", output_shapes
});
return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("TensorDataset",
name: name,
args: new { components, output_shapes });

return _op.output;
}
OpInputArgs = new object[] { components }
}.SetAttributes(new { output_shapes }));

/// <summary>
/// Creates a dataset that emits each dim-0 slice of `components` once.
@@ -37,192 +21,62 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public Tensor tensor_slice_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null)
{
if (tf.Context.executing_eagerly())
=> tf.Context.ExecuteOp("TensorSliceDataset", name, new ExecuteOpArgs()
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"TensorSliceDataset", name,
null,
new object[]
{
components,
"output_shapes", output_shapes
});
return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("TensorSliceDataset",
name: name,
args: new { components, output_shapes });

return _op.outputs[0];
}
OpInputArgs = new object[] { components }
}.SetAttributes(new { output_shapes }));

public Tensor range_dataset(Tensor start, Tensor stop, Tensor step, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RangeDataset", name,
null,
start, stop, step,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("RangeDataset", name, new ExecuteOpArgs(start, stop, step)
.SetAttributes(new { output_types, output_shapes }));

public Tensor repeat_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RepeatDataset", name,
null,
input_dataset, count,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("RepeatDataset", name, new ExecuteOpArgs(input_dataset, count)
.SetAttributes(new { output_types, output_shapes }));

public Tensor shard_dataset(Tensor input_dataset, Tensor num_shards, Tensor index,
TF_DataType[] output_types, TensorShape[] output_shapes,
bool require_non_empty = false, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ShardDataset", name,
null,
input_dataset, num_shards, index,
"require_non_empty", require_non_empty,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("ShardDataset", name, new ExecuteOpArgs(input_dataset, num_shards, index)
.SetAttributes(new { require_non_empty, output_types, output_shapes }));

public Tensor zip_dataset(Tensor[] input_datasets,
TF_DataType[] output_types,
TensorShape[] output_shapes,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ZipDataset", name,
null,
new object[]
{
input_datasets,
"output_types", output_types,
"output_shapes", output_shapes
});
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("ZipDataset", name, new ExecuteOpArgs()
{
OpInputArgs = new object[] { input_datasets }
}.SetAttributes(new { output_types, output_shapes }));

public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size,
Tensor seed, Tensor seed2, Tensor seed_generator,
TF_DataType[] output_types, TensorShape[] output_shapes,
bool reshuffle_each_iteration = true,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ShuffleDatasetV3", name,
null,
input_dataset, buffer_size,
seed, seed2, seed_generator,
"reshuffle_each_iteration", reshuffle_each_iteration,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("ShuffleDatasetV3", name, new ExecuteOpArgs(input_dataset, buffer_size, seed, seed2, seed_generator)
.SetAttributes(new { reshuffle_each_iteration, output_types, output_shapes }));

public Tensor skip_dataset(Tensor input_dataset, Tensor count,
TF_DataType[] output_types, TensorShape[] output_shapes,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"SkipDataset", name,
null,
input_dataset, count,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("SkipDataset", name, new ExecuteOpArgs(input_dataset, count)
.SetAttributes(new { output_types, output_shapes }));

public Tensor dummy_seed_generator(string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DummySeedGenerator", name,
null);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("DummySeedGenerator", name, new ExecuteOpArgs());

public Tensor concatenate_dataset(Tensor input_dataset, Tensor another_dataset,
TF_DataType[] output_types, TensorShape[] output_shapes,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ConcatenateDataset", name,
null,
input_dataset, another_dataset,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("ConcatenateDataset",
name: name,
args: new { input_dataset, another_dataset, output_types, output_shapes });

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("ConcatenateDataset", name, new ExecuteOpArgs(input_dataset, another_dataset)
.SetAttributes(new { output_types, output_shapes }));

public Tensor cache_dataset_v2(Tensor input_dataset, Tensor filename, Tensor cache,
TF_DataType[] output_types, TensorShape[] output_shapes,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"CacheDatasetV2", name,
null,
input_dataset, filename, cache,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("CacheDatasetV2", name, new ExecuteOpArgs(input_dataset, filename, cache)
.SetAttributes(new { output_types, output_shapes }));

/// <summary>
/// Creates a dataset that batches `batch_size` elements from `input_dataset`.
@@ -240,21 +94,9 @@ namespace Tensorflow
TF_DataType[] output_types, TensorShape[] output_shapes,
bool parallel_copy = false,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"BatchDatasetV2", name,
null,
input_dataset, buffer_size, drop_remainder,
"parallel_copy", parallel_copy,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("BatchDatasetV2", name,
new ExecuteOpArgs(input_dataset, buffer_size, drop_remainder)
.SetAttributes(new { parallel_copy, output_types, output_shapes }));

/// <summary>
///
@@ -262,17 +104,7 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public Tensor dummy_memory_cache(string name = "")
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DummyMemoryCache", name,
null);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("DummyMemoryCache", name, new ExecuteOpArgs());

/// <summary>
/// Creates a dataset that asynchronously prefetches elements from `input_dataset`.
@@ -290,22 +122,14 @@ namespace Tensorflow
int? slack_period = 0,
bool legacy_autotune = true,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"PrefetchDataset", name,
null,
input_dataset, buffer_size,
"output_types", output_types,
"output_shapes", output_shapes,
"slack_period", slack_period,
"legacy_autotune", legacy_autotune);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("PrefetchDataset", name, new ExecuteOpArgs(input_dataset, buffer_size)
.SetAttributes(new
{
output_types,
output_shapes,
slack_period,
legacy_autotune
}));

/// <summary>
/// Creates a dataset that contains `count` elements from the `input_dataset`.
@@ -319,20 +143,8 @@ namespace Tensorflow
public Tensor take_dataset(Tensor input_dataset, Tensor count,
TF_DataType[] output_types, TensorShape[] output_shapes,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"TakeDataset", name,
null,
input_dataset, count,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("TakeDataset", name, new ExecuteOpArgs(input_dataset, count)
.SetAttributes(new { output_types, output_shapes }));

/// <summary>
/// Creates a dataset by applying optimizations to `input_dataset`.
@@ -348,24 +160,13 @@ namespace Tensorflow
TF_DataType[] output_types, TensorShape[] output_shapes,
string[] optimization_configs = null,
string name = null)
{
if (optimization_configs == null)
optimization_configs = new string[0];

if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"OptimizeDataset", name,
null,
input_dataset, optimizations,
"output_types", output_types,
"output_shapes", output_shapes,
"optimization_configs", optimization_configs);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("OptimizeDataset", name, new ExecuteOpArgs(input_dataset, optimizations)
.SetAttributes(new
{
output_types,
output_shapes,
optimization_configs = optimization_configs ?? new string[0]
}));

/// <summary>
/// Identity transformation that models performance.
@@ -381,22 +182,14 @@ namespace Tensorflow
TF_DataType[] output_types, TensorShape[] output_shapes,
AutotuneAlgorithm algorithm, long cpu_budget,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ModelDataset", name,
null,
input_dataset,
"algorithm", algorithm,
"cpu_budget", cpu_budget,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("ModelDataset", name, new ExecuteOpArgs(input_dataset)
.SetAttributes(new
{
algorithm,
cpu_budget,
output_types,
output_shapes
}));

/// <summary>
/// A container for an iterator resource.
@@ -407,17 +200,9 @@ namespace Tensorflow
/// <returns>A tuple of `Tensor` objects (handle, deleter).</returns>
public (Tensor, Tensor) anonymous_iterator_v2(TF_DataType[] output_types, TensorShape[] output_shapes, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"AnonymousIteratorV2", name,
null,
"output_types", output_types,
"output_shapes", output_shapes);
return (results[0], results[1]);
}

throw new NotImplementedException("");
var results = tf.Context.ExecuteOp("AnonymousIteratorV2", name,
new ExecuteOpArgs().SetAttributes(new { output_types, output_shapes }));
return (results[0], results[1]);
}

/// <summary>
@@ -427,19 +212,8 @@ namespace Tensorflow
/// <param name="iterator"></param>
/// <param name="name"></param>
/// <returns>The created Operation.</returns>
public ITensorOrOperation make_iterator(Tensor dataset, Tensor iterator, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"MakeIterator", name,
null,
dataset, iterator);
return null;
}

throw new NotImplementedException("");
}
public void make_iterator(Tensor dataset, Tensor iterator, string name = null)
=> tf.Context.ExecuteOp("MakeIterator", name, new ExecuteOpArgs(dataset, iterator));

/// <summary>
///
@@ -450,23 +224,15 @@ namespace Tensorflow
/// <returns></returns>
public Tensor map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes,
bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"MapDataset", name,
null,
dataset, new Tensor[0],
"f", f,
"output_types", output_types,
"output_shapes", output_shapes,
"use_inter_op_parallelism", use_inter_op_parallelism,
"preserve_cardinality", preserve_cardinality);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("MapDataset", name, new ExecuteOpArgs(dataset, new Tensor[0])
.SetAttributes(new
{
f,
output_types,
output_shapes,
use_inter_op_parallelism,
preserve_cardinality
}));

/// <summary>
/// Creates a dataset that applies `f` to the outputs of `input_dataset`.
@@ -479,21 +245,8 @@ namespace Tensorflow
/// <returns></returns>
public Tensor flat_map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"FlatMapDataset", name,
null,
dataset, new Tensor[0],
"f", f,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("FlatMapDataset", name, new ExecuteOpArgs(dataset, new Tensor[0])
.SetAttributes(new { f, output_types, output_shapes }));

/// <summary>
/// Creates a dataset that applies `f` to the outputs of `input_dataset`.
@@ -512,24 +265,17 @@ namespace Tensorflow
string deterministic = "default",
bool preserve_cardinality = false,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ParallelMapDatasetV2", name,
null,
dataset, new Tensor[0], num_parallel_calls,
"f", f,
"output_types", output_types,
"output_shapes", output_shapes,
"use_inter_op_parallelism", use_inter_op_parallelism,
"deterministic", deterministic,
"preserve_cardinality", preserve_cardinality);
return results[0];
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("ParallelMapDatasetV2", name,
new ExecuteOpArgs(dataset, new Tensor[0], num_parallel_calls)
.SetAttributes(new
{
f,
output_types,
output_shapes,
use_inter_op_parallelism,
deterministic,
preserve_cardinality
}));

/// <summary>
/// A container for an iterator resource.
@@ -538,19 +284,8 @@ namespace Tensorflow
/// <param name="deleter"></param>
/// <param name="name"></param>
/// <returns>The created Operation.</returns>
public ITensorOrOperation delete_iterator(Tensor handle, Tensor deleter, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DeleteIterator", name,
null,
handle, deleter);
return null;
}

throw new NotImplementedException("");
}
public void delete_iterator(Tensor handle, Tensor deleter, string name = null)
=> tf.Context.ExecuteOp("DeleteIterator", name, new ExecuteOpArgs(handle, deleter));

/// <summary>
/// Gets the next output from the given iterator .
@@ -561,19 +296,7 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public Tensor[] iterator_get_next(Tensor iterator, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"IteratorGetNext", name,
null,
iterator,
"output_types", output_types,
"output_shapes", output_shapes);
return results;
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("IteratorGetNext", name, new ExecuteOpArgs(iterator)
.SetAttributes(new { output_types, output_shapes }));
}
}
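// Usage sketch (assumed identifiers): `dsOps` stands for an instance of this ops class and
// `dataset_variant` for the variant handle of an upstream dataset; types/shapes are placeholders.
var types = new[] { TF_DataType.TF_INT32 };
var shapes = new[] { new TensorShape() };
var skipped = dsOps.skip_dataset(dataset_variant, tf.constant(2L), types, shapes);
var taken = dsOps.take_dataset(skipped, tf.constant(5L), types, shapes);
// Both wrappers now route through tf.Context.ExecuteOp, so the same call executes eagerly
// or records graph ops depending on the active context.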

+55 -342 src/TensorFlowNET.Core/Operations/gen_array_ops.cs

@@ -45,20 +45,7 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor concat_v2<T, Ta>(T[] values, Ta axis, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ConcatV2", name,
null,
values, axis);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });
return _op.output;
}
=> tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));

public static Tensor concat_v2(Tensor[] values, Tensor axis, string name = null)
{
@@ -72,14 +59,7 @@ namespace Tensorflow
}

public static Tensor concat_v2(Tensor[] values, int axis, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ConcatV2", name: name,
args: new { values, axis }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ConcatV2", name,
null,
values, axis).FirstOrDefault(),
values);
=> tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));

private static Tensor concat_v2_eager_fallback<T1, T2>(T1[] values, T2 axis, string name, Context ctx)
{
@@ -131,38 +111,11 @@ namespace Tensorflow
/// </code>
/// </remarks>
public static Tensor diag(Tensor diagonal, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Diag", name,
null,
diagonal);

return results[0];
}

var op = tf.OpDefLib._apply_op_helper("Diag", name: name, args: new { diagonal });

return op.output;
}
=> tf.Context.ExecuteOp("Diag", name, new ExecuteOpArgs(diagonal));

public static Tensor expand_dims(Tensor input, int axis, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ExpandDims", name,
null,
input, tf.convert_to_tensor(axis));

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("ExpandDims", name: name, args: new { input, dim = axis });

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("ExpandDims", name, new ExecuteOpArgs(input, axis)
.SetAttributes(new { dim = axis }));

public static Tensor gather_v2<T1, T2>(T1 @params, T2 indices, int axis, string name = null)
{
@@ -202,14 +155,10 @@ namespace Tensorflow
}

public static Tensor pack(Tensor[] values, int axis = 0, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Pack", name, new { values, axis }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Pack", name,
null,
values,
"axis", axis).FirstOrDefault(),
values, axis);
=> tf.Context.ExecuteOp("Pack", name, new ExecuteOpArgs()
{
OpInputArgs = new object[] { values }
}.SetAttributes(new { axis }));

/// <summary>
/// Return a tensor with the same shape and contents as the input tensor or value.
@@ -217,29 +166,7 @@ namespace Tensorflow
/// <param name="input"></param>
/// <param name="name"></param>
public static Tensor identity(Tensor input, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Identity", name,
null,
input);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input });
if (tf.Runner.MustRecordGradient())
{
tf.Runner.RecordGradient("Identity", _op.inputs, new object[]
{
"T", _op.get_attr<TF_DataType>("T")
}, _op.outputs);
}

return _op.output;
}
=> tf.Context.ExecuteOp("Identity", name, new ExecuteOpArgs(input));

public static Tensor invert_permutation(Tensor x, string name = null)
{
@@ -256,21 +183,7 @@ namespace Tensorflow
}

public static Tensor rank(Tensor input, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Rank", name,
null,
input);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Rank", name: name, args: new { input });

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("Rank", name, new ExecuteOpArgs(input));

/// <summary>
/// Creates a tensor filled with a scalar value.
@@ -280,20 +193,7 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `value`.</returns>
public static Tensor fill<T>(Tensor dims, T value, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Fill", name,
null,
dims, value);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Fill", name, new { dims, value });
return _op.output;
}
=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));

/// <summary>
/// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
@@ -304,19 +204,8 @@ namespace Tensorflow
/// <returns>A tuple of `Tensor` objects (r0, r1).</returns>
public static (Tensor, Tensor) broadcast_gradient_args(Tensor s0, Tensor s1, string name = "")
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"BroadcastGradientArgs", name,
null,
s0, s1);

return (results[0], results[1]);
}

var _op = tf.OpDefLib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 });

return (_op.outputs[0], _op.outputs[1]);
var results = tf.Context.ExecuteOp("BroadcastGradientArgs", name, new ExecuteOpArgs(s0, s1));
return (results[0], results[1]);
}

public static Tensor reverse<T>(Tensor tensor, T axis, string name = null)
@@ -326,31 +215,10 @@ namespace Tensorflow
}

public static Tensor reshape<T>(Tensor tensor, T shape, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Reshape", name,
null,
tensor, shape).FirstOrDefault(),
tensor, shape);
=> tf.Context.ExecuteOp("Reshape", name, new ExecuteOpArgs(tensor, shape));

public static Tensor reshape(Tensor tensor, object[] shape, string name = null)
{
try
{
return tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Reshape", name,
null,
tensor, shape).FirstOrDefault(),
tensor, shape);
}
catch (InvalidArgumentError ex)
{
return reshape_eager_fallback(tensor, shape, name, tf.Context);
}
}
=> tf.Context.ExecuteOp("Reshape", name, new ExecuteOpArgs(tensor, shape));

private static Tensor reshape_eager_fallback(Tensor tensor, object[] shape, string name, Context ctx)
{
@@ -400,21 +268,8 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
int axis = -1,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"OneHot", name,
null,
indices, depth, on_value, off_value,
"axis", axis);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("OneHot", name, new { indices, depth, on_value, off_value, axis });
return _op.outputs[0];
}
=> tf.Context.ExecuteOp("OneHot", name, new ExecuteOpArgs(indices, depth, on_value, off_value)
.SetAttributes(new { axis }));

/// <summary>
/// A placeholder op that passes through `input` when its output is not fed.
@@ -430,35 +285,10 @@ namespace Tensorflow
}

public static Tensor select<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Select", name,
null,
condition, x, y);

return results[0];
}
=> tf.Context.ExecuteOp("Select", name, new ExecuteOpArgs(condition, x, y));

var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t = x, e = y });
return _op.outputs[0];
}
public static Tensor select_v2<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"SelectV2", name,
null,
condition, x, y);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("SelectV2", name, new { condition, t = x, e = y });
return _op.outputs[0];
}
=> tf.Context.ExecuteOp("SelectV2", name, new ExecuteOpArgs(condition, x, y));

public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor[] shape, string name = null)
{
@@ -467,15 +297,8 @@ namespace Tensorflow
}

public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Shape", name,
new { input, out_type }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Shape", name,
null,
input,
"out_type", out_type).FirstOrDefault(),
input);
=> tf.Context.ExecuteOp("Shape", name, new ExecuteOpArgs(input)
.SetAttributes(new { out_type }));

/// <summary>
/// Returns shape of tensors.
@@ -485,21 +308,10 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ShapeN", name,
null,
input,
"out_type", out_type);

return results;
}

var _op = tf.OpDefLib._apply_op_helper("ShapeN", name, new { input, out_type });
return _op.outputs;
}
=> tf.Context.ExecuteOp("ShapeN", name, new ExecuteOpArgs()
{
OpInputArgs = new object[] { input }
}.SetAttributes(new { out_type }));

public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
@@ -542,72 +354,23 @@ namespace Tensorflow

public static Tensor[] split_v(Tensor value, Tensor size_splits,
int axis, int num_split, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"SplitV", name,
null,
value, size_splits, axis,
"num_split", num_split);

return results;
}

var _op = tf.OpDefLib._apply_op_helper("SplitV", name, new { split_dim = axis, value, num_split });
return _op.outputs;
}
=> tf.Context.ExecuteOp("SplitV", name, new ExecuteOpArgs(value, size_splits, axis)
.SetAttributes(new { num_split }));

public static Tensor tile(Tensor input, Tensor multiples, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Tile", name,
null,
input, multiples).FirstOrDefault(),
input, multiples);
=> tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples));

public static Tensor tile(Tensor input, object[] multiples, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Tile", name,
null,
input, multiples).FirstOrDefault(),
input, multiples);
=> tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples));

public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Transpose", name,
null,
x, perm);

return results[0];
}
var _op = tf.OpDefLib._apply_op_helper("Transpose", name, new { x, perm });
return _op.outputs[0];
}
=> tf.Context.ExecuteOp("Transpose", name, new ExecuteOpArgs(x, perm));

public static Tensor ones_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"OnesLike", name,
null,
x).FirstOrDefault(),
x);
=> tf.Context.ExecuteOp("OnesLike", name, new ExecuteOpArgs(x));

public static Tensor zeros_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ZerosLike", name,
null,
x).FirstOrDefault(),
x);
=> tf.Context.ExecuteOp("ZerosLike", name, new ExecuteOpArgs(x));

public static Tensor stop_gradient(Tensor x, string name = null)
{
@@ -623,53 +386,32 @@ namespace Tensorflow
long new_axis_mask = 0,
long shrink_axis_mask = 0,
string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("StridedSlice", name, new
{
input,
begin,
end,
strides,
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"StridedSlice", name,
null,
input, begin, end, strides,
"begin_mask", begin_mask,
"end_mask", end_mask,
"ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask,
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
input, begin, end, strides);

=> tf.Context.ExecuteOp("StridedSlice", name, new ExecuteOpArgs(input, begin, end, strides)
.SetAttributes(new
{
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
}));

public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value,

public static Tensor resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value,
int begin_mask = 0,
int end_mask = 0,
int ellipsis_mask = 0,
int new_axis_mask = 0,
int shrink_axis_mask = 0,
string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new
{
input, begin, end, strides, value,
begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResourceStridedSliceAssign", name,
null,
input, begin, end, strides, value,
"begin_mask", begin_mask,
"end_mask", end_mask,
"ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask,
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
input, begin, end, strides, value);
=> tf.Context.ExecuteOp("ResourceStridedSliceAssign", name, new ExecuteOpArgs(input, begin, end, strides, value)
.SetAttributes(new
{
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
}));

public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides,
int begin_mask = 0,
@@ -707,23 +449,8 @@ namespace Tensorflow
/// <param name="name"> A name for the operation (optional).</param>
/// <returns> A `Tensor`. Has the same type as `input`.</returns>
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Squeeze", name,
null,
input,
"squeeze_dims", axis);

return results[0];
}

if (axis == null) axis = new int[0];
var _op = tf.OpDefLib._apply_op_helper("Squeeze", name, args: new { input, squeeze_dims = axis });

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("Squeeze", name, new ExecuteOpArgs(input)
.SetAttributes(new { squeeze_dims = axis }));

/// <summary>
/// Return the shape of s0 op s1 with broadcast.
@@ -749,20 +476,6 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor broadcast_to<T>(Tensor input, T shape, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"BroadcastTo", name,
null,
input, shape);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("BroadcastTo", name, args: new { input, shape, name });

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("BroadcastTo", name, new ExecuteOpArgs(input, shape));
}
}
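// Rough call-site sketch (values assumed) for the rewritten wrappers above.
var x = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });
var expanded = gen_array_ops.expand_dims(x, 0);                      // shape (1, 2, 2)
var tiled = gen_array_ops.tile(expanded, tf.constant(new[] { 2, 1, 1 }));
// Eagerly these return concrete tensors; under a graph the same calls add ExpandDims/Tile nodes.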

+24 -94 src/TensorFlowNET.Core/Operations/gen_image_ops.cs

@@ -70,38 +70,17 @@ namespace Tensorflow
float acceptable_fraction = 1,
string dct_method = "",
string name = null)
{
// Add nodes to the TensorFlow graph.
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DecodeJpeg", name,
null,
contents,
"channels", channels,
"ratio", ratio,
"fancy_upscaling", fancy_upscaling,
"try_recover_truncated", try_recover_truncated,
"acceptable_fraction", acceptable_fraction,
"dct_method", dct_method);
return results[0];
}
else
{
var _op = tf.OpDefLib._apply_op_helper("DecodeJpeg", name: name, args: new
{
contents,
channels,
ratio,
fancy_upscaling,
try_recover_truncated,
acceptable_fraction,
dct_method
});

return _op.outputs[0];
}
}
=> tf.Context.ExecuteOp("DecodeJpeg", name,
new ExecuteOpArgs(contents).SetAttributes(
new
{
channels,
ratio,
fancy_upscaling,
try_recover_truncated,
acceptable_fraction,
dct_method
}));

public static Tensor decode_gif(Tensor contents,
string name = null)
@@ -171,85 +150,36 @@ namespace Tensorflow
bool align_corners = false,
bool half_pixel_centers = false,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResizeBilinear", name,
null,
images, size,
"align_corners", align_corners,
"half_pixel_centers", half_pixel_centers);
return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("ResizeBilinear", name: name, args: new
{
images,
size,
align_corners
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("ResizeBilinear", name,
new ExecuteOpArgs(images, size).SetAttributes(new
{
align_corners,
half_pixel_centers
}));

public static Tensor resize_bicubic(Tensor images,
Tensor size,
bool align_corners = false,
bool half_pixel_centers = false,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResizeBicubic", name,
null,
images, size,
"align_corners", align_corners,
"half_pixel_centers", half_pixel_centers);
return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("ResizeBicubic", name: name, args: new
{
images,
size,
align_corners
});

return _op.outputs[0];
}

=> tf.Context.ExecuteOp("ResizeBicubic", name,
new ExecuteOpArgs(images, size).SetAttributes(new { align_corners, half_pixel_centers }));
public static Tensor resize_nearest_neighbor<Tsize>(Tensor images, Tsize size, bool align_corners = false,
bool half_pixel_centers = false, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ResizeNearestNeighbor", name: name, args: new
{
images,
size,
align_corners,
half_pixel_centers
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResizeNearestNeighbor", name,
null,
images, size,
"align_corners", align_corners,
"half_pixel_centers", half_pixel_centers).FirstOrDefault(),
images);
=> tf.Context.ExecuteOp("ResizeNearestNeighbor", name,
new ExecuteOpArgs(images, size).SetAttributes(new { align_corners, half_pixel_centers }));

public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false,
bool half_pixel_centers = false, string name = null)
=> tf.Context.RunInAutoMode2("ResizeNearestNeighborGrad", name, new AutoModeArgs
=> tf.Context.ExecuteOp("ResizeNearestNeighborGrad", name, new ExecuteOpArgs(grads, size)
{
OpInputArgs = new { grads, size },
OpAttrs = new { align_corners, half_pixel_centers },
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),
align_corners = op.get_attr<bool>("align_corners"),
half_pixel_centers = op.get_attr<bool>("half_pixel_centers")
}
});
}.SetAttributes(new { align_corners, half_pixel_centers }));
}
}
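// Minimal sketch of the refactored resize wrappers (input values assumed; the tf.zeros
// overload taking a TensorShape is an assumption of this example).
var image = tf.zeros(new TensorShape(1, 8, 8, 3));
var size = tf.constant(new[] { 16, 16 });
var up = gen_image_ops.resize_bilinear(image, size, half_pixel_centers: true);
var nearest = gen_image_ops.resize_nearest_neighbor(image, size);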

+2 -3 src/TensorFlowNET.Core/Operations/gen_logging_ops.cs

@@ -25,10 +25,9 @@ namespace Tensorflow
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(
"Assert", name,
null,
new object[] { condition, data, summarize });
new object[] { condition, data, summarize }));

return results[0];
}


+72 -698 src/TensorFlowNET.Core/Operations/gen_math_ops.cs (file diff suppressed because it is too large)


+1 -8 src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs

@@ -6,13 +6,6 @@ namespace Tensorflow
public static partial class gen_math_ops
{
public static Tensor mul(IntPtr x, IntPtr y, string name = null)
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Mul", name,
null,
x, y);

return results[0];
}
=> tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(x, y));
}
}

+8 -92 src/TensorFlowNET.Core/Operations/gen_random_ops.cs

@@ -29,31 +29,8 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RandomStandardNormal", name,
null,
shape,
"seed", seed,
"seed2", seed2,
"dtype", dtype);

return results[0];
}

if (!seed.HasValue)
seed = 0;
if (!seed2.HasValue)
seed2 = 0;

var _op = tf.OpDefLib._apply_op_helper("RandomStandardNormal",
name: name,
args: new { shape, dtype, seed, seed2 });

return _op.output;
}
=> tf.Context.ExecuteOp("RandomStandardNormal", name, new ExecuteOpArgs(shape)
.SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 }));

/// <summary>
/// Outputs random integers from a uniform distribution.
@@ -89,31 +66,8 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null)
{
if (!seed.HasValue)
seed = 0;
if (!seed2.HasValue)
seed2 = 0;

if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RandomUniform", name,
null,
shape,
"seed", seed,
"seed2", seed2,
"dtype", dtype);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("RandomUniform",
name: name,
args: new { shape, dtype, seed, seed2 });

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("RandomUniform", name, new ExecuteOpArgs(shape)
.SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 }));

/// <summary>
///
@@ -125,23 +79,8 @@ namespace Tensorflow
/// <returns></returns>
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RandomShuffle", name,
null,
value, seed, seed2);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("RandomShuffle",
name: name,
args: new { value, seed, seed2 });

return _op.output;
}
=> tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value)
.SetAttributes(new { seed = seed, seed2 = seed2 }));

/// <summary>
/// Outputs random values from a truncated normal distribution.
@@ -154,31 +93,8 @@ namespace Tensorflow
/// <returns></returns>
public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0,
int? seed2 = 0, string name = null)
{
if (!seed.HasValue)
seed = 0;
if (!seed2.HasValue)
seed2 = 0;

if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"TruncatedNormal", name,
null,
shape,
"seed", seed,
"seed2", seed2,
"dtype", dtype);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("TruncatedNormal",
name: name,
args: new { shape, dtype, seed, seed2 });

return _op.output;
}
=> tf.Context.ExecuteOp("TruncatedNormal", name, new ExecuteOpArgs(shape)
.SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 }));
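// Usage sketch (shape and dtype assumed); the `seed ?? 0` attribute defaulting above replaces
// the old explicit `if (!seed.HasValue) seed = 0;` branches.
var shape = tf.constant(new[] { 2, 2 });
var normal = gen_random_ops.random_standard_normal(shape, dtype: TF_DataType.TF_FLOAT);
var uniform = gen_random_ops.random_uniform(shape, TF_DataType.TF_FLOAT, seed: 42);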

public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0,
int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null)


+23 -64 src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs

@@ -24,10 +24,8 @@ namespace Tensorflow
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"AssignSubVariableOp", name,
null,
resource, value);
tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(
"AssignSubVariableOp", name, resource, value));

return null;
}
@@ -46,10 +44,8 @@ namespace Tensorflow
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"AssignAddVariableOp", name,
null,
resource, value);
tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AssignAddVariableOp", name,
resource, value));

return null;
}
@@ -63,10 +59,8 @@ namespace Tensorflow
{
if (tf.Context.executing_eagerly())
{
tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"AssignVariableOp", name,
null,
resource, value);
tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AssignVariableOp", name,
resource, value));

return null;
}
@@ -80,10 +74,8 @@ namespace Tensorflow
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"VarIsInitializedOp", name,
null,
resource);
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("VarIsInitializedOp", name,
resource));

return results[0];
}
@@ -107,14 +99,17 @@ namespace Tensorflow
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"VarHandleOp", name,
null,
"container", container,
"shared_name", shared_name,
"dtype", dtype,
"shape", shape.dims,
"allowed_devices", new string[0]);
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("VarHandleOp", name)
{
attrs = ConvertToDict(new
{
dtype,
shape = shape.dims,
container,
shared_name,
allowed_devices = new string[0]
})
});

return results[0];
}
@@ -131,26 +126,8 @@ namespace Tensorflow
}

public static Tensor destroy_resource_op(Tensor resource, bool ignore_lookup_error = true, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DestroyResourceOp", name,
null,
resource,
"ignore_lookup_error", ignore_lookup_error);

return results.Length == 0 ? null : results[0];
}

var _op = tf.OpDefLib._apply_op_helper("DestroyResourceOp", name, new
{
resource,
ignore_lookup_error
});

return _op.output;
}
=> tf.Context.ExecuteOp("DestroyResourceOp", name,
new ExecuteOpArgs(resource).SetAttributes(new { ignore_lookup_error }));

/// <summary>
/// Reads the value of a variable.
@@ -160,26 +137,8 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ReadVariableOp", name,
null,
resource,
"dtype", dtype);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("ReadVariableOp", name, new
{
resource,
dtype
});

return _op.output;
}
=> tf.Context.ExecuteOp("ReadVariableOp", name, new ExecuteOpArgs(resource)
.SetAttributes(new { dtype }));
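// Usage sketch: `var_handle` is an assumed handle produced by var_handle_op above.
var value = gen_resource_variable_ops.read_variable_op(var_handle, TF_DataType.TF_FLOAT);
gen_resource_variable_ops.destroy_resource_op(var_handle, ignore_lookup_error: true);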

public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype,
int batch_dims = 0, bool validate_indices = true, string name = null)


+39 -44 src/TensorFlowNET.Core/Operations/math_ops.cs

@@ -45,10 +45,7 @@ namespace Tensorflow
=> gen_math_ops.add(x, y, name);

public static Tensor add_v2(Tensor x, Tensor y, string name = null)
=> tf.Context.RunInAutoMode2("AddV2", name, new AutoModeArgs
{
OpInputArgs = new { x, y }
});
=> tf.Context.ExecuteOp("AddV2", name, new ExecuteOpArgs(x, y));

public static Tensor add_v2<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.add_v2(x, y, name);
@@ -171,15 +168,12 @@ namespace Tensorflow
}

public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null)
{
return tf_with(ops.name_scope(name, "Cumsum", new { x }), scope =>
{
name = scope;
x = ops.convert_to_tensor(x, name: "x");

return gen_math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name);
});
}
=> tf_with(ops.name_scope(name, "Cumsum", new { x }), scope =>
{
name = scope;
return tf.Context.ExecuteOp("Cumsum", name, new ExecuteOpArgs(x, axis)
.SetAttributes(new { exclusive, reverse }));
});

/// <summary>
/// Computes Psi, the derivative of Lgamma (the log of the absolute value of
@@ -261,19 +255,13 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor erf(Tensor x, string name = null)
=> tf.Context.RunInAutoMode2("Erf", name, new AutoModeArgs
{
OpInputArgs = new { x }
});
=> tf.Context.ExecuteOp("Erf", name, new ExecuteOpArgs(x));

public static Tensor sqrt(Tensor x, string name = null)
=> gen_math_ops.sqrt(x, name: name);

public static Tensor multiply(Tensor x, Tensor y, string name = null)
=> tf.Context.RunInAutoMode2("Mul", name, new AutoModeArgs
{
OpInputArgs = new { x, y }
});
=> tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(x, y));

public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.mul(x, y, name: name);
@@ -720,23 +708,10 @@ namespace Tensorflow
=> tf_with(ops.name_scope(name, "Pow", new { x, y }), scope =>
{
name = scope;
var x_tensor = ops.convert_to_tensor(x, name: "x");
var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype());

if (tf.executing_eagerly())
{
var x_tensor = ops.convert_to_tensor(x, name: "x");
var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype());

var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Pow", name,
null,
x_tensor, y_tensor);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x, y });

return _op.output;
return tf.Context.ExecuteOp("Pow", name, new ExecuteOpArgs(x_tensor, y_tensor));
});

public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
@@ -818,21 +793,41 @@ namespace Tensorflow
public static Tensor batch_matmul(Tensor x, Tensor y,
bool adj_x = false, bool adj_y = false,
string name = null)
{
Tensor result = null;

tf_with(ops.name_scope(name, "MatMul", new Tensor[] { x, y }), scope =>
=> tf_with(ops.name_scope(name, "MatMul", new Tensor[] { x, y }), scope =>
{
name = scope;

x = ops.convert_to_tensor(x, name: "a");
y = ops.convert_to_tensor(y, name: "b");

result = gen_math_ops.batch_mat_mul(x, y, adj_x, adj_y, name);
return tf.Context.ExecuteOp("BatchMatMul", name, new ExecuteOpArgs(x, y)
.SetAttributes(new { adj_x, adj_y }));
});

return result;
}
public static Tensor bincount(Tensor arr, Tensor weights = null,
Tensor minlength = null,
Tensor maxlength = null,
TF_DataType dtype = TF_DataType.TF_INT32,
string name = null,
TensorShape axis = null,
bool binary_output = false)
=> tf_with(ops.name_scope(name, "bincount"), scope =>
{
name = scope;
if(!binary_output && axis == null)
{
var array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0;
var output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (math_ops.reduce_max(arr) + 1);
if (minlength != null)
output_size = math_ops.maximum(minlength, output_size);
if (maxlength != null)
output_size = math_ops.minimum(maxlength, output_size);
var weights = constant_op.constant(new long[0], dtype: dtype);
return tf.Context.ExecuteOp("Bincount", name, new ExecuteOpArgs(arr, output_size, weights));
}

throw new NotImplementedException("");
});
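// Usage sketch (values assumed) for the Bincount path above; expected counts: [0, 2, 1].
var arr = tf.constant(new[] { 1, 1, 2 });
var counts = math_ops.bincount(arr);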

/// <summary>
/// Returns the complex conjugate of a complex number.


+87 -60 src/TensorFlowNET.Core/Operations/string_ops.cs

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using NumSharp;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -21,53 +23,13 @@ namespace Tensorflow
public class string_ops
{
public Tensor lower(Tensor input, string encoding = "", string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"StringLower", name,
null,
input, encoding);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("StringLower", name: name, args: new
{
input,
encoding
});

return _op.output;
}
=> tf.Context.ExecuteOp("StringLower", name, new ExecuteOpArgs(input, encoding));

public Tensor regex_replace(Tensor input, string pattern, string rewrite,
bool replace_global = true, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"StaticRegexReplace", name,
null,
input,
"pattern", pattern,
"rewrite", rewrite,
"replace_global", replace_global);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("StaticRegexReplace", name: name, args: new
{
input,
pattern,
rewrite,
replace_global
});

return _op.output;
}

=> tf.Context.ExecuteOp("StaticRegexReplace", name, new ExecuteOpArgs(input)
.SetAttributes(new { pattern, rewrite, replace_global }));
/// <summary>
/// Return substrings from `Tensor` of strings.
/// </summary>
@@ -79,28 +41,93 @@ namespace Tensorflow
/// <returns></returns>
public Tensor substr<T>(T input, int pos, int len,
string @uint = "BYTE", string name = null)
{
if (tf.Context.executing_eagerly())
{
var input_tensor = tf.constant(input);
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Substr", name,
null,
input, pos, len,
"unit", @uint);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Substr", name: name, args: new
{
input,
pos,
len,
unit = @uint
});

return _op.output;
}
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
.SetAttributes(new { unit = @uint }));

/// <summary>
/// Computes the length of each string given in the input tensor.
/// </summary>
/// <param name="input"></param>
/// <param name="name"></param>
/// <param name="unit"></param>
/// <returns></returns>
public Tensor string_length(Tensor input, string name = null, string unit = "BYTE")
=> tf.Context.ExecuteOp("StringLength", name, new ExecuteOpArgs(input)
{
GetGradientAttrs = op => new
{
unit = op.get_attr<string>("unit")
}
}.SetAttributes(new { unit }));

public RaggedTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
{
return tf_with(ops.name_scope(name, "StringSplit"), scope =>
{
var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING);
var result = tf.Context.ExecuteOp("StringSplitV2", name,
new ExecuteOpArgs(input, sep)
{
GetGradientAttrs = op => new
{
maxsplit = op.get_attr<int>("maxsplit")
}
}.SetAttributes(new { maxsplit }));
var (indices, values, shape) = (result[0], result[1], result[2]);
indices.set_shape(new TensorShape(-1, 2));
values.set_shape(new TensorShape(-1));
shape.set_shape(new TensorShape(2));

var sparse_result = new SparseTensor(indices, values, shape);
return RaggedTensor.from_value_rowids(sparse_result.values,
value_rowids: sparse_result.indices[Slice.All, 0],
nrows: sparse_result.dense_shape[0],
validate: false);
});
}

public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding, string errors,
int replacement_char = 0xFFFD, bool replace_control_characters = false, string name = null)
{
return tf_with(ops.name_scope(name, "UnicodeDecodeWithOffsets"), scope =>
{
var (codepoints, byte_start_offsets) = _unicode_decode(input, input_encoding, errors,
replacement_char, replace_control_characters,
with_offsets: true, name: name);
return (codepoints, byte_start_offsets);
});
}

(RaggedTensor, RaggedTensor) _unicode_decode(Tensor input, string input_encoding, string errors, int replacement_char,
bool replace_control_characters, bool with_offsets, string name = null)
{
if (with_offsets)
{
var flat_result = tf.Context.ExecuteOp("UnicodeDecodeWithOffsets", name, new ExecuteOpArgs(input)
{
GetGradientAttrs = op => new
{
input_encoding = op.get_attr<string>("input_encoding"),
errors = op.get_attr<string>("errors"),
replacement_char = op.get_attr<int>("replacement_char"),
replace_control_characters = op.get_attr<bool>("replace_control_characters"),
Tsplits = op.get_attr<TF_DataType>("Tsplits")
}
}.SetAttributes(new
{
input_encoding,
errors,
replacement_char,
replace_control_characters
}));

var codepoints = RaggedTensor.from_row_splits(flat_result[1], flat_result[0], validate: false);

var offsets = RaggedTensor.from_row_splits(flat_result[2], flat_result[0], validate: false);
return (codepoints, offsets);
}

return (null, null);
}
}
}
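// Usage sketch (input values assumed): string_split_v2 runs StringSplitV2 and repacks the
// sparse result as a RaggedTensor, so row 0 holds ["a", "b", "c"] and row 1 holds ["d", "e"].
var sentences = tf.constant(new[] { "a b c", "d e" });
var ragged = new string_ops().string_split_v2(sentences, sep: " ");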

+1 -0 src/TensorFlowNET.Core/Tensorflow.Binding.csproj

@@ -50,6 +50,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.</PackageReleaseNotes
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG</DefineConstants>
<PlatformTarget>x64</PlatformTarget>
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">


+1 -1 src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs

@@ -7,7 +7,7 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
public class EagerTensorV2 : DisposableObject, ITensor
public class EagerTensorV2 : DisposableObject
{
SafeTensorHandleHandle EagerTensorHandle;
public string Device


+0 -7 src/TensorFlowNET.Core/Tensors/ITensor.cs

@@ -1,7 +0,0 @@
namespace Tensorflow
{
public interface ITensor
{

}
}

+147 -0 src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs

@@ -0,0 +1,147 @@
/*****************************************************************************
Copyright 2021 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding;
using NumSharp;

namespace Tensorflow
{
/// <summary>
/// Represents a ragged tensor.
/// </summary>
public class RaggedTensor : CompositeTensor
{
Tensor _values;
RowPartition _row_partition;
Tensor _row_splits => _row_partition.row_splits;

public TF_DataType dtype => _values.dtype;
public TensorShape shape
{
get
{
var nrows = _row_partition.static_nrows;
var ncols = _row_partition.static_uniform_row_length;
return new TensorShape(nrows, ncols);
}
}

public RaggedTensor this[params Slice[] slices]
{
get
{
var row_key = slices[0];
var inner_keys = slices.Skip(1).ToArray();

var args = tensor_util.ParseSlices(slices);

return tf_with(ops.name_scope(null, "RaggedGetItem", args), scope =>
{
string name = scope;
return _ragged_getitem_inner_dimensions(this, inner_keys);
});
}
}

RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices)
{
return input;
}

public RaggedTensor(Tensor values,
bool @internal = true,
RowPartition row_partition = null)
{
_values = values;
_row_partition = row_partition;
}

public static RaggedTensor from_row_partition(Tensor values, RowPartition row_partition, bool validate = true)
{
return new RaggedTensor(values, @internal: true, row_partition: row_partition);
}

/// <summary>
/// Creates a `RaggedTensor` with rows partitioned by `value_rowids`.
/// </summary>
/// <param name="values"></param>
/// <param name="value_rowids"></param>
/// <param name="nrows"></param>
/// <param name="name"></param>
/// <param name="validate"></param>
/// <returns></returns>
public static RaggedTensor from_value_rowids(Tensor values, Tensor value_rowids,
Tensor nrows = null, string name = null, bool validate = true)
{
return tf_with(ops.name_scope(name, "RaggedFromValueRowIds"), scope =>
{
var row_partition = RowPartition.from_value_rowids(value_rowids,
nrows: nrows,
validate: validate);
return from_row_partition(values, row_partition, validate: validate);
});
}

public static RaggedTensor from_row_splits(Tensor values, Tensor row_splits,
string name = null, bool validate = true)
{
return tf_with(ops.name_scope(name, "RaggedFromRowSplits"), scope =>
{
var row_partition = RowPartition.from_row_splits(row_splits,
validate: validate);
return from_row_partition(values, row_partition, validate: validate);
});
}

Tensor _to_variant(bool batched_input = false, string name = null)
=> tf_with(ops.name_scope(name, "RaggedToVariant"), scope =>
{
return tf.Context.ExecuteOp("RaggedTensorToVariant", name,
new ExecuteOpArgs(nested_row_splits, flat_values)
{
GetGradientAttrs = op => new
{
RAGGED_RANK = op.get_attr<int>("RAGGED_RANK"),
Tvalues = op.get_attr<TF_DataType>("Tvalues"),
Tsplits = op.get_attr<TF_DataType>("Tsplits"),
batched_input = op.get_attr<bool>("batched_input")
}
}.SetAttributes(new { batched_input }));
});

Tensor flat_values
=> _values;

Tensor[] nested_row_splits
=> new[] { _row_splits };

public override string ToString()
=> $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]";

public static implicit operator Tensor(RaggedTensor indexedSlices)
=> indexedSlices._to_variant();

public static implicit operator RaggedTensor(Tensor tensor)
{
return tensor.Tag as RaggedTensor;
}
}
}
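// Construction sketch (values assumed): row_splits [0, 2, 3] partitions the three values
// into the ragged rows [[1, 2], [3]].
var values = tf.constant(new[] { 1, 2, 3 });
var row_splits = tf.constant(new long[] { 0, 2, 3 });
var rt = RaggedTensor.from_row_splits(values, row_splits, validate: false);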

+103 -0 src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs

@@ -0,0 +1,103 @@
/*****************************************************************************
Copyright 2021 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// Partitioning of a sequence of values into contiguous subsequences ("rows").
/// </summary>
public class RowPartition : CompositeTensor
{
Tensor _row_splits;
public Tensor row_splits => _row_splits;
Tensor _row_lengths;
Tensor _value_rowids;
Tensor _nrows;

public int static_nrows
{
get
{
return _row_splits.shape[0] - 1;
}
}

public int static_uniform_row_length
{
get
{
return -1;
}
}

public RowPartition(Tensor row_splits,
Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null,
Tensor uniform_row_length = null)
{
_row_splits = row_splits;
_row_lengths = row_lengths;
_value_rowids = value_rowids;
_nrows = nrows;
}

/// <summary>
/// Creates a `RowPartition` with rows partitioned by `value_rowids`.
/// </summary>
/// <param name="value_rowids"></param>
/// <param name="nrows"></param>
/// <param name="validate"></param>
/// <param name="preferred_dtype"></param>
/// <returns></returns>
public static RowPartition from_value_rowids(Tensor value_rowids,
Tensor nrows = null, bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
{
return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope =>
{
var value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32);
var nrows_int32 = math_ops.cast(nrows, dtypes.int32);
var row_lengths = tf.math.bincount(value_rowids_int32,
minlength: nrows_int32,
maxlength: nrows_int32,
dtype: value_rowids.dtype);
var row_splits = array_ops.concat(new object[]
{
ops.convert_to_tensor(new long[] { 0 }),
tf.cumsum(row_lengths)
}, axis: 0);

return new RowPartition(row_splits,
row_lengths: row_lengths,
value_rowids: value_rowids,
nrows: nrows);
});
}

public static RowPartition from_row_splits(Tensor row_splits,
bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
{
return tf_with(ops.name_scope(null, "RowPartitionFromRowSplits"), scope =>
{
return new RowPartition(row_splits);
});
}
}
}
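// Construction sketch (values assumed): value_rowids [0, 0, 2] with nrows = 3 yields
// row_splits [0, 2, 2, 3] via the bincount/cumsum path above.
var value_rowids = tf.constant(new[] { 0, 0, 2 });
var nrows = tf.constant(3);
var partition = RowPartition.from_value_rowids(value_rowids, nrows: nrows, validate: false);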

+76 -0 src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs

@@ -0,0 +1,76 @@
/*****************************************************************************
Copyright 2021 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// Represents a sparse tensor.
/// </summary>
public class SparseTensor : CompositeTensor
{
public Tensor indices;

public Tensor values;

public Tensor dense_shape;

public SparseTensor(Tensor indices, Tensor values, Tensor dense_shape)
{
this.indices = indices;
this.values = values;
this.dense_shape = dense_shape;
_init();
}

public SparseTensor(long[,] indices_, Array values_, long[] dense_shape_)
{
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate
{
indices = ops.convert_to_tensor(
indices_, name: "indices", dtype: dtypes.int64);
values = ops.convert_to_tensor(values_, name: "values");
dense_shape = ops.convert_to_tensor(
dense_shape_, name: "dense_shape", dtype: dtypes.int64);
});
_init();
}

void _init()
{
var indices_shape = indices.TensorShape.with_rank(2);
var values_shape = values.TensorShape.with_rank(1);
var dense_shape_shape = dense_shape.TensorShape.with_rank(1);

indices_shape["0"].merge_with(values_shape[0]);
indices_shape["1"].merge_with(dense_shape_shape[0]);
}

public static implicit operator Tensor(SparseTensor sparseTensor)
{
return sparseTensor.values;
}

public static implicit operator SparseTensor(Tensor tensor)
{
return tensor.Tag as SparseTensor;
}
}
}
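A short sketch showing how the new SparseTensor can be built and converted, using only the constructor and implicit operators defined above (the shape and values are illustrative):

// A 3x4 sparse tensor with non-zero entries at (0, 0) and (1, 2).
var st = new SparseTensor(new long[,] { { 0, 0 }, { 1, 2 } },
    new[] { 1.0f, 2.0f },
    new long[] { 3, 4 });
Tensor values = st;   // the implicit conversion yields the values tensor
// The reverse conversion reads tensor.Tag, so it only succeeds when the
// Tensor was previously tagged with a SparseTensor.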

+ 1
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -33,9 +33,7 @@ namespace Tensorflow
/// </summary>
[SuppressMessage("ReSharper", "ConvertToAutoProperty")]
public partial class Tensor : DisposableObject,
ITensor,
ITensorOrOperation,
_TensorLike,
ITensorOrTensorArray,
IPackable<Tensor>,
ICanBeFlattened
@@ -97,6 +95,7 @@ namespace Tensorflow
public SafeTensorHandleHandle EagerTensorHandle { get; set; }

public bool IsEagerTensor => this is EagerTensor;
public bool IsSparseTensor => this is SparseTensor;

/// <summary>
/// Returns the shape of a tensor.


+ 10
- 56
src/TensorFlowNET.Core/Training/gen_training_ops.cs View File

@@ -21,46 +21,19 @@ namespace Tensorflow
{
public class gen_training_ops
{
public static Operation resource_apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power,
public static Tensor resource_apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power,
Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad,
bool use_locking = false, bool use_nesterov = false, string name = null)
{
if (tf.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResourceApplyAdam", name,
null,
var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad,
"use_locking", use_locking,
"use_nesterov", use_nesterov);
return null;
}

throw new NotImplementedException("");
}
=> tf.Context.ExecuteOp("ResourceApplyAdam", name,
new ExecuteOpArgs(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
.SetAttributes(new { use_locking, use_nesterov }));

public static Tensor apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power,
Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad,
bool use_locking = false, bool use_nesterov = false, string name = null)
{
var _op = tf.OpDefLib._apply_op_helper("ApplyAdam", name, new
{
var,
m,
v,
beta1_power,
beta2_power,
lr,
beta1,
beta2,
epsilon,
grad,
use_locking,
use_nesterov
});

return _op.outputs[0];
}
=> tf.Context.ExecuteOp("ApplyAdam", name,
new ExecuteOpArgs(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
.SetAttributes(new { use_locking, use_nesterov }));

public static Tensor apply_gradient_descent(IVariableV1 var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
{
@@ -75,27 +48,8 @@ namespace Tensorflow
return _op.output;
}

public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
{
if (tf.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResourceApplyGradientDescent", name,
null,
var, alpha, delta,
"use_locking", use_locking);
return null;
}

var _op = tf.OpDefLib._apply_op_helper("ResourceApplyGradientDescent", name, new
{
var,
alpha,
delta,
use_locking
});

return _op;
}
public static Tensor resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
=> tf.Context.ExecuteOp("ResourceApplyGradientDescent", name,
new ExecuteOpArgs(var, alpha, delta).SetAttributes(new { use_locking }));
}
}
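The rewrite above follows the pattern applied throughout this commit: the explicit eager branch built on TFE_FastPathExecute and the graph-mode _apply_op_helper fallback collapse into a single tf.Context.ExecuteOp call that handles both modes. As a hedged sketch, a newly wrapped op would look like this under the same convention (the op name and attribute below are placeholders, not ops added by this commit):

public static Tensor some_op(Tensor x, bool use_locking = false, string name = null)
    => tf.Context.ExecuteOp("SomeOp", name,
        new ExecuteOpArgs(x).SetAttributes(new { use_locking }));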

+ 2
- 25
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -59,31 +59,8 @@ namespace Tensorflow
bool validate_shape = true,
bool use_locking = true,
string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Assign", name,
null,
@ref, value,
"validate_shape", validate_shape,
"use_locking", use_locking);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });

var _result = _op.outputs;
var _inputs_flat = _op.inputs;

var _attrs = new Dictionary<string, object>();
_attrs["T"] = _op.get_attr("T");
_attrs["validate_shape"] = _op.get_attr("validate_shape");
_attrs["use_locking"] = _op.get_attr("use_locking");

return _result[0];
}
=> tf.Context.ExecuteOp("Assign", name, new ExecuteOpArgs(@ref, value)
.SetAttributes(new { validate_shape, use_locking }));

public static Tensor assign_add<T>(IVariableV1 @ref, T value, bool use_locking = false, string name = null)
{


+ 2
- 16
src/TensorFlowNET.Keras/Activations/Activations.Relu.cs View File

@@ -4,21 +4,7 @@ namespace Tensorflow.Keras
{
public partial class Activations
{
public Activation Relu = (features, name) =>
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Relu", name,
null,
features);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Relu", name: name, args: new { features });

return _op.output;
};
public Activation Relu = (features, name)
=> tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features));
}
}

+ 2
- 16
src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs View File

@@ -5,21 +5,7 @@ namespace Tensorflow.Keras
{
public partial class Activations
{
public Activation Sigmoid = (features, name) =>
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Sigmoid", name,
null,
features);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Sigmoid", name: name, args: new { x = features });

return _op.output;
};
public Activation Sigmoid = (features, name)
=> tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features));
}
}

+ 2
- 16
src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs View File

@@ -5,21 +5,7 @@ namespace Tensorflow.Keras
{
public partial class Activations
{
public Activation Tanh = (features, name) =>
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Tanh", name,
null,
features);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Tanh", name: name, args: new { x = features });

return _op.output;
};
public Activation Tanh = (features, name)
=> tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features));
}
}

+ 13
- 1
src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs View File

@@ -8,11 +8,23 @@ namespace Tensorflow.Keras.Engine
public class CombinerPreprocessingLayer : Layer
{
PreprocessingLayerArgs args;
protected ICombiner combiner;
protected bool _previously_updated;

public CombinerPreprocessingLayer(PreprocessingLayerArgs args)
: base(args)
{
_previously_updated = false;
}

public virtual void adapt(IDatasetV2 data, bool reset_state = true)
{
IAccumulator accumulator;
if (!reset_state)
accumulator = combiner.Restore();

var next_data = data.make_one_shot_iterator();
var data_element = next_data.next();
}
}
}

+ 10
- 0
src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
public interface IAccumulator
{
}
}

+ 19
- 0
src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
/// <summary>
/// Functional object that defines a shardable computation.
/// </summary>
public interface ICombiner
{
void Compute(Tensor values, IAccumulator accumulator = null);
void Merge();
void Extract();
IAccumulator Restore();
void Serialize();
void Deserialize();
}
}

+ 30
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs View File

@@ -0,0 +1,30 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
public class IndexLookup : CombinerPreprocessingLayer
{
public IndexLookup(int max_tokens = -1,
int num_oov_indices = 1,
string mask_token = "",
string oov_token = "[UNK]",
string encoding = "utf-8",
bool invert = false) : base(new PreprocessingLayerArgs())
{
var num_mask_tokens = mask_token == null ? 0 : 1;
var vocab_size = max_tokens - (num_oov_indices + num_mask_tokens);
combiner = new IndexLookupCombiner(vocab_size, mask_token);
}

public override void adapt(IDatasetV2 data, bool reset_state = true)
{
if (!reset_state)
throw new ValueError("IndexLookup does not support streaming adapts.");
base.adapt(data, reset_state);
}
}
}

+ 16
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
public class IndexLookupAccumulator : IAccumulator
{
public Dictionary<string, int> CountDict { get; set; }
public IndexLookupAccumulator()
{
CountDict = new Dictionary<string, int>();
}
}
}

+ 55
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs View File

@@ -0,0 +1,55 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Combiner for the IndexLookup preprocessing layer.
/// </summary>
public class IndexLookupCombiner : ICombiner
{
int _vocab_size;
string _mask_value;

public IndexLookupCombiner(int vocab_size = -1, string mask_value = null)
{
_vocab_size = vocab_size;
_mask_value = mask_value;
}

public void Compute(Tensor values, IAccumulator accumulator = null)
{
if(accumulator == null)
{
accumulator = new IndexLookupAccumulator();
}
}

public void Deserialize()
{
throw new NotImplementedException();
}

public void Extract()
{
throw new NotImplementedException();
}

public void Merge()
{
throw new NotImplementedException();
}

public IAccumulator Restore()
{
throw new NotImplementedException();
}

public void Serialize()
{
throw new NotImplementedException();
}
}
}

+ 23
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs View File

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Maps strings from a vocabulary to integer indices.
/// </summary>
class StringLookup : IndexLookup
{
public StringLookup(int max_tokens = -1,
int num_oov_indices = 1,
string mask_token = "",
string[] vocabulary = null,
string oov_token = "[UNK]",
string encoding = "utf-8",
bool invert = false)
{

}
}
}

+ 15
- 4
src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs View File

@@ -3,12 +3,14 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
{
public class TextVectorization : CombinerPreprocessingLayer
{
TextVectorizationArgs args;
IndexLookup _index_lookup_layer;

public TextVectorization(TextVectorizationArgs args)
: base(args)
@@ -16,6 +18,11 @@ namespace Tensorflow.Keras.Layers
this.args = args;
args.DType = TF_DataType.TF_STRING;
// string standardize = "lower_and_strip_punctuation",

var mask_token = args.OutputMode == "int" ? "" : null;
_index_lookup_layer = new StringLookup(max_tokens: args.MaxTokens,
mask_token: mask_token,
vocabulary: args.Vocabulary);
}

/// <summary>
@@ -23,13 +30,14 @@ namespace Tensorflow.Keras.Layers
/// </summary>
/// <param name="data"></param>
/// <param name="reset_state"></param>
public void adapt(IDatasetV2 data, bool reset_state = true)
public override void adapt(IDatasetV2 data, bool reset_state = true)
{
var shape = data.output_shapes[0];
if (shape.rank == 1)
data = data.map(tensor => array_ops.expand_dims(tensor, -1));
build(data.variant_tensor);
var preprocessed_inputs = data.map(_preprocess);
_index_lookup_layer.adapt(preprocessed_inputs);
}

protected override void build(Tensors inputs)
@@ -39,14 +47,17 @@ namespace Tensorflow.Keras.Layers

Tensors _preprocess(Tensors inputs)
{
Tensor input_tensor = null;
if (args.Standardize != null)
inputs = args.Standardize(inputs);
input_tensor = args.Standardize(inputs);
if (!string.IsNullOrEmpty(args.Split))
{
if (inputs.shape.ndim > 1)
inputs = array_ops.squeeze(inputs, axis: new[] { -1 });
input_tensor = array_ops.squeeze(inputs, axis: new[] { -1 });
if (args.Split == "whitespace")
input_tensor = tf.strings.split(input_tensor);
}
return inputs;
return input_tensor;
}
}
}
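A rough sketch of how the updated adapt flow is intended to be driven, based on the args members referenced above; building the string dataset via tf.data.Dataset.from_tensor_slices is an assumption for illustration and not part of this diff:

var layer = new TextVectorization(new TextVectorizationArgs
{
    MaxTokens = 1000,
    OutputMode = "int",
    Split = "whitespace"
});
// Assumed helper: a from_tensor_slices-style dataset over string constants.
var corpus = tf.data.Dataset.from_tensor_slices(
    tf.constant(new[] { "hello world", "tensorflow .net csharp" }));
layer.adapt(corpus);   // expand dims, standardize/split, then adapt the StringLookup layer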

+ 2
- 0
src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs View File

@@ -1,4 +1,5 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
@@ -60,6 +61,7 @@ namespace Tensorflow.Keras.Preprocessings
}
}

Console.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes.");
return (return_file_paths, return_labels, class_names);
}
}


+ 9
- 8
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -6,7 +6,7 @@
<LangVersion>8.0</LangVersion>
<RootNamespace>Tensorflow.Keras</RootNamespace>
<Platforms>AnyCPU;x64</Platforms>
<Version>0.4.1</Version>
<Version>0.5.0</Version>
<Authors>Haiping Chen</Authors>
<Product>Keras for .NET</Product>
<Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
@@ -23,7 +23,8 @@
* Implemented backward_function.
* Support model.load_weights.
* Add Subtract layer
* Support YOLOv3 model.</PackageReleaseNotes>
* Support YOLOv3 model.
* Text preprocessing</PackageReleaseNotes>
<Description>Keras for .NET

Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent &amp; simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear &amp; actionable error messages.</Description>
@@ -34,8 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<AssemblyVersion>0.4.1.0</AssemblyVersion>
<FileVersion>0.4.1.0</FileVersion>
<AssemblyVersion>0.5.0.0</AssemblyVersion>
<FileVersion>0.5.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
</PropertyGroup>

@@ -48,6 +49,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<DocumentationFile>Tensorflow.Keras.xml</DocumentationFile>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
@@ -62,10 +67,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
</None>
</ItemGroup>

<ItemGroup>
<Folder Include="Engine\Interfaces\" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
</ItemGroup>


+ 27
- 1
src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs View File

@@ -1,6 +1,8 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Text.Tokenizers
{
@@ -13,7 +15,31 @@ namespace Tensorflow.Text.Tokenizers
/// <returns></returns>
public Tensor tokenize(Tensor input)
{
tokenize_with_offsets(input);
throw new NotImplementedException("");
}

Tensor[] tokenize_with_offsets(Tensor input)
{
tf_with(ops.name_scope(null, "WhitespaceTokenize"), scope =>
{
_whitespace_tokenize_with_offsets_encode_decode_wrapper(input);
});
throw new NotImplementedException("");
}

Tensor _whitespace_tokenize_with_offsets_encode_decode_wrapper(Tensor input_tensor)
{
// Decode the strings and get byte offsets
var (codepoints, byte_start_offsets) = tf.strings.unicode_decode_with_offsets(input_tensor, "UTF-8");
var byte_end_offsets = array_ops.concat(new Tensor[]
{
byte_start_offsets[Slice.All, new Slice(1)],
math_ops.cast(
array_ops.expand_dims(tf.strings.string_length(input_tensor), 1),
dtypes.int64)
}, 1);
return input_tensor;
}
}
}

+ 1
- 0
src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj View File

@@ -9,6 +9,7 @@
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<PlatformTarget>x64</PlatformTarget>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">


+ 2
- 2
test/TensorFlowNET.UnitTest/Basics/RandomTest.cs View File

@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// Tests the effect of setting the random seed;
/// this helps reproduce the same results across runs
/// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomSeedTest()
{
var initValue = np.arange(6).reshape(3, 2);
@@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// <summary>
/// This part uses the functions under tf.random rather than only those on tf
/// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomRaodomSeedTest()
{
tf.set_random_seed(1234);


+ 20
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -151,5 +151,25 @@ namespace TensorFlowNET.UnitTest.Dataset
var cardinality = dataset.dataset_cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
}

[TestMethod]
public void Shuffle()
{
tf.set_random_seed(1234);

var dataset = tf.data.Dataset.range(3);
var shuffled = dataset.shuffle(3);

var zipped = tf.data.Dataset.zip(dataset, shuffled);

bool allEqual = true;
foreach (var item in zipped)
{
if (item.Item1 != item.Item2)
allEqual = false;
}

Assert.IsFalse(allEqual);
}
}
}

+ 8
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs View File

@@ -58,5 +58,13 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.AreEqual(strings[1], stringData[1]);
Assert.AreEqual(strings[2], stringData[2]);
}

[TestMethod]
public void StringSplit()
{
var tensor = tf.constant(new[] { "hello world", "tensorflow .net csharp", "fsharp" });
var ragged_tensor = tf.strings.split(tensor);
Assert.AreEqual((3, -1), ragged_tensor.shape);
}
}
}

+ 3
- 1
test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs View File

@@ -10,10 +10,12 @@ namespace TensorFlowNET.UnitTest.Text
[TestClass]
public class TokenizerTest
{
[TestMethod]
[TestMethod, Ignore]
public void Tokenize()
{
var docs = tf.constant(new[] { "Everything not saved will be lost." });
var tokenizer = text.WhitespaceTokenizer();
var tokens = tokenizer.tokenize(docs);
}
}
}
