|
|
@@ -68,7 +68,11 @@ namespace Tensorflow.Layers |
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
protected virtual void add_weight() |
|
|
|
protected virtual void add_weight(string name, |
|
|
|
int[] shape, |
|
|
|
TF_DataType dtype = TF_DataType.DtInvalid, |
|
|
|
IInitializer initializer = null, |
|
|
|
bool? trainable = null) |
|
|
|
{ |
|
|
|
var default_graph = ops.get_default_graph(); |
|
|
|
Graph init_graph = null; |
|
|
@@ -84,7 +88,9 @@ namespace Tensorflow.Layers |
|
|
|
existing_variables = variables.global_variables().ToArray(); |
|
|
|
} |
|
|
|
|
|
|
|
var dtype = TF_DataType.TF_FLOAT; |
|
|
|
if(dtype == TF_DataType.DtInvalid) |
|
|
|
dtype = TF_DataType.TF_FLOAT; |
|
|
|
|
|
|
|
_set_scope(); |
|
|
|
var reuse = built || (_reuse != null && _reuse.Value); |
|
|
|
Python.with(tf.variable_scope(_scope, |
|
|
@@ -94,8 +100,19 @@ namespace Tensorflow.Layers |
|
|
|
_current_scope = scope; |
|
|
|
Python.with(ops.name_scope(_name_scope()), delegate |
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
base.add_weight(name, |
|
|
|
shape, |
|
|
|
dtype: dtype, |
|
|
|
initializer: initializer, |
|
|
|
trainable: trainable, |
|
|
|
getter: (name1, shape1, dtype1, initializer1, trainable1) => |
|
|
|
{ |
|
|
|
return tf.get_variable(name1, |
|
|
|
shape: new TensorShape(shape1), |
|
|
|
dtype: dtype1, |
|
|
|
initializer: initializer1, |
|
|
|
trainable: trainable1); |
|
|
|
}); |
|
|
|
}); |
|
|
|
}); |
|
|
|
} |
|
|
|