Browse Source

tf.sparse_tensor_to_dense, TensorShape.merge_with #396

tags/v0.12
Oceania2018 6 years ago
parent
commit
060cc37dd4
7 changed files with 157 additions and 131 deletions
  1. +12
    -1
      src/TensorFlowNET.Core/APIs/tf.sparse.cs
  2. +15
    -6
      src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs
  3. +20
    -0
      src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs
  4. +27
    -0
      src/TensorFlowNET.Core/Tensors/Dimension.cs
  5. +17
    -2
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  6. +51
    -115
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  7. +15
    -7
      test/TensorFlowNET.UnitTest/TensorTest.cs

+ 12
- 1
src/TensorFlowNET.Core/APIs/tf.sparse.cs View File

@@ -20,9 +20,20 @@ namespace Tensorflow
{ {
public partial class tensorflow public partial class tensorflow
{ {
public SparseTensor<T> SparseTensor<T>(long[,] indices, T[] values, int[] dense_shape)
public SparseTensor<T> SparseTensor<T>(long[,] indices, T[] values, long[] dense_shape)
=> new SparseTensor<T>(indices, values, dense_shape); => new SparseTensor<T>(indices, values, dense_shape);


public Tensor sparse_tensor_to_dense<T>(SparseTensor<T> sp_input,
T default_value = default,
bool validate_indices = true,
string name = null)
=> gen_sparse_ops.sparse_to_dense(sp_input.indices,
sp_input.dense_shape,
sp_input.values,
default_value: default_value,
validate_indices: validate_indices,
name: name);

/// <summary> /// <summary>
/// Converts a sparse representation into a dense tensor. /// Converts a sparse representation into a dense tensor.
/// </summary> /// </summary>


+ 15
- 6
src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs View File

@@ -1,4 +1,6 @@
using static Tensorflow.Binding;
using System;
using System.Linq;
using static Tensorflow.Binding;


namespace Tensorflow.Framework namespace Tensorflow.Framework
{ {
@@ -8,15 +10,20 @@ namespace Tensorflow.Framework
public class SparseTensor<T> : CompositeTensor, _TensorLike public class SparseTensor<T> : CompositeTensor, _TensorLike
{ {
long[,] _indices; long[,] _indices;
Tensor indices;
public Tensor indices;


T[] _values; T[] _values;
Tensor values;
public Tensor values;


int[] _dense_shape;
Tensor dense_shape;
long[] _dense_shape;
public Tensor dense_shape;


public SparseTensor(long[,] indices_, T[] values_, int[] dense_shape_)
TensorShape _shape;
public TensorShape shape => _shape;

public TF_DataType dtype => dtypes.as_dtype(typeof(T));

public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_)
{ {
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate
{ {
@@ -37,6 +44,8 @@ namespace Tensorflow.Framework


indices_shape[0].merge_with(values_shape.dims[0]); indices_shape[0].merge_with(values_shape.dims[0]);
indices_shape[1].merge_with(dense_shape_shape.dims[0]); indices_shape[1].merge_with(dense_shape_shape.dims[0]);

_shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray());
} }
} }




+ 20
- 0
src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Framework;


namespace Tensorflow namespace Tensorflow
{ {
@@ -50,5 +51,24 @@ namespace Tensorflow


return _op.output; return _op.output;
} }

public static Tensor sparse_to_dense<T>(Tensor sparse_indices,
Tensor output_shape,
Tensor sparse_values,
T default_value = default,
bool validate_indices = true,
string name = null)
{
var _op = _op_def_lib._apply_op_helper("SparseToDense", name, args: new
{
sparse_indices,
output_shape,
sparse_values,
default_value,
validate_indices
});

return _op.output;
}
} }
} }

+ 27
- 0
src/TensorFlowNET.Core/Tensors/Dimension.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class Dimension
{
int _value;
public int value => _value;

public Dimension(int value)
{
_value = value;
}

public Dimension merge_with(Dimension other)
{
if (_value == -1)
return new Dimension(other.value);
else
return new Dimension(_value);
}

public override string ToString() => $"Dimension({_value})";
}
}

+ 17
- 2
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -1,9 +1,10 @@
using NumSharp; using NumSharp;
using System; using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using NumSharp.Utilities;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
@@ -196,12 +197,26 @@ namespace Tensorflow
} }
} }


/// <summary>
/// Returns a `TensorShape` combining the information in `self` and `other`.
/// </summary>
/// <param name="other"></param>
/// <returns></returns>
public TensorShape merge_with(TensorShape other) public TensorShape merge_with(TensorShape other)
{ {
if (dims.Length == 0) if (dims.Length == 0)
return other; return other;


throw new NotImplementedException("merge_with");
var new_dims = new List<int>();

foreach (var i in range(ndim))
{
var dim = new Dimension(dims[i]);
var merged = dim.merge_with(new Dimension(other.dims[i]));
new_dims.Add(merged.value);
}

return new TensorShape(new_dims.ToArray());
} }


/// <summary> /// <summary>


+ 51
- 115
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -118,110 +118,10 @@ namespace Tensorflow
if (values == null) if (values == null)
throw new ValueError("None values not supported."); throw new ValueError("None values not supported.");


if(np_dt == null)
{
switch (values)
{
case bool boolVal:
nparray = boolVal;
break;
case int intVal:
nparray = intVal;
break;
case int[] intVals:
nparray = np.array(intVals);
break;
case int[,] intVals:
nparray = np.array(intVals);
break;
case long intVal:
nparray = intVal;
break;
case long[] intVals:
nparray = np.array(intVals);
break;
case long[,] intVals:
nparray = np.array(intVals);
break;
case float floatVal:
nparray = floatVal;
break;
case float[] floatVals:
nparray = floatVals;
break;
case float[,] floatVals:
nparray = np.array(floatVals);
break;
case double doubleVal:
nparray = doubleVal;
break;
case double[] doubleVals:
nparray = np.array(doubleVals);
break;
case double[,] doubleVals:
nparray = np.array(doubleVals);
break;
case string strVal:
nparray = strVal;
break;
case string[] strVals:
nparray = strVals;
break;
case byte[] byteValues:
nparray = byteValues;
break;
case byte[,] byteValues:
nparray = np.array(byteValues);
break;
default:
throw new NotImplementedException($"make_tensor_proto: Support for type {values.GetType()} Not Implemented");
}
}
else
{
// convert data type
switch (np_dt.Name)
{
case "Int32":
if (values.GetType().IsArray)
nparray = np.array((int[])values, np_dt);
else
nparray = Converts.ToInt32(values);
break;
case "Int64":
if (values.GetType().IsArray)
nparray = np.array((int[])values, np_dt);
else
nparray = Converts.ToInt64(values);
break;
case "Single":
if (values.GetType().IsArray)
nparray = np.array((float[])values, np_dt);
else
nparray = Converts.ToSingle(values);
break;
case "Double":
if (values.GetType().IsArray)
nparray = np.array((double[])values, np_dt);
else
nparray = Converts.ToDouble(values);
break;
case "String":
if (values.GetType().IsArray)
nparray = np.array((string[])values, np_dt);
else
nparray = NDArray.FromString(Converts.ToString(values));
break;
case "Boolean":
if (values.GetType().IsArray)
nparray = np.array((bool[])values, np_dt);
else
nparray = Converts.ToBoolean(values);
break;
default:
throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented");
}
}
nparray = convert_to_numpy_ndarray(values);

if (np_dt != null && np_dt != typeof(string))
nparray = nparray.astype(np_dt);
} }


var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype);
@@ -316,23 +216,59 @@ namespace Tensorflow
case NDArray val: case NDArray val:
nd = val; nd = val;
break; break;
case int val:
nd = np.asarray(val);
case bool boolVal:
nd = boolVal;
break;
case int intVal:
nd = intVal;
break;
case int[] intVals:
nd = np.array(intVals);
break;
case int[,] intVals:
nd = np.array(intVals);
break;
case long intVal:
nd = intVal;
break;
case long[] intVals:
nd = np.array(intVals);
break;
case long[,] intVals:
nd = np.array(intVals);
break;
case float floatVal:
nd = floatVal;
break;
case float[] floatVals:
nd = floatVals;
break;
case float[,] floatVals:
nd = np.array(floatVals);
break;
case double doubleVal:
nd = doubleVal;
break;
case double[] doubleVals:
nd = np.array(doubleVals);
break;
case double[,] doubleVals:
nd = np.array(doubleVals);
break; break;
case int[] val:
nd = np.array(val);
case string strVal:
nd = NDArray.FromString(strVal);
break; break;
case float val:
nd = np.asarray(val);
case string[] strVals:
nd = strVals;
break; break;
case double val:
nd = np.asarray(val);
case byte[] byteValues:
nd = byteValues;
break; break;
case string val:
nd = np.asarray(val);
case byte[,] byteValues:
nd = np.array(byteValues);
break; break;
default: default:
throw new Exception("Not Implemented");
throw new NotImplementedException($"convert_to_numpy_ndarray: Support for type {values.GetType()} Not Implemented");
} }


return nd; return nd;


+ 15
- 7
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -225,14 +225,22 @@ namespace TensorFlowNET.UnitTest
[TestMethod] [TestMethod]
public void sparse_tensor_to_dense() public void sparse_tensor_to_dense()
{ {
/*int[,] dense_array =
var decoded_list = tf.SparseTensor(new[,]
{ {
{ 1, 0, 0, 0, 0 },
{ 0, 1, 0, 0, 0 },
{ 0, 0, 1, 0, 0 },
{ 0, 0, 0, 1, 0 }
};
var sparseTensor = new SparseTensor<int>(indices, values, dense_shape);*/
{ 0L, 0L },
{ 1L, 2L }
},
new int[] { 1, 2 },
new[] { 3L, 4L });

var onehot = tf.sparse_tensor_to_dense(decoded_list);
using (var sess = tf.Session())
{
var result = sess.run(onehot);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>()));
}
} }
} }
} }

Loading…
Cancel
Save