@@ -16,51 +16,95 @@ EndProject | |||
Global | |||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
Debug|Any CPU = Debug|Any CPU | |||
Debug|x64 = Debug|x64 | |||
Debug-Minimal|Any CPU = Debug-Minimal|Any CPU | |||
Debug-Minimal|x64 = Debug-Minimal|x64 | |||
Publish|Any CPU = Publish|Any CPU | |||
Publish|x64 = Publish|x64 | |||
Release|Any CPU = Release|Any CPU | |||
Release|x64 = Release|x64 | |||
EndGlobalSection | |||
GlobalSection(ProjectConfigurationPlatforms) = postSolution | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|x64 | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|x64 | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.ActiveCfg = Debug|x64 | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.Build.0 = Debug|x64 | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|x64 | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = Release|x64 | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|x64 | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.Build.0 = Release|x64 | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|x64 | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|x64 | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.ActiveCfg = Debug|x64 | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.Build.0 = Debug|x64 | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Release|x64 | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Release|x64 | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|x64 | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.Build.0 = Release|x64 | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|x64 | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|x64 | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.ActiveCfg = Debug|x64 | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.Build.0 = Debug|x64 | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Release|x64 | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Release|x64 | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|x64 | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|x64 | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|x64 | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|x64 | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|x64 | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|x64 | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|x64 | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|x64 | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|x64 | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.Build.0 = Release|x64 | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|x64 | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|x64 | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|x64 | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|x64 | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|x64 | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|x64 | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|x64 | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|x64 | |||
EndGlobalSection | |||
GlobalSection(SolutionProperties) = preSolution | |||
HideSolutionNode = FALSE | |||
@@ -20,8 +20,8 @@ namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public GradientActor GradientTape() | |||
=> new GradientActor(); | |||
public GradientTape GradientTape() | |||
=> new GradientTape(); | |||
public Tensor[] gradients(Tensor[] ys, | |||
Tensor[] xs, | |||
@@ -123,8 +123,8 @@ namespace Tensorflow | |||
=> gen_nn_ops.relu(features, name); | |||
public Tensor[] fused_batch_norm(Tensor x, | |||
VariableV1 scale, | |||
VariableV1 offset, | |||
IVariableV1 scale, | |||
IVariableV1 offset, | |||
Tensor mean = null, | |||
Tensor variance = null, | |||
float epsilon = 0.001f, | |||
@@ -50,7 +50,7 @@ namespace Tensorflow | |||
public ExponentialMovingAverage ExponentialMovingAverage(float decay) | |||
=> new ExponentialMovingAverage(decay); | |||
public Saver Saver(VariableV1[] var_list = null, int max_to_keep = 5) | |||
public Saver Saver(IVariableV1[] var_list = null, int max_to_keep = 5) | |||
=> new Saver(var_list: var_list, max_to_keep: max_to_keep); | |||
public string write_graph(Graph graph, string logdir, string name, bool as_text = true) | |||
@@ -68,7 +68,7 @@ namespace Tensorflow | |||
clear_devices, | |||
import_scope).Item1; | |||
public (MetaGraphDef, Dictionary<string, VariableV1>) export_meta_graph(string filename = "", | |||
public (MetaGraphDef, Dictionary<string, IVariableV1>) export_meta_graph(string filename = "", | |||
bool as_text = false, | |||
bool clear_devices = false, | |||
bool clear_extraneous_savers = false, | |||
@@ -21,9 +21,9 @@ namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public VariableV1[] global_variables(string scope = null) | |||
public IVariableV1[] global_variables(string scope = null) | |||
{ | |||
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<IVariableV1>) | |||
.ToArray(); | |||
} | |||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||
/// <param name="var_list">List of `Variable` objects to initialize.</param> | |||
/// <param name="name">Optional name for the returned operation.</param> | |||
/// <returns>An Op that run the initializers of all the specified variables.</returns> | |||
public Operation variables_initializer(VariableV1[] var_list, string name = "init") | |||
public Operation variables_initializer(IVariableV1[] var_list, string name = "init") | |||
=> variables.variables_initializer(var_list, name: name); | |||
public Operation global_variables_initializer() | |||
@@ -47,8 +47,8 @@ namespace Tensorflow | |||
/// </summary> | |||
/// <param name="scope"></param> | |||
/// <returns></returns> | |||
public VariableV1[] trainable_variables(string scope = null) | |||
=> (variables.trainable_variables() as List<VariableV1>).ToArray(); | |||
public IVariableV1[] trainable_variables(string scope = null) | |||
=> (variables.trainable_variables() as List<IVariableV1>).ToArray(); | |||
public RefVariable get_variable(string name, | |||
TensorShape shape = null, | |||
@@ -8,6 +8,7 @@ namespace Tensorflow.Eager | |||
{ | |||
public int NumInputs; | |||
public Tensor[] Inputs { get; set; } | |||
public int[] SkipInputIndices { get; set; } | |||
public EagerOperation() : base(IntPtr.Zero) { } | |||
@@ -11,7 +11,17 @@ namespace Tensorflow | |||
public static extern void TFE_RegisterGradientFunction(_gradient_function_callback callbackPointer); | |||
[UnmanagedFunctionPointer(CallingConvention.StdCall)] | |||
public delegate IntPtr _gradient_function_callback(string op_name, int num_inputs, IntPtr op_inputs, int num_attrs, int num_outputs, IntPtr output_grads); | |||
public delegate IntPtr _gradient_function_callback(string op_name, | |||
int num_inputs, | |||
IntPtr op_inputs, | |||
int num_attrs, | |||
int num_outputs, | |||
IntPtr output_grads, | |||
int num_skip_inputs, | |||
IntPtr skip_input_indices); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFE_WrapGradientResult(IntPtr[] gradients, int num_gradients); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr VSpace_Handle(VSpace_callback_Ones ones, VSpace_callback_AggregateGrads aggregate_grads); | |||
@@ -373,11 +383,17 @@ namespace Tensorflow | |||
public static extern void TFE_TapeSetRemove(IntPtr tape); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor); | |||
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr variable); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_TapeVariableAccessed(IntPtr variable); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFE_TapeWatchedVariables(IntPtr tape); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr ResourceVariable_Handle(IntPtr variable); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFE_TapeGradient(IntPtr tape, | |||
IntPtr[] target, int target_size, | |||
@@ -35,7 +35,7 @@ namespace Tensorflow | |||
return meta_graph_def; | |||
} | |||
public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, | |||
public static (Dictionary<string, IVariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, | |||
bool clear_devices = false, | |||
string import_scope = "", | |||
Dictionary<string, Tensor> input_map = null, | |||
@@ -77,7 +77,7 @@ namespace Tensorflow | |||
return_elements: return_elements); | |||
// Restores all the other collections. | |||
var variable_objects = new Dictionary<ByteString, VariableV1>(); | |||
var variable_objects = new Dictionary<ByteString, IVariableV1>(); | |||
foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) | |||
{ | |||
// Don't add unbound_inputs to the new graph. | |||
@@ -99,7 +99,7 @@ namespace Tensorflow | |||
{ | |||
foreach (var value in col.Value.BytesList.Value) | |||
{ | |||
VariableV1 variable = null; | |||
IVariableV1 variable = null; | |||
if (!variable_objects.ContainsKey(value)) | |||
{ | |||
var proto = VariableDef.Parser.ParseFrom(value); | |||
@@ -147,10 +147,10 @@ namespace Tensorflow | |||
} | |||
} | |||
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, | |||
var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, | |||
scope: scope_to_prepend_to_names); | |||
var var_list = new Dictionary<string, VariableV1>(); | |||
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | |||
var var_list = new Dictionary<string, IVariableV1>(); | |||
variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v); | |||
return (var_list, imported_return_elements); | |||
} | |||
@@ -168,7 +168,7 @@ namespace Tensorflow | |||
/// <param name="strip_default_attrs"></param> | |||
/// <param name="meta_info_def"></param> | |||
/// <returns></returns> | |||
public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_graph(string filename = "", | |||
public static (MetaGraphDef, Dictionary<string, IVariableV1>) export_scoped_meta_graph(string filename = "", | |||
GraphDef graph_def = null, | |||
bool as_text = false, | |||
string unbound_inputs_col_name = "unbound_inputs", | |||
@@ -180,14 +180,14 @@ namespace Tensorflow | |||
{ | |||
var graph = ops.get_default_graph(); | |||
var var_list = new Dictionary<string, VariableV1>(); | |||
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES); | |||
var var_list = new Dictionary<string, IVariableV1>(); | |||
var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES); | |||
if (variables != null) | |||
{ | |||
foreach (var v in variables) | |||
{ | |||
var_list[v.name] = v; | |||
var_list[v.Name] = v; | |||
} | |||
} | |||
@@ -268,7 +268,7 @@ namespace Tensorflow | |||
switch (graph.get_collection(key)) | |||
{ | |||
case List<VariableV1> collection_list: | |||
case List<IVariableV1> collection_list: | |||
col_def.BytesList = new Types.BytesList(); | |||
foreach (var x in collection_list) | |||
{ | |||
@@ -1,109 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Gradients | |||
{ | |||
/// <summary> | |||
/// Record operations for automatic differentiation. | |||
/// | |||
/// Operations are recorded if they are executed within this context manager and | |||
/// at least one of their inputs is being "watched". | |||
/// | |||
/// Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`, | |||
/// where `trainable=True` is default in both cases) are automatically watched. | |||
/// Tensors can be manually watched by invoking the `watch` method on this context | |||
/// manager. | |||
/// </summary> | |||
public class GradientActor : IDisposable | |||
{ | |||
bool _recording; | |||
bool _persistent; | |||
bool _watch_accessed_variables; | |||
bool _created_eagerly; | |||
Tape _tape; | |||
public GradientActor(bool persistent = false, | |||
bool watch_accessed_variables = true) | |||
{ | |||
_persistent = persistent; | |||
_watch_accessed_variables = watch_accessed_variables; | |||
_created_eagerly = tf.context.executing_eagerly(); | |||
_push_tape(); | |||
} | |||
private void _push_tape() | |||
{ | |||
if (_recording) | |||
throw new ValueError("Tape is still recording, This can happen if you try to " + | |||
"re-enter an already-active tape."); | |||
if (_tape == null) | |||
_tape = new Tape(_persistent, _watch_accessed_variables); | |||
else | |||
throw new NotImplementedException(""); | |||
_recording = true; | |||
} | |||
private void _pop_tape() | |||
{ | |||
if (!_recording) | |||
throw new ValueError("Tape is not recording."); | |||
_tape.pop_tape(_tape); | |||
_recording = false; | |||
} | |||
/// <summary> | |||
/// Marks this tensor to be watched by the given tape. | |||
/// </summary> | |||
/// <param name="x"></param> | |||
public void watch(Tensor x) | |||
{ | |||
_tape.watch(x as EagerTensor); | |||
} | |||
public Tensor gradient(Tensor target, Tensor source) | |||
{ | |||
if(_recording) | |||
{ | |||
if (!_persistent) | |||
_pop_tape(); | |||
} | |||
using var status = new Status(); | |||
var et = c_api.TFE_TapeGradient(_tape, | |||
new [] { (target as EagerTensor).EagerTensorHandle }, 1, | |||
new [] { (source as EagerTensor).EagerTensorHandle }, 1, | |||
status); | |||
status.Check(true); | |||
return new EagerTensor(et); | |||
} | |||
public Tensor gradient(Tensor target, ResourceVariable[] sources) | |||
{ | |||
if (_recording) | |||
{ | |||
if (!_persistent) | |||
_pop_tape(); | |||
} | |||
using var status = new Status(); | |||
EagerTensorHandle et = c_api.TFE_TapeGradient(_tape, | |||
new[] { (target as EagerTensor).EagerTensorHandle }, 1, | |||
sources.Select(x => (x.handle as EagerTensor).EagerTensorHandle).ToArray(), sources.Length, | |||
status); | |||
status.Check(true); | |||
return et; | |||
} | |||
public void Dispose() | |||
{ | |||
if (_recording) | |||
_pop_tape(); | |||
} | |||
} | |||
} |
@@ -1,6 +1,9 @@ | |||
using System; | |||
using Google.Protobuf.WellKnownTypes; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Gradients | |||
@@ -16,16 +19,104 @@ namespace Tensorflow.Gradients | |||
/// Tensors can be manually watched by invoking the `watch` method on this context | |||
/// manager. | |||
/// </summary> | |||
public class GradientTape | |||
public class GradientTape : IDisposable | |||
{ | |||
bool _recording; | |||
bool _persistent; | |||
bool _watch_accessed_variables; | |||
ResourceVariable[] _watched_variables; | |||
bool _created_eagerly; | |||
Tape _tape; | |||
public GradientTape(bool persistent = false, | |||
bool watch_accessed_variables = true) | |||
{ | |||
_persistent = persistent; | |||
_watch_accessed_variables = watch_accessed_variables; | |||
_created_eagerly = tf.context.executing_eagerly(); | |||
_push_tape(); | |||
} | |||
private void _push_tape() | |||
{ | |||
if (_recording) | |||
throw new ValueError("Tape is still recording, This can happen if you try to " + | |||
"re-enter an already-active tape."); | |||
if (_tape == null) | |||
_tape = new Tape(_persistent, _watch_accessed_variables); | |||
else | |||
throw new NotImplementedException(""); | |||
_recording = true; | |||
} | |||
private void _pop_tape() | |||
{ | |||
if (!_recording) | |||
throw new ValueError("Tape is not recording."); | |||
_tape.pop_tape(_tape); | |||
_recording = false; | |||
} | |||
/// <summary> | |||
/// Marks this tensor to be watched by the given tape. | |||
/// </summary> | |||
/// <param name="x"></param> | |||
public void watch(Tensor x) | |||
{ | |||
_tape.watch(x as EagerTensor); | |||
} | |||
public Tensor gradient(Tensor target, Tensor source) | |||
{ | |||
if(_recording) | |||
{ | |||
if (!_persistent) | |||
_pop_tape(); | |||
} | |||
using var status = new Status(); | |||
var et = c_api.TFE_TapeGradient(_tape, | |||
new [] { (target as EagerTensor).EagerTensorHandle }, 1, | |||
new [] { (source as EagerTensor).EagerTensorHandle }, 1, | |||
status); | |||
status.Check(true); | |||
return new EagerTensor(et); | |||
} | |||
public unsafe (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) | |||
{ | |||
if (_recording) | |||
{ | |||
if (!_persistent) | |||
_pop_tape(); | |||
} | |||
using var status = new Status(); | |||
IntPtr et = c_api.TFE_TapeGradient(_tape, | |||
new IntPtr[] { target as EagerTensor }, 1, | |||
new IntPtr[] { sources.Item1.Handle as EagerTensor, sources.Item2.Handle as EagerTensor }, 2, | |||
status); | |||
status.Check(true); | |||
var results = new Tensor[2]; | |||
for (int i = 0; i < 2; i++) | |||
results[i] = new EagerTensor(*((IntPtr*)et + i)); | |||
if (!_persistent) | |||
{ | |||
// Keep track of watched variables before setting tape to None | |||
_watched_variables = _tape.watched_variables(); | |||
_tape = null; | |||
} | |||
return (results[0], results[1]); | |||
} | |||
public void Dispose() | |||
{ | |||
if (_recording) | |||
_pop_tape(); | |||
} | |||
} | |||
} |
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
@@ -7,7 +8,6 @@ namespace Tensorflow.Gradients | |||
{ | |||
public class Tape : DisposableObject | |||
{ | |||
public GradientTape tape { get; set; } | |||
public int nesting_id { get; set; } | |||
public Tape(bool persistent, bool watch_accessed_variables) | |||
@@ -27,7 +27,21 @@ namespace Tensorflow.Gradients | |||
public static void variable_accessed(ResourceVariable variable) | |||
{ | |||
c_api.TFE_TapeVariableAccessed(variable.handle as EagerTensor); | |||
c_api.TFE_TapeVariableAccessed(variable); | |||
} | |||
public unsafe ResourceVariable[] watched_variables() | |||
{ | |||
BindingArray result = c_api.TFE_TapeWatchedVariables(_handle); | |||
var variables = new ResourceVariable[result.length]; | |||
for (int i = 0; i < result.length; i++) | |||
{ | |||
var handle = *((IntPtr*)result.array + i); | |||
var tensor = c_api.ResourceVariable_Handle(handle); | |||
variables[i] = new ResourceVariable(handle, tensor); | |||
} | |||
return variables; | |||
} | |||
public static bool IsDtypeTrainable(DataType dtype) | |||
@@ -191,7 +191,7 @@ namespace Tensorflow.Gradients | |||
grad_ctxt.Enter(); | |||
var result = control_flow_ops._Enter( | |||
grad, grad_ctxt.name, is_constant: false, | |||
grad, grad_ctxt.Name, is_constant: false, | |||
parallel_iterations: grad_ctxt.parallel_iterations, | |||
name: "b_exit"); | |||
@@ -17,6 +17,7 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
@@ -169,10 +170,28 @@ namespace Tensorflow.Gradients | |||
var x = op.inputs[0]; | |||
var y = op.inputs[1]; | |||
var grad = grads[0]; | |||
if (grad is Tensor && | |||
if (op is EagerOperation op_eager && | |||
op_eager.SkipInputIndices.Contains(1) && | |||
y.NDims == 0) | |||
{ | |||
return new Tensor[] | |||
{ | |||
gen_math_ops.mul(grad, math_ops.conj(y)), | |||
null | |||
}; | |||
} | |||
if (grad is Tensor && | |||
_ShapesFullySpecifiedAndEqual(x, y, grad) && | |||
new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype)) | |||
return new Tensor[] { gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) }; | |||
{ | |||
return new Tensor[] | |||
{ | |||
gen_math_ops.mul(grad, y), | |||
gen_math_ops.mul(grad, x) | |||
}; | |||
} | |||
var (sx, sy) = SmartBroadcastGradientArgs(x, y); | |||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
@@ -180,15 +199,39 @@ namespace Tensorflow.Gradients | |||
x = math_ops.conj(x); | |||
y = math_ops.conj(y); | |||
var mul1 = gen_math_ops.mul(grad, y); | |||
var reduce_sum1 = math_ops.reduce_sum(mul1, rx); | |||
var reshape1 = gen_array_ops.reshape(reduce_sum1, sx); | |||
Tensor gx = null, gy = null; | |||
if (op is EagerOperation op_eager1 && | |||
op_eager1.SkipInputIndices.Contains(0)) | |||
{ | |||
return new Tensor[] | |||
{ | |||
gen_math_ops.mul(grad, math_ops.conj(y)), | |||
null | |||
}; | |||
} | |||
// else if not must_reduce_x: | |||
// gx = gen_math_ops.mul(grad, y) | |||
else | |||
{ | |||
gx = array_ops.reshape( | |||
math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx); | |||
} | |||
if (op is EagerOperation op_eager2 && | |||
op_eager2.SkipInputIndices.Contains(1)) | |||
{ | |||
var mul2 = gen_math_ops.mul(x, grad); | |||
var reduce_sum2 = math_ops.reduce_sum(mul2, ry); | |||
var reshape2 = gen_array_ops.reshape(reduce_sum2, sy); | |||
} | |||
// else if not must_reduce_y: | |||
// gy = gen_math_ops.mul(x, grad) | |||
else | |||
{ | |||
gy = array_ops.reshape( | |||
math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy); | |||
} | |||
return new Tensor[] { reshape1, reshape2 }; | |||
return new Tensor[] { gx, gy }; | |||
} | |||
[RegisterGradient("MatMul")] | |||
@@ -617,7 +660,9 @@ namespace Tensorflow.Gradients | |||
var x = op.inputs[0]; | |||
var y = op.inputs[1]; | |||
if (tf.context.executing_eagerly()) | |||
if (op is EagerOperation op_eager && | |||
op_eager.SkipInputIndices.Contains(1) && | |||
y.NDims == 0) | |||
{ | |||
x = math_ops.conj(x); | |||
y = math_ops.conj(y); | |||
@@ -444,7 +444,7 @@ namespace Tensorflow | |||
var collection = _collections.ContainsKey(name) ? _collections[name] : new List<T>(); | |||
switch (collection) | |||
{ | |||
case List<VariableV1> list: | |||
case List<IVariableV1> list: | |||
t = list.Select(x => (T)(object)x).ToList(); | |||
break; | |||
case List<ResourceVariable> list: | |||
@@ -37,8 +37,8 @@ namespace Tensorflow.Keras.Layers | |||
private IInitializer gamma_initializer; | |||
private IInitializer moving_mean_initializer; | |||
private IInitializer moving_variance_initializer; | |||
private VariableV1 gamma; | |||
private VariableV1 beta; | |||
private IVariableV1 gamma; | |||
private IVariableV1 beta; | |||
private RefVariable moving_mean; | |||
private RefVariable moving_variance; | |||
@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Layers | |||
private int input_dim; | |||
private int output_dim; | |||
private bool mask_zero; | |||
public VariableV1 embeddings; | |||
public IVariableV1 embeddings; | |||
public IInitializer embeddings_initializer; | |||
int input_length; | |||
@@ -51,8 +51,8 @@ namespace Tensorflow.Keras.Layers | |||
/// </summary> | |||
protected InputSpec input_spec; | |||
protected bool supports_masking; | |||
protected List<VariableV1> _trainable_weights; | |||
protected List<VariableV1> _non_trainable_weights; | |||
protected List<IVariableV1> _trainable_weights; | |||
protected List<IVariableV1> _non_trainable_weights; | |||
private string _name; | |||
public string name => _name; | |||
protected string _base_name; | |||
@@ -84,8 +84,8 @@ namespace Tensorflow.Keras.Layers | |||
this.supports_masking = false; | |||
_init_set_name(name); | |||
_trainable_weights = new List<VariableV1>(); | |||
_non_trainable_weights = new List<VariableV1>(); | |||
_trainable_weights = new List<IVariableV1>(); | |||
_non_trainable_weights = new List<IVariableV1>(); | |||
_compute_previous_mask = false; | |||
_updates = new List<Operation>(); | |||
@@ -207,12 +207,12 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected virtual VariableV1 add_weight(string name, | |||
protected virtual IVariableV1 add_weight(string name, | |||
int[] shape, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
bool? trainable = null, | |||
Func<string, int[], TF_DataType, IInitializer, bool, VariableV1> getter = null) | |||
Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null) | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = TF_DataType.TF_FLOAT; | |||
@@ -10,5 +10,15 @@ namespace Tensorflow.Keras.Optimizers | |||
/// </summary> | |||
public class OptimizerV2 : Trackable, IOptimizer | |||
{ | |||
public OptimizerV2() : base() | |||
{ | |||
} | |||
public void apply_gradients((Tensor, Tensor) gradients, | |||
(ResourceVariable, ResourceVariable) vars) | |||
{ | |||
} | |||
} | |||
} |
@@ -4,9 +4,9 @@ using System.Text; | |||
namespace Tensorflow.Keras.Optimizers | |||
{ | |||
public class SGD | |||
public class SGD : OptimizerV2 | |||
{ | |||
public SGD(float learning_rate) | |||
public SGD(float learning_rate) : base() | |||
{ | |||
} | |||
@@ -32,7 +32,7 @@ namespace Tensorflow.Keras.Utils | |||
/// <param name="initializer"></param> | |||
/// <param name="trainable"></param> | |||
/// <returns></returns> | |||
public static VariableV1 make_variable(string name, | |||
public static IVariableV1 make_variable(string name, | |||
int[] shape, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
IInitializer initializer = null, | |||
@@ -42,14 +42,14 @@ namespace Tensorflow.Keras | |||
/// Allows to give unique autogenerated names to layers, in a graph-specific way. | |||
/// </summary> | |||
public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||
public static Dictionary<string, VariableV1> _GRAPH_VARIABLES = new Dictionary<string, VariableV1>(); | |||
public static Dictionary<string, IVariableV1> _GRAPH_VARIABLES = new Dictionary<string, IVariableV1>(); | |||
public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>(); | |||
public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph(); | |||
public static void track_variable(VariableV1 v) | |||
public static void track_variable(IVariableV1 v) | |||
{ | |||
var graph = v.graph; | |||
var graph = v.Graph; | |||
_GRAPH_VARIABLES[graph.graph_key] = v; | |||
} | |||
@@ -42,8 +42,8 @@ namespace Tensorflow.Layers | |||
this._reuse = _reuse; | |||
// Avoid an incorrect lint error | |||
_trainable_weights = new List<VariableV1>(); | |||
_non_trainable_weights = new List<VariableV1>(); | |||
_trainable_weights = new List<IVariableV1>(); | |||
_non_trainable_weights = new List<IVariableV1>(); | |||
this.built = false; | |||
_keras_style = false; | |||
} | |||
@@ -116,7 +116,7 @@ namespace Tensorflow.Layers | |||
/// <param name="synchronization"></param> | |||
/// <param name="aggregation"></param> | |||
/// <returns></returns> | |||
protected virtual VariableV1 add_weight(string name, | |||
protected virtual IVariableV1 add_weight(string name, | |||
int[] shape, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
@@ -126,7 +126,7 @@ namespace Tensorflow.Layers | |||
{ | |||
var default_graph = ops.get_default_graph(); | |||
Graph init_graph = null; | |||
VariableV1[] existing_variables = null; | |||
IVariableV1[] existing_variables = null; | |||
if (synchronization == VariableSynchronization.OnRead) | |||
trainable = false; | |||
@@ -77,7 +77,7 @@ namespace Tensorflow.Operations | |||
_external_values = new Dictionary<string, ITensorOrOperation>(); | |||
} | |||
public string name { get => _name; } | |||
public string Name { get => _name; } | |||
protected string _name; | |||
public void __init__(ValuesDef values_def = null, string import_scope = null) | |||
@@ -141,7 +141,7 @@ namespace Tensorflow.Operations.ControlFlows | |||
parallel_iterations: forward_ctxt.parallel_iterations, | |||
back_prop: forward_ctxt.back_prop, | |||
swap_memory: forward_ctxt.swap_memory, | |||
name: forward_ctxt.name, | |||
name: forward_ctxt.Name, | |||
grad_state: this); | |||
_grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state); | |||
if (outer_forward_ctxt != null) | |||
@@ -21,8 +21,8 @@ namespace Tensorflow | |||
bool _state_is_tuple; | |||
IActivation _activation; | |||
LSTMStateTuple _state; | |||
VariableV1 _kernel; | |||
VariableV1 _bias; | |||
IVariableV1 _kernel; | |||
IVariableV1 _bias; | |||
string _WEIGHTS_VARIABLE_NAME = "kernel"; | |||
string _BIAS_VARIABLE_NAME = "bias"; | |||
@@ -28,9 +28,9 @@ namespace Tensorflow | |||
public override object state_size => _num_units; | |||
public override int output_size => _num_units; | |||
public VariableV1 _kernel; | |||
public IVariableV1 _kernel; | |||
string _WEIGHTS_VARIABLE_NAME = "kernel"; | |||
public VariableV1 _bias; | |||
public IVariableV1 _bias; | |||
string _BIAS_VARIABLE_NAME = "bias"; | |||
public BasicRnnCell(int num_units, | |||
@@ -64,6 +64,7 @@ namespace Tensorflow | |||
bool _is_stateful; | |||
public NodeDef node_def | |||
{ | |||
get | |||
@@ -61,7 +61,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <param name="max_norm"></param> | |||
/// <returns></returns> | |||
public static Tensor _embedding_lookup_and_transform(VariableV1 @params, | |||
public static Tensor _embedding_lookup_and_transform(IVariableV1 @params, | |||
Tensor ids, | |||
string partition_strategy = "mod", | |||
string name = null, | |||
@@ -131,7 +131,7 @@ namespace Tensorflow | |||
max_norm: max_norm); | |||
} | |||
public static Tensor embedding_lookup(VariableV1 @params, Tensor ids, | |||
public static Tensor embedding_lookup(IVariableV1 @params, Tensor ids, | |||
string partition_strategy = "mod", | |||
string name = null, | |||
bool validate_indices = true, | |||
@@ -821,7 +821,7 @@ namespace Tensorflow | |||
{ | |||
x as EagerTensor, | |||
y as EagerTensor, | |||
}, 1, null, status); | |||
}, 2, null, status); | |||
status.Check(true); | |||
return tensor; | |||
} | |||
@@ -98,8 +98,8 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor[] fused_batch_norm(Tensor x, | |||
VariableV1 scale, | |||
VariableV1 offset, | |||
IVariableV1 scale, | |||
IVariableV1 offset, | |||
Tensor mean, | |||
Tensor variance, | |||
float epsilon = 0.001f, | |||
@@ -15,6 +15,7 @@ | |||
******************************************************************************/ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.CppShapeInferenceResult.Types; | |||
@@ -70,7 +71,7 @@ namespace Tensorflow | |||
throw new NotImplementedException(); | |||
} | |||
public static bool is_resource_variable(VariableV1 var) | |||
public static bool is_resource_variable(IVariableV1 var) | |||
{ | |||
return var is ResourceVariable; | |||
} | |||
@@ -128,14 +129,34 @@ namespace Tensorflow | |||
// When in eager mode, explicitly ensure so here. When in graph mode, it's | |||
// ensured by always generating different variable names. | |||
var exists = gen_resource_variable_ops.var_is_initialized_op(handle); | |||
} | |||
return handle; | |||
// We create an assert Op instead of checking right away in order to be | |||
// compatible with ASYNC execution mode. Further, since not all devices | |||
// support string tensors, we encode the assertion string in the Op name | |||
/*gen_logging_ops._assert( | |||
math_ops.logical_not(exists), [exists], name = "EagerVariableNameReuse");*/ | |||
var handle_data = new HandleData(); | |||
handle_data.IsSet = true; | |||
handle_data.ShapeAndType.Add(new HandleShapeAndType | |||
{ | |||
Dtype = dtype.as_datatype_enum(), | |||
Shape = shape.as_proto() | |||
}); | |||
_set_handle_shapes_and_types(handle, handle_data, graph_mode); | |||
return handle; | |||
} | |||
} | |||
private static void _set_handle_shapes_and_types(Tensor handle, HandleData full_handle_data, bool graph_mode) | |||
/// <summary> | |||
/// Sets the shape inference result HandleData on tensor. | |||
/// </summary> | |||
/// <param name="handle"></param> | |||
/// <param name="full_handle_data"></param> | |||
/// <param name="graph_mode"></param> | |||
private static void _set_handle_shapes_and_types(Tensor handle, HandleData handle_data, bool graph_mode) | |||
{ | |||
if (!graph_mode) | |||
return; | |||
} | |||
/// <summary> | |||
@@ -171,20 +192,5 @@ namespace Tensorflow | |||
return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||
} | |||
} | |||
/// <summary> | |||
/// Represents a future for a read of a variable. | |||
/// Pretends to be the tensor if anyone looks. | |||
/// </summary> | |||
public class _UnreadVariable : BaseResourceVariable | |||
{ | |||
} | |||
/// <summary> | |||
/// A python variable from an existing handle. | |||
/// </summary> | |||
public class BaseResourceVariable : VariableV1 | |||
{ | |||
} | |||
} | |||
} |
@@ -6,7 +6,7 @@ | |||
/// </summary> | |||
public interface IProtoBuf<TProtoDef, TDef> | |||
{ | |||
string name { get; } | |||
string Name { get; } | |||
/// <summary> | |||
/// Converts a `Variable` to a `VariableDef` protocol buffer. | |||
@@ -31,10 +31,16 @@ https://tensorflownet.readthedocs.io</Description> | |||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
<SignAssembly>true</SignAssembly> | |||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
<Platforms>AnyCPU</Platforms> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
<DefineConstants>TRACE;DEBUG</DefineConstants> | |||
<PlatformTarget>AnyCPU</PlatformTarget> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
<DefineConstants>TRACE;DEBUG</DefineConstants> | |||
<PlatformTarget>x64</PlatformTarget> | |||
@@ -44,6 +50,10 @@ https://tensorflownet.readthedocs.io</Description> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<Compile Remove="Distribute\**" /> | |||
<Compile Remove="Models\**" /> | |||
@@ -111,7 +111,7 @@ namespace Tensorflow.Train | |||
protected override void _create_slots(RefVariable[] var_list) | |||
{ | |||
var first_var = var_list.OrderBy(x => x.name).First(); | |||
var first_var = var_list.OrderBy(x => x.Name).First(); | |||
_create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var); | |||
_create_non_slot_variable(initial_value: _beta2, name: "beta2_power", colocate_with: first_var); | |||
@@ -44,7 +44,7 @@ namespace Tensorflow | |||
public Tensor LearningRateTensor => _lr_t; | |||
public bool _use_locking; | |||
public Dictionary<string, Dictionary<string, RefVariable>> _slots; | |||
public Dictionary<string, VariableV1> _non_slot_dict; | |||
public Dictionary<string, IVariableV1> _non_slot_dict; | |||
public Dictionary<string, object> _deferred_slot_restorations; | |||
SlotCreator slot_creator = new SlotCreator(); | |||
@@ -58,7 +58,7 @@ namespace Tensorflow | |||
_lr = learning_rate; | |||
// Dictionary of slots. | |||
_slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | |||
_non_slot_dict = new Dictionary<string, VariableV1>(); | |||
_non_slot_dict = new Dictionary<string, IVariableV1>(); | |||
_deferred_slot_restorations = new Dictionary<string, object>(); | |||
} | |||
@@ -72,7 +72,7 @@ namespace Tensorflow | |||
_lr_t = learning_rate; | |||
// Dictionary of slots. | |||
_slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | |||
_non_slot_dict = new Dictionary<string, VariableV1>(); | |||
_non_slot_dict = new Dictionary<string, IVariableV1>(); | |||
_deferred_slot_restorations = new Dictionary<string, object>(); | |||
} | |||
@@ -122,7 +122,7 @@ namespace Tensorflow | |||
var vars_with_grad = grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray(); | |||
if (vars_with_grad.Length == 0) | |||
throw new ValueError($"No gradients provided for any variable, check your graph for ops" + | |||
$" that do not support gradients, between variables {string.Join(",", vars_with_grad.Select(x => x.name))} and loss {loss}."); | |||
$" that do not support gradients, between variables {string.Join(",", vars_with_grad.Select(x => x.Name))} and loss {loss}."); | |||
return apply_gradients(grads_and_vars, global_step:global_step, name:name); | |||
} | |||
@@ -175,7 +175,7 @@ namespace Tensorflow | |||
if (grad == null) | |||
continue; | |||
var scope_name = var.op.name; | |||
var scope_name = var.Op.name; | |||
tf_with(ops.name_scope("update_" + scope_name), scope2 => | |||
{ | |||
var op = processor.update_op(this, grad); | |||
@@ -241,10 +241,10 @@ namespace Tensorflow | |||
/// <param name="initial_value"></param> | |||
/// <param name="name"></param> | |||
/// <param name="colocate_with"></param> | |||
protected VariableV1 _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) | |||
protected IVariableV1 _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) | |||
{ | |||
// Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. | |||
var graph = colocate_with.graph; | |||
var graph = colocate_with.Graph; | |||
var key = $"{name}.{graph.graph_key}"; | |||
var v = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; | |||
if(v == null) | |||
@@ -333,10 +333,10 @@ namespace Tensorflow | |||
private string _var_key(RefVariable var) | |||
{ | |||
return $"{var.op.graph.graph_key}.{var.op.name}"; | |||
return $"{var.Op.graph.graph_key}.{var.Op.name}"; | |||
} | |||
protected VariableV1 _get_non_slot_variable(string name, Graph graph = null) | |||
protected IVariableV1 _get_non_slot_variable(string name, Graph graph = null) | |||
{ | |||
var key = $"{name}.{graph.graph_key}"; | |||
var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; | |||
@@ -385,7 +385,7 @@ namespace Tensorflow | |||
case List<RefVariable> values: | |||
var_list = values.Concat(vars).ToList(); | |||
break; | |||
case List<VariableV1> values: | |||
case List<IVariableV1> values: | |||
var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); | |||
break; | |||
} | |||
@@ -79,7 +79,7 @@ namespace Tensorflow | |||
return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); | |||
} | |||
public virtual SaverDef _build_internal(VariableV1[] names_to_saveables, | |||
public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, | |||
bool reshape = false, | |||
bool sharded = false, | |||
int max_to_keep = 5, | |||
@@ -22,7 +22,7 @@ namespace Tensorflow | |||
Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially); | |||
SaverDef _build_internal(VariableV1[] names_to_saveables, | |||
SaverDef _build_internal(IVariableV1[] names_to_saveables, | |||
bool reshape = false, | |||
bool sharded = false, | |||
int max_to_keep = 5, | |||
@@ -29,7 +29,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public class Saver | |||
{ | |||
private VariableV1[] _var_list; | |||
private IVariableV1[] _var_list; | |||
private bool _reshape; | |||
private bool _sharded; | |||
private int _max_to_keep; | |||
@@ -50,7 +50,7 @@ namespace Tensorflow | |||
private Dictionary<string, float> _last_checkpoints; | |||
private Dictionary<string, float> _checkpoints_to_be_deleted; | |||
public Saver(VariableV1[] var_list = null, | |||
public Saver(IVariableV1[] var_list = null, | |||
bool reshape = false, | |||
bool sharded = false, | |||
int max_to_keep = 5, | |||
@@ -28,7 +28,7 @@ namespace Tensorflow | |||
/// </summary> | |||
/// <param name="names_to_saveables"></param> | |||
/// <returns></returns> | |||
public static SaveableObject[] validate_and_slice_inputs(VariableV1[] names_to_saveables) | |||
public static SaveableObject[] validate_and_slice_inputs(IVariableV1[] names_to_saveables) | |||
{ | |||
var names_to_saveables_dict = op_list_to_dict(names_to_saveables); | |||
var saveables = new List<SaveableObject>(); | |||
@@ -76,9 +76,9 @@ namespace Tensorflow | |||
} | |||
} | |||
public static Dictionary<string, Tensor> op_list_to_dict(VariableV1[] op_list, bool convert_variable_to_tensor = true) | |||
public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) | |||
{ | |||
op_list = op_list.OrderBy(x => x.name).ToArray(); | |||
op_list = op_list.OrderBy(x => x.Name).ToArray(); | |||
var names_to_saveables = new Dictionary<string, Tensor>(); | |||
foreach(var var in op_list) | |||
@@ -103,7 +103,7 @@ namespace Tensorflow | |||
if (convert_variable_to_tensor) | |||
{ | |||
if (var is ResourceVariable) | |||
tensor = var.graph_element; | |||
tensor = var.GraphElement; | |||
else | |||
tensor = ops.internal_convert_to_tensor(var, as_ref: true); | |||
} | |||
@@ -111,7 +111,7 @@ namespace Tensorflow | |||
if (tensor.op.type == "ReadVariableOp") | |||
name = tensor.op.inputs[0].op.name; | |||
else | |||
name = var.op.name; | |||
name = var.Op.name; | |||
if (names_to_saveables.ContainsKey(name)) | |||
throw new ValueError($"At least two variables have the same name: {name}"); | |||
@@ -53,7 +53,7 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def, | |||
string import_scope, | |||
Dictionary<string, VariableV1> imported_vars) | |||
Dictionary<string, IVariableV1> imported_vars) | |||
{ | |||
if(meta_graph_def.SaverDef != null) | |||
{ | |||
@@ -64,7 +64,7 @@ namespace Tensorflow | |||
{ | |||
var sample_key = var_names[0]; | |||
var sample_var = imported_vars[sample_key]; | |||
scope = string.Join("", sample_var.name.Skip(sample_key.Length)); | |||
scope = string.Join("", sample_var.Name.Skip(sample_key.Length)); | |||
} | |||
return new Saver(saver_def: meta_graph_def.SaverDef, name: scope); | |||
} | |||
@@ -33,7 +33,7 @@ namespace Tensorflow.Train | |||
public RefVariable create_slot(RefVariable primary, Tensor val, string name, bool colocate_with_primary = true) | |||
{ | |||
var validate_shape = val.TensorShape.is_fully_defined(); | |||
var prefix = primary.op.name; | |||
var prefix = primary.Op.name; | |||
return tf_with(tf.variable_scope(name: null, prefix + "/" + name), delegate | |||
{ | |||
return _create_slot_var(primary, val, "", validate_shape, null, TF_DataType.DtInvalid); | |||
@@ -74,7 +74,7 @@ namespace Tensorflow.Train | |||
TF_DataType dtype, string name, bool colocate_with_primary = true) | |||
{ | |||
var validate_shape = shape.is_fully_defined(); | |||
var prefix = primary.op.name; | |||
var prefix = primary.Op.name; | |||
return tf_with(new variable_scope(string.Empty, prefix + "/" + name), delegate | |||
{ | |||
return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype); | |||
@@ -91,7 +91,7 @@ namespace Tensorflow.Train | |||
/// <param name="shape"></param> | |||
/// <param name="dtype"></param> | |||
/// <returns></returns> | |||
private RefVariable _create_slot_var(VariableV1 primary, object val, string scope, bool validate_shape, | |||
private RefVariable _create_slot_var(IVariableV1 primary, object val, string scope, bool validate_shape, | |||
TensorShape shape, TF_DataType dtype) | |||
{ | |||
bool use_resource = primary is ResourceVariable; | |||
@@ -26,11 +26,11 @@ namespace Tensorflow.Train | |||
/// Restore-on-create for a variable be saved with this `Checkpointable`. | |||
/// </summary> | |||
/// <returns></returns> | |||
protected virtual VariableV1 _add_variable_with_custom_getter(string name, | |||
protected virtual IVariableV1 _add_variable_with_custom_getter(string name, | |||
int[] shape, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
IInitializer initializer = null, | |||
Func<string, int[], TF_DataType, IInitializer, bool, VariableV1> getter = null, | |||
Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null, | |||
bool overwrite = false, | |||
bool trainable = false) | |||
{ | |||
@@ -53,13 +53,13 @@ namespace Tensorflow.Train | |||
/// </summary> | |||
/// <param name="name"></param> | |||
/// <param name="trackable"></param> | |||
protected void _handle_deferred_dependencies(string name, VariableV1 trackable) | |||
protected void _handle_deferred_dependencies(string name, IVariableV1 trackable) | |||
{ | |||
_maybe_initialize_trackable(); | |||
// TODO | |||
} | |||
protected VariableV1 _track_checkpointable(VariableV1 checkpointable, string name, bool overwrite = false) | |||
protected IVariableV1 _track_checkpointable(IVariableV1 checkpointable, string name, bool overwrite = false) | |||
{ | |||
return checkpointable; | |||
} | |||
@@ -62,7 +62,7 @@ namespace Tensorflow.Train | |||
var g = graph.as_default(); | |||
g.name_scope(null); | |||
g.name_scope(global_step_tensor.op.name + "/"); | |||
g.name_scope(global_step_tensor.Op.name + "/"); | |||
// using initialized_value to ensure that global_step is initialized before | |||
// this run. This is needed for example Estimator makes all model_fn build | |||
// under global_step_read_tensor dependency. | |||
@@ -0,0 +1,31 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Runtime.InteropServices; | |||
namespace Tensorflow | |||
{ | |||
[StructLayout(LayoutKind.Sequential)] | |||
public struct BindingArray | |||
{ | |||
public IntPtr array; | |||
public int length; | |||
public static implicit operator BindingArray(IntPtr handle) | |||
=> Marshal.PtrToStructure<BindingArray>(handle); | |||
} | |||
} |
@@ -2,13 +2,18 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Gradients; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public class BaseResourceVariable : VariableV1 | |||
public class BaseResourceVariable : DisposableObject, IVariableV1 | |||
{ | |||
protected string _name; | |||
public virtual string Name => _handle_name; | |||
protected TF_DataType _dtype; | |||
public TF_DataType dtype => _dtype; | |||
protected string _handle_name; | |||
protected string handle_name => _handle_name; | |||
@@ -26,17 +31,30 @@ namespace Tensorflow | |||
protected Tensor _parent_op; | |||
public Tensor parent_op => _parent_op; | |||
protected Tensor _handle; | |||
/// <summary> | |||
/// Variable handle | |||
/// Tensor handle | |||
/// </summary> | |||
public Tensor handle => _handle; | |||
protected Tensor handle; | |||
public Tensor Handle => handle; | |||
protected Tensor _graph_element; | |||
public Tensor GraphElement => _graph_element; | |||
protected TensorShape _shape; | |||
public TensorShape shape => _shape; | |||
public BaseResourceVariable() : base() | |||
protected Operation initializer_op; | |||
public Operation Initializer => initializer_op; | |||
public Operation Op => handle.op; | |||
public Graph Graph => handle.graph; | |||
public BaseResourceVariable() | |||
{ | |||
_handle = c_api.TFE_NewResourceVariable(); | |||
} | |||
public BaseResourceVariable(IntPtr handle, IntPtr tensor) | |||
{ | |||
_handle = handle; | |||
this.handle = new EagerTensor(tensor); | |||
} | |||
public void __init__(bool trainable = true, | |||
@@ -48,15 +66,17 @@ namespace Tensorflow | |||
_trainable = trainable; | |||
_handle_name = handle_name + ":0"; | |||
_unique_id = unique_id; | |||
_handle = handle; | |||
this.handle = handle; | |||
_name = name; | |||
// handle_deleter | |||
} | |||
public override BaseResourceVariable assign(object value, bool use_locking = false, string name = null, bool read_value = true) | |||
public BaseResourceVariable assign(object value, bool use_locking = false, string name = null, bool read_value = true) | |||
{ | |||
var value_tensor = ops.convert_to_tensor(value, dtype: dtype); | |||
var assign_op = gen_resource_variable_ops.assign_variable_op( | |||
_handle, value_tensor, name: name); | |||
handle, value_tensor, name: name); | |||
if (read_value) | |||
return _lazy_read(assign_op, value_tensor); | |||
return null; | |||
@@ -67,7 +87,7 @@ namespace Tensorflow | |||
protected Tensor _read_variable_op() | |||
{ | |||
variable_accessed(this); | |||
var result = gen_resource_variable_ops.read_variable_op(_handle, _dtype); | |||
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | |||
// _maybe_set_handle_data(_dtype, _handle, result); | |||
return result; | |||
} | |||
@@ -75,7 +95,7 @@ namespace Tensorflow | |||
BaseResourceVariable _lazy_read(Operation op, Tensor value) | |||
{ | |||
variable_accessed(this); | |||
return new _UnreadVariable(_handle, _dtype, _shape, _in_graph_mode, _unique_id); | |||
return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id); | |||
} | |||
/// <summary> | |||
@@ -102,8 +122,13 @@ namespace Tensorflow | |||
}); | |||
public override string ToString() | |||
=> $"tf.Variable '{name}' shape={shape} dtype={dtype.as_numpy_name()}, numpy={numpy()}"; | |||
=> $"tf.Variable '{Name}' shape={shape} dtype={dtype.as_numpy_name()}, numpy={numpy()}"; | |||
public NDArray numpy() => read_value().numpy(); | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
{ | |||
// delete | |||
} | |||
} | |||
} |
@@ -1,5 +1,5 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
@@ -29,39 +29,13 @@ namespace Tensorflow | |||
/// the variable are fixed. The value can be changed using one of the assign methods. | |||
/// https://tensorflow.org/guide/variables | |||
/// </summary> | |||
public abstract class VariableV1 | |||
public interface IVariableV1 | |||
{ | |||
protected string _name; | |||
public virtual string name { get; } | |||
public virtual Tensor graph_element { get; } | |||
public virtual Operation op { get; } | |||
public virtual Operation initializer { get; } | |||
public Tensor _variable; | |||
protected string _graph_key; | |||
public Graph graph => _variable.graph; | |||
public Tensor _is_initialized_op { get; set; } | |||
protected TF_DataType _dtype; | |||
public TF_DataType dtype => _dtype; | |||
public VariableV1() | |||
{ | |||
} | |||
public virtual Tensor eval() | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
public virtual BaseResourceVariable assign(object value, bool use_locking = false, string name = null, bool read_value = true) | |||
{ | |||
throw new NotImplementedException(""); | |||
/*var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||
if (read_value) | |||
return assign; | |||
return assign.op;*/ | |||
} | |||
public string Name { get; } | |||
public Tensor Handle { get; } | |||
public Operation Initializer { get; } | |||
public Operation Op { get; } | |||
public Tensor GraphElement { get; } | |||
public Graph Graph { get; } | |||
} | |||
} |
@@ -22,8 +22,19 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable> | |||
public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable> | |||
{ | |||
protected string _name; | |||
public Tensor GraphElement { get; } | |||
public Tensor _variable; | |||
public Tensor Handle => _variable; | |||
protected string _graph_key; | |||
public Graph Graph => _variable.graph; | |||
public Tensor _is_initialized_op { get; set; } | |||
protected TF_DataType _dtype; | |||
public bool _in_graph_mode = true; | |||
public Tensor _initial_value; | |||
public bool _trainable; | |||
@@ -32,13 +43,13 @@ namespace Tensorflow | |||
public bool _save_slice_info; | |||
private Operation _initializer_op; | |||
public override Operation initializer => _initializer_op; | |||
public override Operation op => _variable.op; | |||
public Operation Initializer => _initializer_op; | |||
public Operation Op => _variable.op; | |||
public TF_DataType dtype => _variable.dtype; | |||
public TensorShape shape => tensor_util.to_shape(_variable.shape); | |||
public override string name => _variable.name; | |||
public string Name => _variable.name; | |||
public Tensor eval() => _variable; | |||
@@ -198,7 +209,7 @@ namespace Tensorflow | |||
_snapshot = gen_array_ops.identity(_variable, name = "read"); | |||
} | |||
ops.add_to_collections(collections, this as VariableV1); | |||
ops.add_to_collections(collections, this as IVariableV1); | |||
}); | |||
}); | |||
} | |||
@@ -299,7 +310,7 @@ namespace Tensorflow | |||
tf.GraphKeys.LOCAL_VARIABLES }) | |||
{ | |||
foreach (var var in variable_op.graph.get_collection<RefVariable>(collection_name)) | |||
if (var_names.Contains(var.name)) | |||
if (var_names.Contains(var.Name)) | |||
return var.initialized_value(); | |||
} | |||
@@ -330,7 +341,7 @@ namespace Tensorflow | |||
public override string ToString() | |||
{ | |||
return $"tf.RefVariable '{name}' shape={shape} dtype={dtype}"; | |||
return $"tf.RefVariable '{Name}' shape={shape} dtype={dtype}"; | |||
} | |||
public VariableDef to_proto(string export_scope) | |||
@@ -342,7 +353,7 @@ namespace Tensorflow | |||
if (_initial_value != null) | |||
var_def.InitialValueName = ops.strip_name_scope(_initial_value.name, export_scope); | |||
var_def.Trainable = _trainable; | |||
var_def.InitializerName = ops.strip_name_scope(initializer.name, export_scope); | |||
var_def.InitializerName = ops.strip_name_scope(Initializer.name, export_scope); | |||
var_def.SnapshotName = ops.strip_name_scope(_snapshot.name, export_scope); | |||
if (_save_slice_info) | |||
throw new NotImplementedException("to_proto _save_slice_info"); | |||
@@ -1,4 +1,7 @@ | |||
namespace Tensorflow | |||
using System; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow | |||
{ | |||
public partial class ResourceVariable | |||
{ | |||
@@ -13,14 +16,20 @@ | |||
} | |||
public static implicit operator Tensor(ResourceVariable var) | |||
=> var.handle; | |||
=> var.Handle; | |||
public static implicit operator EagerTensor(ResourceVariable var) | |||
=> var.Handle as EagerTensor; | |||
public static implicit operator ResourceVariable(Tensor var) | |||
=> var.ResourceVar; | |||
/*public static implicit operator ResourceVariable(Tensor var) | |||
=> var.ResourceVar;*/ | |||
public static implicit operator RefVariable(ResourceVariable var) | |||
{ | |||
return null; | |||
} | |||
public static implicit operator IntPtr(ResourceVariable var) | |||
=> var._handle; | |||
} | |||
} |
@@ -31,7 +31,7 @@ namespace Tensorflow | |||
public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); | |||
public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); | |||
public static Tensor operator *(ResourceVariable x, ResourceVariable y) => gen_math_ops.mul(x, y); | |||
public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y); | |||
public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y); | |||
public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y); | |||
@@ -62,8 +62,8 @@ namespace Tensorflow | |||
throw new NotImplementedException(""); | |||
} | |||
x.assign(result); | |||
result.ResourceVar = x; | |||
// x.assign(result); | |||
// result.ResourceVar = x; | |||
return result; | |||
}); | |||
} | |||
@@ -28,15 +28,15 @@ namespace Tensorflow | |||
/// </summary> | |||
public partial class ResourceVariable : BaseResourceVariable | |||
{ | |||
public override string name => _handle_name; | |||
Operation _initializer_op; | |||
public override Operation initializer => _initializer_op; | |||
Tensor _cached_value; | |||
Tensor _graph_element; | |||
public override Tensor graph_element => _graph_element; | |||
public string Device => _handle.Device; | |||
public Graph Graph => _handle.graph; | |||
public override Operation op => _handle.op; | |||
public string Device => handle.Device; | |||
public Graph Graph => handle.graph; | |||
public Operation op => handle.op; | |||
public Tensor is_initialized_op { get; set; } | |||
public ResourceVariable(IntPtr handle, IntPtr tensor) : base(handle, tensor) | |||
{ | |||
} | |||
public ResourceVariable(object initial_value = null, | |||
bool trainable = true, | |||
@@ -47,7 +47,7 @@ namespace Tensorflow | |||
VariableDef variable_def = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
string import_scope = "", | |||
TensorShape shape = null) : base() | |||
TensorShape shape = null) | |||
{ | |||
if (variable_def != null) | |||
{ | |||
@@ -66,7 +66,7 @@ namespace Tensorflow | |||
shape: shape); | |||
} | |||
_handle.ResourceVar = this; | |||
// handle.ResourceVar = this; | |||
} | |||
private void _init_from_args(object initial_value = null, | |||
@@ -91,14 +91,19 @@ namespace Tensorflow | |||
{ | |||
name = scope; | |||
var handle_name = ops.name_from_scope_name(name); | |||
var unique_id = $"{handle_name}_{ops.uid()}"; | |||
var shared_name = tf.context.shared_name(); | |||
string unique_id = ""; | |||
string shared_name = ""; | |||
if (_in_graph_mode) | |||
{ | |||
shared_name = handle_name; | |||
unique_id = shared_name; | |||
} | |||
else | |||
{ | |||
unique_id = $"{handle_name}_{ops.uid()}"; | |||
shared_name = tf.context.shared_name(); | |||
} | |||
var attr = new AttrValue(); | |||
attr.List = new AttrValue.Types.ListValue(); | |||
@@ -111,7 +116,7 @@ namespace Tensorflow | |||
}); | |||
_shape = shape ?? (initial_value as Tensor).TensorShape; | |||
_initial_value = initial_value as Tensor; | |||
_handle = resource_variable_ops.eager_safe_variable_handle( | |||
handle = resource_variable_ops.eager_safe_variable_handle( | |||
initial_value: _initial_value, | |||
shape: _shape, | |||
shared_name: shared_name, | |||
@@ -124,7 +129,7 @@ namespace Tensorflow | |||
{ | |||
tf_with(ops.name_scope("IsInitialized"), delegate | |||
{ | |||
_is_initialized_op = gen_resource_variable_ops.var_is_initialized_op(_handle); | |||
is_initialized_op = gen_resource_variable_ops.var_is_initialized_op(handle); | |||
}); | |||
if(initial_value != null) | |||
@@ -132,7 +137,7 @@ namespace Tensorflow | |||
tf_with(ops.name_scope("Assign"), scope1 => | |||
{ | |||
string n = scope1; | |||
_initializer_op = gen_resource_variable_ops.assign_variable_op(_handle, | |||
initializer_op = gen_resource_variable_ops.assign_variable_op(handle, | |||
variables._try_guard_against_uninitialized_dependencies(name, _initial_value), | |||
name: n); | |||
}); | |||
@@ -150,11 +155,18 @@ namespace Tensorflow | |||
} | |||
else | |||
{ | |||
gen_resource_variable_ops.assign_variable_op(_handle, _initial_value); | |||
gen_resource_variable_ops.assign_variable_op(handle, _initial_value); | |||
is_initialized_op = null; | |||
initializer_op = null; | |||
_graph_element = null; | |||
initial_value = _in_graph_mode ? initial_value : null; | |||
c_api.TFE_SetResourceVariableHandle(_handle, handle as EagerTensor); | |||
c_api.TFE_SetResourceVariableName(_handle, handle_name + ":0"); | |||
} | |||
base.__init__(trainable: trainable, | |||
handle: _handle, | |||
handle: handle, | |||
name: name, | |||
unique_id: unique_id, | |||
handle_name: handle_name); | |||
@@ -170,11 +182,11 @@ namespace Tensorflow | |||
// Create from variable_def. | |||
var g = ops.get_default_graph(); | |||
var prepend_name_scope = ops.prepend_name_scope(variable_def.VariableName, import_scope: import_scope); | |||
_handle = g.as_graph_element(prepend_name_scope) as Tensor; | |||
_shape = new TensorShape(_handle.op.get_attr("shape") as TensorShapeProto); | |||
handle = g.as_graph_element(prepend_name_scope) as Tensor; | |||
_shape = new TensorShape(handle.op.get_attr("shape") as TensorShapeProto); | |||
prepend_name_scope = ops.prepend_name_scope(variable_def.InitializerName, import_scope: import_scope); | |||
_initializer_op = g.as_graph_element(prepend_name_scope) as Operation; | |||
initializer_op = g.as_graph_element(prepend_name_scope) as Operation; | |||
if (!string.IsNullOrEmpty(variable_def.InitialValueName)) | |||
{ | |||
prepend_name_scope = ops.prepend_name_scope(variable_def.InitialValueName, import_scope: import_scope); | |||
@@ -208,7 +220,7 @@ namespace Tensorflow | |||
throw new NotImplementedException("SaveSliceInfoDef _init_from_proto"); | |||
} | |||
_dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype")); | |||
_dtype = dtypes.as_tf_dtype((DataType)handle.op.get_attr("dtype")); | |||
} | |||
public Tensor sparse_read(Tensor indices, string name = "Gather") | |||
@@ -217,7 +229,7 @@ namespace Tensorflow | |||
{ | |||
name = scope; | |||
var value = gen_resource_variable_ops.resource_gather( | |||
_handle, indices, dtype: _dtype, name: name); | |||
handle, indices, dtype: _dtype, name: name); | |||
return array_ops.identity(value); | |||
}); | |||
@@ -225,7 +237,7 @@ namespace Tensorflow | |||
public override string ToString() | |||
{ | |||
return $"tf.Variable: '{name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}"; | |||
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}"; | |||
} | |||
} | |||
} |
@@ -11,14 +11,14 @@ namespace Tensorflow | |||
/// </summary> | |||
public class _UnreadVariable : BaseResourceVariable | |||
{ | |||
public override string name => _in_graph_mode ? _parent_op.name : "UnreadVariable"; | |||
public override string Name => _in_graph_mode ? _parent_op.name : "UnreadVariable"; | |||
public _UnreadVariable(Tensor handle, TF_DataType dtype, TensorShape shape, | |||
bool in_graph_mode, string unique_id) : base() | |||
{ | |||
_dtype = dtype; | |||
_shape = shape; | |||
_handle = handle; | |||
base.handle = handle; | |||
_unique_id = unique_id; | |||
_in_graph_mode = in_graph_mode; | |||
@@ -36,7 +36,7 @@ namespace Tensorflow | |||
_store_eager_variables = false; | |||
} | |||
public VariableV1 get_variable(string name, | |||
public IVariableV1 get_variable(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
object initializer = null, // IInitializer or Tensor | |||
@@ -61,7 +61,7 @@ namespace Tensorflow | |||
aggregation: aggregation); | |||
} | |||
private VariableV1 _true_getter(string name, | |||
private IVariableV1 _true_getter(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
object initializer = null, | |||
@@ -110,7 +110,7 @@ namespace Tensorflow | |||
} | |||
} | |||
private VariableV1 _get_single_variable(string name, | |||
private IVariableV1 _get_single_variable(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
@@ -136,7 +136,7 @@ namespace Tensorflow | |||
throw new NotImplementedException("_get_single_variable"); | |||
} | |||
VariableV1 v = null; | |||
IVariableV1 v = null; | |||
// Create the tensor to initialize the variable with default value. | |||
if (initializer == null) | |||
{ | |||
@@ -0,0 +1,19 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public partial class c_api | |||
{ | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFE_NewResourceVariable(); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_SetResourceVariableHandle(IntPtr variable, IntPtr tensor); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_SetResourceVariableName(IntPtr variable, string name); | |||
} | |||
} |
@@ -172,7 +172,7 @@ namespace Tensorflow | |||
return $"{prefix}_{idx}"; | |||
} | |||
public static VariableV1 default_variable_creator(object initial_value, | |||
public static IVariableV1 default_variable_creator(object initial_value, | |||
string name = null, | |||
bool? trainable = null, | |||
List<string> collections = null, | |||
@@ -37,12 +37,12 @@ namespace Tensorflow | |||
/// </summary> | |||
/// <param name="scope"></param> | |||
/// <returns></returns> | |||
public static VariableV1[] _all_saveable_objects(string scope = "") | |||
public static IVariableV1[] _all_saveable_objects(string scope = "") | |||
{ | |||
var all = new List<VariableV1>(); | |||
var all = new List<IVariableV1>(); | |||
all.AddRange(ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope)); | |||
all.AddRange(ops.get_collection<VariableV1>(tf.GraphKeys.SAVEABLE_OBJECTS, scope)); | |||
all.AddRange(ops.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope)); | |||
all.AddRange(ops.get_collection<IVariableV1>(tf.GraphKeys.SAVEABLE_OBJECTS, scope)); | |||
return all.ToArray(); | |||
} | |||
@@ -58,9 +58,9 @@ namespace Tensorflow | |||
/// special tokens filters by prefix. | |||
/// </param> | |||
/// <returns>A list of `Variable` objects.</returns> | |||
public static List<VariableV1> global_variables(string scope = null) | |||
public static List<IVariableV1> global_variables(string scope = null) | |||
{ | |||
return ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||
return ops.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||
} | |||
/// <summary> | |||
@@ -69,10 +69,10 @@ namespace Tensorflow | |||
/// <param name="var_list">List of `Variable` objects to initialize.</param> | |||
/// <param name="name">Optional name for the returned operation.</param> | |||
/// <returns>An Op that run the initializers of all the specified variables.</returns> | |||
public static Operation variables_initializer(VariableV1[] var_list, string name = "init") | |||
public static Operation variables_initializer(IVariableV1[] var_list, string name = "init") | |||
{ | |||
if (var_list.Length > 0) | |||
return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray(), name); | |||
return control_flow_ops.group(var_list.Select(x => x.Initializer).ToArray(), name); | |||
else | |||
return gen_control_flow_ops.no_op(name: name); | |||
} | |||
@@ -62,7 +62,7 @@ namespace Tensorflow | |||
}); | |||
ops.RegisterFromAssembly(); | |||
c_api.TFE_RegisterGradientFunction((op_name, num_inputs, op_inputs, num_attrs, num_outputs, output_grads) => | |||
c_api.TFE_RegisterGradientFunction((op_name, num_inputs, op_inputs, num_attrs, num_outputs, output_grads, num_skip_inputs, skip_input_indices) => | |||
{ | |||
var input_tensors = new EagerTensor[num_inputs]; | |||
for (int i = 0; i < num_inputs; i++) | |||
@@ -72,16 +72,21 @@ namespace Tensorflow | |||
for (int i = 0; i < num_outputs; i++) | |||
output_grad_tensors[i] = new EagerTensor(*((IntPtr*)output_grads + i)); | |||
var skip_input_indices_param = new int[num_skip_inputs]; | |||
for (int i = 0; i < num_skip_inputs; i++) | |||
skip_input_indices_param[i] = *((int*)skip_input_indices + i); | |||
var gradients = ops.gradientFunctions[op_name](new EagerOperation | |||
{ | |||
NumInputs = num_inputs, | |||
Inputs = input_tensors | |||
Inputs = input_tensors, | |||
SkipInputIndices = skip_input_indices_param | |||
}, output_grad_tensors); | |||
var ret_tensors = Marshal.AllocHGlobal(sizeof(IntPtr) * num_inputs); | |||
Marshal.Copy(gradients.Select(x => x == null ? IntPtr.Zero : (x as EagerTensor).EagerTensorHandle).ToArray(), 0, ret_tensors, 2); | |||
// Marshal.FreeHGlobal(ret_tensors); | |||
return ret_tensors; | |||
var gradients_handles = gradients.Select(x => x == null ? IntPtr.Zero : (x as EagerTensor).EagerTensorHandle).ToArray(); | |||
var wrap_handle = c_api.TFE_WrapGradientResult(gradients_handles, gradients.Length); | |||
return wrap_handle; | |||
}); | |||
} | |||
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
public static (Metric, Metric) create_mean_metric(Tensor value, string name = null) => throw new NotImplementedException(); | |||
public static VariableV1 make_variable(string name, TensorShape shape= null, TF_DataType dtype= TF_DataType.TF_FLOAT, Initializer initializer= null, | |||
public static IVariableV1 make_variable(string name, TensorShape shape= null, TF_DataType dtype= TF_DataType.TF_FLOAT, Initializer initializer= null, | |||
bool trainable= true, string caching_device= null, bool validate_shape= true, Constraints.ConstraintBase constraint= null, | |||
bool use_resource= false, Graph[] collections= null, VariableSynchronization synchronization= VariableSynchronization.Auto, | |||
VariableAggregation aggregation= VariableAggregation.None) => throw new NotImplementedException(); | |||
@@ -373,7 +373,7 @@ namespace Keras.Layers | |||
private void _symbolic_add_metric(Metric value, string aggregation = null, string name = null) => throw new NotImplementedException(); | |||
private void _handle_weight_regularization(string name, VariableV1 variable, Regularizer regularizer) => throw new NotImplementedException(); | |||
private void _handle_weight_regularization(string name, IVariableV1 variable, Regularizer regularizer) => throw new NotImplementedException(); | |||
private void _handle_activity_regularization(Tensor[] inputs, Tensor[] outputs) => throw new NotImplementedException(); | |||
@@ -36,7 +36,7 @@ namespace Tensorflow.Keras | |||
public static void in_place_subclassed_model_state_restoration(Model model) => throw new NotImplementedException(); | |||
public static void clone_and_build_model(Model model, Tensor[] input_tensors= null, Tensor[] target_tensors= null, object custom_objects= null, | |||
bool compile_clone= true, bool in_place_reset= false, VariableV1 optimizer_iterations= null, Hashtable optimizer_config= null) | |||
bool compile_clone= true, bool in_place_reset= false, IVariableV1 optimizer_iterations= null, Hashtable optimizer_config= null) | |||
=> throw new NotImplementedException(); | |||
} | |||
} |
@@ -4,6 +4,7 @@ | |||
<TargetFramework>netstandard2.0</TargetFramework> | |||
<AssemblyName>Tensorflow.Keras</AssemblyName> | |||
<RootNamespace>Tensorflow.Keras</RootNamespace> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
@@ -3,16 +3,25 @@ | |||
<PropertyGroup> | |||
<OutputType>Exe</OutputType> | |||
<TargetFramework>netcoreapp3.1</TargetFramework> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<None Remove="tensorflow.dll" /> | |||
</ItemGroup> | |||
@@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
public void NewVariable() | |||
{ | |||
var x = tf.Variable(10, name: "new_variable_x"); | |||
Assert.AreEqual("new_variable_x:0", x.name); | |||
Assert.AreEqual("new_variable_x:0", x.Name); | |||
Assert.AreEqual(0, x.shape.ndim); | |||
Assert.AreEqual(10, (int)x.numpy()); | |||
} | |||
@@ -56,10 +56,10 @@ namespace TensorFlowNET.UnitTest.Basics | |||
public void Accumulation() | |||
{ | |||
var x = tf.Variable(10, name: "x"); | |||
for (int i = 0; i < 5; i++) | |||
/*for (int i = 0; i < 5; i++) | |||
x = x + 1; | |||
Assert.AreEqual(15, (int)x.numpy()); | |||
Assert.AreEqual(15, (int)x.numpy());*/ | |||
} | |||
[TestMethod] | |||
@@ -12,9 +12,17 @@ | |||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
<LangVersion>8.0</LangVersion> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
<PlatformTarget>AnyCPU</PlatformTarget> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
<PlatformTarget>x64</PlatformTarget> | |||
@@ -24,6 +32,10 @@ | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<Compile Remove="KerasTests.cs" /> | |||
</ItemGroup> | |||
@@ -92,7 +92,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||
self.assertEqual(op.graph, g); | |||
self.assertIsNotNone(op._get_control_flow_context()); | |||
var cond_text = op._get_control_flow_context() as ControlFlowContext; | |||
self.assertEqual(cond_text.name, "cond/cond_text"); | |||
self.assertEqual(cond_text.Name, "cond/cond_text"); | |||
} | |||
[Ignore("Todo: Port")] | |||
@@ -122,7 +122,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||
self.assertItemsEqual(op_input.inputs.OfType<Operation>().ToArray(), new[] {x}); | |||
self.assertEqual(op.graph, graph); | |||
self.assertIsNotNone(op._get_control_flow_context()); | |||
self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).name, "myloop/while_context"); | |||
self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).Name, "myloop/while_context"); | |||
/* | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoop(self): | |||
@@ -4,6 +4,8 @@ | |||
<TargetFramework>netcoreapp3.1</TargetFramework> | |||
<IsPackable>false</IsPackable> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
<ItemGroup> | |||