Browse Source

np.ones_like and np.zeros_like

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
ec340eeff5
5 changed files with 24 additions and 14 deletions
  1. +0
    -1
      src/TensorFlowNET.Console/SimpleRnnTest.cs
  2. +6
    -3
      src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs
  3. +8
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  4. +10
    -8
      src/python/simple_rnn.py
  5. +0
    -1
      test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj

+ 0
- 1
src/TensorFlowNET.Console/SimpleRnnTest.cs View File

@@ -12,7 +12,6 @@ namespace Tensorflow
{
public void Run()
{
tf.UseKeras<KerasInterface>();
var inputs = np.random.random((6, 10, 8)).astype(np.float32);
//var simple_rnn = tf.keras.layers.SimpleRNN(4);
//var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.


+ 6
- 3
src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs View File

@@ -2,7 +2,6 @@
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Numerics;
using System.Text;
using static Tensorflow.Binding;

@@ -103,11 +102,15 @@ namespace Tensorflow.NumPy
public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
=> new NDArray(tf.ones(shape, dtype: dtype));

public static NDArray ones_like(NDArray a, Type dtype = null)
=> throw new NotImplementedException("");
public static NDArray ones_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid)
=> new NDArray(tf.ones_like(a, dtype: dtype));

[AutoNumPy]
public static NDArray zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
=> new NDArray(tf.zeros(shape, dtype: dtype));

[AutoNumPy]
public static NDArray zeros_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid)
=> new NDArray(tf.zeros_like(a, dtype: dtype));
}
}

+ 8
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -291,7 +291,14 @@ namespace Tensorflow
protected override void DisposeUnmanagedResources(IntPtr handle)
{
// c_api.TF_CloseSession(handle, tf.Status.Handle);
c_api.TF_DeleteSession(handle, c_api.TF_NewStatus());
if (tf.Status == null || tf.Status.Handle.IsInvalid)
{
c_api.TF_DeleteSession(handle, c_api.TF_NewStatus());
}
else
{
c_api.TF_DeleteSession(handle, tf.Status.Handle);
}
}
}
}

+ 10
- 8
src/python/simple_rnn.py View File

@@ -1,15 +1,17 @@
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

# tf.experimental.numpy
inputs = np.random.random([32, 10, 8]).astype(np.float32)
simple_rnn = tf.keras.layers.SimpleRNN(4)
inputs = np.arange(6 * 10 * 8).reshape([6, 10, 8]).astype(np.float32)
# simple_rnn = tf.keras.layers.SimpleRNN(4)

output = simple_rnn(inputs) # The output has shape `[32, 4]`.
# output = simple_rnn(inputs) # The output has shape `[6, 4]`.

simple_rnn = tf.keras.layers.SimpleRNN(
4, return_sequences=True, return_state=True)
simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences=True, return_state=True)

# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = simple_rnn(inputs)
# whole_sequence_output has shape `[6, 10, 4]`.
# final_state has shape `[6, 4]`.
whole_sequence_output, final_state = simple_rnn(inputs)
print(whole_sequence_output)
print(final_state)

+ 0
- 1
test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj View File

@@ -4,7 +4,6 @@
<TargetFramework>net6.0</TargetFramework>

<IsPackable>false</IsPackable>
<LangVersion>11.0</LangVersion>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>



Loading…
Cancel
Save