@@ -57,5 +57,21 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
} | } | ||||
/// <summary>
/// Adds a weight to the layer by forwarding every argument to
/// <c>_add_variable_with_custom_getter</c> with <c>overwrite: true</c>.
/// </summary>
/// <param name="name">Name of the variable.</param>
/// <param name="shape">Shape of the variable.</param>
/// <param name="dtype">Data type, forwarded unchanged.</param>
/// <param name="initializer">Initializer, forwarded unchanged.</param>
/// <param name="trainable">
/// Whether the variable is trainable. <c>null</c> (the default) is treated as <c>true</c>.
/// </param>
/// <param name="getter">Callback that creates the variable, forwarded unchanged.</param>
protected virtual void add_weight(string name,
    int[] shape,
    TF_DataType dtype = TF_DataType.DtInvalid,
    IInitializer initializer = null,
    bool? trainable = null,
    Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null)
{
    // `trainable` is nullable and defaults to null, so reading `.Value`
    // directly would throw InvalidOperationException for every caller that
    // omits it. Fall back to true, as Keras' Python `add_weight` does.
    _add_variable_with_custom_getter(name,
        shape,
        dtype: dtype,
        getter: getter,
        overwrite: true,
        initializer: initializer,
        trainable: trainable ?? true);
}
} | } | ||||
} | } |
@@ -53,7 +53,11 @@ namespace Tensorflow.Keras.Layers | |||||
int channel_axis = data_format == "channels_first" ? 1 : -1; | int channel_axis = data_format == "channels_first" ? 1 : -1; | ||||
int input_dim = input_shape.Dimensions[input_shape.NDim - 1]; | int input_dim = input_shape.Dimensions[input_shape.NDim - 1]; | ||||
var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; | var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; | ||||
add_weight(); | |||||
add_weight(name: "kernel", | |||||
shape: kernel_shape, | |||||
initializer: kernel_initializer, | |||||
trainable: true, | |||||
dtype: _dtype); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -68,7 +68,11 @@ namespace Tensorflow.Layers | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
protected virtual void add_weight() | |||||
protected virtual void add_weight(string name, | |||||
int[] shape, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
IInitializer initializer = null, | |||||
bool? trainable = null) | |||||
{ | { | ||||
var default_graph = ops.get_default_graph(); | var default_graph = ops.get_default_graph(); | ||||
Graph init_graph = null; | Graph init_graph = null; | ||||
@@ -84,7 +88,9 @@ namespace Tensorflow.Layers | |||||
existing_variables = variables.global_variables().ToArray(); | existing_variables = variables.global_variables().ToArray(); | ||||
} | } | ||||
var dtype = TF_DataType.TF_FLOAT; | |||||
if(dtype == TF_DataType.DtInvalid) | |||||
dtype = TF_DataType.TF_FLOAT; | |||||
_set_scope(); | _set_scope(); | ||||
var reuse = built || (_reuse != null && _reuse.Value); | var reuse = built || (_reuse != null && _reuse.Value); | ||||
Python.with(tf.variable_scope(_scope, | Python.with(tf.variable_scope(_scope, | ||||
@@ -94,8 +100,19 @@ namespace Tensorflow.Layers | |||||
_current_scope = scope; | _current_scope = scope; | ||||
Python.with(ops.name_scope(_name_scope()), delegate | Python.with(ops.name_scope(_name_scope()), delegate | ||||
{ | { | ||||
base.add_weight(name, | |||||
shape, | |||||
dtype: dtype, | |||||
initializer: initializer, | |||||
trainable: trainable, | |||||
getter: (name1, shape1, dtype1, initializer1, trainable1) => | |||||
{ | |||||
return tf.get_variable(name1, | |||||
shape: new TensorShape(shape1), | |||||
dtype: dtype1, | |||||
initializer: initializer1, | |||||
trainable: trainable1); | |||||
}); | |||||
}); | }); | ||||
}); | }); | ||||
} | } | ||||
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
@@ -64,7 +65,16 @@ namespace Tensorflow.Operations.Initializers | |||||
if (shape.Length == 2) | if (shape.Length == 2) | ||||
return (shape[0], shape[1]); | return (shape[0], shape[1]); | ||||
else | else | ||||
throw new NotImplementedException("VarianceScaling._compute_fans"); | |||||
{
    // Assuming convolution kernels (2D, 3D, or more).
    // kernel shape: (..., input_depth, depth)
    // The receptive field is the product of every dimension EXCEPT the
    // trailing two (input_depth, depth). Taking a fixed two leading
    // dimensions is only right for rank-4 kernels: on rank 3 it would
    // multiply in input_depth, and on rank 5+ it would drop spatial dims.
    int receptive_field_size = 1;
    for (int i = 0; i < shape.Length - 2; i++)
        receptive_field_size *= shape[i];
    var fan_in = shape[shape.Length - 2] * receptive_field_size;
    var fan_out = shape[shape.Length - 1] * receptive_field_size;
    return (fan_in, fan_out);
}
} | } | ||||
public virtual object get_config() | public virtual object get_config() | ||||
@@ -4,7 +4,22 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class CheckpointableBase | |||||
public abstract class CheckpointableBase | |||||
{ | { | ||||
/// <summary>
/// Restore-on-create for a variable to be saved with this `Checkpointable`.
/// </summary>
/// <remarks>
/// Not fully implemented yet: the method invokes <paramref name="getter"/>
/// and then always throws <see cref="NotImplementedException"/>.
/// </remarks>
/// <param name="name">Name passed to the getter.</param>
/// <param name="shape">Shape passed to the getter.</param>
/// <param name="dtype">Data type passed to the getter.</param>
/// <param name="initializer">Initializer passed to the getter.</param>
/// <param name="getter">Callback that creates the variable; must not be null.</param>
/// <param name="overwrite">Currently unused by this implementation.</param>
/// <param name="trainable">Trainable flag passed to the getter.</param>
/// <returns>Nothing at present; the method always throws.</returns>
/// <exception cref="ArgumentNullException"><paramref name="getter"/> is null.</exception>
/// <exception cref="NotImplementedException">Always, after the getter has run.</exception>
protected virtual RefVariable _add_variable_with_custom_getter(string name,
    int[] shape,
    TF_DataType dtype = TF_DataType.TF_FLOAT,
    IInitializer initializer = null,
    Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null,
    bool overwrite = false,
    bool trainable = false)
{
    // `getter` defaults to null; fail with a descriptive exception instead
    // of a NullReferenceException at the call below.
    if (getter == null)
        throw new ArgumentNullException(nameof(getter));

    var new_variable = getter(name, shape, dtype, initializer, trainable);
    throw new NotImplementedException("_add_variable_with_custom_getter");
}
} | } | ||||
} | } |
@@ -48,6 +48,7 @@ namespace Tensorflow | |||||
shape: shape, | shape: shape, | ||||
dtype: dtype, | dtype: dtype, | ||||
initializer: initializer, | initializer: initializer, | ||||
reuse: resue, | |||||
trainable: trainable, | trainable: trainable, | ||||
synchronization: synchronization, | synchronization: synchronization, | ||||
aggregation: aggregation); | aggregation: aggregation); | ||||
@@ -24,6 +24,7 @@ namespace Tensorflow | |||||
TensorShape shape = null, | TensorShape shape = null, | ||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
object initializer = null, // IInitializer or Tensor | object initializer = null, // IInitializer or Tensor | ||||
bool? reuse = null, | |||||
bool? trainable = null, | bool? trainable = null, | ||||
bool validate_shape = true, | bool validate_shape = true, | ||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
@@ -100,7 +101,7 @@ namespace Tensorflow | |||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
VariableAggregation aggregation = VariableAggregation.NONE) | VariableAggregation aggregation = VariableAggregation.NONE) | ||||
{ | { | ||||
bool initializing_from_value = true; | |||||
bool initializing_from_value = false; | |||||
if (use_resource == null) | if (use_resource == null) | ||||
use_resource = false; | use_resource = false; | ||||