@@ -40,6 +40,7 @@ namespace Tensorflow.Eager | |||
}*/ | |||
} | |||
// Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}"); | |||
if (!should_record) return should_record; | |||
Tensor[] op_outputs; | |||
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
@@ -14,6 +15,8 @@ namespace Tensorflow.Functions | |||
{ | |||
public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||
IntPtr _handle; | |||
public Tensor[] Outputs; | |||
public TensorSpec[] OutputStructure; | |||
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||
{ | |||
@@ -43,23 +46,38 @@ namespace Tensorflow.Functions | |||
var input = tf.placeholder(dtype); | |||
var output = func(input); | |||
OutputStructure = output.structure; | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new Operation[] { input }, | |||
new Operation[] { }, | |||
new Operation[] { output.variant_tensor.op }, | |||
null); | |||
} | |||
} | |||
public Tensor Execute(Tensor arg) | |||
public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func, | |||
TF_DataType[] dtypes, TensorShape[] shapes) | |||
{ | |||
var result = tf.Runner.TFE_Execute(tf.Context, | |||
tf.Context.DeviceName, | |||
Name, | |||
new[] { arg }, | |||
null, | |||
1); | |||
return result[0]; | |||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
// IntPtr func_handle; | |||
using (var graph = new FuncGraph(func_name)) | |||
{ | |||
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); | |||
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); | |||
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); | |||
var outputs = func(input1, (input2, input3)); | |||
Outputs = new[] { outputs.Item1, outputs.Item2 }; | |||
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new Operation[] { input1, input2, input3 }, | |||
new Operation[] { outputs.Item1.op, outputs.Item2.op }, | |||
null); | |||
} | |||
} | |||
public void Dispose() | |||
@@ -35,7 +35,7 @@ namespace Tensorflow.Gradients | |||
if (!state.op_tape.find(op, out var trace)) | |||
continue; | |||
Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||
// Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||
state.op_tape.erase(op); | |||
var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | |||
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine.DataAdapters | |||
{ | |||
@@ -12,10 +13,29 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
{ | |||
DataHandlerArgs args; | |||
IDataAdapter _adapter; | |||
IDatasetV2 _dataset; | |||
int _inferred_steps; | |||
int _current_step; | |||
int _step_increment; | |||
bool _insufficient_data; | |||
int _steps_per_execution_value; | |||
int _initial_epoch => args.InitialEpoch; | |||
int _epochs => args.Epochs; | |||
IVariableV1 _steps_per_execution; | |||
public DataHandler(DataHandlerArgs args) | |||
{ | |||
this.args = args; | |||
if(args.StepsPerExecution == null) | |||
{ | |||
_steps_per_execution = tf.Variable(1); | |||
_steps_per_execution_value = 1; | |||
} | |||
else | |||
{ | |||
_steps_per_execution = args.StepsPerExecution; | |||
_steps_per_execution_value = args.StepsPerExecution.numpy(); | |||
} | |||
_adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs | |||
{ | |||
@@ -30,11 +50,64 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
UseMultiprocessing = args.UseMultiprocessing, | |||
Model = args.Model | |||
}); | |||
// Fetch the dataset built by the adapter and resolve the number of steps per epoch from it.
_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
_current_step = 0;
// Use the cached integer rather than args.StepsPerExecution.numpy():
// args.StepsPerExecution may be null (see the null check earlier in this
// constructor), in which case dereferencing it here would throw.
_step_increment = _steps_per_execution_value - 1;
_insufficient_data = false;
} | |||
Tensor _infer_steps(IDatasetV2 dataset) | |||
/// <summary>
/// Resolves how many steps make up one epoch.
/// </summary>
/// <param name="steps_per_epoch">Explicit step count; any value greater than -1 is returned as is.</param>
/// <param name="dataset">Not read by the current implementation.</param>
/// <returns>
/// <paramref name="steps_per_epoch"/> when it is non-negative, otherwise the size
/// reported by the adapter when that is non-negative.
/// </returns>
/// <exception cref="NotImplementedException">Neither source yields a non-negative count.</exception>
int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
{
    // An explicit non-negative request takes priority over the adapter's size.
    if (steps_per_epoch >= 0)
        return steps_per_epoch;

    var size = _adapter.GetSize();
    if (size < 0)
        throw new NotImplementedException("");

    return size;
}
/// <summary>
/// Lazily yields one (epoch, iterator) pair per epoch from _initial_epoch up to _epochs.
/// A single OwnedIterator over _dataset is created up front, handed out with every
/// epoch, and disposed when enumeration finishes.
/// </summary>
/// <returns>Sequence of epoch numbers paired with the shared iterator.</returns>
public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
{
    using (var iterator = new OwnedIterator(_dataset))
    {
        foreach (var epoch in range(_initial_epoch, _epochs))
        {
            // Stop handing out epochs once the data has been flagged as insufficient.
            if (_insufficient_data)
                yield break;

            yield return (epoch, iterator);
        }
    }
}
/// <summary>
/// Lazily yields the starting step index of each execution within one epoch.
/// _current_step is reset to zero and then advanced by the number of steps covered
/// by each execution; _step_increment is set to that number minus one before each yield.
/// </summary>
/// <returns>Sequence of starting step indices.</returns>
public IEnumerable<int> steps()
{
    _current_step = 0;

    while (_current_step < _inferred_steps && !_insufficient_data)
    {
        var remaining = _inferred_steps - _current_step;

        // A shortened execution is needed only when several steps run per execution,
        // the step count is known, and fewer steps remain than one execution covers.
        bool needs_partial_execution = _steps_per_execution_value != 1
            && _inferred_steps >= 0
            && remaining < _steps_per_execution_value;

        if (needs_partial_execution)
        {
            // Temporarily shrink the variable to the leftover steps, then restore it
            // after the consumer resumes the enumeration.
            _steps_per_execution.assign(remaining);
            _step_increment = remaining - 1;
            yield return _current_step;
            _current_step += remaining;
            _steps_per_execution.assign(_steps_per_execution_value);
        }
        else
        {
            _step_increment = _steps_per_execution_value - 1;
            yield return _current_step;
            _current_step += _steps_per_execution_value;
        }
    }
}
} | |||
} |
@@ -18,5 +18,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
/// <param name="y">target labels</param> | |||
/// <returns></returns> | |||
bool CanHandle(Tensor x, Tensor y = null); | |||
IDatasetV2 GetDataset(); | |||
int GetSize(); | |||
} | |||
} |
@@ -16,6 +16,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
int _batch_size; | |||
int num_samples; | |||
int num_full_batches; | |||
IDatasetV2 _dataset; | |||
public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) | |||
{ | |||
@@ -32,6 +33,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
indices_dataset = indices_dataset.repeat(); | |||
indices_dataset = indices_dataset.map(permutation).prefetch(1); | |||
indices_dataset = indices_dataset.flat_map(slice_batch_indices); | |||
_dataset = slice_inputs(indices_dataset, args.X, args.Y); | |||
} | |||
Tensor permutation(Tensor tensor) | |||
@@ -53,13 +55,24 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch }); | |||
first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); | |||
var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); | |||
return flat_dataset; | |||
} | |||
void slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y) | |||
/// <summary>
/// Builds the dataset of (x, y) batches: zips the dataset of index batches with a
/// repeated dataset made from <paramref name="x"/> and <paramref name="y"/>, then
/// gathers the indexed entries along axis 0 from both tensors.
/// </summary>
/// <param name="indices_dataset">Dataset whose elements are passed as gather indices.</param>
/// <param name="x">First tensor to gather from.</param>
/// <param name="y">Second tensor to gather from.</param>
/// <returns>The mapped dataset with a default <c>DatasetOptions</c> applied.</returns>
IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y)
{
    // NOTE(review): repeat() presumably keeps the (x, y) element available for
    // every index batch that zip pairs it with - confirm.
    var dataset2 = tf.data.Dataset.from_tensor(x, y).repeat();
    var dataset = tf.data.Dataset.zip(indices_dataset, dataset2);

    // Lambda locals are named distinctly so they do not shadow the x / y parameters.
    dataset = dataset.map((batch, data) =>
    {
        var x_batch = gen_array_ops.gather_v2(data.Item1, batch, 0);
        var y_batch = gen_array_ops.gather_v2(data.Item2, batch, 0);
        return (x_batch, y_batch);
    });

    dataset = dataset.with_options(new DatasetOptions { });
    return dataset;
}
public bool CanHandle(Tensor x, Tensor y = null) | |||
@@ -70,5 +83,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
void _process_tensorlike() | |||
{ | |||
} | |||
/// <summary>
/// Returns the dataset held in the _dataset field.
/// </summary>
public IDatasetV2 GetDataset()
{
    return _dataset;
}
/// <summary>
/// Returns the value of the _size field.
/// </summary>
public int GetSize()
{
    return _size;
}
} | |||
} |
@@ -21,5 +21,20 @@ namespace Tensorflow.Keras.Engine | |||
_loss_metric = new Mean(name: "loss"); | |||
_built = false; | |||
} | |||
/// <summary>
/// Computes the overall loss.
/// The body is currently empty, so nothing is computed or returned yet.
/// </summary>
/// <param name="y_true">Not read by the current implementation.</param>
/// <param name="y_pred">Not read by the current implementation.</param>
public void Apply(Tensor y_true, Tensor y_pred)
{
    // Placeholder: no logic yet.
}
/// <summary>
/// Build hook. The body is currently empty and has no effect.
/// </summary>
public void Build()
{
    // Placeholder: no logic yet.
}
} | |||
} |
@@ -51,6 +51,21 @@ namespace Tensorflow.Keras.Engine | |||
Model = this, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
stop_training = false; | |||
_train_counter.assign(0); | |||
foreach(var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
{ | |||
// reset_metrics(); | |||
// callbacks.on_epoch_begin(epoch) | |||
// data_handler.catch_stop_iteration(); | |||
foreach(var step in data_handler.steps()) | |||
{ | |||
// callbacks.on_train_batch_begin(step) | |||
step_function(iterator); | |||
} | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,30 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Model | |||
{ | |||
/// <summary>
/// Pulls the next element from <paramref name="iterator"/> and passes its first two
/// entries to train_step.
/// Unfinished: after train_step returns, this method always throws.
/// </summary>
/// <param name="iterator">Iterator supplying the next element.</param>
/// <exception cref="NotImplementedException">Always thrown at the end of the method.</exception>
Tensor step_function(OwnedIterator iterator)
{
    var batch = iterator.next();
    var x = batch[0];
    var y = batch[1];

    train_step(x, y);

    // No result is produced yet.
    throw new NotImplementedException("");
}
/// <summary>
/// The logic for one training step.
/// Only the forward pass under a gradient tape is in place; the method then throws.
/// </summary>
/// <param name="x">Tensor passed to Apply with is_training set to true.</param>
/// <param name="y">Not read by the current implementation.</param>
/// <exception cref="NotImplementedException">Always thrown after the forward pass.</exception>
Tensor train_step(Tensor x, Tensor y)
{
    using (var tape = tf.GradientTape())
    {
        var predictions = Apply(x, is_training: true);

        // Loss and gradient handling are not written yet.
        throw new NotImplementedException("");
    }
}
} | |||
} |
@@ -33,11 +33,12 @@ namespace Tensorflow.Keras.Engine | |||
IVariableV1 _test_counter; | |||
IVariableV1 _predict_counter; | |||
bool _base_model_initialized; | |||
bool stop_training; | |||
/// <summary>
/// Creates the model: forwards <paramref name="args"/> to the base constructor,
/// then calls _init_batch_counters().
/// </summary>
/// <param name="args">Arguments handed to the base constructor.</param>
public Model(ModelArgs args) : base(args)
{
    _init_batch_counters();
}
void _configure_steps_per_execution(int steps_per_execution) | |||
@@ -64,6 +64,7 @@ namespace Tensorflow | |||
var inferred_from = new Dictionary<string, object>(); | |||
var base_types = new List<TF_DataType>(); | |||
var types = new List<TF_DataType>(); | |||
string _scope_name = scope; | |||
// Perform input type inference | |||
foreach (var input_arg in op_def.InputArg) | |||
@@ -241,7 +242,7 @@ namespace Tensorflow | |||
var op = g.create_op(op_type_name, | |||
inputs.ToArray(), | |||
output_types.ToArray(), | |||
name: scope, | |||
name: _scope_name, | |||
input_types: input_types.ToArray(), | |||
attrs: attr_protos, | |||
op_def: op_def); | |||
@@ -471,6 +471,42 @@ namespace Tensorflow | |||
throw new NotImplementedException(""); | |||
} | |||
/// <summary>
/// Creates a dataset that applies `f` to the outputs of `dataset` by executing the
/// "ParallelMapDatasetV2" op. Only eager execution is supported.
/// </summary>
/// <param name="dataset">Input dataset tensor.</param>
/// <param name="num_parallel_calls">Passed to the op as its third input.</param>
/// <param name="f">Function passed as the "f" attribute.</param>
/// <param name="output_types">Passed as the "output_types" attribute.</param>
/// <param name="output_shapes">Passed as the "output_shapes" attribute.</param>
/// <param name="use_inter_op_parallelism">Passed as the "use_inter_op_parallelism" attribute.</param>
/// <param name="deterministic">Passed as the "deterministic" attribute.</param>
/// <param name="preserve_cardinality">Passed as the "preserve_cardinality" attribute.</param>
/// <param name="name">Name forwarded to the op execution.</param>
/// <returns>The first result of the executed op.</returns>
/// <exception cref="NotImplementedException">Thrown when not executing eagerly.</exception>
public Tensor parallel_map_dataset_v2(Tensor dataset, Tensor num_parallel_calls, ConcreteFunction f,
    TF_DataType[] output_types, TensorShape[] output_shapes,
    bool use_inter_op_parallelism = true,
    string deterministic = "default",
    bool preserve_cardinality = false,
    string name = null)
{
    // Graph mode has no implementation yet.
    if (!tf.Context.executing_eagerly())
        throw new NotImplementedException("");

    // An empty tensor array is passed between the dataset and num_parallel_calls.
    var other_arguments = new Tensor[0];

    var outputs = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
        "ParallelMapDatasetV2", name,
        null,
        dataset, other_arguments, num_parallel_calls,
        "f", f,
        "output_types", output_types,
        "output_shapes", output_shapes,
        "use_inter_op_parallelism", use_inter_op_parallelism,
        "deterministic", deterministic,
        "preserve_cardinality", preserve_cardinality);

    return outputs[0];
}
/// <summary> | |||
/// A container for an iterator resource. | |||
/// </summary> | |||
@@ -739,9 +739,9 @@ namespace Tensorflow | |||
return tf_with(ops.name_scope(name, "Range", new { start, limit, delta }), scope => | |||
{ | |||
name = scope; | |||
var start1 = ops.convert_to_tensor(start, name: "start"); | |||
var limit1 = ops.convert_to_tensor(limit, name: "limit"); | |||
var delta1 = ops.convert_to_tensor(delta, name: "delta"); | |||
var start1 = ops.convert_to_tensor(start, name: "start", dtype: dtype); | |||
var limit1 = ops.convert_to_tensor(limit, name: "limit", dtype: dtype); | |||
var delta1 = ops.convert_to_tensor(delta, name: "delta", dtype: dtype); | |||
return gen_math_ops.range(start1, limit1, delta1, name); | |||
}); | |||
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
@@ -31,24 +32,25 @@ namespace Tensorflow | |||
/// </summary> | |||
public interface IVariableV1
{
    // Each member is declared exactly once, without an access modifier.
    string UniqueId { get; }
    string Name { get; }

    /// <summary>
    /// Handle is ref type
    /// </summary>
    Tensor Handle { get; }
    string Device { get; }
    Operation Initializer { get; }
    Operation Op { get; }

    /// <summary>
    /// GraphElement is a copy of Handle
    /// </summary>
    Tensor GraphElement { get; }
    Graph Graph { get; }
    TF_DataType dtype { get; }
    TensorShape shape { get; }

    Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
    Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true);
    Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false);
    NDArray numpy();
}
} |
@@ -15,6 +15,7 @@ | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
@@ -424,5 +425,8 @@ namespace Tensorflow | |||
var _op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking }); | |||
return _op; | |||
} | |||
/// <summary>
/// Not available on this variable type: always throws.
/// </summary>
/// <exception cref="RuntimeError">Always thrown.</exception>
public NDArray numpy()
{
    throw new RuntimeError("Graph mode can't use numpy().");
}
} | |||
} |