Browse Source

linear regression test 1

tags/v0.8.0
haiping008 6 years ago
parent
commit
3b93c7b0fb
18 changed files with 125 additions and 111 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  4. +1
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +1
    -5
      src/TensorFlowNET.Core/Operations/Operation.cs
  6. +2
    -2
      src/TensorFlowNET.Core/Python.cs
  7. +2
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  8. +1
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  9. +31
    -6
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  10. +53
    -66
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
  12. +4
    -4
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  13. +22
    -15
      test/TensorFlowNET.Examples/LinearRegression.cs
  14. +1
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  15. +1
    -0
      test/TensorFlowNET.UnitTest/GradientTest.cs
  16. +1
    -1
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
  17. +1
    -1
      test/TensorFlowNET.UnitTest/TensorTest.cs
  18. +0
    -6
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow
public static Tensor divide<T>(Tensor x, T[] y, string name = "") where T : struct
=> x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");

public static Tensor pow(Tensor x, double y) => gen_math_ops.pow(x, y);
public static Tensor pow<T1, T2>(T1 x, T2 y) => gen_math_ops.pow(x, y);

/// <summary>
/// Computes the sum of elements across dimensions of a tensor.


+ 1
- 1
src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs View File

@@ -357,7 +357,7 @@ namespace Tensorflow
if (y.dtype.is_complex())
throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})");
var shape = array_ops.shape(y);
var constant = constant_op.constant(1.0, name: $"grad_ys_{i}");
var constant = constant_op.constant(1.0f, name: $"grad_ys_{i}");
var fill = gen_array_ops.fill(shape, constant);
new_grad_ys.Add(fill);
}


+ 1
- 0
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -41,6 +41,7 @@ namespace Tensorflow
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true)
{
var ret = new Operation(c_op);
_add_op(ret);

var name_key = ret.name.ToLower();
if (!_names_in_use.ContainsKey(name_key))


+ 1
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -212,6 +212,7 @@ namespace Tensorflow

public void _add_op(Operation op)
{
op._id_value = _next_id();
_nodes_by_id[op._id] = op;
_nodes_by_name[op.name] = op;
_version = Math.Max(_version, op._id);


+ 1
- 5
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow

public Graph graph { get; }
public int _id => _id_value;
private int _id_value;
public int _id_value;

public string type => OpType;
public Operation op => this;
@@ -46,8 +46,6 @@ namespace Tensorflow
_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));

graph._add_op(this);
}

public Operation(Graph g, string opType, string oper_name)
@@ -100,8 +98,6 @@ namespace Tensorflow
}

// This will be set by self.inputs.

_id_value = graph._next_id();
if(op_def == null)
op_def = g.GetOpDef(node_def.Op);



+ 2
- 2
src/TensorFlowNET.Core/Python.cs View File

@@ -88,8 +88,8 @@ namespace Tensorflow

public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2)
{
int index = 0;
yield return(t1.Data<T>(index), t2.Data<T>(index));
for (int i = 0; i < t1.size; i++)
yield return (t1.Data<T>(i), t2.Data<T>(i));
}

public static IEnumerable<(T1, T2)> zip<T1, T2>(IList<T1> t1, IList<T2> t2)


+ 2
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -59,6 +59,7 @@ namespace Tensorflow
{
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();

switch (subfeed_val)
{
case IntPtr pointer:
@@ -86,6 +87,7 @@ namespace Tensorflow
Console.WriteLine($"can't handle data type of subfeed_val");
throw new NotImplementedException("_run subfeed");
}
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
}


+ 1
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -45,7 +45,7 @@ Upgraded to TensorFlow 1.13 RC2.

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.6.1" />
<PackageReference Include="NumSharp" Version="0.7.2" />
<PackageReference Include="NumSharp" Version="0.7.3" />
</ItemGroup>

<ItemGroup>


+ 31
- 6
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -52,20 +52,45 @@ namespace Tensorflow
var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();
var nd1 = nd.ravel();
switch (nd.dtype.Name)
{
case "Int16":
Marshal.Copy(nd.ravel().Data<short>(), 0, dotHandle, nd.size);
Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size);
break;
case "Int32":
Marshal.Copy(nd.ravel().Data<int>(), 0, dotHandle, nd.size);
Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size);
break;
case "Single":
Marshal.Copy(nd.ravel().Data<float>(), 0, dotHandle, nd.size);
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
/*if (nd.size > 1)
{
var bb = nd.Data<byte>();
var bytes = Marshal.AllocHGlobal(bb.Length);
Marshal.Copy(bb, 0, bytes, bb.Length);
ulong bytes_len = c_api.TF_StringEncodedSize((ulong)bb.Length);
var dataTypeByte = ToTFDataType(nd.dtype);
// shape
var dims2 = nd.shape.Select(x => (long)x).ToArray();

var tfHandle2 = c_api.TF_AllocateTensor(dataTypeByte,
dims2,
nd.ndim,
bytes_len + sizeof(Int64));

dotHandle = c_api.TF_TensorData(tfHandle2);
Marshal.WriteInt64(dotHandle, 0);
c_api.TF_StringEncode(bytes, (ulong)bb.Length, dotHandle + sizeof(Int64), bytes_len, status);
return tfHandle2;
}
else
{
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
}*/
break;
case "Double":
Marshal.Copy(nd.ravel().Data<double>(), 0, dotHandle, nd.size);
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size);
break;
//case "Byte":
/*var bb = nd.Data<byte>();
@@ -119,7 +144,7 @@ namespace Tensorflow
dims,
dims.Length,
dotHandle,
size,
(UIntPtr)size,
deallocator,
ref deallocator_called);



+ 53
- 66
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -6,83 +6,70 @@ namespace Tensorflow
{
public partial class Tensor
{
public static Tensor operator +(Tensor x, Tensor y)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new Tensor[] { x, y }), scope =>
{
return gen_math_ops.add(x, y, scope);
});
}

public static Tensor operator +(Tensor x, int y)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new object[] { x, y }), scope =>
{
var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y");
return gen_math_ops.add(x, y1, scope);
});
}
public static Tensor operator +(Tensor x, Tensor y) => BinaryOpWrapper("add", x, y);
public static Tensor operator +(Tensor x, int y) => BinaryOpWrapper("add", x, y);

public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1);
public static Tensor operator -(Tensor t1, Tensor t2) => gen_math_ops.sub(t1, t2);
public static Tensor operator -(Tensor t1, int t2) => gen_math_ops.sub(t1, t2);
public static Tensor operator -(Tensor t1, double t2) => gen_math_ops.sub(t1, t2);

public static Tensor operator *(double x, Tensor y)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new { x, y }),
scope =>
{
var x1 = ops.convert_to_tensor(x, y.dtype.as_base_dtype(), name: "x");
return gen_math_ops.mul(x1, y, name: scope);
});
}
public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y);

public static Tensor operator *(Tensor x, Tensor y)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new Tensor[] { x, y }), scope =>
{
return gen_math_ops.mul(x, y, name: scope);
});
}
public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y);
public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y);
public static Tensor operator *(Tensor x, Tensor y) => BinaryOpWrapper("mul", x, y);
public static Tensor operator *(Tensor x, int y) => BinaryOpWrapper("mul", x, y);

public static Tensor operator *(Tensor x, int y)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new object[] { x, y }), scope =>
{
var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y");
return gen_math_ops.mul(x, y1, name: scope);
});
}
public static Tensor operator /(Tensor x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, double y) => BinaryOpWrapper("truediv", x, y);

public static Tensor operator /(Tensor x, Tensor y)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("truediv/", "truediv", new Tensor[] { x, y }), scope =>
{
return gen_math_ops.real_div(x, y, scope);
});
}
public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y);

public static Tensor operator /(Tensor x, double y)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("truediv/", "truediv", new object[] { x, y }), scope =>
{
var y1 = ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");
return gen_math_ops.real_div(x, y1, scope);
});
}
public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y);
public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y);
public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y);
public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y);

public static Tensor operator %(Tensor x, Tensor y)
private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mod", new object[] { x, y }), scope =>
TF_DataType dtype = TF_DataType.DtInvalid;
if (x is Tensor tl)
dtype = tl.dtype.as_base_dtype();
if( y is Tensor tr)
dtype = tr.dtype.as_base_dtype();
var namescope = new ops.name_scope("", name, new { x, y });
return Python.with<ops.name_scope, Tensor>(namescope, scope =>
{
return gen_math_ops.floor_mod(x, y, scope);
Tensor result = null;
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y");

switch (name)
{
case "add":
result = gen_math_ops.add(x1, y1, name: scope);
break;
case "truediv":
result = gen_math_ops.real_div(x1, y1, name: scope);
break;
case "mul":
result = gen_math_ops.mul(x1, y1, name: scope);
break;
case "sub":
result = gen_math_ops.sub(x1, y1, name: scope);
break;
case "mod":
result = gen_math_ops.floor_mod(x1, y1, name: scope);
break;
default:
throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty)}");
}

return result;
});
}

public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y);
public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y);
public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y);
public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y);
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensors/c_api.tensor.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow
/// <param name="deallocator_arg"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, ref bool deallocator_arg);
public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg);

/// <summary>
/// Return the number of dimensions that the tensor has.


+ 4
- 4
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -96,19 +96,19 @@ namespace Tensorflow
if (values.GetType().IsArray)
nparray = np.array((int[])values, np_dt);
else
nparray = (int)values;
nparray = Convert.ToInt32(values);
break;
case "Single":
if (values.GetType().IsArray)
nparray = np.array((float[])values, np_dt);
else
nparray = (float)values;
nparray = Convert.ToSingle(values);
break;
case "Double":
nparray = (double)values;
nparray = Convert.ToDouble(values);
break;
case "String":
nparray = values.ToString();
nparray = Convert.ToString(values);
break;
default:
throw new NotImplementedException("make_tensor_proto Not Implemented");


+ 22
- 15
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -10,41 +10,47 @@ namespace TensorFlowNET.Examples
/// A linear regression learning algorithm example using TensorFlow library.
/// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py
/// </summary>
public class LinearRegression : IExample
public class LinearRegression : Python, IExample
{
private NumPyRandom rng = np.random;

public void Run()
{
var graph = tf.Graph().as_default();

// Parameters
double learning_rate = 0.01;
float learning_rate = 0.01f;
int training_epochs = 1000;
int display_step = 50;
int display_step = 1;

// Training Data
var train_X = np.array(3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1);
var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3);
var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f);
var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
var n_samples = train_X.shape[0];
// tf Graph Input
var X = tf.placeholder(tf.float64);
var Y = tf.placeholder(tf.float64);
var X = tf.placeholder(tf.float32);
var Y = tf.placeholder(tf.float32);

// Set model weights
var W = tf.Variable(rng.randn<double>(), name: "weight");
var b = tf.Variable(rng.randn<double>(), name: "bias");
//var rnd1 = rng.randn<float>();
//var rnd2 = rng.randn<float>();
var W = tf.Variable(-0.06f, name: "weight");
var b = tf.Variable(-0.73f, name: "bias");

var mul = tf.multiply(X, W);
var pred = tf.add(mul, b);

// Mean squared error
var sub = pred - Y;
var pow = tf.pow(sub, 2);
var pow = tf.pow(sub, 2.0f);

var reduce = tf.reduce_sum(pow);
var cost = reduce / (2d * n_samples);
var cost = reduce / (2.0f * n_samples);

// import graph

// radient descent
// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default
@@ -55,7 +61,7 @@ namespace TensorFlowNET.Examples
var init = tf.global_variables_initializer();

// Start training
Python.with<Session>(tf.Session(), sess =>
Python.with<Session>(tf.Session(graph), sess =>
{
// Run the initializer
sess.run(init);
@@ -63,11 +69,12 @@ namespace TensorFlowNET.Examples
// Fit all training data
for (int epoch = 0; epoch < training_epochs; epoch++)
{
foreach (var (x, y) in Python.zip<double>(train_X, train_Y))
foreach (var (x, y) in zip<float>(train_X, train_Y))
{
sess.run(optimizer,
new FeedItem(X, x),
new FeedItem(Y, y));
var w = sess.run(W);
}

// Display logs per epoch step


+ 1
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -6,7 +6,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="NumSharp" Version="0.7.2" />
<PackageReference Include="NumSharp" Version="0.7.3" />
<PackageReference Include="TensorFlow.NET" Version="0.3.0" />
</ItemGroup>



+ 1
- 0
test/TensorFlowNET.UnitTest/GradientTest.cs View File

@@ -12,6 +12,7 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void Gradients()
{
var graph = tf.Graph().as_default();
var a = tf.constant(0.0);
var b = 2.0 * a;
Assert.AreEqual(b.name, "mul:0");


+ 1
- 1
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -19,7 +19,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
<PackageReference Include="NumSharp" Version="0.7.2" />
<PackageReference Include="NumSharp" Version="0.7.3" />
<PackageReference Include="TensorFlow.NET" Version="0.3.0" />
</ItemGroup>



+ 1
- 1
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 4, 2, 5, 3, 6 }));
}

/// <summary>


+ 0
- 6
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -10,7 +10,6 @@ namespace TensorFlowNET.UnitTest
[TestClass]
public class TrainSaverTest : Python
{
[TestMethod]
public void ExportGraph()
{
var v = tf.Variable(0, name: "my_variable");
@@ -18,7 +17,6 @@ namespace TensorFlowNET.UnitTest
tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt");
}

[TestMethod]
public void ImportGraph()
{
with<Session>(tf.Session(), sess =>
@@ -27,7 +25,6 @@ namespace TensorFlowNET.UnitTest
});
}

[TestMethod]
public void ImportSavedModel()
{
with<Session>(Session.LoadFromSavedModel("mobilenet"), sess =>
@@ -36,14 +33,12 @@ namespace TensorFlowNET.UnitTest
});
}

[TestMethod]
public void ImportGraphDefFromPbFile()
{
var g = new Graph();
var status = g.Import("mobilenet/saved_model.pb");
}

[TestMethod]
public void Save1()
{
var w1 = tf.Variable(0, name: "save1");
@@ -63,7 +58,6 @@ namespace TensorFlowNET.UnitTest
});
}

[TestMethod]
public void Save2()
{
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);


Loading…
Cancel
Save