@@ -18,22 +18,27 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class tensorflow | public partial class tensorflow | ||||
{ | { | ||||
/// <summary> | |||||
/// Outputs random values from a normal distribution. | |||||
/// </summary> | |||||
/// <param name="shape"></param> | |||||
/// <param name="mean"></param> | |||||
/// <param name="stddev"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="seed"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public Tensor random_normal(TensorShape shape, | |||||
float mean = 0.0f, | |||||
float stddev = 1.0f, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
int? seed = null, | |||||
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | |||||
public Random random => new Random(); | |||||
public class Random | |||||
{ | |||||
/// <summary> | |||||
/// Outputs random values from a normal distribution. | |||||
/// </summary> | |||||
/// <param name="shape"></param> | |||||
/// <param name="mean"></param> | |||||
/// <param name="stddev"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="seed"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public Tensor normal(TensorShape shape, | |||||
float mean = 0.0f, | |||||
float stddev = 1.0f, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
int? seed = null, | |||||
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | |||||
} | |||||
public Tensor random_uniform(TensorShape shape, | public Tensor random_uniform(TensorShape shape, | ||||
float minval = 0, | float minval = 0, | ||||
@@ -45,7 +45,7 @@ namespace Tensorflow.Eager | |||||
op_name, | op_name, | ||||
inputs.Select(x => (x as EagerTensor).GetTfeTensorHandle()).ToArray(), | inputs.Select(x => (x as EagerTensor).GetTfeTensorHandle()).ToArray(), | ||||
inputs.Length, | inputs.Length, | ||||
op => wrap_tfe_src.SetOpAttrs(ctx, op, attrs, status), | |||||
op => wrap_tfe_src.SetOpAttrs(op, attrs), | |||||
outputs, | outputs, | ||||
num_outputs, | num_outputs, | ||||
status); | status); | ||||
@@ -2,6 +2,7 @@ | |||||
using System.Linq; | using System.Linq; | ||||
using System; | using System; | ||||
using static Tensorflow.OpDef.Types; | using static Tensorflow.OpDef.Types; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
{ | { | ||||
@@ -10,8 +11,9 @@ namespace Tensorflow.Eager | |||||
/// </summary> | /// </summary> | ||||
public partial class wrap_tfe_src | public partial class wrap_tfe_src | ||||
{ | { | ||||
public static void SetOpAttrs(Context ctx, TFE_Op op, object[] attrs, Status out_status) | |||||
public static void SetOpAttrs(TFE_Op op, params object[] attrs) | |||||
{ | { | ||||
using var status = new Status(); | |||||
var len = attrs.Length; | var len = attrs.Length; | ||||
for (int i = 0; i < len; i += 2) | for (int i = 0; i < len; i += 2) | ||||
{ | { | ||||
@@ -19,13 +21,13 @@ namespace Tensorflow.Eager | |||||
var value = attrs[i + 1]; | var value = attrs[i + 1]; | ||||
byte is_list = 0; | byte is_list = 0; | ||||
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, out_status); | |||||
if (!out_status.ok()) return; | |||||
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status); | |||||
if (!status.ok()) return; | |||||
if (is_list != 0) | if (is_list != 0) | ||||
SetOpAttrList(ctx, op, key, value, type, null, out_status); | |||||
SetOpAttrList(tf.context, op, key, value, type, null, status); | |||||
else | else | ||||
SetOpAttrScalar(ctx, op, key, value, type, null, out_status); | |||||
out_status.Check(true); | |||||
SetOpAttrScalar(tf.context, op, key, value, type, null, status); | |||||
status.Check(true); | |||||
} | } | ||||
} | } | ||||
@@ -165,7 +165,7 @@ namespace Tensorflow | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Pack", name, | "Pack", name, | ||||
values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), values.Length, | values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), values.Length, | ||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "axis", axis } , status), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "axis", axis), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
@@ -421,11 +421,8 @@ namespace Tensorflow | |||||
"Shape", name, new IntPtr[] | "Shape", name, new IntPtr[] | ||||
{ | { | ||||
input as EagerTensor, | input as EagerTensor, | ||||
}, 1, | |||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||||
{ | |||||
"out_type", out_type | |||||
}, status), | |||||
}, 1, | |||||
op => wrap_tfe_src.SetOpAttrs(op, "out_type", out_type), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return tensor; | return tensor; | ||||
@@ -531,14 +528,12 @@ namespace Tensorflow | |||||
end as EagerTensor, | end as EagerTensor, | ||||
strides as EagerTensor, | strides as EagerTensor, | ||||
}, 4, | }, 4, | ||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||||
{ | |||||
op => wrap_tfe_src.SetOpAttrs(op, | |||||
"begin_mask", begin_mask, | "begin_mask", begin_mask, | ||||
"end_mask", end_mask, | "end_mask", end_mask, | ||||
"ellipsis_mask", ellipsis_mask, | "ellipsis_mask", ellipsis_mask, | ||||
"new_axis_mask", new_axis_mask, | "new_axis_mask", new_axis_mask, | ||||
"shrink_axis_mask", shrink_axis_mask | |||||
}, status), | |||||
"shrink_axis_mask", shrink_axis_mask), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return tensor; | return tensor; | ||||
@@ -44,13 +44,13 @@ namespace Tensorflow | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using var status = new Status(); | using var status = new Status(); | ||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
EagerTensorHandle _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"AddN", name, | "AddN", name, | ||||
inputs.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), inputs.Length, | inputs.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), inputs.Length, | ||||
null, | null, | ||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(_result); | |||||
return _result; | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs }); | var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs }); | ||||
@@ -132,17 +132,17 @@ namespace Tensorflow | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using var status = new Status(); | using var status = new Status(); | ||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Mean", name, | "Mean", name, | ||||
new IntPtr[] | new IntPtr[] | ||||
{ | { | ||||
input as EagerTensor, | input as EagerTensor, | ||||
axis as EagerTensor | axis as EagerTensor | ||||
}, 2, | }, 2, | ||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "keep_dims", keep_dims }, status), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(tensor); | |||||
return tensor; | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); | var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); | ||||
@@ -185,10 +185,7 @@ namespace Tensorflow | |||||
input as EagerTensor, | input as EagerTensor, | ||||
axis as EagerTensor | axis as EagerTensor | ||||
}, 2, | }, 2, | ||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||||
{ | |||||
"keep_dims", keep_dims | |||||
}, status), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return tensor; | return tensor; | ||||
@@ -232,14 +229,14 @@ namespace Tensorflow | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using var status = new Status(); | using var status = new Status(); | ||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
EagerTensorHandle _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Add", name, new IntPtr[] | "Add", name, new IntPtr[] | ||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, status); | }, 2, null, status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(_result); | |||||
return _result; | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | ||||
@@ -273,14 +270,14 @@ namespace Tensorflow | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using var status = new Status(); | using var status = new Status(); | ||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"AddV2", name, new IntPtr[] | "AddV2", name, new IntPtr[] | ||||
{ | { | ||||
x as EagerTensor, | x as EagerTensor, | ||||
y as EagerTensor | y as EagerTensor | ||||
}, 2, null, status); | }, 2, null, status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(tensor); | |||||
return tensor; | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("AddV2", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("AddV2", name, args: new { x, y }); | ||||
@@ -574,6 +571,18 @@ namespace Tensorflow | |||||
/// <returns> A `Tensor`. Has the same type as `x`.</returns> | /// <returns> A `Tensor`. Has the same type as `x`.</returns> | ||||
public static Tensor square(Tensor x, string name = null) | public static Tensor square(Tensor x, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
using var status = new Status(); | |||||
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Square", name, new IntPtr[] | |||||
{ | |||||
x as EagerTensor, | |||||
}, 1, null, status); | |||||
status.Check(true); | |||||
return tensor; | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x }); | var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -633,7 +642,7 @@ namespace Tensorflow | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Cast", name, | "Cast", name, | ||||
new IntPtr[] { x as EagerTensor }, 1, | new IntPtr[] { x as EagerTensor }, 1, | ||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "DstT", DstT, "Truncate", Truncate }, status), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "DstT", DstT, "Truncate", Truncate), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
@@ -918,11 +927,9 @@ namespace Tensorflow | |||||
a as EagerTensor, | a as EagerTensor, | ||||
b as EagerTensor | b as EagerTensor | ||||
}, 2, | }, 2, | ||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||||
{ | |||||
op => wrap_tfe_src.SetOpAttrs(op, | |||||
"transpose_a", transpose_a, | "transpose_a", transpose_a, | ||||
"transpose_b", transpose_b | |||||
}, status), | |||||
"transpose_b", transpose_b), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
@@ -1049,7 +1056,7 @@ namespace Tensorflow | |||||
input as EagerTensor, | input as EagerTensor, | ||||
axis as EagerTensor | axis as EagerTensor | ||||
}, 2, | }, 2, | ||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "keep_dims", keep_dims }, status), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
@@ -13,6 +13,9 @@ | |||||
See the License for the specific language governing permissions and | See the License for the specific language governing permissions and | ||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | |||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -36,6 +39,23 @@ namespace Tensorflow | |||||
if (!seed2.HasValue) | if (!seed2.HasValue) | ||||
seed2 = 0; | seed2 = 0; | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
using var status = new Status(); | |||||
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"RandomStandardNormal", name, new IntPtr[] | |||||
{ | |||||
shape as EagerTensor, | |||||
}, 1, | |||||
op => wrap_tfe_src.SetOpAttrs(op, | |||||
"seed", seed, | |||||
"seed2", seed2, | |||||
"dtype", dtype), | |||||
status); | |||||
status.Check(true); | |||||
return tensor; | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", | var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", | ||||
name: name, | name: name, | ||||
args: new { shape, dtype, seed, seed2 }); | args: new { shape, dtype, seed, seed2 }); | ||||
@@ -25,6 +25,25 @@ namespace Tensorflow | |||||
{ | { | ||||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
public static Operation assign_sub_variable_op(Tensor resource, Tensor value, string name = null) | |||||
{ | |||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
using var status = new Status(); | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"AssignSubVariableOp", name, | |||||
new IntPtr[] | |||||
{ | |||||
resource as EagerTensor, | |||||
value as EagerTensor | |||||
}, 2, null, status); | |||||
status.Check(true); | |||||
return tensor; | |||||
} | |||||
return null; | |||||
} | |||||
public static Operation assign_variable_op(Tensor resource, Tensor value, string name = null) | public static Operation assign_variable_op(Tensor resource, Tensor value, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
@@ -51,12 +70,12 @@ namespace Tensorflow | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using var status = new Status(); | using var status = new Status(); | ||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"VarIsInitializedOp", name, | "VarIsInitializedOp", name, | ||||
new IntPtr[] { resource as EagerTensor }, | new IntPtr[] { resource as EagerTensor }, | ||||
1, null, status); | 1, null, status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(tensor); | |||||
return tensor; | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); | var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); | ||||
@@ -79,18 +98,16 @@ namespace Tensorflow | |||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using var status = new Status(); | using var status = new Status(); | ||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"VarHandleOp", name, null, 0, | "VarHandleOp", name, null, 0, | ||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||||
{ | |||||
op => wrap_tfe_src.SetOpAttrs(op, | |||||
"container", container, | "container", container, | ||||
"shared_name", shared_name, | "shared_name", shared_name, | ||||
"dtype", dtype, | "dtype", dtype, | ||||
"shape", shape.dims | |||||
}, status), | |||||
"shape", shape.dims), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(tensor); | |||||
return tensor; | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("VarHandleOp", name, new { | var _op = _op_def_lib._apply_op_helper("VarHandleOp", name, new { | ||||
@@ -118,7 +135,7 @@ namespace Tensorflow | |||||
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"ReadVariableOp", name, | "ReadVariableOp", name, | ||||
new IntPtr[] { resource as EagerTensor }, 1, | new IntPtr[] { resource as EagerTensor }, 1, | ||||
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "dtype", dtype }, status), | |||||
op => wrap_tfe_src.SetOpAttrs(op, "dtype", dtype), | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return tensor; | return tensor; | ||||
@@ -47,6 +47,7 @@ namespace Tensorflow | |||||
var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); | var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); | ||||
var mul = rnd * stddev_tensor; | var mul = rnd * stddev_tensor; | ||||
var value = math_ops.add(mul, mean_tensor, name: name); | var value = math_ops.add(mul, mean_tensor, name: name); | ||||
// tensor_util.maybe_set_static_shape(value, shape) | |||||
return value; | return value; | ||||
}); | }); | ||||
} | } | ||||
@@ -0,0 +1,37 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using NumSharp; | |||||
using System; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class ResourceVariable | |||||
{ | |||||
/// <summary> | |||||
/// Subtracts a value from this variable. | |||||
/// </summary> | |||||
/// <param name="delta"></param> | |||||
/// <param name="use_locking"></param> | |||||
/// <param name="name"></param> | |||||
/// <param name="read_value"></param> | |||||
public void assign_sub(Tensor delta, bool use_locking = false, string name = null, bool read_value = true) | |||||
{ | |||||
gen_resource_variable_ops.assign_sub_variable_op(handle, delta, name: name); | |||||
} | |||||
} | |||||
} |
@@ -11,22 +11,52 @@ namespace TensorFlowNET.UnitTest.Training | |||||
[TestClass] | [TestClass] | ||||
public class BasicLinearModel | public class BasicLinearModel | ||||
{ | { | ||||
int NUM_EXAMPLES = 1000; | |||||
/// <summary> | |||||
/// Linear Regression without tf.train.Optimizer | |||||
/// https://www.tensorflow.org/tutorials/customization/custom_training | |||||
/// </summary> | |||||
[TestMethod] | [TestMethod] | ||||
public void FitLinear() | |||||
public void LinearRegression() | |||||
{ | { | ||||
// Initialize the weights to `5.0` and the bias to `0.0` | // Initialize the weights to `5.0` and the bias to `0.0` | ||||
// In practice, these should be initialized to random values (for example, with `tf.random.normal`) | // In practice, these should be initialized to random values (for example, with `tf.random.normal`) | ||||
var W = tf.Variable(5.0f); | var W = tf.Variable(5.0f); | ||||
var b = tf.Variable(0.0); | |||||
var b = tf.Variable(0.0f); | |||||
// Define linear model | |||||
Func<Tensor, Tensor> model = (x) => W * x + b; | |||||
// Define the loss function | |||||
Func<Tensor, Tensor, Tensor> loss = (target_y, predicted_y) | |||||
=> tf.reduce_mean(tf.square(target_y - predicted_y)); | |||||
int NUM_EXAMPLES = 1000; | |||||
float TRUE_W = 3.0f; | |||||
float TRUE_b = 2.0f; | |||||
var inputs = tf.random.normal(shape: NUM_EXAMPLES); | |||||
var noise = tf.random.normal(shape: NUM_EXAMPLES); | |||||
var outputs = inputs * TRUE_W + TRUE_b + noise; | |||||
print($"Current loss: {loss(model(inputs), outputs).numpy()}"); | |||||
// define linear model | |||||
Func<NDArray, Tensor> model = (x) => W * x + b; | |||||
// Define a training loop | |||||
Action<Tensor, Tensor, float> train = (inputs, outputs, learning_rate) | |||||
=> | |||||
{ | |||||
using var t = tf.GradientTape(); | |||||
var current_loss = loss(outputs, model(inputs)); | |||||
var (dW, db) = t.gradient(current_loss, (W, b)); | |||||
W.assign_sub(learning_rate * dW); | |||||
b.assign_sub(learning_rate * db); | |||||
}; | |||||
// var inputs = tf.random.normal(shape =[NUM_EXAMPLES]); | |||||
// noise = tf.random.normal(shape =[NUM_EXAMPLES]) | |||||
// outputs = inputs * TRUE_W + TRUE_b + noise | |||||
var epochs = range(10); | |||||
foreach(var epoch in epochs) | |||||
{ | |||||
train(inputs, outputs, 0.1f); | |||||
print($"Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f"); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |