Browse Source

only resolve if necessary.

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
392fa09fb4
9 changed files with 121 additions and 114 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  2. +3
    -7
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  3. +48
    -0
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  4. +5
    -3
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  6. +48
    -84
      src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  8. +12
    -15
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  9. +2
    -2
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -213,7 +213,7 @@ namespace Tensorflow.Eager

if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
{
var dtype = c_api.TFE_TensorHandleDataType(tensor.EagerTensorHandle);
var dtype = tensor.dtype;
c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype);
flattened_attrs.Add(input_arg.TypeAttr);
flattened_attrs.Add(dtype);


+ 3
- 7
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -7,16 +7,10 @@ namespace Tensorflow.Eager
{
public partial class EagerTensor
{
public EagerTensor(SafeTensorHandle handle)
{
NewEagerTensorHandle(handle);
}

public EagerTensor(SafeEagerTensorHandle handle)
{
_id = ops.uid();
_eagerTensorHandle = handle;
Resolve();
}

#region scalar eager tensor
@@ -67,8 +61,10 @@ namespace Tensorflow.Eager
tf.Status.Check(true);
}

private void Resolve()
public void Resolve()
{
if (_handle != null)
return;
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle);
tf.Status.Check(true);
}


+ 48
- 0
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -6,11 +6,47 @@ namespace Tensorflow.Eager
{
public partial class EagerTensor : Tensor
{
public override SafeTensorHandle Handle
{
get
{
Resolve();
return _handle;
}
}

public override IntPtr buffer
{
get
{
Resolve();
return base.buffer;
}
}

public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle));
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle);

public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle);

public override ulong bytesize
{
get
{
Resolve();
return base.bytesize;
}
}

public override IntPtr TensorDataPointer
{
get
{
Resolve();
return base.TensorDataPointer;
}
}

protected override Shape GetShapeInternal()
{
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)];
@@ -19,6 +55,12 @@ namespace Tensorflow.Eager
return dims;
}

protected override void SetShapeInternal(Shape value)
{
if (!shape.is_compatible_with(value))
throw new ValueError($"Tensor's shape is not compatible.");
}

public static int GetRank(IntPtr handle)
{
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle);
@@ -33,5 +75,11 @@ namespace Tensorflow.Eager
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle);
return dims;
}

public override T[] ToArray<T>()
{
Resolve();
return base.ToArray<T>();
}
}
}

+ 5
- 3
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -38,8 +38,8 @@ namespace Tensorflow.NumPy
tensor = tf.defaultSession.eval(tensor);
_handle = tensor.Handle;
}
NewEagerTensorHandle();
NewEagerTensorHandle();
}

public static NDArray Scalar<T>(T value) where T : unmanaged
@@ -57,7 +57,9 @@ namespace Tensorflow.NumPy
void NewEagerTensorHandle()
{
if (_handle is not null)
_eagerTensorHandle = new EagerTensor(_handle).EagerTensorHandle;
{
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
}
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")]
public partial class Tensor
{
public IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);
public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);

public Tensor()
{


+ 48
- 84
src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs View File

@@ -5,124 +5,88 @@ namespace Tensorflow
{
public partial class Tensor
{
public static explicit operator bool(Tensor tensor)
public unsafe static explicit operator bool(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_BOOL);
return *(bool*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_BOOL);
return *(bool*)tensor.buffer;
}

public static explicit operator sbyte(Tensor tensor)
public unsafe static explicit operator sbyte(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT8);
return *(sbyte*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT8);
return *(sbyte*)tensor.buffer;
}

public static explicit operator byte(Tensor tensor)
public unsafe static explicit operator byte(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT8);
return *(byte*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT8);
return *(byte*)tensor.buffer;
}

public static explicit operator ushort(Tensor tensor)
public unsafe static explicit operator ushort(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT16);
return *(ushort*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT16);
return *(ushort*)tensor.buffer;
}

public static explicit operator short(Tensor tensor)
public unsafe static explicit operator short(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT16);
return *(short*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT16);
return *(short*)tensor.buffer;
}

public static explicit operator int(Tensor tensor)
public unsafe static explicit operator int(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT32);
return *(int*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT32);
return *(int*)tensor.buffer;
}

public static explicit operator uint(Tensor tensor)
public unsafe static explicit operator uint(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT32);
return *(uint*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT32);
return *(uint*)tensor.buffer;
}

public static explicit operator long(Tensor tensor)
public unsafe static explicit operator long(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT64);
return *(long*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT64);
return *(long*)tensor.buffer;
}

public static explicit operator ulong(Tensor tensor)
public unsafe static explicit operator ulong(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT64);
return *(ulong*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT64);
return *(ulong*)tensor.buffer;
}

public static explicit operator float(Tensor tensor)
public unsafe static explicit operator float(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_FLOAT);
return *(float*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_FLOAT);
return *(float*)tensor.buffer;
}

public static explicit operator double(Tensor tensor)
public unsafe static explicit operator double(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_DOUBLE);
return *(double*)tensor.buffer;
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_DOUBLE);
return *(double*)tensor.buffer;
}

public static explicit operator string(Tensor tensor)
public unsafe static explicit operator string(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_STRING);
return new string((char*)tensor.buffer, 0, (int)tensor.size);
}
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_STRING);
return new string((char*)tensor.buffer, 0, (int)tensor.size);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public unsafe T[] ToArray<T>() where T : unmanaged
public virtual unsafe T[] ToArray<T>() where T : unmanaged
{
//Are the types matching?
if (typeof(T).as_tf_dtype() != dtype)


+ 12
- 15
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -68,10 +68,10 @@ namespace Tensorflow
/// The DType of elements in this tensor.
/// </summary>
public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle);
public virtual ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle);
public ulong dtypesize => (ulong)dtype.get_datatype_size();
public ulong size => _handle == null ? 0 : bytesize / dtypesize;
public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public virtual IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
public int ndim => rank;

@@ -86,7 +86,7 @@ namespace Tensorflow
/// </summary>
public object Tag { get; set; }
protected new SafeTensorHandle _handle;
public SafeTensorHandle Handle => _handle;
public virtual SafeTensorHandle Handle => _handle;

protected SafeEagerTensorHandle _eagerTensorHandle;
/// <summary>
@@ -114,18 +114,7 @@ namespace Tensorflow

set
{
if (this is EagerTensor)
{
if(!shape.is_compatible_with(value))
throw new ValueError($"Tensor's shape is not compatible.");
return;
}

if (value == null)
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle);
else
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle);

SetShapeInternal(value);
tf.Status.Check(true);
}
}
@@ -147,6 +136,14 @@ namespace Tensorflow
return dims;
}

protected virtual void SetShapeInternal(Shape value)
{
if (value == null)
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle);
else
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle);
}

public int[] _shape_tuple()
{
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();


+ 2
- 2
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -233,9 +233,9 @@ namespace Tensorflow
return false;
}

if (tensor.GetType() == typeof(EagerTensor))
if (tensor is EagerTensor eagerTensor)
{
if(tensor.dtype == TF_DataType.TF_INT64)
if(tensor.dtype == tf.int64)
return new Shape(tensor.ToArray<long>());
else
return new Shape(tensor.ToArray<int>());


Loading…
Cancel
Save