@@ -34,26 +34,7 @@ namespace Tensorflow
         /// <returns></returns>
         public static Tensor reduce_sum(Tensor input, int[] axis = null)
         {
-            Tensor rank;
-            string name;
-            using (var namescop = new ops.name_scope("", "Rank", new List<Tensor> { input }))
-            {
-                name = namescop;
-                rank = gen_array_ops.rank(input, namescop);
-            }
-            using (var namescope = new ops.name_scope("range", "Range", new List<Tensor> { 0D, input, 1D }))
-            {
-                name = namescope;
-                var start = ops.convert_to_tensor(0D);
-                var limit = ops.convert_to_tensor(input);
-                var delta = ops.convert_to_tensor(1D);
-                var t = gen_math_ops.range(start, limit, delta, name);
-            }
-            var s = gen_math_ops.sum(input, rank);
-            return s;
+            return math_ops.reduce_sum(input);
         }
     }
 }
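The public reduce_sum now simply defers to math_ops.reduce_sum, which derives the default axes from the tensor's rank (see _ReductionDims later in this diff) instead of hand-building Rank/Range ops inline. A hedged usage sketch; the tensor x and its value are hypothetical, not part of the diff:

    // With axis omitted, _ReductionDims supplies range(0, rank(x), 1), so a
    // hypothetical 2x2 tensor {{1, 2}, {3, 4}} sums over axes [0, 1] to 10.
    var total = math_ops.reduce_sum(x);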
@@ -181,7 +181,11 @@ namespace Tensorflow
         {
             foreach(var x in _NonEagerInputs(op, xs))
             {
-                pending_count[x.op.Name] -= 1;
+                if (!pending_count.ContainsKey(x.op.Name))
+                    pending_count[x.op.Name] = 0;
+                else
+                    pending_count[x.op.Name] -= 1;
                 var ready = pending_count[x.op.Name] == 0;
                 if(loop_state != null && !ready)
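The added guard keeps the dictionary indexer from throwing on a key that was never counted, standing in for Python's defaultdict(int). A strict defaultdict-style decrement would look like the sketch below (same names as the hunk above); the committed version instead pins a missing entry at 0, which immediately marks that op as ready:

    // defaultdict(int)-style: a missing key reads as 0 and is then decremented.
    pending_count.TryGetValue(x.op.Name, out var count);  // count == 0 when absent
    pending_count[x.op.Name] = count - 1;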
@@ -440,22 +444,43 @@ namespace Tensorflow
                     reached_ops.Add(op);
                     foreach (var output in op.outputs)
                     {
-                        var c = _Consumers(output, func_graphs).ToList();
-                        c.ForEach(x => queue.Enqueue(x));
+                        if (_IsBackpropagatable(output))
+                        {
+                            var c = _Consumers(output, func_graphs).ToList();
+                            c.ForEach(x => queue.Enqueue(x));
+                        }
                     }
                 }
             }
         }
+        private static bool _IsTrainable(Tensor tensor)
+        {
+            var dtype = tensor.dtype.as_base_dtype();
+            return new TF_DataType[] { TF_DataType.TF_HALF, TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE,
+                TF_DataType.TF_COMPLEX64, TF_DataType.TF_COMPLEX128, TF_DataType.TF_RESOURCE }.Contains(dtype);
+        }
+        private static bool _IsBackpropagatable(Tensor tensor)
+        {
+            if(_IsTrainable(tensor))
+            {
+                return true;
+            }
+            else
+            {
+                var dtype = tensor.dtype.as_base_dtype();
+                return new TF_DataType[] { TF_DataType.TF_BFLOAT16, TF_DataType.TF_VARIANT }.Contains(dtype);
+            }
+        }
         /// <summary>
         /// Returns the consumers of t, crossing closure boundaries where necessary.
         /// </summary>
         /// <param name="t"></param>
         /// <param name="func_graphs"></param>
-        private static List<Operation> _Consumers(Tensor t, List<object> func_graphs)
+        private static Operation[] _Consumers(Tensor t, List<object> func_graphs)
         {
-            var consumers = t.consumers();
-            return consumers;
+            return t.consumers();
         }
         private static List<Tensor> _AsList(object ys)
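_IsBackpropagatable widens _IsTrainable by the two dtypes that can carry gradients without being trainable themselves, and the reachability walk above now prunes outputs that fail it. A hedged truth table, derived from the dtype lists in this hunk:

    // _IsTrainable:        TF_FLOAT -> true   TF_INT32    -> false
    // _IsBackpropagatable: TF_FLOAT -> true   TF_BFLOAT16 -> true   TF_INT32 -> false
    // Integer-valued outputs therefore no longer enqueue their consumers.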
@@ -17,6 +17,22 @@ namespace Tensorflow
             return (grad, grad);
         }
+        public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
+        {
+            if (op.inputs[0].NDims > -1)
+            {
+            }
+            var input_shape = array_ops.shape(op.inputs[0]);
+            ops.colocate_with(input_shape);
+            var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]);
+            //var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims);
+            //var grad = array_ops.reshape(grad, output_shape_kept_dims);
+            return (grad, grad);
+        }
         public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
         {
             var x = op.inputs[0];
@@ -24,9 +40,17 @@ namespace Tensorflow
             var sx = array_ops.shape(x);
             var sy = array_ops.shape(y);
             // rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+            var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
             x = math_ops.conj(x);
             y = math_ops.conj(y);
-            return (grad, grad);
+            var realdiv1 = gen_math_ops.real_div(grad, y);
+            var reduce_sum1 = math_ops.reduce_sum(realdiv1, rx);
+            var realdiv2 = gen_math_ops.real_div(-x, y);
+            var realdiv3 = gen_math_ops.real_div(realdiv2, y);
+            var reduce_sum2 = math_ops.reduce_sum(grad * realdiv3, ry);
+            return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy));
         }
     }
 }
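_RealDivGrad is the quotient rule threaded through broadcasting: for z = x / y, dz/dx = 1/y and dz/dy = -x/y^2, so the upstream grad is scaled by each factor, reduced over the broadcast axes rx/ry, and reshaped back to sx/sy. A hedged scalar check (values hypothetical; at rank 0 the reductions are no-ops):

    // x = 6, y = 2, grad = 1:
    //   dx = grad / y          = 0.5     (matches 1/y)
    //   dy = grad * (-x / y^2) = -1.5    (matches -x/y^2)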
@@ -53,6 +53,11 @@ namespace Tensorflow
             }
         }
+        public static Tensor rank(Tensor input, string name = "")
+        {
+            return math_ops.rank_internal(input, name, optimize: true);
+        }
         /// <summary>
         /// Returns the shape of a tensor.
         /// </summary>
@@ -68,11 +73,14 @@ namespace Tensorflow
             return shape_internal(input, name, optimize: true, out_type: out_type);
         }
-        private static Tensor shape_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
+        public static Tensor size(Tensor input, string name = "", TF_DataType out_type = TF_DataType.TF_INT32)
         {
-            Tensor result = null;
+            return size_internal(input, name, optimize: true, out_type: out_type);
+        }
-            Python.with<ops.name_scope>(new ops.name_scope(name, "Shape", new Tensor[] { input }), scope =>
+        private static Tensor shape_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
         {
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Shape", new Tensor[] { input }), scope =>
+            {
                 name = scope;
@@ -83,16 +91,46 @@ namespace Tensorflow
                     if (optimize && input_shape.is_fully_defined())
                     {
                         var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
-                        result = constant_op.constant(nd, name);
+                        return constant_op.constant(nd, name);
                     }
                 }
                 else
                 {
                     // result = gen_array_ops.shape();
                 }
+                return null;
             });
+        }
-            return result;
+        private static Tensor size_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
+        {
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Size", new Tensor[] { input }), scope =>
+            {
+                name = scope;
+                if (!tf.context.executing_eagerly())
+                {
+                    var input_tensor = ops.convert_to_tensor(input);
+                    var input_shape = tensor_util.to_shape(input_tensor.shape);
+                    if (optimize)
+                    {
+                        if (input_shape.is_fully_defined())
+                        {
+                            var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
+                            return constant_op.constant(nd, name);
+                        }
+                    }
+                    return gen_array_ops.size(input, name: name, out_type: out_type);
+                }
+                else
+                {
+                    // result = gen_array_ops.shape();
+                }
+                return null;
+            });
+        }
     }
 }
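size mirrors shape: a public wrapper over an *_internal worker that constant-folds when the static shape is fully defined and otherwise emits the graph op. One hedged caveat: in the fully-defined branch, size_internal builds np.array(input_tensor.shape, ...), i.e. the dimension vector, whereas TF's Python size constant-folds to the element count, the product of the dims:

    // Expected fold for a fully defined [2, 3] input (per the Python reference):
    //   size  -> 6        (2 * 3)
    //   shape -> [2, 3]   (what the branch above returns for both ops)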
@@ -60,9 +60,30 @@ namespace Tensorflow
             return _op.outputs[0];
         }
+        /// <summary>
+        /// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
+        /// </summary>
+        /// <param name="s0">A `Tensor`. Must be one of the following types: `int32`, `int64`.</param>
+        /// <param name="s1">A `Tensor`. Must have the same type as `s0`.</param>
+        /// <param name="name">A name for the operation (optional).</param>
+        /// <returns>A tuple of `Tensor` objects (r0, r1).</returns>
         public static (Tensor, Tensor) broadcast_gradient_args(Tensor s0, Tensor s1, string name = "")
         {
-            return (null, null);
+            var _op = _op_def_lib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 });
+            return (_op.outputs[0], _op.outputs[1]);
         }
+        public static Tensor reshape(Tensor tensor, Tensor shape, string name = "")
+        {
+            var _op = _op_def_lib._apply_op_helper("Reshape", name, new { tensor, shape });
+            return _op.outputs[0];
+        }
+        public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = "")
+        {
+            var _op = _op_def_lib._apply_op_helper("Size", name, new { input, out_type });
+            return _op.outputs[0];
+        }
     }
 }
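broadcast_gradient_args now emits the real BroadcastGradientArgs op instead of the (null, null) stub, giving _RealDivGrad its reduction indices. A hedged example of the op's semantics, per the TF op documentation:

    // s0 = [2, 3, 5], s1 = [1, 3, 1]  ->  r0 = [],  r1 = [0, 2]
    // Reducing a [2, 3, 5] upstream grad over r1 and reshaping to s1
    // recovers a gradient of shape [1, 3, 1].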
@@ -17,9 +17,16 @@ namespace Tensorflow
             return _op.outputs[0];
         }
-        public static Tensor sub(Tensor x, Tensor y)
+        public static Tensor neg(Tensor x, string name = "")
         {
-            var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", args: new { x, y });
+            var _op = _op_def_lib._apply_op_helper("Neg", name, args: new { x });
             return _op.outputs[0];
         }
+        public static Tensor sub(Tensor x, Tensor y, string name = "")
+        {
+            var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y });
+            return _op.outputs[0];
+        }
@@ -31,9 +38,9 @@ namespace Tensorflow
             return _op.outputs[0];
         }
-        public static Tensor real_div(Tensor x, Tensor y)
+        public static Tensor real_div(Tensor x, Tensor y, string name = "")
         {
-            var _op = _op_def_lib._apply_op_helper("RealDiv", name: "truediv", args: new { x, y });
+            var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y });
             return _op.outputs[0];
         }
@@ -61,9 +68,9 @@ namespace Tensorflow
             return _op.outputs[0];
         }
-        public static Tensor sum(Tensor input, Tensor axis = null)
+        public static Tensor sum(Tensor input, Tensor axis = null, bool keep_dims = false, string name = "")
         {
-            var _op = _op_def_lib._apply_op_helper("Sum", args: new { input, reduction_indices = axis, keep_dims = false });
+            var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims });
             return _op.outputs[0];
         }
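The generated wrappers stop hard-coding op names ("sub", "truediv") and instead thread the caller-supplied name through _apply_op_helper, so names chosen by a surrounding name_scope survive. A hedged shape sketch for the widened sum; the [2, 3] tensor x and axes tensor a are hypothetical:

    // a holds [1]
    // gen_math_ops.sum(x, a)                   -> shape [2]
    // gen_math_ops.sum(x, a, keep_dims: true)  -> shape [2, 1]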
@@ -6,6 +6,104 @@ namespace Tensorflow
 {
     public class math_ops
     {
+        /// <summary>
+        /// Helper function for reduction ops.
+        /// </summary>
+        /// <param name="input_shape">1-D Tensor, the shape of the Tensor being reduced.</param>
+        /// <param name="axes">1-D Tensor, the reduction axes.</param>
+        /// <returns>A 1-D Tensor, the output shape as if keepdims were set to True.</returns>
+        public static Tensor reduced_shape(Tensor input_shape, Tensor axes)
+        {
+            input_shape = to_int32(input_shape);
+            axes = to_int32(axes);
+            var input_rank = array_ops.size(input_shape);
+            axes = (axes + input_rank) % input_rank;
+            return null;
+        }
+        /// <summary>
+        /// Casts a tensor to type `int32`.
+        /// </summary>
+        /// <param name="x">A `Tensor` or `SparseTensor` or `IndexedSlices`.</param>
+        /// <param name="name">A name for the operation (optional).</param>
+        /// <returns>A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with type `int32`.</returns>
+        private static Tensor to_int32(Tensor x, string name = "ToInt32")
+        {
+            return __case__(x, TF_DataType.TF_INT32, name: name);
+        }
+        /// <summary>
+        /// Casts a tensor to a new type.
+        /// </summary>
+        /// <param name="x"></param>
+        /// <param name="dtype"></param>
+        /// <param name="name"></param>
+        /// <returns>A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and same type as `dtype`.</returns>
+        public static Tensor __case__(Tensor x, TF_DataType dtype, string name = "")
+        {
+            var base_type = dtype.as_base_dtype();
+            if (x is Tensor && base_type == x.dtype)
+                return x;
+            // math_ops.py cast
+            throw new NotImplementedException();
+        }
+        public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false)
+        {
+            var r = _ReductionDims(input_tensor, axis);
+            var m = gen_math_ops.sum(input_tensor, r);
+            return _may_reduce_to_scalar(keepdims, m);
+        }
+        private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor output)
+        {
+            output.shape = new long[0];
+            return output;
+        }
+        private static Tensor _ReductionDims(Tensor x, Tensor axis)
+        {
+            if (axis != null)
+            {
+                return axis;
+            }
+            else
+            {
+                var rank = array_ops.rank(x);
+                return range(0, rank, 1);
+            }
+        }
+        public static Tensor range(object start, Tensor limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
+        {
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
+            {
+                name = scope;
+                var start1 = ops.convert_to_tensor(start, name: "start");
+                var limit1 = ops.convert_to_tensor(limit, name: "limit");
+                var delta1 = ops.convert_to_tensor(delta, name: "delta");
+                return gen_math_ops.range(start1, limit1, delta1, name);
+            });
+        }
+        public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true)
+        {
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>
+            {
+                name = scope;
+                var input_tensor = ops.convert_to_tensor(input);
+                var input_shape = tensor_util.to_shape(input_tensor.shape);
+                if (optimize && input_shape.NDim == null)
+                    return constant_op.constant(input_shape.NDim);
+                else
+                    return gen_array_ops.rank(input, name);
+            });
+        }
         public static Tensor matmul(Tensor a, Tensor b,
             bool transpose_a = false, bool transpose_b = false,
             bool adjoint_a = false, bool adjoint_b = false,
@@ -31,5 +129,24 @@ namespace Tensorflow
             return result;
         }
+        /// <summary>
+        /// Returns the complex conjugate of a complex number.
+        /// </summary>
+        /// <param name="x">`Tensor` to conjugate. Must have numeric or variant type.</param>
+        /// <param name="name">A name for the operation (optional).</param>
+        /// <returns>A `Tensor` that is the conjugate of `x` (with the same type).</returns>
+        public static Tensor conj(Tensor x, string name = "")
+        {
+            var dt = x.dtype;
+            if (dt.is_floating() || dt.is_integer())
+                return x;
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Conj", new List<Tensor> { x }), scope =>
+            {
+                return x;
+            });
+        }
     }
 }
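reduce_sum composes the helpers above: _ReductionDims falls back to range(0, rank(x), 1) when no axis is given, and conj is a passthrough for real dtypes, so gradient code may call it unconditionally. Two hedged observations: rank_internal's optimize branch tests input_shape.NDim == null, which looks inverted relative to the Python rank_internal (constant-fold when ndims is known), and _may_reduce_to_scalar stamps a scalar shape on every result. A sketch of the default-axes path for a hypothetical rank-2 x:

    // _ReductionDims(x, null) -> range(0, 2, 1) -> axes [0, 1]
    // conj(x)                 -> x unchanged for float/int dtypes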
@@ -6,15 +6,19 @@ namespace Tensorflow
 {
     public partial class Tensor
     {
-        public static implicit operator Tensor(double scalar)
+        /// <summary>
+        /// Unresolved issue: this implicit conversion causes a name_scope problem.
+        /// </summary>
+        /// <param name="scalar"></param>
+        /*public static implicit operator Tensor(double scalar)
         {
             return constant_op.constant(scalar);
-        }
+        }*/
-        public static implicit operator Tensor(int scalar)
+        /*public static implicit operator Tensor(int scalar)
         {
             return constant_op.constant(scalar);
-        }
+        }*/
         public static implicit operator int(Tensor tensor)
         {
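The in-code comment hints at why these conversions were disabled: an implicit constant_op.constant(scalar) fires while the operands are being evaluated, before the operator's name_scope is entered, so the constant lands outside the scope. The operator overloads in the next hunk convert inside the scope instead; a hedged sketch of that pattern:

    // Inside an operator's name_scope: the constant is created under the scope's name.
    var y1 = ops.convert_to_tensor(3.0, dtype: x.dtype.as_base_dtype(), name: "y");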
@@ -8,14 +8,24 @@ namespace Tensorflow
     {
         public static Tensor operator +(Tensor x, Tensor y)
         {
-            Tensor t = null;
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new Tensor[] { x, y }), scope =>
+            {
+                return gen_math_ops.add(x, y, scope);
+            });
+        }
-            Python.with<ops.name_scope>(new ops.name_scope("", "add", new Tensor[] { x, y }), scope =>
+        public static Tensor operator +(Tensor x, int y)
         {
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new object[] { x, y }), scope =>
+            {
-                t = gen_math_ops.add(x, y, scope);
+                var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y");
+                return gen_math_ops.add(x, y1, scope);
             });
+        }
-            return t;
+        public static Tensor operator -(Tensor t1)
+        {
+            return gen_math_ops.neg(t1);
         }
         public static Tensor operator -(Tensor t1, Tensor t2)
@@ -35,19 +45,41 @@ namespace Tensorflow
         public static Tensor operator *(Tensor x, Tensor y)
         {
-            Tensor t = null;
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new Tensor[] { x, y }), scope =>
+            {
+                return gen_math_ops.mul(x, y, name: scope);
+            });
+        }
-            Python.with<ops.name_scope>(new ops.name_scope("", "mul", new Tensor[] { x, y }), scope =>
+        public static Tensor operator *(Tensor x, int y)
         {
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new object[] { x, y }), scope =>
+            {
-                t = gen_math_ops.mul(x, y, name: scope);
+                var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y");
+                return gen_math_ops.mul(x, y1, name: scope);
             });
+        }
-            return t;
+        public static Tensor operator /(Tensor x, Tensor y)
+        {
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope("truediv/", "truediv", new Tensor[] { x, y }), scope =>
+            {
+                return gen_math_ops.real_div(x, y, scope);
+            });
         }
+        public static Tensor operator /(Tensor x, double y)
+        {
+            return Python.with<ops.name_scope, Tensor>(new ops.name_scope("truediv/", "truediv", new object[] { x, y }), scope =>
+            {
+                var y1 = ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");
+                return gen_math_ops.real_div(x, y1, scope);
+            });
+        }
-        public static Tensor operator /(Tensor t1, Tensor t2)
+        public static Tensor operator %(Tensor x, Tensor y)
         {
-            return gen_math_ops.real_div(t1, t2);
+            throw new NotImplementedException("math mod is not implemented");
         }
     }
 }
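Each binary overload now opens a name_scope, converts a scalar right-hand side to the tensor's base dtype inside that scope, and returns the op output directly; the old shape captured t through a void lambda. `/` maps to RealDiv, i.e. Python-3-style true division, and `%` currently throws, so reduced_shape's `axes % input_rank` earlier in this diff will raise until a Mod op lands. A hedged usage sketch with a hypothetical tensor x:

    var z = x * 2;      // "mul" scope: convert_to_tensor(2), then gen_math_ops.mul
    var q = x / 2.0;    // "truediv" scope: RealDiv
    var n = -x;         // gen_math_ops.neg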
@@ -38,6 +38,7 @@ namespace Tensorflow
         public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dataTypeSize;
         public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
+        public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
         public long[] shape
         {
             get
@@ -60,7 +61,10 @@ namespace Tensorflow
             set
             {
                 // c_api.TF_GraphSetTensorShape_wrapper
+                if (value == null)
+                    c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), null, -1, status);
+                else
+                    c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), value, value.Length, status);
             }
         }
@@ -170,11 +174,12 @@ namespace Tensorflow
             _id = ops.uid();
         }
-        public List<Operation> consumers()
+        public Operation[] Consumers => consumers();
+        public Operation[] consumers()
         {
             var output = _as_tf_output();
             var consumer_names = c_api.TF_OperationOutputConsumers_wrapper(output);
-            return consumer_names.Select(x => Graph.OperationByName(x)).ToList();
+            return consumer_names.Select(x => Graph.OperationByName(x)).ToArray();
         }
         public TF_Output _as_tf_output()
@@ -23,7 +23,7 @@ namespace Tensorflow
         /// <returns></returns>
         public bool is_fully_defined()
         {
-            return Dimensions != null;
+            return Dimensions != null && Dimensions.Count(x => x > 0) > 0;
         }
     }
 }
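A hedged reading of the new predicate: Dimensions.Count(x => x > 0) > 0 requires at least one positive dimension, so a rank-0 shape (empty Dimensions) now reports not fully defined, and a partially known shape such as [-1, 3] still passes because one dim is positive; TF's Python is_fully_defined instead requires every dim to be known:

    // dims = [2, 3]  -> true      dims = [-1, 3] -> true (3 > 0)
    // dims = []      -> false     dims = null    -> false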
@@ -82,5 +82,15 @@ namespace Tensorflow
         {
             return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128;
         }
+        public static bool is_integer(this TF_DataType type)
+        {
+            return type == TF_DataType.TF_INT8 || type == TF_DataType.TF_INT16 || type == TF_DataType.TF_INT32 || type == TF_DataType.TF_INT64;
+        }
+        public static bool is_floating(this TF_DataType type)
+        {
+            return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE;
+        }
     }
 }
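These extension predicates back conj's fast path above. Hedged edge cases implied by the explicit lists: unsigned integer types fail is_integer, and TF_BFLOAT16 fails is_floating, so tensors of those dtypes fall through to conj's name_scope branch:

    TF_DataType.TF_INT32.is_integer();      // true
    TF_DataType.TF_UINT8.is_integer();      // false (not in the list)
    TF_DataType.TF_BFLOAT16.is_floating();  // false (not in the list)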
@@ -220,6 +220,11 @@ namespace Tensorflow
             _colocate_with_for_gradient(op, null, ignore_existing);
         }
+        public static void colocate_with(Tensor tensor, bool ignore_existing = false)
+        {
+            _colocate_with_for_gradient(tensor.op, null, ignore_existing);
+        }
         private static void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false)
         {
             var default_graph = get_default_graph();
@@ -290,10 +295,12 @@ namespace Tensorflow
             {
                 case "Add":
                     return math_grad._AddGrad(op, out_grads);
+                case "Sum":
+                    return math_grad._SumGrad(op, out_grads);
+                case "RealDiv":
+                    return math_grad._RealDivGrad(op, out_grads);
                 default:
-                    throw new NotImplementedException("get_gradient_function");
+                    throw new NotImplementedException($"get_gradient_function {oper.type}");
             }
             /*var result = typeof(math_grad).GetMethod($"_{op.type}Grad").Invoke(null, new object[] { op, out_grads });
             var p1 = result.GetType().GetProperty("Item1");
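Gradient lookup stays an explicit switch on the op type string; the commented block sketches the reflective alternative (resolve _{op.type}Grad on math_grad at runtime), which would avoid editing this switch per op at the cost of reflection overhead and compile-time safety. Registering another gradient is mechanical; a hedged sketch assuming a _SubGrad helper that this diff does not define:

    case "Sub":
        return math_grad._SubGrad(op, out_grads);  // hypothetical _SubGrad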
@@ -14,4 +14,8 @@
     <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
   </ItemGroup>
+  <ItemGroup>
+    <Folder Include="python\" />
+  </ItemGroup>
 </Project>
@@ -0,0 +1,94 @@
+'''
+A linear regression learning algorithm example using TensorFlow library.
+Author: Aymeric Damien
+Project: https://github.com/aymericdamien/TensorFlow-Examples/
+'''
+from __future__ import print_function
+import tensorflow as tf
+import numpy
+import matplotlib.pyplot as plt
+rng = numpy.random
+# Parameters
+learning_rate = 0.01
+training_epochs = 1000
+display_step = 50
+# Training Data
+train_X = numpy.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,
+                         7.042,10.791,5.313,7.997,5.654,9.27,3.1])
+train_Y = numpy.asarray([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,
+                         2.827,3.465,1.65,2.904,2.42,2.94,1.3])
+n_samples = train_X.shape[0]
+# tf Graph Input
+X = tf.placeholder("float")
+Y = tf.placeholder("float")
+# Set model weights
+W = tf.Variable(rng.randn(), name="weight")
+b = tf.Variable(rng.randn(), name="bias")
+# Construct a linear model
+mul = tf.multiply(X, W)
+pred = tf.add(mul, b)
+# Mean squared error
+sub = pred-Y
+pow = tf.pow(sub, 2)
+reduce = tf.reduce_sum(pow)
+cost = reduce/(2*n_samples)
+# Gradient descent
+# Note, minimize() knows to modify W and b because Variable objects are trainable=True by default
+grad = tf.train.GradientDescentOptimizer(learning_rate)
+optimizer = grad.minimize(cost)
+# Initialize the variables (i.e. assign their default value)
+init = tf.global_variables_initializer()
+# Start training
+with tf.Session() as sess:
+    # Run the initializer
+    sess.run(init)
+    # Fit all training data
+    for epoch in range(training_epochs):
+        for (x, y) in zip(train_X, train_Y):
+            sess.run(optimizer, feed_dict={X: x, Y: y})
+        # Display logs per epoch step
+        if (epoch+1) % display_step == 0:
+            c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
+            print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c), \
+                "W=", sess.run(W), "b=", sess.run(b))
+    print("Optimization Finished!")
+    training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
+    print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')
+    # Graphic display
+    plt.plot(train_X, train_Y, 'ro', label='Original data')
+    plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
+    plt.legend()
+    plt.show()
+    # Testing example, as requested (Issue #2)
+    test_X = numpy.asarray([6.83, 4.668, 8.9, 7.91, 5.7, 8.7, 3.1, 2.1])
+    test_Y = numpy.asarray([1.84, 2.273, 3.2, 2.831, 2.92, 3.24, 1.35, 1.03])
+    print("Testing... (Mean square loss Comparison)")
+    testing_cost = sess.run(
+        tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * test_X.shape[0]),
+        feed_dict={X: test_X, Y: test_Y})  # same function as cost above
+    print("Testing cost=", testing_cost)
+    print("Absolute mean square loss difference:", abs(
+        training_cost - testing_cost))
+    plt.plot(test_X, test_Y, 'bo', label='Testing data')
+    plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
+    plt.legend()
+    plt.show()
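The script is the upstream TensorFlow-Examples linear regression with the model and cost deliberately split into one op per line (multiply, add, sub, pow, reduce_sum, div), matching the Add/Sum/RealDiv gradient cases registered in ops.cs above. A hedged cross-check one could append inside the session; numpy.polyfit is plain NumPy and not part of the example:

    # Closed-form least-squares line the descent should approach:
    W_ls, b_ls = numpy.polyfit(train_X, train_Y, 1)  # slope, intercept
    print("least-squares W=", W_ls, "b=", b_ls)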