Browse Source

load all images for paths_and_labels_to_dataset

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
9a2ed2768f
4 changed files with 21 additions and 4 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  4. +17
    -2
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs

+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow
public unsafe Tensor(SafeTensorHandle handle, bool clone = false)
{
_handle = handle;
if (clone)
if (clone && handle != null)
_handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer());
isCreatedInGraphMode = !tf.executing_eagerly();


+ 2
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -17,6 +17,7 @@
using Tensorflow.NumPy;
using System;
using System.Numerics;
using System.Diagnostics;

namespace Tensorflow
{
@@ -221,6 +222,7 @@ namespace Tensorflow
return (int)type > 100 ? (DataType)((int)type - 100) : type;
}

[DebuggerStepThrough]
public static TF_DataType as_tf_dtype(this DataType type)
{
return (TF_DataType)type;


+ 1
- 1
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -222,7 +222,7 @@ namespace Tensorflow
public override string ToString()
{
if (tf.Context.executing_eagerly())
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={read_value()}";
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={read_value().numpy()}";
else
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";
}


+ 17
- 2
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs View File

@@ -1,6 +1,7 @@
using System;
using System.IO;
using static Tensorflow.Binding;
using Tensorflow.NumPy;

namespace Tensorflow.Keras
{
@@ -14,9 +15,23 @@ namespace Tensorflow.Keras
int num_classes,
string interpolation)
{
var path_ds = tf.data.Dataset.from_tensor_slices(image_paths);
var img_ds = path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation));
// option 1: will load all images into memory, not efficient
var images = np.zeros((image_paths.Length, image_size[0], image_size[1], num_channels), np.float32);
for (int i = 0; i < len(images); i++)
{
var img = tf.io.read_file(image_paths[i]);
img = tf.image.decode_image(
img, channels: num_channels, expand_animations: false);
var resized_image = tf.image.resize_images_v2(img, image_size, method: interpolation);
images[i] = resized_image.numpy();
tf_output_redirect.WriteLine(image_paths[i]);
};

// option 2: dynamic load, but has error, need to fix
/* var path_ds = tf.data.Dataset.from_tensor_slices(image_paths);
var img_ds = path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation));*/

var img_ds = tf.data.Dataset.from_tensor_slices(images);
if (label_mode == "int")
{
var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes);


Loading…
Cancel
Save