@@ -11,8 +11,6 @@ | |||||
*master branch is based on tensorflow 2.2 now, v0.15-tensorflow1.15 is from tensorflow1.15.* | *master branch is based on tensorflow 2.2 now, v0.15-tensorflow1.15 is from tensorflow1.15.* | ||||
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | |||||
 |  | ||||
@@ -56,59 +54,40 @@ using static Tensorflow.Binding; | |||||
Linear Regression: | Linear Regression: | ||||
```c# | ```c# | ||||
// We can set a fixed init value in order to debug | |||||
// Parameters | |||||
int training_steps = 1000; | |||||
float learning_rate = 0.01f; | |||||
int display_step = 100; | |||||
// We can set a fixed init value in order to demo | |||||
var W = tf.Variable(-0.06f, name: "weight"); | var W = tf.Variable(-0.06f, name: "weight"); | ||||
var b = tf.Variable(-0.73f, name: "bias"); | var b = tf.Variable(-0.73f, name: "bias"); | ||||
var optimizer = tf.optimizers.SGD(learning_rate); | |||||
// Construct a linear model | |||||
var pred = tf.add(tf.multiply(X, W), b); | |||||
// Mean squared error | |||||
var cost = tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * n_samples); | |||||
// Gradient descent | |||||
// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | |||||
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | |||||
// Initialize the variables (i.e. assign their default value) | |||||
var init = tf.global_variables_initializer(); | |||||
// Start training | |||||
using(tf.Session()) | |||||
// Run training for the given number of steps. | |||||
foreach (var step in range(1, training_steps + 1)) | |||||
{ | { | ||||
// Run the initializer | |||||
sess.run(init); | |||||
// Fit all training data | |||||
for (int epoch = 0; epoch < training_epochs; epoch++) | |||||
// Run the optimization to update W and b values. | |||||
// Wrap computation inside a GradientTape for automatic differentiation. | |||||
using var g = tf.GradientTape(); | |||||
// Linear regression (Wx + b). | |||||
var pred = W * X + b; | |||||
// Mean square error. | |||||
var loss = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); | |||||
// should stop recording | |||||
// Compute gradients. | |||||
var gradients = g.gradient(loss, (W, b)); | |||||
// Update W and b following gradients. | |||||
optimizer.apply_gradients(zip(gradients, (W, b))); | |||||
if (step % display_step == 0) | |||||
{ | { | ||||
foreach (var (x, y) in zip<float>(train_X, train_Y)) | |||||
sess.run(optimizer, (X, x), (Y, y)); | |||||
// Display logs per epoch step | |||||
if ((epoch + 1) % display_step == 0) | |||||
{ | |||||
var c = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | |||||
} | |||||
pred = W * X + b; | |||||
loss = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); | |||||
print($"step: {step}, loss: {loss.numpy()}, W: {W.numpy()}, b: {b.numpy()}"); | |||||
} | } | ||||
Console.WriteLine("Optimization Finished!"); | |||||
var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); | |||||
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||||
// Testing example | |||||
var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||||
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||||
Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||||
var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | |||||
(X, test_X), (Y, test_Y)); | |||||
Console.WriteLine($"Testing cost={testing_cost}"); | |||||
var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||||
Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||||
return diff < 0.01; | |||||
}); | |||||
} | |||||
``` | ``` | ||||
Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube). | Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube). | ||||
@@ -25,7 +25,15 @@ TensorFlow.NET uses the .NET Standard 2.0 standard, so your new project Target F | |||||
```cmd | ```cmd | ||||
### install tensorflow C# binding | |||||
PM> Install-Package TensorFlow.NET | PM> Install-Package TensorFlow.NET | ||||
### Install tensorflow binary | |||||
### For CPU version | |||||
PM> Install-Package SciSharp.TensorFlow.Redist | |||||
### For GPU version (CUDA and cuDNN are required) | |||||
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU | |||||
``` | ``` | ||||
### Start coding Hello World | ### Start coding Hello World | ||||
@@ -36,7 +44,7 @@ After installing the TensorFlow.NET package, you can use the `using Tensorflow` | |||||
```csharp | ```csharp | ||||
using System; | using System; | ||||
using Tensorflow; | |||||
using static Tensorflow.Binding; | |||||
namespace TensorFlowNET.Examples | namespace TensorFlowNET.Examples | ||||
{ | { | ||||
@@ -8,13 +8,13 @@ In this chapter we will talk about another common data type in TensorFlow: Place | |||||
var x = tf.placeholder(tf.int32); | var x = tf.placeholder(tf.int32); | ||||
var y = x * 3; | var y = x * 3; | ||||
Python.with<Session>(tf.Session(), sess => | |||||
using (var sess = tf.Session()) | |||||
{ | { | ||||
var result = sess.run(y, feed_dict: new FeedItem[] | var result = sess.run(y, feed_dict: new FeedItem[] | ||||
{ | { | ||||
new FeedItem(x, 2) | new FeedItem(x, 2) | ||||
}); | }); | ||||
// (int)result should be 6; | // (int)result should be 6; | ||||
}); | |||||
} | |||||
``` | ``` | ||||
@@ -8,7 +8,7 @@ | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.2.0.1" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.2.0.2" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -43,7 +43,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public partial class c_api | public partial class c_api | ||||
{ | { | ||||
public const string TensorFlowLibName = @"D:\SciSharp\tensorflow-google\bazel-bin\tensorflow\tensorflow.dll"; | |||||
public const string TensorFlowLibName = "tensorflow"; | |||||
public static string StringPiece(IntPtr handle) | public static string StringPiece(IntPtr handle) | ||||
{ | { | ||||
@@ -186,7 +186,7 @@ namespace Tensorflow | |||||
=> array_ops.slice(input, begin, size, name: name); | => array_ops.slice(input, begin, size, name: name); | ||||
public Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) | public Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) | ||||
=> gen_array_ops.squeeze(input, axis, name); | |||||
=> array_ops.squeeze(input, axis, name); | |||||
/// <summary> | /// <summary> | ||||
/// Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. | /// Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. | ||||
@@ -217,7 +217,7 @@ namespace Tensorflow | |||||
Tensor off_value = null, | Tensor off_value = null, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
int axis = -1, | int axis = -1, | ||||
string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name); | |||||
string name = null) => array_ops.one_hot(indices, ops.convert_to_tensor(depth), dtype: dtype, axis: axis, name: name); | |||||
/// <summary> | /// <summary> | ||||
/// Pads a tensor | /// Pads a tensor | ||||
@@ -0,0 +1,30 @@ | |||||
/***************************************************************************** | |||||
Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using NumSharp; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class tensorflow | |||||
{ | |||||
public DataOps data { get; } = new DataOps(); | |||||
public class DataOps | |||||
{ | |||||
public TensorSliceDataset Dataset { get; } = new TensorSliceDataset(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,25 @@ | |||||
/***************************************************************************** | |||||
Copyright 2020 Haiping Chen. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using NumSharp; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class tensorflow | |||||
{ | |||||
public KerasApi keras { get; } = new KerasApi(); | |||||
} | |||||
} |
@@ -21,6 +21,13 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class tensorflow | public partial class tensorflow | ||||
{ | { | ||||
public MathApi math { get; } = new MathApi(); | |||||
public class MathApi | |||||
{ | |||||
public Tensor log(Tensor x, string name = null) | |||||
=> gen_math_ops.log(x, name); | |||||
} | |||||
public Tensor abs(Tensor x, string name = null) | public Tensor abs(Tensor x, string name = null) | ||||
=> math_ops.abs(x, name); | => math_ops.abs(x, name); | ||||
@@ -254,7 +261,7 @@ namespace Tensorflow | |||||
/// Any values less than <c>clip_value_min</c> are set to <c>clip_value_min</c>. Any values | /// Any values less than <c>clip_value_min</c> are set to <c>clip_value_min</c>. Any values | ||||
/// greater than <c>clip_value_max</c> are set to <c>clip_value_max</c>. | /// greater than <c>clip_value_max</c> are set to <c>clip_value_max</c>. | ||||
/// </remarks> | /// </remarks> | ||||
public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue") | |||||
public Tensor clip_by_value<T1, T2>(Tensor t, T1 clip_value_min, T2 clip_value_max, string name = "ClipByValue") | |||||
=> clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name); | => clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name); | ||||
public Tensor sub<Tx, Ty>(Tx a, Ty b, string name = null) | public Tensor sub<Tx, Ty>(Tx a, Ty b, string name = null) | ||||
@@ -14,6 +14,7 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -182,7 +183,13 @@ namespace Tensorflow | |||||
=> nn_impl.sigmoid_cross_entropy_with_logits(labels: labels, logits: logits, name: name); | => nn_impl.sigmoid_cross_entropy_with_logits(labels: labels, logits: logits, name: name); | ||||
public Tensor softmax(Tensor logits, int axis = -1, string name = null) | public Tensor softmax(Tensor logits, int axis = -1, string name = null) | ||||
=> gen_nn_ops.softmax(logits, name); | |||||
{ | |||||
if (axis == -1) | |||||
return gen_nn_ops.softmax(logits, name); | |||||
else | |||||
throw new NotImplementedException(""); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Computes sparse softmax cross entropy between `logits` and `labels`. | /// Computes sparse softmax cross entropy between `logits` and `labels`. | ||||
@@ -38,6 +38,24 @@ namespace Tensorflow | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
int? seed = null, | int? seed = null, | ||||
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | ||||
/// <summary> | |||||
/// Outputs random values from a truncated normal distribution. | |||||
/// </summary> | |||||
/// <param name="shape"></param> | |||||
/// <param name="mean"></param> | |||||
/// <param name="stddev"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="seed"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public Tensor truncated_normal(TensorShape shape, | |||||
float mean = 0.0f, | |||||
float stddev = 1.0f, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
int? seed = null, | |||||
string name = null) => random_ops.truncated_normal(shape, mean, stddev, dtype, seed, name); | |||||
public Tensor categorical( | public Tensor categorical( | ||||
Tensor logits, | Tensor logits, | ||||
int num_samples, | int num_samples, | ||||
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class DatasetOps | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,20 @@ | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class TensorSliceDataset | |||||
{ | |||||
public TensorSliceDataset(params NDArray[] elements) | |||||
{ | |||||
} | |||||
public TensorSliceDataset from_tensor_slices(params NDArray[] elements) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |
@@ -1,20 +1,30 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
{ | { | ||||
public class EagerOperation : Operation | public class EagerOperation : Operation | ||||
{ | { | ||||
public int NumInputs; | |||||
static Dictionary<string, OpDef> op_dict; | |||||
public string Name { get; set; } | |||||
public new int NumInputs; | |||||
public IntPtr[] InputHandles { get; set; } | public IntPtr[] InputHandles { get; set; } | ||||
public Tensor[] Inputs { get; set; } | public Tensor[] Inputs { get; set; } | ||||
public int NumOutputs; | |||||
public new int NumOutputs; | |||||
public IntPtr[] OutputHandles { get; set; } | public IntPtr[] OutputHandles { get; set; } | ||||
public Tensor[] Outputs { get; set; } | public Tensor[] Outputs { get; set; } | ||||
public int[] SkipInputIndices { get; set; } | |||||
public BindingArray SkipInputIndicesArray { get; set; } | |||||
public unsafe int[] SkipInputIndices => SkipInputIndicesArray.Data.Select(x => *(int*) x).ToArray(); | |||||
public string[] AttrsArray { get; set; } | |||||
public EagerOperation() : base(IntPtr.Zero) { } | |||||
public EagerOperation() : base(IntPtr.Zero) | |||||
{ | |||||
if (op_dict == null) | |||||
op_dict = op_def_registry.get_registered_ops(); | |||||
} | |||||
public override InputList inputs | public override InputList inputs | ||||
{ | { | ||||
@@ -22,13 +32,6 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
if (_inputs_val == null) | if (_inputs_val == null) | ||||
{ | { | ||||
var retval = new Tensor[NumInputs]; | |||||
for (int i = 0; i < NumInputs; i++) | |||||
{ | |||||
} | |||||
_inputs_val = new InputList(Inputs); | _inputs_val = new InputList(Inputs); | ||||
} | } | ||||
@@ -48,5 +51,35 @@ namespace Tensorflow.Eager | |||||
return _outputs; | return _outputs; | ||||
} | } | ||||
} | } | ||||
public override object get_attr(string attr_name) | |||||
{ | |||||
object value = null; | |||||
byte isList = 0; | |||||
using var status = new Status(); | |||||
var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, status.Handle); | |||||
switch (attrType) | |||||
{ | |||||
case TF_AttrType.TF_ATTR_BOOL: | |||||
value = get_attr_bool(attr_name); | |||||
break; | |||||
default: | |||||
break; | |||||
} | |||||
return value; | |||||
} | |||||
public bool get_attr_bool(string attr_name) | |||||
{ | |||||
for (int i = 0; i < AttrsArray.Length; i = i + 2) | |||||
if (AttrsArray[i] == attr_name) | |||||
return AttrsArray[i + 1] == "1"; | |||||
throw new ValueError($"Can't find attr: {attr_name}"); | |||||
} | |||||
public override string ToString() | |||||
=> $"tf.EagerOperation {Name}"; | |||||
} | } | ||||
} | } |
@@ -2,6 +2,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using System.Threading; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
@@ -49,6 +50,10 @@ namespace Tensorflow.Eager | |||||
print($"new TensorHandle {Id} {tfe_tensor_handle.ToString("x16")}"); | print($"new TensorHandle {Id} {tfe_tensor_handle.ToString("x16")}"); | ||||
print($"new EagerTensor {Id} {EagerTensorHandle.ToString("x16")}");*/ | print($"new EagerTensor {Id} {EagerTensorHandle.ToString("x16")}");*/ | ||||
if (tfe_tensor_handle == IntPtr.Zero && _id == 0) | |||||
{ | |||||
} | |||||
GarbageCollector.Increase(_handle, GCItemType.TensorHandle); | GarbageCollector.Increase(_handle, GCItemType.TensorHandle); | ||||
GarbageCollector.Increase(tfe_tensor_handle, GCItemType.LocalTensorHandle); | GarbageCollector.Increase(tfe_tensor_handle, GCItemType.LocalTensorHandle); | ||||
GarbageCollector.Increase(EagerTensorHandle, GCItemType.EagerTensorHandle); | GarbageCollector.Increase(EagerTensorHandle, GCItemType.EagerTensorHandle); | ||||
@@ -56,6 +61,9 @@ namespace Tensorflow.Eager | |||||
return this; | return this; | ||||
} | } | ||||
public override IntPtr ToPointer() | |||||
=> EagerTensorHandle; | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
{ | { | ||||
GarbageCollector.Decrease(_handle); | GarbageCollector.Decrease(_handle); | ||||
@@ -13,7 +13,7 @@ namespace Tensorflow.Eager | |||||
public IntPtr EagerTensorHandle { get; set; } | public IntPtr EagerTensorHandle { get; set; } | ||||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(tfe_tensor_handle, status.Handle)); | public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(tfe_tensor_handle, status.Handle)); | ||||
// public override int rank => c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status); | |||||
public override int rank => c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status.Handle); | |||||
public static int GetRank(IntPtr handle) | public static int GetRank(IntPtr handle) | ||||
{ | { | ||||
@@ -25,7 +25,7 @@ namespace Tensorflow | |||||
public delegate IntPtr gradient_function_callback(string op_name, | public delegate IntPtr gradient_function_callback(string op_name, | ||||
IntPtr op_inputs, | IntPtr op_inputs, | ||||
IntPtr op_outputs, | IntPtr op_outputs, | ||||
int num_attrs, | |||||
string attrs_string, | |||||
IntPtr output_grads, | IntPtr output_grads, | ||||
IntPtr skip_input_indices); | IntPtr skip_input_indices); | ||||
@@ -72,6 +72,9 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TF_AttrType TFE_OpNameGetAttrType(IntPtr ct, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the length (number of tensors) of the input argument `input_name` | /// Returns the length (number of tensors) of the input argument `input_name` | ||||
/// found in the provided `op`. | /// found in the provided `op`. | ||||
@@ -399,6 +402,7 @@ namespace Tensorflow | |||||
string name, | string name, | ||||
IntPtr[] inputs, | IntPtr[] inputs, | ||||
int input_size, | int input_size, | ||||
string attrs_string, | |||||
TFE_FastPathExecute_SetOpAttrs set_op_attrs, | TFE_FastPathExecute_SetOpAttrs set_op_attrs, | ||||
IntPtr[] outputs, | IntPtr[] outputs, | ||||
int output_size); | int output_size); | ||||
@@ -31,6 +31,37 @@ namespace Tensorflow.Eager | |||||
} | } | ||||
} | } | ||||
public static string SetOpAttrs2(params object[] attrs) | |||||
{ | |||||
string attr_string = string.Empty; | |||||
for(int i = 0; i < attrs.Length; i = i + 2) | |||||
{ | |||||
object key = attrs[i]; | |||||
object value = attrs[i + 1]; | |||||
switch (value) | |||||
{ | |||||
case TF_DataType dtype: | |||||
value = (int)dtype; | |||||
break; | |||||
case bool bVal: | |||||
value = bVal ? 1 : 0; | |||||
break; | |||||
case int[] shape: | |||||
value = shape.Length == 0 ? "null" : string.Join(" ", shape); | |||||
break; | |||||
default: | |||||
break; | |||||
} | |||||
attr_string += string.IsNullOrEmpty(attr_string) ? | |||||
$"{key},{value}" : | |||||
$",{key},{value}"; | |||||
} | |||||
return attr_string; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// This function will set the op attrs required. If an attr has the value of | /// This function will set the op attrs required. If an attr has the value of | ||||
/// None, then it will read the AttrDef to get the default value and set that | /// None, then it will read the AttrDef to get the default value and set that | ||||
@@ -1,7 +1,9 @@ | |||||
using Google.Protobuf.WellKnownTypes; | using Google.Protobuf.WellKnownTypes; | ||||
using NumSharp.Utilities; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Reflection; | |||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
@@ -72,22 +74,40 @@ namespace Tensorflow.Gradients | |||||
public Tensor gradient(Tensor target, Tensor source) | public Tensor gradient(Tensor target, Tensor source) | ||||
{ | { | ||||
if(_recording) | |||||
if (_recording) | |||||
{ | { | ||||
if (!_persistent) | if (!_persistent) | ||||
_pop_tape(); | _pop_tape(); | ||||
} | } | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
var targets = EagerTensorPass.From(target); | |||||
var sources = EagerTensorPass.From(source); | |||||
using Status status = new Status(c_api.TFE_TapeGradient(_tape, | using Status status = new Status(c_api.TFE_TapeGradient(_tape, | ||||
new [] { (target as EagerTensor).EagerTensorHandle }, 1, | |||||
new [] { (source as EagerTensor).EagerTensorHandle }, 1, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
targets.Points, targets.Length, | |||||
sources.Points, sources.Length, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
public unsafe (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) | |||||
public Tensor gradient(Tensor target, ResourceVariable source) | |||||
{ | |||||
var results = gradient(target as EagerTensor, new[] { source }); | |||||
return results[0]; | |||||
} | |||||
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) | |||||
{ | |||||
var results = gradient(target as EagerTensor, new[] { sources.Item1, sources.Item2 }); | |||||
return (results[0], results[1]); | |||||
} | |||||
public EagerTensor[] gradient(EagerTensor target, ResourceVariable[] sources) | |||||
{ | { | ||||
if (_recording) | if (_recording) | ||||
{ | { | ||||
@@ -95,18 +115,14 @@ namespace Tensorflow.Gradients | |||||
_pop_tape(); | _pop_tape(); | ||||
} | } | ||||
var results = new[] { new EagerTensor(), new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(sources.Length); | |||||
var target_inputs = EagerTensorPass.From(target); | |||||
var source_inputs = EagerTensorPass.From(sources.Select(x => x.Handle).ToArray()); | |||||
using Status status = new Status(c_api.TFE_TapeGradient(_tape, | using Status status = new Status(c_api.TFE_TapeGradient(_tape, | ||||
new IntPtr[] | |||||
{ | |||||
target as EagerTensor | |||||
}, 1, | |||||
new IntPtr[] | |||||
{ | |||||
(sources.Item1.Handle as EagerTensor).EagerTensorHandle, | |||||
(sources.Item2.Handle as EagerTensor).EagerTensorHandle | |||||
}, 2, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
target_inputs.Points, target_inputs.Length, | |||||
source_inputs.Points, source_inputs.Length, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
if (!_persistent) | if (!_persistent) | ||||
@@ -116,13 +132,15 @@ namespace Tensorflow.Gradients | |||||
_tape = null; | _tape = null; | ||||
} | } | ||||
return (results[0].Resolve(), results[1].Resolve()); | |||||
return results.Items.Select(x => x.Resolve()).ToArray(); | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
if (_recording) | if (_recording) | ||||
_pop_tape(); | _pop_tape(); | ||||
tf.tensorMgr.Reset(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -310,23 +310,26 @@ namespace Tensorflow.Gradients | |||||
var input_shape = op.inputs[0]._shape_tuple(); | var input_shape = op.inputs[0]._shape_tuple(); | ||||
var output_shape = op.outputs[0]._shape_tuple(); | var output_shape = op.outputs[0]._shape_tuple(); | ||||
Tensor result, factor_tensor; | |||||
if(input_shape != null && | if(input_shape != null && | ||||
output_shape != null) | output_shape != null) | ||||
{ | { | ||||
var input_size = np.prod(input_shape); | var input_size = np.prod(input_shape); | ||||
var output_size = np.prod(output_shape); | var output_size = np.prod(output_shape); | ||||
var factor = (int)input_size / Math.Max((int)output_size, 1); | var factor = (int)input_size / Math.Max((int)output_size, 1); | ||||
var factor_tensor = constant_op.constant((int)input_size, dtype: sum_grad.dtype); | |||||
return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype)), null }; | |||||
factor_tensor = constant_op.constant(factor, dtype: sum_grad.dtype); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
var input_shape_tensor = array_ops.shape(op.inputs[0]); | var input_shape_tensor = array_ops.shape(op.inputs[0]); | ||||
var output_shape_tensor = array_ops.shape(op.outputs[0]); | var output_shape_tensor = array_ops.shape(op.outputs[0]); | ||||
var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor)); | var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor)); | ||||
return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null }; | |||||
throw new NotImplementedException(""); | |||||
factor_tensor = null; | |||||
} | } | ||||
result = math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype)); | |||||
return new Tensor[] { result, null }; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -497,8 +500,8 @@ namespace Tensorflow.Gradients | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
// should add ones_rank_cache | // should add ones_rank_cache | ||||
var new_shape_tensor = constant_op.constant(np.array(new int[] { 1 }) * rank, dtype: TF_DataType.TF_INT32); | |||||
grad = array_ops.reshape(grad, new_shape_tensor); | |||||
var new_shape = constant_op.constant(range(0, rank).Select(x => 1).ToArray(), dtype: TF_DataType.TF_INT32); | |||||
grad = array_ops.reshape(grad, new_shape); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -513,20 +516,23 @@ namespace Tensorflow.Gradients | |||||
input_shape = array_ops.shape(op.inputs[0]); | input_shape = array_ops.shape(op.inputs[0]); | ||||
return new Tensor[] { gen_array_ops.tile(grad, input_shape), null }; | return new Tensor[] { gen_array_ops.tile(grad, input_shape), null }; | ||||
} | } | ||||
else | |||||
else if (!input_0_shape.Contains(-1) && !tf.context.executing_eagerly()) | |||||
{ | { | ||||
throw new NotImplementedException(""); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
input_shape = array_ops.shape(op.inputs[0]); | input_shape = array_ops.shape(op.inputs[0]); | ||||
ops.colocate_with(input_shape); | |||||
var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); | |||||
var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims); | |||||
grad = gen_array_ops.reshape(grad, output_shape_kept_dims); | |||||
if (!op.get_attr<bool>("keep_dims")) | |||||
{ | |||||
ops.colocate_with(input_shape); | |||||
var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); | |||||
// var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims); | |||||
grad = gen_array_ops.reshape(grad, output_shape_kept_dims); | |||||
} | |||||
return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null }; | |||||
return new Tensor[] { gen_array_ops.broadcast_to(grad, input_shape), null }; | |||||
} | } | ||||
[RegisterGradient("RealDiv")] | [RegisterGradient("RealDiv")] | ||||
@@ -17,6 +17,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Eager; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -81,6 +82,9 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public _ControlDependenciesController control_dependencies(object[] control_inputs) | public _ControlDependenciesController control_dependencies(object[] control_inputs) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
return new _ControlDependenciesController(this, null); | |||||
if (control_inputs == null) | if (control_inputs == null) | ||||
return new _ControlDependenciesController(this, null); | return new _ControlDependenciesController(this, null); | ||||
@@ -0,0 +1,34 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class NullContextmanager : ITensorFlowObject | |||||
{ | |||||
public void __init__() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public void __enter__() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public void __del__() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public void __exit__() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,27 @@ | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Datasets | |||||
{ | |||||
public class DatasetPass | |||||
{ | |||||
public (NDArray, NDArray) Train { get; set; } | |||||
public (NDArray, NDArray) Test { get; set; } | |||||
public void Deconstruct(out NDArray x_train, out NDArray y_train, out NDArray x_test, out NDArray y_test) | |||||
{ | |||||
x_train = Train.Item1; | |||||
y_train = Train.Item2; | |||||
x_test = Test.Item1; | |||||
y_test = Test.Item2; | |||||
} | |||||
public void Deconstruct(out (NDArray, NDArray) train, out (NDArray, NDArray) test) | |||||
{ | |||||
train = Train; | |||||
test = Test; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,27 @@ | |||||
/***************************************************************************** | |||||
Copyright 2020 Haiping Chen. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Datasets | |||||
{ | |||||
public class KerasDataset | |||||
{ | |||||
public Mnist mnist { get; } = new Mnist(); | |||||
} | |||||
} |
@@ -0,0 +1,76 @@ | |||||
/***************************************************************************** | |||||
Copyright 2020 Haiping Chen. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Net; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Datasets | |||||
{ | |||||
public class Mnist | |||||
{ | |||||
string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"; | |||||
string file_name = "mnist.npz"; | |||||
/// <summary> | |||||
/// Loads the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public DatasetPass load_data() | |||||
{ | |||||
var file = Download(); | |||||
var bytes = File.ReadAllBytes(file); | |||||
var datax = LoadX(bytes); | |||||
var datay = LoadY(bytes); | |||||
return new DatasetPass | |||||
{ | |||||
Train = (datax.Item1, datay.Item1), | |||||
Test = (datax.Item2, datay.Item2) | |||||
}; | |||||
} | |||||
(NDArray, NDArray) LoadX(byte[] bytes) | |||||
{ | |||||
var y = np.Load_Npz<byte[,,]>(bytes); | |||||
return (y["x_train.npy"], y["x_test.npy"]); | |||||
} | |||||
(NDArray, NDArray) LoadY(byte[] bytes) | |||||
{ | |||||
var y = np.Load_Npz<byte[]>(bytes); | |||||
return (y["y_train.npy"], y["y_test.npy"]); | |||||
} | |||||
string Download() | |||||
{ | |||||
var fileSaveTo = Path.Combine(Path.GetTempPath(), file_name); | |||||
if (File.Exists(fileSaveTo)) | |||||
{ | |||||
Console.WriteLine($"The file {fileSaveTo} already exists"); | |||||
return fileSaveTo; | |||||
} | |||||
using var wc = new WebClient(); | |||||
wc.DownloadFileTaskAsync(origin_folder + file_name, fileSaveTo).Wait(); | |||||
return fileSaveTo; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,12 @@ | |||||
using System.Data; | |||||
using Tensorflow.Keras; | |||||
using Tensorflow.Keras.Datasets; | |||||
namespace Tensorflow | |||||
{ | |||||
public class KerasApi | |||||
{ | |||||
public KerasDataset datasets { get; } = new KerasDataset(); | |||||
public Initializers initializers { get; } = new Initializers(); | |||||
} | |||||
} |
@@ -36,6 +36,13 @@ namespace Tensorflow.Keras.Optimizers | |||||
apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>(); | apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>(); | ||||
} | } | ||||
public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | |||||
string name = null, | |||||
bool experimental_aggregate_gradients = true) | |||||
=> apply_gradients(new (Tensor, ResourceVariable)[] { grads_and_vars }, | |||||
name: name, | |||||
experimental_aggregate_gradients: experimental_aggregate_gradients); | |||||
/// <summary> | /// <summary> | ||||
/// Apply gradients to variables. | /// Apply gradients to variables. | ||||
/// </summary> | /// </summary> | ||||
@@ -1,12 +0,0 @@ | |||||
using Tensorflow.Keras; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class tensorflow | |||||
{ | |||||
public class keras | |||||
{ | |||||
public static Initializers initializers => new Initializers(); | |||||
} | |||||
} | |||||
} |
@@ -373,6 +373,19 @@ namespace Tensorflow.Operations | |||||
public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null) | public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(gradients, features); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"ReluGrad", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("ReluGrad", name: name, args: new | var _op = _op_def_lib._apply_op_helper("ReluGrad", name: name, args: new | ||||
{ | { | ||||
gradients, | gradients, | ||||
@@ -396,6 +409,19 @@ namespace Tensorflow.Operations | |||||
public static Tensor softmax(Tensor logits, string name = null) | public static Tensor softmax(Tensor logits, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(logits); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Softmax", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("Softmax", name: name, args: new | var _op = _op_def_lib._apply_op_helper("Softmax", name: name, args: new | ||||
{ | { | ||||
logits | logits | ||||
@@ -473,7 +499,8 @@ namespace Tensorflow.Operations | |||||
"Relu", name, new IntPtr[] | "Relu", name, new IntPtr[] | ||||
{ | { | ||||
features as EagerTensor, | features as EagerTensor, | ||||
}, 1, null, | |||||
}, 1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -492,7 +519,8 @@ namespace Tensorflow.Operations | |||||
"Tanh", name, new IntPtr[] | "Tanh", name, new IntPtr[] | ||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
}, 1, null, | |||||
}, 1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -57,7 +57,7 @@ namespace Tensorflow | |||||
public int _id_value { get; set; } | public int _id_value { get; set; } | ||||
public Operation op => this; | public Operation op => this; | ||||
public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
public virtual string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | ||||
public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | ||||
@@ -228,7 +228,7 @@ namespace Tensorflow | |||||
public T get_attr<T>(string name) | public T get_attr<T>(string name) | ||||
=> (T)get_attr(name); | => (T)get_attr(name); | ||||
public object get_attr(string name) | |||||
public virtual object get_attr(string name) | |||||
{ | { | ||||
AttrValue x = null; | AttrValue x = null; | ||||
@@ -349,7 +349,7 @@ namespace Tensorflow | |||||
return fill(shape_tensor, ones, name: name); | return fill(shape_tensor, ones, name: name); | ||||
}); | }); | ||||
public static Tensor one_hot(Tensor indices, int depth, | |||||
public static Tensor one_hot(Tensor indices, Tensor depth, | |||||
Tensor on_value = null, | Tensor on_value = null, | ||||
Tensor off_value = null, | Tensor off_value = null, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
@@ -25,7 +25,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public class clip_ops | public class clip_ops | ||||
{ | { | ||||
public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null) | |||||
public static Tensor clip_by_value<T1, T2>(Tensor t, T1 clip_value_min, T2 clip_value_max, string name = null) | |||||
{ | { | ||||
return tf_with(ops.name_scope(name, "clip_by_value", new { t, clip_value_min, clip_value_max }), delegate | return tf_with(ops.name_scope(name, "clip_by_value", new { t, clip_value_min, clip_value_max }), delegate | ||||
{ | { | ||||
@@ -21,6 +21,7 @@ using static Tensorflow.Binding; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using System.Linq; | using System.Linq; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System.Security.Cryptography.X509Certificates; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -60,7 +61,8 @@ namespace Tensorflow | |||||
{ | { | ||||
values as EagerTensor, | values as EagerTensor, | ||||
axis as EagerTensor | axis as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -165,7 +167,8 @@ namespace Tensorflow | |||||
var results = new[] { new EagerTensor() }; | var results = new[] { new EagerTensor() }; | ||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Pack", name, | "Pack", name, | ||||
values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), values.Length, | |||||
values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), values.Length, | |||||
wrap_tfe_src.SetOpAttrs2("axis", axis), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "axis", axis), | op => wrap_tfe_src.SetOpAttrs(op, "axis", axis), | ||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
@@ -235,7 +238,8 @@ namespace Tensorflow | |||||
"Identity", name, new IntPtr[] | "Identity", name, new IntPtr[] | ||||
{ | { | ||||
input as EagerTensor | input as EagerTensor | ||||
}, 1, null, | |||||
}, 1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -278,15 +282,16 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(dims, value); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Fill", name, new IntPtr[] | |||||
{ | |||||
dims as EagerTensor, | |||||
value as EagerTensor | |||||
}, 2, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
"Fill", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -311,7 +316,8 @@ namespace Tensorflow | |||||
{ | { | ||||
s0 as EagerTensor, | s0 as EagerTensor, | ||||
s1 as EagerTensor | s1 as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return (results[0].Resolve(), results[1].Resolve()); | return (results[0].Resolve(), results[1].Resolve()); | ||||
@@ -338,7 +344,8 @@ namespace Tensorflow | |||||
{ | { | ||||
tensor as EagerTensor, | tensor as EagerTensor, | ||||
shape as EagerTensor | shape as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -381,13 +388,30 @@ namespace Tensorflow | |||||
return _op.output; | return _op.output; | ||||
} | } | ||||
public static Tensor one_hot(Tensor indices, int depth, | |||||
public static Tensor one_hot(Tensor indices, Tensor depth, | |||||
Tensor on_value = null, | Tensor on_value = null, | ||||
Tensor off_value = null, | Tensor off_value = null, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
int axis = -1, | int axis = -1, | ||||
string name = null) | string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(indices, depth, on_value, off_value); | |||||
var attrs = new object[] { "axis", axis }; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"OneHot", name, | |||||
inputs.Points, inputs.Length, | |||||
wrap_tfe_src.SetOpAttrs2(attrs), | |||||
op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("OneHot", name, new { indices, depth, on_value, off_value, axis }); | var _op = _op_def_lib._apply_op_helper("OneHot", name, new { indices, depth, on_value, off_value, axis }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
@@ -407,6 +431,21 @@ namespace Tensorflow | |||||
public static Tensor select<Tx, Ty>(Tensor condition, Tx t, Ty e, string name = null) | public static Tensor select<Tx, Ty>(Tensor condition, Tx t, Ty e, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(condition, t, e); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"SelectV2", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e }); | var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
@@ -427,6 +466,7 @@ namespace Tensorflow | |||||
{ | { | ||||
input as EagerTensor, | input as EagerTensor, | ||||
}, 1, | }, 1, | ||||
wrap_tfe_src.SetOpAttrs2("out_type", out_type), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "out_type", out_type), | op => wrap_tfe_src.SetOpAttrs(op, "out_type", out_type), | ||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
@@ -486,7 +526,8 @@ namespace Tensorflow | |||||
{ | { | ||||
input as EagerTensor, | input as EagerTensor, | ||||
multiples as EagerTensor | multiples as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -526,6 +567,14 @@ namespace Tensorflow | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | var results = new[] { new EagerTensor() }; | ||||
var attrs = new object[] | |||||
{ | |||||
"begin_mask", begin_mask, | |||||
"end_mask", end_mask, | |||||
"ellipsis_mask", ellipsis_mask, | |||||
"new_axis_mask", new_axis_mask, | |||||
"shrink_axis_mask", shrink_axis_mask | |||||
}; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"StridedSlice", name, new IntPtr[] | "StridedSlice", name, new IntPtr[] | ||||
{ | { | ||||
@@ -534,12 +583,8 @@ namespace Tensorflow | |||||
end as EagerTensor, | end as EagerTensor, | ||||
strides as EagerTensor, | strides as EagerTensor, | ||||
}, 4, | }, 4, | ||||
op => wrap_tfe_src.SetOpAttrs(op, | |||||
"begin_mask", begin_mask, | |||||
"end_mask", end_mask, | |||||
"ellipsis_mask", ellipsis_mask, | |||||
"new_axis_mask", new_axis_mask, | |||||
"shrink_axis_mask", shrink_axis_mask), | |||||
wrap_tfe_src.SetOpAttrs2(attrs), | |||||
op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -645,6 +690,21 @@ namespace Tensorflow | |||||
/// <returns> A `Tensor`. Has the same type as `input`.</returns> | /// <returns> A `Tensor`. Has the same type as `input`.</returns> | ||||
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null) | public static Tensor squeeze(Tensor input, int[] axis = null, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = new[] { new EagerTensor() }; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Squeeze", name, new IntPtr[] | |||||
{ | |||||
input as EagerTensor | |||||
}, 1, | |||||
wrap_tfe_src.SetOpAttrs2("squeeze_dims", axis), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "squeeze_dims", axis), | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
if (axis == null) axis = new int[0]; | if (axis == null) axis = new int[0]; | ||||
var _op = _op_def_lib._apply_op_helper("Squeeze", name, args: new { input, squeeze_dims = axis }); | var _op = _op_def_lib._apply_op_helper("Squeeze", name, args: new { input, squeeze_dims = axis }); | ||||
@@ -674,8 +734,22 @@ namespace Tensorflow | |||||
/// <param name="shape"></param> | /// <param name="shape"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor broadcast_to(Tensor input, int[] shape, string name = null) | |||||
public static Tensor broadcast_to<T>(Tensor input, T shape, string name = null) | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(input, shape); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"BroadcastTo", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("BroadcastTo", name, args: new { input, shape, name }); | var _op = _op_def_lib._apply_op_helper("BroadcastTo", name, args: new { input, shape, name }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -48,7 +48,7 @@ namespace Tensorflow | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"AddN", name, | "AddN", name, | ||||
inputs.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), inputs.Length, | inputs.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), inputs.Length, | ||||
null, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -65,7 +65,7 @@ namespace Tensorflow | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"AddN", name, | "AddN", name, | ||||
inputs, inputs.Length, | inputs, inputs.Length, | ||||
null, | |||||
null, null, | |||||
results, results.Length)); | results, results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0]; | return results[0]; | ||||
@@ -80,7 +80,23 @@ namespace Tensorflow | |||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) | public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) | ||||
=> _op_def_lib._apply_op_helper("ArgMax", name, args: new { input, dimension, output_type }).outputs[0]; | |||||
{ | |||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(input, dimension); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"ArgMax", name, | |||||
inputs.Points, inputs.Length, | |||||
wrap_tfe_src.SetOpAttrs2("output_type", output_type), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "output_type", output_type), | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
return _op_def_lib._apply_op_helper("ArgMax", name, args: new { input, dimension, output_type }).output; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Returns the index with the smallest value across dimensions of a tensor. | /// Returns the index with the smallest value across dimensions of a tensor. | ||||
@@ -152,6 +168,7 @@ namespace Tensorflow | |||||
input as EagerTensor, | input as EagerTensor, | ||||
axis as EagerTensor | axis as EagerTensor | ||||
}, 2, | }, 2, | ||||
wrap_tfe_src.SetOpAttrs2("keep_dims", keep_dims), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | ||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
@@ -198,6 +215,7 @@ namespace Tensorflow | |||||
input as EagerTensor, | input as EagerTensor, | ||||
axis as EagerTensor | axis as EagerTensor | ||||
}, 2, | }, 2, | ||||
wrap_tfe_src.SetOpAttrs2("keep_dims", keep_dims), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | ||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
@@ -247,7 +265,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -268,7 +287,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -290,7 +310,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -324,7 +345,8 @@ namespace Tensorflow | |||||
"Sin", name, new IntPtr[] | "Sin", name, new IntPtr[] | ||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
}, 1, null, | |||||
}, 1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -358,7 +380,8 @@ namespace Tensorflow | |||||
"Sigmoid", name, new IntPtr[] | "Sigmoid", name, new IntPtr[] | ||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
}, 1, null, | |||||
}, 1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -451,7 +474,8 @@ namespace Tensorflow | |||||
"Tan", name, new IntPtr[] | "Tan", name, new IntPtr[] | ||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
}, 1, null, | |||||
}, 1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -464,6 +488,20 @@ namespace Tensorflow | |||||
public static Tensor tanh(Tensor x, string name = null) | public static Tensor tanh(Tensor x, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = new[] { new EagerTensor() }; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Tanh", name, new IntPtr[] | |||||
{ | |||||
x as EagerTensor, | |||||
}, 1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("Tanh", name, args: new { x }); | var _op = _op_def_lib._apply_op_helper("Tanh", name, args: new { x }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -477,7 +515,25 @@ namespace Tensorflow | |||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null) | public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null) | ||||
=> _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output; | |||||
{ | |||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = new[] { new EagerTensor() }; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"TanhGrad", name, new IntPtr[] | |||||
{ | |||||
y as EagerTensor, | |||||
dy as EagerTensor | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output; | |||||
return _op.outputs[0]; | |||||
} | |||||
public static Tensor floor(Tensor x, string name = null) | public static Tensor floor(Tensor x, string name = null) | ||||
{ | { | ||||
@@ -495,6 +551,19 @@ namespace Tensorflow | |||||
public static Tensor greater<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor greater<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(x, y); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Greater", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("Greater", name: name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Greater", name: name, args: new { x, y }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -520,6 +589,21 @@ namespace Tensorflow | |||||
public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(x, y); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"GreaterEqual", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("GreaterEqual", name: name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("GreaterEqual", name: name, args: new { x, y }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -529,14 +613,13 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(x, y); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Less", name, new IntPtr[] | |||||
{ | |||||
x as EagerTensor, | |||||
y as EagerTensor | |||||
}, 2, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
"Less", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -548,6 +631,19 @@ namespace Tensorflow | |||||
public static Tensor less_equal<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor less_equal<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(x, y); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"LessEqual", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("LessEqual", name: name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("LessEqual", name: name, args: new { x, y }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -611,7 +707,8 @@ namespace Tensorflow | |||||
"Square", name, new IntPtr[] | "Square", name, new IntPtr[] | ||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
}, 1, null, | |||||
}, 1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -663,6 +760,21 @@ namespace Tensorflow | |||||
/// <returns> A `Tensor`. Has the same type as `x`.</returns> | /// <returns> A `Tensor`. Has the same type as `x`.</returns> | ||||
public static Tensor log(Tensor x, string name = null) | public static Tensor log(Tensor x, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(x); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Log", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x }); | var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -673,12 +785,20 @@ namespace Tensorflow | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | var results = new[] { new EagerTensor() }; | ||||
var attrs = new object[] | |||||
{ | |||||
"DstT", DstT, | |||||
"Truncate", Truncate | |||||
}; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Cast", name, | "Cast", name, | ||||
new IntPtr[] { x as EagerTensor }, 1, | new IntPtr[] { x as EagerTensor }, 1, | ||||
op => wrap_tfe_src.SetOpAttrs(op, "DstT", DstT, "Truncate", Truncate), | |||||
wrap_tfe_src.SetOpAttrs2(attrs), | |||||
op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -691,14 +811,16 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(x); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Neg", name, new IntPtr[] | |||||
{ | |||||
x as EagerTensor | |||||
}, 2, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
"Neg", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -716,7 +838,8 @@ namespace Tensorflow | |||||
"Sqrt", name, new IntPtr[] | "Sqrt", name, new IntPtr[] | ||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
}, 1, null, | |||||
}, 1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -737,7 +860,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -758,7 +882,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -786,7 +911,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -815,7 +941,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -836,7 +963,8 @@ namespace Tensorflow | |||||
{ | { | ||||
y as EagerTensor, | y as EagerTensor, | ||||
x as EagerTensor | x as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -856,7 +984,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -877,7 +1006,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor, | y as EagerTensor, | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -905,7 +1035,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -918,6 +1049,21 @@ namespace Tensorflow | |||||
public static Tensor reciprocal(Tensor x, string name = null) | public static Tensor reciprocal(Tensor x, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(x); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Reciprocal", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("Reciprocal", name, args: new { x }); | var _op = _op_def_lib._apply_op_helper("Reciprocal", name, args: new { x }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -933,7 +1079,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -954,7 +1101,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -978,18 +1126,19 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(a, b); | |||||
var attrs = new object[] | |||||
{ | |||||
"transpose_a", transpose_a, | |||||
"transpose_b", transpose_b | |||||
}; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"MatMul", name, | "MatMul", name, | ||||
new IntPtr[] | |||||
{ | |||||
a as EagerTensor, | |||||
b as EagerTensor | |||||
}, 2, | |||||
op => wrap_tfe_src.SetOpAttrs(op, | |||||
"transpose_a", transpose_a, | |||||
"transpose_b", transpose_b), | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
inputs.Points, inputs.Length, | |||||
wrap_tfe_src.SetOpAttrs2(attrs), | |||||
op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -1043,6 +1192,21 @@ namespace Tensorflow | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor maximum<T1, T2>(T1 x, T2 y, string name = null) | public static Tensor maximum<T1, T2>(T1 x, T2 y, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(x, y); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Maximum", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("Maximum", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Maximum", name, args: new { x, y }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -1050,6 +1214,21 @@ namespace Tensorflow | |||||
public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) | public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(x, y); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Minimum", name, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("Minimum", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Minimum", name, args: new { x, y }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -1093,7 +1272,8 @@ namespace Tensorflow | |||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, | |||||
}, 2, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -1108,17 +1288,18 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(input, axis); | |||||
var attrs = new object[] { "keep_dims", keep_dims }; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Sum", name, | "Sum", name, | ||||
new IntPtr[] | |||||
{ | |||||
input as EagerTensor, | |||||
axis as EagerTensor | |||||
}, 2, | |||||
op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
inputs.Points, inputs.Length, | |||||
wrap_tfe_src.SetOpAttrs2(attrs), | |||||
op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -1169,7 +1350,8 @@ namespace Tensorflow | |||||
start as EagerTensor, | start as EagerTensor, | ||||
limit as EagerTensor, | limit as EagerTensor, | ||||
delta as EagerTensor | delta as EagerTensor | ||||
}, 3, null, | |||||
}, 3, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -11,14 +11,15 @@ namespace Tensorflow | |||||
{ | { | ||||
public static EagerTensor mul(IntPtr x, IntPtr y, string name = null) | public static EagerTensor mul(IntPtr x, IntPtr y, string name = null) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Mul", name, new IntPtr[] | "Mul", name, new IntPtr[] | ||||
{ | { | ||||
x, | x, | ||||
y, | y, | ||||
}, 2, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
}, 2, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -42,17 +42,20 @@ namespace Tensorflow | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
var attrs = new object[] | |||||
{ | |||||
"seed", seed, | |||||
"seed2", seed2, | |||||
"dtype", dtype | |||||
}; | |||||
var inputs = EagerTensorPass.From(shape); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"RandomStandardNormal", name, new IntPtr[] | |||||
{ | |||||
shape as EagerTensor, | |||||
}, 1, | |||||
op => wrap_tfe_src.SetOpAttrs(op, | |||||
"seed", seed, | |||||
"seed2", seed2, | |||||
"dtype", dtype), | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
"RandomStandardNormal", name, | |||||
inputs.Points, inputs.Length, | |||||
wrap_tfe_src.SetOpAttrs2(attrs), | |||||
op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -146,6 +149,26 @@ namespace Tensorflow | |||||
if (!seed2.HasValue) | if (!seed2.HasValue) | ||||
seed2 = 0; | seed2 = 0; | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(shape); | |||||
var attrs = new object[] | |||||
{ | |||||
"seed", seed, | |||||
"seed2", seed2, | |||||
"dtype", dtype | |||||
}; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"TruncatedNormal", name, | |||||
inputs.Points, inputs.Length, | |||||
wrap_tfe_src.SetOpAttrs2(attrs), | |||||
op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||||
results.Points, results.Length)); | |||||
status.Check(true); | |||||
return results[0].Resolve(); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("TruncatedNormal", | var _op = _op_def_lib._apply_op_helper("TruncatedNormal", | ||||
name: name, | name: name, | ||||
args: new { shape, dtype, seed, seed2 }); | args: new { shape, dtype, seed, seed2 }); | ||||
@@ -29,15 +29,13 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
var inputs = EagerTensorPass.From(resource, value); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"AssignSubVariableOp", name, | "AssignSubVariableOp", name, | ||||
new IntPtr[] | |||||
{ | |||||
resource as EagerTensor, | |||||
value as EagerTensor | |||||
}, 2, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -56,13 +54,11 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var inputs = EagerTensorPass.From(resource, value); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"AssignAddVariableOp", name, | "AssignAddVariableOp", name, | ||||
new IntPtr[] | |||||
{ | |||||
resource as EagerTensor, | |||||
value as EagerTensor | |||||
}, 2, null, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
null, 0)); | null, 0)); | ||||
status.Check(true); | status.Check(true); | ||||
return null; | return null; | ||||
@@ -75,13 +71,11 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var inputs = EagerTensorPass.From(resource, value); | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"AssignVariableOp", name, | "AssignVariableOp", name, | ||||
new IntPtr[] | |||||
{ | |||||
resource as EagerTensor, | |||||
value as EagerTensor | |||||
}, 2, null, | |||||
inputs.Points, inputs.Length, | |||||
null, null, | |||||
null, 0)); | null, 0)); | ||||
status.Check(true); | status.Check(true); | ||||
return null; | return null; | ||||
@@ -100,7 +94,8 @@ namespace Tensorflow | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"VarIsInitializedOp", name, | "VarIsInitializedOp", name, | ||||
new IntPtr[] { resource as EagerTensor }, | new IntPtr[] { resource as EagerTensor }, | ||||
1, null, | |||||
1, | |||||
null, null, | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
@@ -125,15 +120,19 @@ namespace Tensorflow | |||||
{ | { | ||||
if(tf.context.executing_eagerly()) | if(tf.context.executing_eagerly()) | ||||
{ | { | ||||
var results = new[] { new EagerTensor() }; | |||||
var results = EagerTensorPass.Create(); | |||||
var attrs = new object[] | |||||
{ | |||||
"container", container, | |||||
"shared_name", shared_name, | |||||
"dtype", dtype, | |||||
"shape", shape.dims | |||||
}; | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"VarHandleOp", name, null, 0, | "VarHandleOp", name, null, 0, | ||||
op => wrap_tfe_src.SetOpAttrs(op, | |||||
"container", container, | |||||
"shared_name", shared_name, | |||||
"dtype", dtype, | |||||
"shape", shape.dims), | |||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | |||||
wrap_tfe_src.SetOpAttrs2(attrs), | |||||
op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||||
results.Points, results.Length)); | |||||
status.Check(true); | status.Check(true); | ||||
return results[0].Resolve(); | return results[0].Resolve(); | ||||
} | } | ||||
@@ -163,6 +162,7 @@ namespace Tensorflow | |||||
using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"ReadVariableOp", name, | "ReadVariableOp", name, | ||||
new IntPtr[] { resource as EagerTensor }, 1, | new IntPtr[] { resource as EagerTensor }, 1, | ||||
wrap_tfe_src.SetOpAttrs2("dtype", dtype), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "dtype", dtype), | op => wrap_tfe_src.SetOpAttrs(op, "dtype", dtype), | ||||
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); | ||||
status.Check(true); | status.Check(true); | ||||
@@ -348,6 +348,14 @@ namespace Tensorflow | |||||
/// <returns>A 1-D Tensor, the output shape as if keepdims were set to True.</returns> | /// <returns>A 1-D Tensor, the output shape as if keepdims were set to True.</returns> | ||||
public static Tensor reduced_shape(Tensor input_shape, Tensor axes) | public static Tensor reduced_shape(Tensor input_shape, Tensor axes) | ||||
{ | { | ||||
if(tf.context.executing_eagerly()) | |||||
{ | |||||
var input_shape_val = input_shape.numpy(); | |||||
var axes_val = (int)axes.numpy(); | |||||
input_shape_val[axes_val] = 1; | |||||
return tf.constant(input_shape_val); | |||||
} | |||||
input_shape = to_int32(input_shape); | input_shape = to_int32(input_shape); | ||||
axes = to_int32(axes); | axes = to_int32(axes); | ||||
@@ -522,7 +530,8 @@ namespace Tensorflow | |||||
public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false, string name = null) | public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false, string name = null) | ||||
{ | { | ||||
var m = gen_math_ops._sum(input_tensor, axis, keep_dims: keepdims, name: name); | |||||
var dims = _ReductionDims(input_tensor, axis); | |||||
var m = gen_math_ops._sum(input_tensor, dims, keep_dims: keepdims, name: name); | |||||
return _may_reduce_to_scalar(keepdims, axis, m); | return _may_reduce_to_scalar(keepdims, axis, m); | ||||
} | } | ||||
@@ -54,8 +54,11 @@ namespace Tensorflow | |||||
public static void Decrease(IntPtr handle) | public static void Decrease(IntPtr handle) | ||||
{ | { | ||||
if (handle != IntPtr.Zero && container.ContainsKey(handle)) | |||||
container[handle].RefCounter--; | |||||
lock (locker) | |||||
{ | |||||
if (handle != IntPtr.Zero && container.ContainsKey(handle)) | |||||
container[handle].RefCounter--; | |||||
} | |||||
} | } | ||||
private static void Recycle() | private static void Recycle() | ||||
@@ -64,7 +67,7 @@ namespace Tensorflow | |||||
lock (locker) | lock (locker) | ||||
{ | { | ||||
var items = container.Values | var items = container.Values | ||||
.Where(x => x.RefCounter <= 0 && (DateTime.Now - x.LastUpdateTime).TotalMilliseconds > 100) | |||||
.Where(x => x.RefCounter <= 0 && (DateTime.Now - x.LastUpdateTime).TotalMilliseconds > 300) | |||||
.ToArray(); | .ToArray(); | ||||
foreach (var item in items) | foreach (var item in items) | ||||
@@ -74,15 +77,15 @@ namespace Tensorflow | |||||
switch (item.ItemType) | switch (item.ItemType) | ||||
{ | { | ||||
case GCItemType.TensorHandle: | case GCItemType.TensorHandle: | ||||
// print($"c_api.TF_DeleteTensor({item.Handle.ToString("x16")})"); | |||||
//print($"c_api.TF_DeleteTensor({item.Handle.ToString("x16")})"); | |||||
c_api.TF_DeleteTensor(item.Handle); | c_api.TF_DeleteTensor(item.Handle); | ||||
break; | break; | ||||
case GCItemType.LocalTensorHandle: | case GCItemType.LocalTensorHandle: | ||||
// print($"c_api.TFE_DeleteTensorHandle({item.Handle.ToString("x16")})"); | |||||
//print($"c_api.TFE_DeleteTensorHandle({item.Handle.ToString("x16")})"); | |||||
c_api.TFE_DeleteTensorHandle(item.Handle); | c_api.TFE_DeleteTensorHandle(item.Handle); | ||||
break; | break; | ||||
case GCItemType.EagerTensorHandle: | case GCItemType.EagerTensorHandle: | ||||
// print($"c_api.TFE_DeleteEagerTensor({item.Handle.ToString("x16")})"); | |||||
//print($"c_api.TFE_DeleteEagerTensor({item.Handle.ToString("x16")})"); | |||||
c_api.TFE_DeleteEagerTensor(item.Handle); | c_api.TFE_DeleteEagerTensor(item.Handle); | ||||
break; | break; | ||||
default: | default: | ||||
@@ -5,7 +5,7 @@ | |||||
<AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<TargetTensorFlow>2.2.0</TargetTensorFlow> | <TargetTensorFlow>2.2.0</TargetTensorFlow> | ||||
<Version>0.20.0-alpha2</Version> | |||||
<Version>0.20.0-preview1</Version> | |||||
<LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
<Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
@@ -38,7 +38,8 @@ namespace Tensorflow | |||||
_TensorLike, | _TensorLike, | ||||
ITensorOrTensorArray, | ITensorOrTensorArray, | ||||
IPackable<Tensor>, | IPackable<Tensor>, | ||||
ICanBeFlattened | |||||
ICanBeFlattened, | |||||
IPointerInputs | |||||
{ | { | ||||
protected int _id; | protected int _id; | ||||
private readonly Operation _op; | private readonly Operation _op; | ||||
@@ -280,6 +281,10 @@ namespace Tensorflow | |||||
} else | } else | ||||
throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType})."); | throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType})."); | ||||
} | } | ||||
public virtual IntPtr ToPointer() | |||||
=> _handle; | |||||
public bool IsDisposed => _disposed; | public bool IsDisposed => _disposed; | ||||
// public int tensor_int_val { get; set; } | // public int tensor_int_val { get; set; } | ||||
@@ -199,6 +199,7 @@ namespace Tensorflow | |||||
=> type switch | => type switch | ||||
{ | { | ||||
TF_DataType.TF_STRING => "string", | TF_DataType.TF_STRING => "string", | ||||
TF_DataType.TF_UINT8 => "uint8", | |||||
TF_DataType.TF_INT32 => "int32", | TF_DataType.TF_INT32 => "int32", | ||||
TF_DataType.TF_FLOAT => "float32", | TF_DataType.TF_FLOAT => "float32", | ||||
TF_DataType.TF_BOOL => "bool", | TF_DataType.TF_BOOL => "bool", | ||||
@@ -72,6 +72,7 @@ namespace Tensorflow | |||||
alpha, | alpha, | ||||
delta | delta | ||||
}, 3, | }, 3, | ||||
wrap_tfe_src.SetOpAttrs2("use_locking", use_locking), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "use_locking", use_locking), | op => wrap_tfe_src.SetOpAttrs(op, "use_locking", use_locking), | ||||
null, 0)); | null, 0)); | ||||
status.Check(true); | status.Check(true); | ||||
@@ -0,0 +1,22 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Eager; | |||||
namespace Tensorflow | |||||
{ | |||||
public class EagerTensorPass : PointerInputs<EagerTensor> | |||||
{ | |||||
public EagerTensorPass(params EagerTensor[] tensors) | |||||
{ | |||||
data = tensors; | |||||
} | |||||
public static EagerTensorPass Create(int count = 1) | |||||
=> new EagerTensorPass(Enumerable.Range(0, count).Select(x => new EagerTensor()).ToArray()); | |||||
public static EagerTensorPass From(params object[] objects) | |||||
=> new EagerTensorPass(objects.Select(x => ops.convert_to_tensor(x) as EagerTensor).ToArray()); | |||||
} | |||||
} |
@@ -0,0 +1,11 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public interface IPointerInputs | |||||
{ | |||||
public IntPtr ToPointer(); | |||||
} | |||||
} |
@@ -0,0 +1,30 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using System.Linq; | |||||
namespace Tensorflow | |||||
{ | |||||
public abstract class PointerInputs<T> | |||||
where T : IPointerInputs, new() | |||||
{ | |||||
protected T[] data; | |||||
public int Length | |||||
=> data.Length; | |||||
public IntPtr[] Points | |||||
=> data.Select(x => x.ToPointer()).ToArray(); | |||||
public PointerInputs(params T[] data) | |||||
=> this.data = data; | |||||
public T this[int idx] | |||||
=> data[idx]; | |||||
public T[] Items | |||||
=> data; | |||||
public static implicit operator IntPtr[](PointerInputs<T> inputs) | |||||
=> inputs.data.Select(x => x.ToPointer()).ToArray(); | |||||
} | |||||
} |
@@ -0,0 +1,31 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Eager; | |||||
namespace Tensorflow | |||||
{ | |||||
public class TensorManager | |||||
{ | |||||
Dictionary<IntPtr, EagerTensor> tensors; | |||||
public TensorManager() | |||||
{ | |||||
tensors = new Dictionary<IntPtr, EagerTensor>(); | |||||
} | |||||
public EagerTensor GetTensor(IntPtr handle) | |||||
{ | |||||
if (tensors.ContainsKey(handle)) | |||||
return tensors[handle]; | |||||
//return new EagerTensor(handle); | |||||
tensors[handle] = new EagerTensor(handle); | |||||
return tensors[handle]; | |||||
} | |||||
public void Reset() | |||||
{ | |||||
tensors.Clear(); | |||||
} | |||||
} | |||||
} |
@@ -54,7 +54,7 @@ namespace Tensorflow | |||||
public BaseResourceVariable(IntPtr handle, IntPtr tensor) | public BaseResourceVariable(IntPtr handle, IntPtr tensor) | ||||
{ | { | ||||
_handle = handle; | _handle = handle; | ||||
this.handle = new EagerTensor(tensor); | |||||
this.handle = tf.tensorMgr.GetTensor(tensor); | |||||
} | } | ||||
public void __init__(bool trainable = true, | public void __init__(bool trainable = true, | ||||
@@ -22,21 +22,24 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class ResourceVariable | public partial class ResourceVariable | ||||
{ | { | ||||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||||
public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y); | public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y); | ||||
public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y); | public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y); | ||||
public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y); | public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y); | ||||
public static Tensor operator +(ResourceVariable x, ResourceVariable y) => op_helper("add", x, y); | |||||
public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y); | public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y); | ||||
public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y); | public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y); | ||||
public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); | public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); | ||||
public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); | public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); | ||||
public static Tensor operator -(ResourceVariable x, ResourceVariable y) => op_helper("sub", x, y); | |||||
public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y); | public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y); | ||||
public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y); | public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y); | ||||
public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y); | |||||
public static Tensor operator <(ResourceVariable x, Tensor y) => op_helper("less", x, y); | |||||
public static Tensor operator >(ResourceVariable x, Tensor y) => gen_math_ops.greater(x.value(), y); | |||||
public static Tensor operator >(ResourceVariable x, Tensor y) => op_helper("greater", x, y); | |||||
private static Tensor op_helper<T>(string default_name, ResourceVariable x, T y) | private static Tensor op_helper<T>(string default_name, ResourceVariable x, T y) | ||||
=> tf_with(ops.name_scope(null, default_name, new { x, y }), scope => | => tf_with(ops.name_scope(null, default_name, new { x, y }), scope => | ||||
@@ -58,6 +61,12 @@ namespace Tensorflow | |||||
case "mul": | case "mul": | ||||
result = gen_math_ops.mul(xVal, yTensor, name: name); | result = gen_math_ops.mul(xVal, yTensor, name: name); | ||||
break; | break; | ||||
case "less": | |||||
result = gen_math_ops.less(xVal, yTensor, name); | |||||
break; | |||||
case "greater": | |||||
result = gen_math_ops.greater(xVal, yTensor, name); | |||||
break; | |||||
default: | default: | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
@@ -15,6 +15,7 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -96,15 +97,18 @@ namespace Tensorflow | |||||
get_default_graph()._name_stack = old_scope_name; | get_default_graph()._name_stack = old_scope_name; | ||||
} | } | ||||
[DebuggerNonUserCode] | |||||
public void __exit__() | public void __exit__() | ||||
{ | { | ||||
} | } | ||||
[DebuggerNonUserCode] | |||||
public void __init__() | public void __init__() | ||||
{ | { | ||||
} | } | ||||
[DebuggerNonUserCode] | |||||
public void __del__() | public void __del__() | ||||
{ | { | ||||
@@ -40,14 +40,15 @@ namespace Tensorflow | |||||
public TF_DataType @string = TF_DataType.TF_STRING; | public TF_DataType @string = TF_DataType.TF_STRING; | ||||
public Context context = new Context(new ContextOptions(), new Status()); | public Context context = new Context(new ContextOptions(), new Status()); | ||||
public TensorManager tensorMgr; | |||||
public tensorflow() | public tensorflow() | ||||
{ | { | ||||
_constructThreadingObjects(); | _constructThreadingObjects(); | ||||
InitGradientEnvironment(); | InitGradientEnvironment(); | ||||
tensorMgr = new TensorManager(); | |||||
} | } | ||||
private unsafe void InitGradientEnvironment() | |||||
private void InitGradientEnvironment() | |||||
{ | { | ||||
GarbageCollector.Init(); | GarbageCollector.Init(); | ||||
@@ -64,25 +65,30 @@ namespace Tensorflow | |||||
ops.RegisterFromAssembly(); | ops.RegisterFromAssembly(); | ||||
// ops.RegisterFromAssemblyEager(); | // ops.RegisterFromAssemblyEager(); | ||||
c_api.TFE_RegisterGradientFunction((op_name, op_inputs, op_outputs, num_attrs, output_grads, skip_input_indices) => | |||||
c_api.TFE_RegisterGradientFunction((op_name, op_inputs, op_outputs, attrs_string, output_grads, skip_input_indices) => | |||||
{ | { | ||||
/*var input_tensors = new BindingArray(op_inputs); | /*var input_tensors = new BindingArray(op_inputs); | ||||
var output_tensors = new BindingArray(op_outputs); | var output_tensors = new BindingArray(op_outputs); | ||||
var output_grad_tensors = new BindingArray(output_grads);*/ | var output_grad_tensors = new BindingArray(output_grads);*/ | ||||
var input_tensors = new BindingTensorArray(op_inputs).Data.Select(x => new EagerTensor(x)).ToArray(); | |||||
var output_tensors = new BindingTensorArray(op_outputs).Data.Select(x => new EagerTensor(x)).ToArray(); | |||||
var output_grad_tensors = new BindingTensorArray(output_grads).Data.Select(x => new EagerTensor(x)).ToArray(); | |||||
var skip_input_indices_param = new BindingArray(skip_input_indices).Data.Select(x => *(int*)x).ToArray(); | |||||
var input_tensors = new BindingTensorArray(op_inputs) | |||||
.Data.Select(x => tf.tensorMgr.GetTensor(x)).ToArray(); | |||||
var output_tensors = new BindingTensorArray(op_outputs) | |||||
.Data.Select(x => tf.tensorMgr.GetTensor(x)).ToArray(); | |||||
var output_grad_tensors = new BindingTensorArray(output_grads) | |||||
.Data.Select(x => tf.tensorMgr.GetTensor(x)).ToArray(); | |||||
var skip_input_indices_param = new BindingArray(skip_input_indices); | |||||
var gradients = ops.gradientFunctions[op_name](new EagerOperation | var gradients = ops.gradientFunctions[op_name](new EagerOperation | ||||
{ | { | ||||
Name = op_name, | |||||
NumInputs = input_tensors.Length, | NumInputs = input_tensors.Length, | ||||
Inputs = input_tensors, | Inputs = input_tensors, | ||||
// InputHandles = input_tensors.Data, | // InputHandles = input_tensors.Data, | ||||
NumOutputs = output_tensors.Length, | NumOutputs = output_tensors.Length, | ||||
Outputs = output_tensors, | Outputs = output_tensors, | ||||
// OutputHandles = output_tensors.Data, | // OutputHandles = output_tensors.Data, | ||||
SkipInputIndices = skip_input_indices_param | |||||
SkipInputIndicesArray = skip_input_indices_param, | |||||
AttrsArray = attrs_string.Split(',') | |||||
}, output_grad_tensors); | }, output_grad_tensors); | ||||
var gradients_handles = gradients.Select(x => x == null ? IntPtr.Zero : (x as EagerTensor).EagerTensorHandle).ToArray(); | var gradients_handles = gradients.Select(x => x == null ? IntPtr.Zero : (x as EagerTensor).EagerTensorHandle).ToArray(); | ||||
@@ -56,10 +56,10 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
public void Accumulation() | public void Accumulation() | ||||
{ | { | ||||
var x = tf.Variable(10, name: "x"); | var x = tf.Variable(10, name: "x"); | ||||
/*for (int i = 0; i < 5; i++) | |||||
x = x + 1; | |||||
for (int i = 0; i < 5; i++) | |||||
x.assign(x + 1); | |||||
Assert.AreEqual(15, (int)x.numpy());*/ | |||||
Assert.AreEqual(15, (int)x.numpy()); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -44,10 +44,10 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.6.1" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.6.1" /> | ||||
<PackageReference Include="MSTest.TestAdapter" Version="2.1.1" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.1.1" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | |||||
<PackageReference Include="NumSharp.Lite" Version="0.1.7" /> | <PackageReference Include="NumSharp.Lite" Version="0.1.7" /> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.2.0.1" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.2.0.2" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||