Browse Source

linear regression gradients.

tags/v0.8.0
haiping008 6 years ago
parent
commit
f2c8883bd9
15 changed files with 428 additions and 59 deletions
  1. +1
    -20
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +31
    -6
      src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
  3. +26
    -2
      src/TensorFlowNET.Core/Gradients/math_grad.py.cs
  4. +43
    -5
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  5. +22
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  6. +13
    -6
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  7. +117
    -0
      src/TensorFlowNET.Core/Operations/math_ops.py.cs
  8. +8
    -4
      src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
  9. +42
    -10
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  10. +8
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  12. +10
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  13. +8
    -1
      src/TensorFlowNET.Core/ops.py.cs
  14. +4
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  15. +94
    -0
      test/TensorFlowNET.Examples/python/linear_regression.py

+ 1
- 20
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -34,26 +34,7 @@ namespace Tensorflow
/// <returns></returns>
public static Tensor reduce_sum(Tensor input, int[] axis = null)
{
    // The reduction is delegated to math_ops.reduce_sum, which derives the
    // reduction indices itself when no axis tensor is supplied.
    // NOTE(review): `axis` is accepted but not forwarded, so every call
    // reduces over all dimensions - confirm whether that is intended.
    return math_ops.reduce_sum(input);
}
}
}

+ 31
- 6
src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs View File

@@ -181,7 +181,11 @@ namespace Tensorflow
{
foreach(var x in _NonEagerInputs(op, xs))
{
pending_count[x.op.Name] -= 1;
if (!pending_count.ContainsKey(x.op.Name))
pending_count[x.op.Name] = 0;
else
pending_count[x.op.Name] -= 1;

var ready = pending_count[x.op.Name] == 0;

if(loop_state != null && !ready)
@@ -440,22 +444,43 @@ namespace Tensorflow
reached_ops.Add(op);
foreach (var output in op.outputs)
{
var c = _Consumers(output, func_graphs).ToList();
c.ForEach(x => queue.Enqueue(x));
if (_IsBackpropagatable(output))
{
var c = _Consumers(output, func_graphs).ToList();
c.ForEach(x => queue.Enqueue(x));
}
}
}
}
}

/// <summary>
/// Tells whether the base dtype of <paramref name="tensor"/> is one of
/// half, float, double, complex64, complex128 or resource.
/// </summary>
/// <param name="tensor">Tensor whose dtype is inspected.</param>
/// <returns>True when the base dtype is in the list above.</returns>
private static bool _IsTrainable(Tensor tensor)
{
    switch (tensor.dtype.as_base_dtype())
    {
        case TF_DataType.TF_HALF:
        case TF_DataType.TF_FLOAT:
        case TF_DataType.TF_DOUBLE:
        case TF_DataType.TF_COMPLEX64:
        case TF_DataType.TF_COMPLEX128:
        case TF_DataType.TF_RESOURCE:
            return true;
        default:
            return false;
    }
}
/// <summary>
/// Tells whether <paramref name="tensor"/> passes <c>_IsTrainable</c> or has
/// a base dtype of bfloat16 or variant.
/// </summary>
/// <param name="tensor">Tensor whose dtype is inspected.</param>
/// <returns>True when either condition holds.</returns>
private static bool _IsBackpropagatable(Tensor tensor)
{
    if (_IsTrainable(tensor))
        return true;

    // Only looked up when the trainable check fails, as before.
    var baseType = tensor.dtype.as_base_dtype();
    return baseType == TF_DataType.TF_BFLOAT16 || baseType == TF_DataType.TF_VARIANT;
}

/// <summary>
/// Returns the operations that consume <paramref name="t"/>.
/// </summary>
/// <param name="t">Tensor whose consumers are requested.</param>
/// <param name="func_graphs">Not used by the current implementation.</param>
/// <returns>The result of <c>t.consumers()</c>.</returns>
private static Operation[] _Consumers(Tensor t, List<object> func_graphs)
{
    // NOTE(review): the earlier summary mentioned crossing closure
    // boundaries; nothing here does that yet.
    return t.consumers();
}

private static List<Tensor> _AsList(object ys)


+ 26
- 2
src/TensorFlowNET.Core/Gradients/math_grad.py.cs View File

@@ -17,6 +17,22 @@ namespace Tensorflow
return (grad, grad);
}

/// <summary>
/// Gradient function registered for the "Sum" op. Work in progress: the
/// incoming gradient is returned unchanged for both inputs.
/// </summary>
/// <param name="op">The Sum operation being differentiated.</param>
/// <param name="grad">Gradient flowing into the op's output.</param>
/// <returns>The pair (grad, grad).</returns>
public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
{
// Placeholder branch for inputs whose NDims is not -1; intentionally empty for now.
if (op.inputs[0].NDims > -1)
{

}

// The shape values below are computed but do not influence the returned
// gradient yet; the commented-out lines sketch the remaining steps.
var input_shape = array_ops.shape(op.inputs[0]);
ops.colocate_with(input_shape);
var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]);
//var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims);
//var grad = array_ops.reshape(grad, output_shape_kept_dims);

return (grad, grad);
}

public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
{
var x = op.inputs[0];
@@ -24,9 +40,17 @@ namespace Tensorflow

var sx = array_ops.shape(x);
var sy = array_ops.shape(y);
// rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
x = math_ops.conj(x);
y = math_ops.conj(y);

return (grad, grad);
var realdiv1 = gen_math_ops.real_div(grad, y);
var reduce_sum1 = math_ops.reduce_sum(realdiv1, rx);
var realdiv2 = gen_math_ops.real_div(-x, y);
var realdiv3 = gen_math_ops.real_div(realdiv2, y);
var reduce_sum2 = math_ops.reduce_sum(grad * realdiv3, ry);

return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy));
}
}
}

+ 43
- 5
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -53,6 +53,11 @@ namespace Tensorflow
}
}

/// <summary>
/// Forwards to <c>math_ops.rank_internal</c> with optimization enabled.
/// </summary>
/// <param name="input">Tensor to inspect.</param>
/// <param name="name">Name passed through to the helper.</param>
/// <returns>Whatever <c>math_ops.rank_internal</c> returns.</returns>
public static Tensor rank(Tensor input, string name = "")
    => math_ops.rank_internal(input, name, optimize: true);

/// <summary>
/// Returns the shape of a tensor.
/// </summary>
@@ -68,11 +73,14 @@ namespace Tensorflow
return shape_internal(input, name, optimize: true, out_type: out_type);
}

/// <summary>
/// Forwards to <c>size_internal</c> with optimization enabled.
/// </summary>
/// <param name="input">Tensor to inspect.</param>
/// <param name="name">Name passed through to the helper.</param>
/// <param name="out_type">Requested dtype of the result.</param>
/// <returns>Whatever <c>size_internal</c> returns.</returns>
public static Tensor size(Tensor input, string name = "", TF_DataType out_type = TF_DataType.TF_INT32)
{
    return size_internal(input, name, optimize: true, out_type: out_type);
}

Python.with<ops.name_scope>(new ops.name_scope(name, "Shape", new Tensor[] { input }), scope =>
private static Tensor shape_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Shape", new Tensor[] { input }), scope =>
{
name = scope;

@@ -83,16 +91,46 @@ namespace Tensorflow
if (optimize && input_shape.is_fully_defined())
{
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
result = constant_op.constant(nd, name);
return constant_op.constant(nd, name);
}
}
else
{
// result = gen_array_ops.shape();
}

return null;
});
}

return result;
/// <summary>
/// Builds a tensor holding the number of elements of <paramref name="input"/>.
/// </summary>
/// <param name="input">Tensor to inspect.</param>
/// <param name="name">Name for the "Size" scope / op.</param>
/// <param name="optimize">When true, a constant is emitted if the element count can be computed statically.</param>
/// <param name="out_type">Requested dtype of the result.</param>
/// <returns>The size tensor, or null when executing eagerly (not supported yet).</returns>
private static Tensor size_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
{
    return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Size", new Tensor[] { input }), scope =>
    {
        name = scope;

        if (tf.context.executing_eagerly())
        {
            // Eager execution has no implementation here yet.
            return null;
        }

        var input_tensor = ops.convert_to_tensor(input);
        var input_shape = tensor_util.to_shape(input_tensor.shape);

        // The static fast path must yield the element COUNT (product of the
        // dimensions), not the shape vector. It is limited to int32 results
        // because the constant is built from an int.
        if (optimize && out_type == TF_DataType.TF_INT32 && input_shape.is_fully_defined())
        {
            long total = 1;
            bool known = true;
            foreach (var dim in input_tensor.shape)
            {
                if (dim < 0)
                {
                    // A negative entry is an unknown dimension; no static answer.
                    known = false;
                    break;
                }

                total *= dim;
                if (total > int.MaxValue)
                {
                    // Does not fit in int32; defer to the runtime op.
                    known = false;
                    break;
                }
            }

            if (known)
                return constant_op.constant((int)total, name);
        }

        return gen_array_ops.size(input, name: name, out_type: out_type);
    });
}
}
}

+ 22
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -60,9 +60,30 @@ namespace Tensorflow
return _op.outputs[0];
}

/// <summary>
/// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
/// </summary>
/// <param name="s0">A `Tensor`. Must be one of the following types: `int32`, `int64`.</param>
/// <param name="s1">A `Tensor`. Must have the same type as `s0`.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A tuple of `Tensor` objects (r0, r1): outputs 0 and 1 of the "BroadcastGradientArgs" op.</returns>
public static (Tensor, Tensor) broadcast_gradient_args(Tensor s0, Tensor s1, string name = "")
{
    // The earlier placeholder `return (null, null);` made this op unreachable.
    var _op = _op_def_lib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 });

    return (_op.outputs[0], _op.outputs[1]);
}

/// <summary>
/// Adds a "Reshape" op and returns its first output.
/// </summary>
/// <param name="tensor">Passed to the op as `tensor`.</param>
/// <param name="shape">Passed to the op as `shape`.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>Output 0 of the created op.</returns>
public static Tensor reshape(Tensor tensor, Tensor shape, string name = "")
{
    var reshapeOp = _op_def_lib._apply_op_helper("Reshape", name, new { tensor, shape });
    var result = reshapeOp.outputs[0];
    return result;
}

/// <summary>
/// Adds a "Size" op and returns its first output.
/// </summary>
/// <param name="input">Passed to the op as `input`.</param>
/// <param name="out_type">Passed to the op as `out_type`.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>Output 0 of the created op.</returns>
public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = "")
{
    var sizeOp = _op_def_lib._apply_op_helper("Size", name, new { input, out_type });
    var result = sizeOp.outputs[0];
    return result;
}
}
}

+ 13
- 6
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -17,9 +17,16 @@ namespace Tensorflow
return _op.outputs[0];
}

/// <summary>
/// Adds a "Neg" op and returns its first output.
/// </summary>
/// <param name="x">Passed to the op as `x`.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>Output 0 of the created op.</returns>
public static Tensor neg(Tensor x, string name = "")
{
    var _op = _op_def_lib._apply_op_helper("Neg", name, args: new { x });

    return _op.outputs[0];
}

/// <summary>
/// Adds a "Sub" op and returns its first output.
/// </summary>
/// <param name="x">Passed to the op as `x`.</param>
/// <param name="y">Passed to the op as `y`.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>Output 0 of the created op.</returns>
public static Tensor sub(Tensor x, Tensor y, string name = "")
{
    var subOp = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y });
    var result = subOp.outputs[0];
    return result;
}
@@ -31,9 +38,9 @@ namespace Tensorflow
return _op.outputs[0];
}

/// <summary>
/// Adds a "RealDiv" op and returns its first output.
/// </summary>
/// <param name="x">Passed to the op as `x`.</param>
/// <param name="y">Passed to the op as `y`.</param>
/// <param name="name">A name for the operation (optional); no longer hard-coded to "truediv".</param>
/// <returns>Output 0 of the created op.</returns>
public static Tensor real_div(Tensor x, Tensor y, string name = "")
{
    var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y });

    return _op.outputs[0];
}
@@ -61,9 +68,9 @@ namespace Tensorflow
return _op.outputs[0];
}

/// <summary>
/// Adds a "Sum" op and returns its first output.
/// </summary>
/// <param name="input">Passed to the op as `input`.</param>
/// <param name="axis">Passed to the op as `reduction_indices`.</param>
/// <param name="keep_dims">Passed to the op as `keep_dims` (previously fixed to false).</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>Output 0 of the created op.</returns>
public static Tensor sum(Tensor input, Tensor axis = null, bool keep_dims = false, string name = "")
{
    var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims });

    return _op.outputs[0];
}


+ 117
- 0
src/TensorFlowNET.Core/Operations/math_ops.py.cs View File

@@ -6,6 +6,104 @@ namespace Tensorflow
{
public class math_ops
{
/// <summary>
/// Helper function for reduction ops. Unfinished: the body prepares its
/// inputs and then returns null.
/// </summary>
/// <param name="input_shape">1-D Tensor, the shape of the Tensor being reduced.</param>
/// <param name="axes">1-D Tensor, the reduction axes.</param>
/// <returns>
/// Intended: a 1-D Tensor, the output shape as if keepdims were set to True.
/// Currently always null.
/// </returns>
public static Tensor reduced_shape(Tensor input_shape, Tensor axes)
{
// Both inputs go through the int32 cast helper.
input_shape = to_int32(input_shape);
axes = to_int32(axes);

// Add the rank and take the remainder; presumably this wraps negative
// axes into range - TODO confirm once the rest is implemented.
var input_rank = array_ops.size(input_shape);
axes = (axes + input_rank) % input_rank;

// The wrapped axes are not used yet.
return null;
}

/// <summary>
/// Requests an `int32` version of <paramref name="x"/> from <c>__case__</c>.
/// </summary>
/// <param name="x">Tensor to convert.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>Whatever <c>__case__</c> returns for `TF_INT32`.</returns>
private static Tensor to_int32(Tensor x, string name = "ToInt32")
    => __case__(x, TF_DataType.TF_INT32, name: name);

/// <summary>
/// Cast helper. Only the no-op case is implemented: when
/// <paramref name="x"/> already has the requested base dtype it is returned as is.
/// </summary>
/// <param name="x">Tensor to cast.</param>
/// <param name="dtype">Target dtype; its base dtype is compared with <c>x.dtype</c>.</param>
/// <param name="name">A name for the operation (currently unused).</param>
/// <returns><paramref name="x"/> itself when no conversion is needed.</returns>
/// <exception cref="NotImplementedException">Thrown whenever a real conversion would be required.</exception>
public static Tensor __case__(Tensor x, TF_DataType dtype, string name = "")
{
    var target = dtype.as_base_dtype();
    bool alreadyTyped = x is Tensor && target == x.dtype;

    if (!alreadyTyped)
    {
        // math_ops.py cast
        throw new NotImplementedException();
    }

    return x;
}

/// <summary>
/// Sums <paramref name="input_tensor"/> over the given axes, or over all
/// dimensions when <paramref name="axis"/> is null.
/// </summary>
/// <param name="input_tensor">Tensor to reduce.</param>
/// <param name="axis">Reduction indices; null means every dimension.</param>
/// <param name="keepdims">Forwarded to the "Sum" op as `keep_dims`.</param>
/// <returns>The output of the "Sum" op.</returns>
public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false)
{
    var r = _ReductionDims(input_tensor, axis);
    // keepdims used to be dropped here, so the op always ran with keep_dims = false.
    var m = gen_math_ops.sum(input_tensor, r, keep_dims: keepdims);
    return _may_reduce_to_scalar(keepdims, m, axis);
}

/// <summary>
/// Marks <paramref name="output"/> as a scalar when the reduction is known
/// to collapse every dimension.
/// </summary>
/// <param name="keepdims">Whether reduced dimensions are retained.</param>
/// <param name="output">Result of the reduction op.</param>
/// <param name="axis">Axis argument given to the reduction; null means all dimensions.</param>
/// <returns><paramref name="output"/>, possibly with its static shape set to scalar.</returns>
private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor output, Tensor axis = null)
{
    // Forcing a scalar shape is only valid for a full reduction without
    // kept dimensions; with an explicit axis the result may keep some.
    if (!keepdims && axis == null)
        output.shape = new long[0];

    return output;
}

/// <summary>
/// Picks the reduction indices: the caller's <paramref name="axis"/> when
/// one is given, otherwise <c>range(0, rank(x), 1)</c>.
/// </summary>
/// <param name="x">Tensor being reduced.</param>
/// <param name="axis">Explicit reduction indices, or null.</param>
/// <returns>The tensor of reduction indices.</returns>
private static Tensor _ReductionDims(Tensor x, Tensor axis)
{
    if (axis != null)
        return axis;

    // No axis supplied: build the rank first, then the range over it.
    var dims = array_ops.rank(x);
    return range(0, dims, 1);
}

/// <summary>
/// Adds a "Range" op.
/// </summary>
/// <param name="start">First value; when <paramref name="limit"/> is null this value is the limit and the sequence starts at 0.</param>
/// <param name="limit">Upper bound, or null (see <paramref name="start"/>).</param>
/// <param name="delta">Step; null means 1.</param>
/// <param name="dtype">Currently unused.</param>
/// <param name="name">Name for the scope / op.</param>
/// <returns>The result of <c>gen_math_ops.range</c>.</returns>
public static Tensor range(object start, Tensor limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range" )
{
    // The declared defaults used to reach convert_to_tensor as null.
    // Resolve them the way the Python API does: range(n) == range(0, n, 1).
    bool single = limit == null;
    object first = single ? 0 : start;
    object step = delta ?? 1;

    return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { first, single ? start : limit, step }), scope =>
    {
        name = scope;
        var start1 = ops.convert_to_tensor(first, name: "start");
        var limit1 = single
            ? ops.convert_to_tensor(start, name: "limit")
            : ops.convert_to_tensor(limit, name: "limit");
        var delta1 = ops.convert_to_tensor(step, name: "delta");

        return gen_math_ops.range(start1, limit1, delta1, name);
    });
}

/// <summary>
/// Produces the rank of <paramref name="input"/> inside a "Rank" name scope,
/// either as a constant or through the rank op.
/// </summary>
/// <param name="input">Tensor to inspect.</param>
/// <param name="name">Name for the scope / op.</param>
/// <param name="optimize">Enables the constant branch.</param>
/// <returns>A tensor holding the rank.</returns>
public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true)
{
    return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>
    {
        name = scope;
        var converted = ops.convert_to_tensor(input);
        var staticShape = tensor_util.to_shape(converted.shape);

        // NOTE(review): the constant branch is taken when NDim == null and
        // then passes NDim to the constant; the test looks inverted, but
        // changing it depends on how unknown rank is represented - confirm.
        return optimize && staticShape.NDim == null
            ? constant_op.constant(staticShape.NDim)
            : gen_array_ops.rank(input, name);
    });
}

public static Tensor matmul(Tensor a, Tensor b,
bool transpose_a = false, bool transpose_b = false,
bool adjoint_a = false, bool adjoint_b = false,
@@ -31,5 +129,24 @@ namespace Tensorflow

return result;
}

/// <summary>
/// Conjugate helper. Floating-point and integer tensors are returned
/// directly; every other dtype enters a "Conj" name scope and is, for now,
/// also returned unchanged.
/// </summary>
/// <param name="x">`Tensor` to conjugate.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns><paramref name="x"/> itself in all current code paths.</returns>
public static Tensor conj(Tensor x, string name = "")
{
    var dt = x.dtype;
    bool isReal = dt.is_floating() || dt.is_integer();

    if (!isReal)
    {
        // NOTE(review): no conjugate op is emitted here yet.
        return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Conj", new List<Tensor> { x }), scope => x);
    }

    return x;
}
}
}

+ 8
- 4
src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs View File

@@ -6,15 +6,19 @@ namespace Tensorflow
{
public partial class Tensor
{
public static implicit operator Tensor(double scalar)
/// <summary>
/// Issue unresolved, will cause name_scope problem.
/// </summary>
/// <param name="scalar"></param>
/*public static implicit operator Tensor(double scalar)
{
return constant_op.constant(scalar);
}
}*/

public static implicit operator Tensor(int scalar)
/*public static implicit operator Tensor(int scalar)
{
return constant_op.constant(scalar);
}
}*/

public static implicit operator int(Tensor tensor)
{


+ 42
- 10
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -8,14 +8,24 @@ namespace Tensorflow
{
/// <summary>
/// Adds two tensors through <c>gen_math_ops.add</c> inside an "add" name scope.
/// </summary>
public static Tensor operator +(Tensor x, Tensor y)
{
    return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new Tensor[] { x, y }), scope =>
    {
        return gen_math_ops.add(x, y, scope);
    });
}

/// <summary>
/// Adds an integer to a tensor. The integer is first converted to a tensor
/// of the left operand's base dtype.
/// </summary>
public static Tensor operator +(Tensor x, int y)
{
    return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new object[] { x, y }), scope =>
    {
        var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y");
        return gen_math_ops.add(x, y1, scope);
    });
}

/// <summary>
/// Unary minus: forwards to <c>gen_math_ops.neg</c>.
/// </summary>
public static Tensor operator -(Tensor t1)
{
    return gen_math_ops.neg(t1);
}

public static Tensor operator -(Tensor t1, Tensor t2)
@@ -35,19 +45,41 @@ namespace Tensorflow

/// <summary>
/// Multiplies two tensors through <c>gen_math_ops.mul</c> inside a "mul" name scope.
/// </summary>
public static Tensor operator *(Tensor x, Tensor y)
{
    return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new Tensor[] { x, y }), scope =>
    {
        return gen_math_ops.mul(x, y, name: scope);
    });
}

/// <summary>
/// Multiplies a tensor by an integer. The integer is first converted to a
/// tensor of the left operand's base dtype.
/// </summary>
public static Tensor operator *(Tensor x, int y)
{
    return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new object[] { x, y }), scope =>
    {
        var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y");
        return gen_math_ops.mul(x, y1, name: scope);
    });
}

/// <summary>
/// Divides two tensors through <c>gen_math_ops.real_div</c> inside a "truediv" name scope.
/// </summary>
public static Tensor operator /(Tensor x, Tensor y)
{
    return Python.with<ops.name_scope, Tensor>(new ops.name_scope("truediv/", "truediv", new Tensor[] { x, y }), scope =>
    {
        return gen_math_ops.real_div(x, y, scope);
    });
}

/// <summary>
/// Divides a tensor by a double. The divisor is first converted to a tensor
/// of the left operand's base dtype.
/// </summary>
public static Tensor operator /(Tensor x, double y)
{
    var divScope = new ops.name_scope("truediv/", "truediv", new object[] { x, y });

    return Python.with<ops.name_scope, Tensor>(divScope, scope =>
    {
        var divisor = ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");
        return gen_math_ops.real_div(x, divisor, scope);
    });
}

/// <summary>
/// Modulo between tensors. Not implemented yet.
/// </summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public static Tensor operator %(Tensor x, Tensor y)
{
    throw new NotImplementedException("math mod is not implemented");
}
}
}

+ 8
- 3
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -38,6 +38,7 @@ namespace Tensorflow
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dataTypeSize;
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);

public long[] shape
{
get
@@ -60,7 +61,10 @@ namespace Tensorflow

set
{
// c_api.TF_GraphSetTensorShape_wrapper
if (value == null)
c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), null, -1, status);
else
c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), value, value.Length, status);
}
}
@@ -170,11 +174,12 @@ namespace Tensorflow
_id = ops.uid();
}

/// <summary>
/// Property form of <see cref="consumers"/>.
/// </summary>
public Operation[] Consumers => consumers();

/// <summary>
/// Looks up the operations that consume this tensor.
/// </summary>
/// <returns>One graph operation per consumer name reported by the C API wrapper.</returns>
public Operation[] consumers()
{
    var output = _as_tf_output();
    var consumer_names = c_api.TF_OperationOutputConsumers_wrapper(output);
    return consumer_names.Select(x => Graph.OperationByName(x)).ToArray();
}

public TF_Output _as_tf_output()


+ 1
- 1
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow
/// <returns></returns>
public bool is_fully_defined()
{
    // A shape is fully defined only when EVERY dimension is known. The
    // previous check was satisfied by a single positive entry, so a shape
    // such as [2, -1] was reported as defined.
    // An empty dimension list is still reported as not fully defined, as before.
    return Dimensions != null
        && Dimensions.Any()
        && Dimensions.All(x => x >= 0);
}
}
}

+ 10
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -82,5 +82,15 @@ namespace Tensorflow
{
return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128;
}

/// <summary>
/// Tells whether <paramref name="type"/> is a signed or unsigned integer type.
/// </summary>
/// <param name="type">Data type to classify.</param>
/// <returns>True for int8/16/32/64 and uint8/16/32/64.</returns>
public static bool is_integer(this TF_DataType type)
{
    switch (type)
    {
        case TF_DataType.TF_INT8:
        case TF_DataType.TF_INT16:
        case TF_DataType.TF_INT32:
        case TF_DataType.TF_INT64:
        // Unsigned types are integers too and were previously missed.
        case TF_DataType.TF_UINT8:
        case TF_DataType.TF_UINT16:
        case TF_DataType.TF_UINT32:
        case TF_DataType.TF_UINT64:
            return true;
        default:
            return false;
    }
}

/// <summary>
/// Tells whether <paramref name="type"/> is a real floating-point type.
/// </summary>
/// <param name="type">Data type to classify.</param>
/// <returns>True for half, bfloat16, float and double.</returns>
public static bool is_floating(this TF_DataType type)
{
    switch (type)
    {
        case TF_DataType.TF_HALF:
        // bfloat16 is a floating-point type as well and was previously missed.
        case TF_DataType.TF_BFLOAT16:
        case TF_DataType.TF_FLOAT:
        case TF_DataType.TF_DOUBLE:
            return true;
        default:
            return false;
    }
}
}
}

+ 8
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -220,6 +220,11 @@ namespace Tensorflow
_colocate_with_for_gradient(op, null, ignore_existing);
}

/// <summary>
/// Tensor overload: forwards the tensor's producing op to
/// <c>_colocate_with_for_gradient</c> with no gradient uid.
/// </summary>
/// <param name="tensor">Tensor whose <c>op</c> is used.</param>
/// <param name="ignore_existing">Passed through unchanged.</param>
public static void colocate_with(Tensor tensor, bool ignore_existing = false)
{
    var producer = tensor.op;
    _colocate_with_for_gradient(producer, null, ignore_existing);
}

private static void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false)
{
var default_graph = get_default_graph();
@@ -290,10 +295,12 @@ namespace Tensorflow
{
case "Add":
return math_grad._AddGrad(op, out_grads);
case "Sum":
return math_grad._SumGrad(op, out_grads);
case "RealDiv":
return math_grad._RealDivGrad(op, out_grads);
default:
throw new NotImplementedException("get_gradient_function");
throw new NotImplementedException($"get_gradient_function {oper.type}");
}
/*var result = typeof(math_grad).GetMethod($"_{op.type}Grad").Invoke(null, new object[] { op, out_grads });
var p1 = result.GetType().GetProperty("Item1");


+ 4
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -14,4 +14,8 @@
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>

<ItemGroup>
<Folder Include="python\" />
</ItemGroup>

</Project>

+ 94
- 0
test/TensorFlowNET.Examples/python/linear_regression.py View File

@@ -0,0 +1,94 @@
'''
A linear regression learning algorithm example using TensorFlow library.
Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''

from __future__ import print_function

import tensorflow as tf
import numpy
import matplotlib.pyplot as plt
rng = numpy.random

# Parameters
learning_rate = 0.01
training_epochs = 1000
display_step = 50

# Training Data
train_X = numpy.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,
                         7.042,10.791,5.313,7.997,5.654,9.27,3.1])
train_Y = numpy.asarray([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,
                         2.827,3.465,1.65,2.904,2.42,2.94,1.3])
n_samples = train_X.shape[0]

# tf Graph Input
X = tf.placeholder("float")
Y = tf.placeholder("float")

# Set model weights
W = tf.Variable(rng.randn(), name="weight")
b = tf.Variable(rng.randn(), name="bias")

# Construct a linear model
mul = tf.multiply(X, W)
pred = tf.add(mul, b)

# Mean squared error, built one op at a time so each intermediate tensor
# has its own variable.
sub = pred-Y
pow = tf.pow(sub, 2)

reduce = tf.reduce_sum(pow)
cost = reduce/(2*n_samples)
# Gradient descent
# Note, minimize() knows to modify W and b because Variable objects are trainable=True by default
grad = tf.train.GradientDescentOptimizer(learning_rate)
optimizer = grad.minimize(cost)

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    # Fit all training data
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})

        # Display logs per epoch step
        if (epoch+1) % display_step == 0:
            c = sess.run(cost, feed_dict={X: train_X, Y:train_Y})
            print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c), \
                "W=", sess.run(W), "b=", sess.run(b))

    print("Optimization Finished!")
    training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
    print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')

    # Graphic display
    plt.plot(train_X, train_Y, 'ro', label='Original data')
    plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
    plt.legend()
    plt.show()

    # Testing example, as requested (Issue #2)
    test_X = numpy.asarray([6.83, 4.668, 8.9, 7.91, 5.7, 8.7, 3.1, 2.1])
    test_Y = numpy.asarray([1.84, 2.273, 3.2, 2.831, 2.92, 3.24, 1.35, 1.03])

    print("Testing... (Mean square loss Comparison)")
    testing_cost = sess.run(
        tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * test_X.shape[0]),
        feed_dict={X: test_X, Y: test_Y})  # same function as cost above
    print("Testing cost=", testing_cost)
    print("Absolute mean square loss difference:", abs(
        training_cost - testing_cost))

    plt.plot(test_X, test_Y, 'bo', label='Testing data')
    plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
    plt.legend()
    plt.show()

Loading…
Cancel
Save