@@ -12,7 +12,6 @@ namespace Tensorflow | |||||
{ | { | ||||
public void Run() | public void Run() | ||||
{ | { | ||||
tf.UseKeras<KerasInterface>(); | |||||
var inputs = np.random.random((6, 10, 8)).astype(np.float32); | var inputs = np.random.random((6, 10, 8)).astype(np.float32); | ||||
//var simple_rnn = tf.keras.layers.SimpleRNN(4); | //var simple_rnn = tf.keras.layers.SimpleRNN(4); | ||||
//var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`. | //var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`. | ||||
@@ -2,7 +2,6 @@ | |||||
using System.Collections; | using System.Collections; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | using System.IO; | ||||
using System.Numerics; | |||||
using System.Text; | using System.Text; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -103,11 +102,15 @@ namespace Tensorflow.NumPy | |||||
public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | ||||
=> new NDArray(tf.ones(shape, dtype: dtype)); | => new NDArray(tf.ones(shape, dtype: dtype)); | ||||
public static NDArray ones_like(NDArray a, Type dtype = null) | |||||
=> throw new NotImplementedException(""); | |||||
public static NDArray ones_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
=> new NDArray(tf.ones_like(a, dtype: dtype)); | |||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | public static NDArray zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | ||||
=> new NDArray(tf.zeros(shape, dtype: dtype)); | => new NDArray(tf.zeros(shape, dtype: dtype)); | ||||
[AutoNumPy] | |||||
public static NDArray zeros_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
=> new NDArray(tf.zeros_like(a, dtype: dtype)); | |||||
} | } | ||||
} | } |
@@ -291,7 +291,14 @@ namespace Tensorflow | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
{ | { | ||||
// c_api.TF_CloseSession(handle, tf.Status.Handle); | // c_api.TF_CloseSession(handle, tf.Status.Handle); | ||||
c_api.TF_DeleteSession(handle, c_api.TF_NewStatus()); | |||||
if (tf.Status == null || tf.Status.Handle.IsInvalid) | |||||
{ | |||||
c_api.TF_DeleteSession(handle, c_api.TF_NewStatus()); | |||||
} | |||||
else | |||||
{ | |||||
c_api.TF_DeleteSession(handle, tf.Status.Handle); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,15 +1,17 @@ | |||||
import numpy as np | import numpy as np | ||||
import tensorflow as tf | import tensorflow as tf | ||||
import tensorflow.experimental.numpy as tnp | |||||
# tf.experimental.numpy | # tf.experimental.numpy | ||||
inputs = np.random.random([32, 10, 8]).astype(np.float32) | |||||
simple_rnn = tf.keras.layers.SimpleRNN(4) | |||||
inputs = np.arange(6 * 10 * 8).reshape([6, 10, 8]).astype(np.float32) | |||||
# simple_rnn = tf.keras.layers.SimpleRNN(4) | |||||
output = simple_rnn(inputs) # The output has shape `[32, 4]`. | |||||
# output = simple_rnn(inputs) # The output has shape `[6, 4]`. | |||||
simple_rnn = tf.keras.layers.SimpleRNN( | |||||
4, return_sequences=True, return_state=True) | |||||
simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences=True, return_state=True) | |||||
# whole_sequence_output has shape `[32, 10, 4]`. | |||||
# final_state has shape `[32, 4]`. | |||||
whole_sequence_output, final_state = simple_rnn(inputs) | |||||
# whole_sequence_output has shape `[6, 10, 4]`. | |||||
# final_state has shape `[6, 4]`. | |||||
whole_sequence_output, final_state = simple_rnn(inputs) | |||||
print(whole_sequence_output) | |||||
print(final_state) |
@@ -4,7 +4,6 @@ | |||||
<TargetFramework>net6.0</TargetFramework> | <TargetFramework>net6.0</TargetFramework> | ||||
<IsPackable>false</IsPackable> | <IsPackable>false</IsPackable> | ||||
<LangVersion>11.0</LangVersion> | |||||
<Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
</PropertyGroup> | </PropertyGroup> | ||||