Browse Source

add np.load.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
7f0b9b68ac
5 changed files with 175 additions and 3 deletions
  1. +16
    -0
      src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs
  2. +145
    -0
      src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs
  3. +4
    -3
      src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
  4. +7
    -0
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  5. +3
    -0
      src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs

+ 16
- 0
src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs View File

@@ -1,6 +1,9 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.NumPy
@@ -67,6 +70,19 @@ namespace Tensorflow.NumPy
return new NDArray(result);
}

Array ReadValueMatrix(BinaryReader reader, Array matrix, int bytes, Type type, int[] shape)
{
int total = 1;
for (int i = 0; i < shape.Length; i++)
total *= shape[i];
var buffer = new byte[bytes * total];

reader.Read(buffer, 0, buffer.Length);
System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length);

return matrix;
}

public (NDArray, NDArray) meshgrid<T>(T[] array, bool copy = true, bool sparse = false)
{
var tensors = array_ops.meshgrid(array, copy: copy, sparse: sparse);


+ 145
- 0
src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs View File

@@ -0,0 +1,145 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using Tensorflow.Util;

namespace Tensorflow.NumPy
{
public partial class NumPyImpl
{
public NDArray load(string file)
{
using var stream = new FileStream(file, FileMode.Open);
using var reader = new BinaryReader(stream, Encoding.ASCII, leaveOpen: true);
int bytes;
Type type;
int[] shape;
if (!ParseReader(reader, out bytes, out type, out shape))
throw new FormatException();

Array array = Create(type, shape.Aggregate((dims, dim) => dims * dim));

var result = new NDArray(ReadValueMatrix(reader, array, bytes, type, shape));
return result.reshape(shape);
}

bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape)
{
bytes = 0;
t = null;
shape = null;

// The first 6 bytes are a magic string: exactly "x93NUMPY"
if (reader.ReadChar() != 63) return false;
if (reader.ReadChar() != 'N') return false;
if (reader.ReadChar() != 'U') return false;
if (reader.ReadChar() != 'M') return false;
if (reader.ReadChar() != 'P') return false;
if (reader.ReadChar() != 'Y') return false;

byte major = reader.ReadByte(); // 1
byte minor = reader.ReadByte(); // 0

if (major != 1 || minor != 0)
throw new NotSupportedException();

ushort len = reader.ReadUInt16();

string header = new String(reader.ReadChars(len));
string mark = "'descr': '";
int s = header.IndexOf(mark) + mark.Length;
int e = header.IndexOf("'", s + 1);
string type = header.Substring(s, e - s);
bool? isLittleEndian;
t = GetType(type, out bytes, out isLittleEndian);

if (isLittleEndian.HasValue && isLittleEndian.Value == false)
throw new Exception();

mark = "'fortran_order': ";
s = header.IndexOf(mark) + mark.Length;
e = header.IndexOf(",", s + 1);
bool fortran = bool.Parse(header.Substring(s, e - s));

if (fortran)
throw new Exception();

mark = "'shape': (";
s = header.IndexOf(mark) + mark.Length;
e = header.IndexOf(")", s + 1);
shape = header.Substring(s, e - s).Split(',').Where(v => !String.IsNullOrEmpty(v)).Select(Int32.Parse).ToArray();

return true;
}

Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
{
isLittleEndian = IsLittleEndian(dtype);
bytes = Int32.Parse(dtype.Substring(2));

string typeCode = dtype.Substring(1);

if (typeCode == "b1")
return typeof(bool);
if (typeCode == "i1")
return typeof(Byte);
if (typeCode == "i2")
return typeof(Int16);
if (typeCode == "i4")
return typeof(Int32);
if (typeCode == "i8")
return typeof(Int64);
if (typeCode == "u1")
return typeof(Byte);
if (typeCode == "u2")
return typeof(UInt16);
if (typeCode == "u4")
return typeof(UInt32);
if (typeCode == "u8")
return typeof(UInt64);
if (typeCode == "f4")
return typeof(Single);
if (typeCode == "f8")
return typeof(Double);
if (typeCode.StartsWith("S"))
return typeof(String);

throw new NotSupportedException();
}

bool? IsLittleEndian(string type)
{
bool? littleEndian = null;

switch (type[0])
{
case '<':
littleEndian = true;
break;
case '>':
littleEndian = false;
break;
case '|':
littleEndian = null;
break;
default:
throw new Exception();
}

return littleEndian;
}

Array Create(Type type, int length)
{
// ReSharper disable once PossibleNullReferenceException
while (type.IsArray)
type = type.GetElementType();

return Array.CreateInstance(type, length);
}
}
}

+ 4
- 3
src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs View File

@@ -8,9 +8,10 @@ namespace Tensorflow.NumPy
{
public void Deconstruct(out byte blue, out byte green, out byte red)
{
blue = (byte)dims[0];
green = (byte)dims[1];
red = (byte)dims[2];
var data = Data<byte>();
blue = data[0];
green = data[1];
red = data[2];
}

public static implicit operator NDArray(Array array)


+ 7
- 0
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -19,6 +19,7 @@ namespace Tensorflow.NumPy
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype);
public NDArray(Tensor value, Shape? shape = null) => Init(value, shape);
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) => Init(bytes, shape, dtype);
public NDArray(IntPtr address, Shape shape, TF_DataType dtype) => Init(address, shape, dtype);

public static NDArray Scalar<T>(T value) where T : unmanaged
=> value switch
@@ -75,5 +76,11 @@ namespace Tensorflow.NumPy
_tensor = new Tensor(bytes, shape, dtype);
_tensor.SetReferencedByNDArray();
}

void Init(IntPtr address, Shape shape, TF_DataType dtype)
{
_tensor = new Tensor(address, shape, dtype);
_tensor.SetReferencedByNDArray();
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs View File

@@ -40,6 +40,9 @@ namespace Tensorflow.NumPy
TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0) where T : unmanaged
=> tf.numpy.linspace(start, stop, num: num, endpoint: endpoint, retstep: retstep, dtype: dtype, axis: axis);

public static NDArray load(string file)
=> tf.numpy.load(file);

public static (NDArray, NDArray) meshgrid<T>(T x, T y, bool copy = true, bool sparse = false)
=> tf.numpy.meshgrid(new[] { x, y }, copy: copy, sparse: sparse);



Loading…
Cancel
Save