Browse Source

only resolve if necessary.

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
392fa09fb4
9 changed files with 121 additions and 114 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  2. +3
    -7
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  3. +48
    -0
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  4. +5
    -3
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  6. +48
    -84
      src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  8. +12
    -15
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  9. +2
    -2
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -213,7 +213,7 @@ namespace Tensorflow.Eager


if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
{ {
var dtype = c_api.TFE_TensorHandleDataType(tensor.EagerTensorHandle);
var dtype = tensor.dtype;
c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype);
flattened_attrs.Add(input_arg.TypeAttr); flattened_attrs.Add(input_arg.TypeAttr);
flattened_attrs.Add(dtype); flattened_attrs.Add(dtype);


+ 3
- 7
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -7,16 +7,10 @@ namespace Tensorflow.Eager
{ {
public partial class EagerTensor public partial class EagerTensor
{ {
public EagerTensor(SafeTensorHandle handle)
{
NewEagerTensorHandle(handle);
}

public EagerTensor(SafeEagerTensorHandle handle) public EagerTensor(SafeEagerTensorHandle handle)
{ {
_id = ops.uid(); _id = ops.uid();
_eagerTensorHandle = handle; _eagerTensorHandle = handle;
Resolve();
} }


#region scalar eager tensor #region scalar eager tensor
@@ -67,8 +61,10 @@ namespace Tensorflow.Eager
tf.Status.Check(true); tf.Status.Check(true);
} }


private void Resolve()
public void Resolve()
{ {
if (_handle != null)
return;
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle);
tf.Status.Check(true); tf.Status.Check(true);
} }


+ 48
- 0
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -6,11 +6,47 @@ namespace Tensorflow.Eager
{ {
public partial class EagerTensor : Tensor public partial class EagerTensor : Tensor
{ {
public override SafeTensorHandle Handle
{
get
{
Resolve();
return _handle;
}
}

public override IntPtr buffer
{
get
{
Resolve();
return base.buffer;
}
}

public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle));
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle);


public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle);


public override ulong bytesize
{
get
{
Resolve();
return base.bytesize;
}
}

public override IntPtr TensorDataPointer
{
get
{
Resolve();
return base.TensorDataPointer;
}
}

protected override Shape GetShapeInternal() protected override Shape GetShapeInternal()
{ {
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)];
@@ -19,6 +55,12 @@ namespace Tensorflow.Eager
return dims; return dims;
} }


protected override void SetShapeInternal(Shape value)
{
if (!shape.is_compatible_with(value))
throw new ValueError($"Tensor's shape is not compatible.");
}

public static int GetRank(IntPtr handle) public static int GetRank(IntPtr handle)
{ {
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle);
@@ -33,5 +75,11 @@ namespace Tensorflow.Eager
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle);
return dims; return dims;
} }

public override T[] ToArray<T>()
{
Resolve();
return base.ToArray<T>();
}
} }
} }

+ 5
- 3
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -38,8 +38,8 @@ namespace Tensorflow.NumPy
tensor = tf.defaultSession.eval(tensor); tensor = tf.defaultSession.eval(tensor);
_handle = tensor.Handle; _handle = tensor.Handle;
} }
NewEagerTensorHandle();
NewEagerTensorHandle();
} }


public static NDArray Scalar<T>(T value) where T : unmanaged public static NDArray Scalar<T>(T value) where T : unmanaged
@@ -57,7 +57,9 @@ namespace Tensorflow.NumPy
void NewEagerTensorHandle() void NewEagerTensorHandle()
{ {
if (_handle is not null) if (_handle is not null)
_eagerTensorHandle = new EagerTensor(_handle).EagerTensorHandle;
{
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
}
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")]
public partial class Tensor public partial class Tensor
{ {
public IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);
public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);


public Tensor() public Tensor()
{ {


+ 48
- 84
src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs View File

@@ -5,124 +5,88 @@ namespace Tensorflow
{ {
public partial class Tensor public partial class Tensor
{ {
public static explicit operator bool(Tensor tensor)
public unsafe static explicit operator bool(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_BOOL);
return *(bool*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_BOOL);
return *(bool*)tensor.buffer;
} }


public static explicit operator sbyte(Tensor tensor)
public unsafe static explicit operator sbyte(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT8);
return *(sbyte*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT8);
return *(sbyte*)tensor.buffer;
} }


public static explicit operator byte(Tensor tensor)
public unsafe static explicit operator byte(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT8);
return *(byte*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT8);
return *(byte*)tensor.buffer;
} }


public static explicit operator ushort(Tensor tensor)
public unsafe static explicit operator ushort(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT16);
return *(ushort*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT16);
return *(ushort*)tensor.buffer;
} }


public static explicit operator short(Tensor tensor)
public unsafe static explicit operator short(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT16);
return *(short*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT16);
return *(short*)tensor.buffer;
} }


public static explicit operator int(Tensor tensor)
public unsafe static explicit operator int(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT32);
return *(int*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT32);
return *(int*)tensor.buffer;
} }


public static explicit operator uint(Tensor tensor)
public unsafe static explicit operator uint(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT32);
return *(uint*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT32);
return *(uint*)tensor.buffer;
} }


public static explicit operator long(Tensor tensor)
public unsafe static explicit operator long(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT64);
return *(long*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT64);
return *(long*)tensor.buffer;
} }


public static explicit operator ulong(Tensor tensor)
public unsafe static explicit operator ulong(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT64);
return *(ulong*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT64);
return *(ulong*)tensor.buffer;
} }


public static explicit operator float(Tensor tensor)
public unsafe static explicit operator float(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_FLOAT);
return *(float*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_FLOAT);
return *(float*)tensor.buffer;
} }


public static explicit operator double(Tensor tensor)
public unsafe static explicit operator double(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_DOUBLE);
return *(double*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_DOUBLE);
return *(double*)tensor.buffer;
} }


public static explicit operator string(Tensor tensor)
public unsafe static explicit operator string(Tensor tensor)
{ {
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_STRING);
return new string((char*)tensor.buffer, 0, (int)tensor.size);
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_STRING);
return new string((char*)tensor.buffer, 0, (int)tensor.size);
} }


[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <typeparam name="T"></typeparam> /// <typeparam name="T"></typeparam>
/// <returns></returns> /// <returns></returns>
public unsafe T[] ToArray<T>() where T : unmanaged
public virtual unsafe T[] ToArray<T>() where T : unmanaged
{ {
//Are the types matching? //Are the types matching?
if (typeof(T).as_tf_dtype() != dtype) if (typeof(T).as_tf_dtype() != dtype)


+ 12
- 15
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -68,10 +68,10 @@ namespace Tensorflow
/// The DType of elements in this tensor. /// The DType of elements in this tensor.
/// </summary> /// </summary>
public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle);
public virtual ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle);
public ulong dtypesize => (ulong)dtype.get_datatype_size(); public ulong dtypesize => (ulong)dtype.get_datatype_size();
public ulong size => _handle == null ? 0 : bytesize / dtypesize; public ulong size => _handle == null ? 0 : bytesize / dtypesize;
public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public virtual IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
public int ndim => rank; public int ndim => rank;


@@ -86,7 +86,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public object Tag { get; set; } public object Tag { get; set; }
protected new SafeTensorHandle _handle; protected new SafeTensorHandle _handle;
public SafeTensorHandle Handle => _handle;
public virtual SafeTensorHandle Handle => _handle;


protected SafeEagerTensorHandle _eagerTensorHandle; protected SafeEagerTensorHandle _eagerTensorHandle;
/// <summary> /// <summary>
@@ -114,18 +114,7 @@ namespace Tensorflow


set set
{ {
if (this is EagerTensor)
{
if(!shape.is_compatible_with(value))
throw new ValueError($"Tensor's shape is not compatible.");
return;
}

if (value == null)
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle);
else
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle);

SetShapeInternal(value);
tf.Status.Check(true); tf.Status.Check(true);
} }
} }
@@ -147,6 +136,14 @@ namespace Tensorflow
return dims; return dims;
} }


protected virtual void SetShapeInternal(Shape value)
{
if (value == null)
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle);
else
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle);
}

public int[] _shape_tuple() public int[] _shape_tuple()
{ {
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();


+ 2
- 2
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -233,9 +233,9 @@ namespace Tensorflow
return false; return false;
} }


if (tensor.GetType() == typeof(EagerTensor))
if (tensor is EagerTensor eagerTensor)
{ {
if(tensor.dtype == TF_DataType.TF_INT64)
if(tensor.dtype == tf.int64)
return new Shape(tensor.ToArray<long>()); return new Shape(tensor.ToArray<long>());
else else
return new Shape(tensor.ToArray<int>()); return new Shape(tensor.ToArray<int>());


Loading…
Cancel
Save