@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using NumSharp; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -18,6 +19,22 @@ namespace Tensorflow | |||||
var tensor = tf.constant(3112.0f); | var tensor = tf.constant(3112.0f); | ||||
} | } | ||||
}; | }; | ||||
public Action<int> Constant2x3 | |||||
=> (iterate) => | |||||
{ | |||||
var nd = np.array(new byte[,] | |||||
{ | |||||
{1, 2, 3}, | |||||
{4, 5, 6} | |||||
}); | |||||
for (int i = 0; i < iterate; i++) | |||||
{ | |||||
var tensor = tf.constant(nd); | |||||
var data = tensor.numpy(); | |||||
} | |||||
}; | |||||
public Action<int> Variable | public Action<int> Variable | ||||
=> (iterate) => | => (iterate) => | ||||
{ | { | ||||
@@ -15,6 +15,9 @@ namespace Tensorflow | |||||
int batchSize = 1000; | int batchSize = 1000; | ||||
// explaination of constant | |||||
mm.Execute(10, 100 * batchSize, cases.Constant2x3); | |||||
// 1 million float tensor 68M. | // 1 million float tensor 68M. | ||||
mm.Execute(10, 100 * batchSize, cases.Constant); | mm.Execute(10, 100 * batchSize, cases.Constant); | ||||
@@ -43,7 +43,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public partial class c_api | public partial class c_api | ||||
{ | { | ||||
public const string TensorFlowLibName = "tensorflow"; | |||||
public const string TensorFlowLibName = @"C:\Users\haipi\Documents\Projects\tensorflow\bazel-bin\tensorflow\tensorflow"; | |||||
public static string StringPiece(IntPtr handle) | public static string StringPiece(IntPtr handle) | ||||
{ | { | ||||
@@ -70,8 +70,8 @@ namespace Tensorflow.Eager | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
{ | { | ||||
base.DisposeUnmanagedResources(handle); | |||||
//print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}"); | //print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}"); | ||||
c_api.TF_DeleteTensor(_handle); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -311,6 +311,7 @@ namespace Tensorflow | |||||
while (queue.Count > 0) | while (queue.Count > 0) | ||||
{ | { | ||||
var op = queue.Dequeue(); | var op = queue.Dequeue(); | ||||
if (reached_ops.Contains(op)) | if (reached_ops.Contains(op)) | ||||
{ | { | ||||
between_ops.Add(op); | between_ops.Add(op); | ||||
@@ -278,7 +278,11 @@ namespace Tensorflow | |||||
// after removing the trailing '/'. | // after removing the trailing '/'. | ||||
name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | ||||
var node_def = ops._NodeDef(op_type, name, attrs: attrs); | var node_def = ops._NodeDef(op_type, name, attrs: attrs); | ||||
if (name == "rnn/while/basic_rnn_cell/MatMul" | |||||
|| name == "rnn/while/basic_rnn_cell/MatMul/Enter") | |||||
{ | |||||
} | |||||
var input_ops = inputs.Select(x => x.op).ToArray(); | var input_ops = inputs.Select(x => x.op).ToArray(); | ||||
var control_inputs = _control_dependencies_for_inputs(input_ops); | var control_inputs = _control_dependencies_for_inputs(input_ops); | ||||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||||
_channels_first = args.DataFormat == "channels_first"; | _channels_first = args.DataFormat == "channels_first"; | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
{ | { | ||||
if (_channels_first) | if (_channels_first) | ||||
{ | { | ||||
@@ -121,7 +121,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// <param name="input"></param> | /// <param name="input"></param> | ||||
/// <param name="is_training"></param> | /// <param name="is_training"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor Apply(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
public Tensor Apply(Tensor inputs, bool is_training = false) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
@@ -148,7 +148,7 @@ namespace Tensorflow.Keras.Engine | |||||
if (!built) | if (!built) | ||||
MaybeBuild(inputs); | MaybeBuild(inputs); | ||||
outputs = call(inputs, is_training: is_training, state: state); | |||||
outputs = call(inputs, is_training: is_training); | |||||
outputs = _set_connectivity_metadata_(inputs, outputs); | outputs = _set_connectivity_metadata_(inputs, outputs); | ||||
_handle_activity_regularization(inputs, outputs); | _handle_activity_regularization(inputs, outputs); | ||||
@@ -161,6 +161,35 @@ namespace Tensorflow.Keras.Engine | |||||
return outputs; | return outputs; | ||||
} | } | ||||
public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false) | |||||
{ | |||||
Tensor[] outputs = null; | |||||
callContext = callContext ?? new ThreadLocal<CallContext>() | |||||
{ | |||||
Value = new CallContext() | |||||
}; | |||||
var eager = tf.executing_eagerly(); | |||||
using var ctxManager = CallContext.enter(); | |||||
string nameScope = ""; | |||||
if (eager) | |||||
nameScope = name; | |||||
else | |||||
nameScope = _name_scope(); | |||||
tf_with(ops.name_scope(nameScope), scope => | |||||
{ | |||||
if (!built) | |||||
MaybeBuild(inputs[0]); | |||||
outputs = call(inputs, is_training: is_training, state: state); | |||||
}); | |||||
return outputs; | |||||
} | |||||
private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) | private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) | ||||
{ | { | ||||
/*var returnOutputs = new List<Tensor>(); | /*var returnOutputs = new List<Tensor>(); | ||||
@@ -200,7 +229,12 @@ namespace Tensorflow.Keras.Engine | |||||
return null; | return null; | ||||
} | } | ||||
protected virtual Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected virtual Tensor call(Tensor inputs, bool is_training = false) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) | |||||
{ | { | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
@@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool training = false) | |||||
{ | { | ||||
var outputs = _convolution_op.__call__(inputs, kernel); | var outputs = _convolution_op.__call__(inputs, kernel); | ||||
if (use_bias) | if (use_bias) | ||||
@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool training = false) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
var rank = inputs.rank; | var rank = inputs.rank; | ||||
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
{ | { | ||||
var output = tf_utils.smart_cond(is_training, | var output = tf_utils.smart_cond(is_training, | ||||
() => tf.nn.dropout(inputs, | () => tf.nn.dropout(inputs, | ||||
@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
{ | { | ||||
var dtype = inputs.dtype; | var dtype = inputs.dtype; | ||||
if (dtype != tf.int32 && dtype != tf.int64) | if (dtype != tf.int32 && dtype != tf.int64) | ||||
@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers | |||||
.ToArray(); | .ToArray(); | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
{ | { | ||||
return base.call(inputs, is_training, state); | |||||
return base.call(inputs, is_training); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||||
input_spec = new InputSpec(ndim: 4); | input_spec = new InputSpec(ndim: 4); | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
{ | { | ||||
int[] pool_shape; | int[] pool_shape; | ||||
int[] strides; | int[] strides; | ||||
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
{ | { | ||||
scale = math_ops.cast(args.Scale, args.DType); | scale = math_ops.cast(args.Scale, args.DType); | ||||
offset = math_ops.cast(args.Offset, args.DType); | offset = math_ops.cast(args.Offset, args.DType); | ||||
@@ -61,9 +61,8 @@ namespace Tensorflow.Layers | |||||
return (results[0], results[1]); | return (results[0], results[1]); | ||||
} | } | ||||
public Tensor[] __call__(Tensor inputs, | |||||
public Tensor __call__(Tensor inputs, | |||||
Tensor training = null, | Tensor training = null, | ||||
Tensor state = null, | |||||
VariableScope scope = null) | VariableScope scope = null) | ||||
{ | { | ||||
_set_scope(scope); | _set_scope(scope); | ||||
@@ -88,16 +87,54 @@ namespace Tensorflow.Layers | |||||
{ | { | ||||
_current_scope = scope2; | _current_scope = scope2; | ||||
// Actually call layer | // Actually call layer | ||||
outputs = base.Apply(inputs, | |||||
is_training: training == null ? false : false, | |||||
state: state); | |||||
outputs = base.Apply(inputs[0], | |||||
is_training: training == null ? false : false); | |||||
}); | |||||
// Update global default collections. | |||||
_add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | |||||
return outputs; | |||||
} | |||||
public Tensor[] __call__(Tensor[] inputs, | |||||
Tensor state = null, | |||||
Tensor training = null, | |||||
VariableScope scope = null) | |||||
{ | |||||
_set_scope(scope); | |||||
_graph = ops._get_graph_from_inputs(inputs, graph: _graph); | |||||
variable_scope scope_context_manager = null; | |||||
if (built) | |||||
{ | |||||
scope_context_manager = tf.variable_scope(_scope, | |||||
reuse: true, | |||||
auxiliary_name_scope: false); | |||||
} | |||||
else | |||||
{ | |||||
scope_context_manager = tf.variable_scope(_scope, | |||||
reuse: _reuse, | |||||
auxiliary_name_scope: false); | |||||
} | |||||
Tensor[] outputs = null; | |||||
tf_with(scope_context_manager, scope2 => | |||||
{ | |||||
_current_scope = scope2; | |||||
// Actually call layer | |||||
outputs = base.Apply(inputs, | |||||
state, | |||||
is_training: training == null ? false : false); | |||||
}); | }); | ||||
// Update global default collections. | // Update global default collections. | ||||
_add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | _add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | ||||
return new Tensor[] { outputs }; | |||||
return outputs; | |||||
} | } | ||||
protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list) | protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list) | ||||
@@ -326,7 +326,7 @@ namespace Tensorflow.Operations | |||||
protected override void _AddOpInternal(Operation op) | protected override void _AddOpInternal(Operation op) | ||||
{ | { | ||||
if (op.name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad") | |||||
if (op.name == "rnn/basic_rnn_cell/kernel/Initializer/random_uniform/shape") | |||||
{ | { | ||||
} | } | ||||
@@ -61,7 +61,7 @@ namespace Tensorflow | |||||
built = true; | built = true; | ||||
} | } | ||||
public Tensor[] __call__(Tensor inputs, LSTMStateTuple state) | |||||
public Tensor __call__(Tensor inputs, LSTMStateTuple state) | |||||
{ | { | ||||
_state = state; | _state = state; | ||||
return base.__call__(inputs); | return base.__call__(inputs); | ||||
@@ -74,7 +74,7 @@ namespace Tensorflow | |||||
/// <param name="training"></param> | /// <param name="training"></param> | ||||
/// <param name="state"></param> | /// <param name="state"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
{ | { | ||||
var one = constant_op.constant(1, dtype: dtypes.int32); | var one = constant_op.constant(1, dtype: dtypes.int32); | ||||
// Parameters of gates are concatenated into one multiply for efficiency. | // Parameters of gates are concatenated into one multiply for efficiency. | ||||
@@ -67,14 +67,14 @@ namespace Tensorflow | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) | |||||
{ | { | ||||
// Most basic RNN: output = new_state = act(W * input + U * state + B). | // Most basic RNN: output = new_state = act(W * input + U * state + B). | ||||
var concat = array_ops.concat(new[] { inputs, state }, 1); | |||||
var concat = array_ops.concat(new[] { inputs[0], state }, 1); | |||||
var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); | var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); | ||||
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); | gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); | ||||
var output = _activation(gate_inputs, null); | var output = _activation(gate_inputs, null); | ||||
return output; | |||||
return new[] { output, output }; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -364,7 +364,7 @@ namespace Tensorflow.Operations | |||||
if (sequence_length != null) | if (sequence_length != null) | ||||
throw new NotImplementedException("sequence_length != null"); | throw new NotImplementedException("sequence_length != null"); | ||||
else | else | ||||
outputs = cell.__call__(input_t_t, state: state1); | |||||
outputs = cell.__call__(new[] { input_t_t }, state: state1); | |||||
var (output, new_state) = (outputs[0], outputs[1]); | var (output, new_state) = (outputs[0], outputs[1]); | ||||
// Keras cells always wrap state as list, even if it's a single tensor. | // Keras cells always wrap state as list, even if it's a single tensor. | ||||
@@ -326,7 +326,7 @@ namespace Tensorflow | |||||
// the updated inputs are reloaded from the c_api | // the updated inputs are reloaded from the c_api | ||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
{ | { | ||||
// c_api.UpdateEdge(_graph, output, input, tf.Status.Handle); | |||||
c_api.UpdateEdge(_graph, output, input, tf.Status.Handle); | |||||
//var updated_inputs = inputs; | //var updated_inputs = inputs; | ||||
tf.Status.Check(); | tf.Status.Check(); | ||||
} | } | ||||
@@ -5,7 +5,7 @@ | |||||
<AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<TargetTensorFlow>2.2.0</TargetTensorFlow> | <TargetTensorFlow>2.2.0</TargetTensorFlow> | ||||
<Version>0.20.0</Version> | |||||
<Version>0.20.1</Version> | |||||
<LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
<Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
@@ -19,13 +19,13 @@ | |||||
<Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
Building, training and infering deep learning models. | Building, training and infering deep learning models. | ||||
https://tensorflownet.readthedocs.io</Description> | https://tensorflownet.readthedocs.io</Description> | ||||
<AssemblyVersion>0.20.0.0</AssemblyVersion> | |||||
<AssemblyVersion>0.20.1.0</AssemblyVersion> | |||||
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x. | <PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x. | ||||
* Eager Mode is added finally. | * Eager Mode is added finally. | ||||
* tf.keras is partially working. | * tf.keras is partially working. | ||||
* tf.data is added.</PackageReleaseNotes> | * tf.data is added.</PackageReleaseNotes> | ||||
<FileVersion>0.20.0.0</FileVersion> | |||||
<FileVersion>0.20.1.0</FileVersion> | |||||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
<SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
@@ -50,6 +50,8 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public AllocationType AllocationType { get; protected set; } | public AllocationType AllocationType { get; protected set; } | ||||
public IntPtr TensorDataPointer => TF_TensorData(_handle); | |||||
/// <summary> | /// <summary> | ||||
/// Create a Tensor object from an existing TF handle | /// Create a Tensor object from an existing TF handle | ||||
/// </summary> | /// </summary> | ||||
@@ -261,7 +261,6 @@ namespace Tensorflow | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
{ | { | ||||
c_api.TF_DeleteTensor(handle); | c_api.TF_DeleteTensor(handle); | ||||
if (AllocationHandle == null) | if (AllocationHandle == null) | ||||
return; | return; | ||||
@@ -88,80 +88,83 @@ namespace Tensorflow | |||||
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | ||||
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | ||||
ops.init_scope(); | |||||
_in_graph_mode = !tf.Context.executing_eagerly(); | _in_graph_mode = !tf.Context.executing_eagerly(); | ||||
tf_with(ops.name_scope(name, "Variable"), scope => | |||||
tf_with(ops.init_scope2(), delegate | |||||
{ | { | ||||
name = scope; | |||||
var handle_name = ops.name_from_scope_name(name); | |||||
string unique_id = ""; | |||||
string shared_name = ""; | |||||
if (_in_graph_mode) | |||||
{ | |||||
shared_name = handle_name; | |||||
unique_id = shared_name; | |||||
} | |||||
else | |||||
var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||||
tf_with(ops.name_scope(name, "Variable", values), scope => | |||||
{ | { | ||||
unique_id = $"{handle_name}_{ops.uid()}"; | |||||
shared_name = tf.Context.shared_name(); | |||||
} | |||||
var attr = new AttrValue(); | |||||
attr.List = new AttrValue.Types.ListValue(); | |||||
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); | |||||
tf_with(ops.name_scope("Initializer"), delegate | |||||
{ | |||||
if (initial_value.GetType().GetInterface("IInitializer") != null) | |||||
initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); | |||||
name = scope; | |||||
var handle_name = ops.name_from_scope_name(name); | |||||
string unique_id = ""; | |||||
string shared_name = ""; | |||||
if (_in_graph_mode) | |||||
{ | |||||
shared_name = handle_name; | |||||
unique_id = shared_name; | |||||
} | |||||
else | else | ||||
{ | { | ||||
var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value; | |||||
initial_value = ops.convert_to_tensor(value, | |||||
name: "initial_value", | |||||
dtype: dtype); | |||||
unique_id = $"{handle_name}_{ops.uid()}"; | |||||
shared_name = tf.Context.shared_name(); | |||||
} | } | ||||
}); | |||||
_shape = shape ?? (initial_value as Tensor).TensorShape; | |||||
_initial_value = initial_value as Tensor; | |||||
var attr = new AttrValue(); | |||||
attr.List = new AttrValue.Types.ListValue(); | |||||
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); | |||||
tf_with(ops.name_scope("Initializer"), delegate | |||||
{ | |||||
if (initial_value.GetType().GetInterface("IInitializer") != null) | |||||
initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); | |||||
else | |||||
{ | |||||
var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value; | |||||
initial_value = ops.convert_to_tensor(value, | |||||
name: "initial_value", | |||||
dtype: dtype); | |||||
} | |||||
}); | |||||
_shape = shape ?? (initial_value as Tensor).TensorShape; | |||||
_initial_value = initial_value as Tensor; | |||||
if (_in_graph_mode) | |||||
{ | |||||
handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||||
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | |||||
if (_in_graph_mode) | |||||
{ | |||||
handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||||
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | |||||
ops.colocate_with(initializer_op); | |||||
ops.colocate_with(initializer_op); | |||||
_graph_element = gen_array_ops.identity(handle, name = "read"); | |||||
ops.add_to_collections<IVariableV1>(collections, this); | |||||
_dtype = handle.dtype; | |||||
} | |||||
else | |||||
{ | |||||
handle = resource_variable_ops.eager_safe_variable_handle( | |||||
initial_value: _initial_value, | |||||
shape: _shape, | |||||
shared_name: shared_name, | |||||
name: name, | |||||
graph_mode: _in_graph_mode); | |||||
gen_resource_variable_ops.assign_variable_op(handle, _initial_value); | |||||
is_initialized_op = null; | |||||
initializer_op = null; | |||||
_graph_element = null; | |||||
_dtype = _initial_value.dtype.as_base_dtype(); | |||||
initial_value = _in_graph_mode ? initial_value : null; | |||||
} | |||||
_graph_element = gen_array_ops.identity(handle, name = "read"); | |||||
ops.add_to_collections<IVariableV1>(collections, this); | |||||
_dtype = handle.dtype; | |||||
} | |||||
else | |||||
{ | |||||
handle = resource_variable_ops.eager_safe_variable_handle( | |||||
initial_value: _initial_value, | |||||
shape: _shape, | |||||
shared_name: shared_name, | |||||
name: name, | |||||
graph_mode: _in_graph_mode); | |||||
gen_resource_variable_ops.assign_variable_op(handle, _initial_value); | |||||
is_initialized_op = null; | |||||
initializer_op = null; | |||||
_graph_element = null; | |||||
_dtype = _initial_value.dtype.as_base_dtype(); | |||||
initial_value = _in_graph_mode ? initial_value : null; | |||||
} | |||||
base.__init__(trainable: trainable, | |||||
handle: handle, | |||||
name: name, | |||||
unique_id: unique_id, | |||||
handle_name: handle_name); | |||||
base.__init__(trainable: trainable, | |||||
handle: handle, | |||||
name: name, | |||||
unique_id: unique_id, | |||||
handle_name: handle_name); | |||||
}); | |||||
}); | }); | ||||
} | } | ||||
@@ -30,6 +30,7 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="BenchmarkDotNet" Version="0.12.1" /> | <PackageReference Include="BenchmarkDotNet" Version="0.12.1" /> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | ||||
<PackageReference Include="TensorFlow.NET" Version="0.20.0" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -43,7 +43,7 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.0" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.1" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | <PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | ||||