|
|
@@ -23,18 +23,68 @@ namespace Tensorflow |
|
|
|
|
|
|
|
public string name => _variable.name; |
|
|
|
|
|
|
|
public RefVariable(object initial_value, |
|
|
|
public RefVariable(object initial_value = null, |
|
|
|
bool trainable = true, |
|
|
|
List<string> collections = null, |
|
|
|
bool validate_shape = true, |
|
|
|
string caching_device = "", |
|
|
|
string name = "", |
|
|
|
TF_DataType dtype = TF_DataType.DtInvalid) : |
|
|
|
base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype) |
|
|
|
VariableDef variable_def = null, |
|
|
|
TF_DataType dtype = TF_DataType.DtInvalid, |
|
|
|
string import_scope = "") : base(initial_value, |
|
|
|
trainable, |
|
|
|
collections, |
|
|
|
validate_shape, |
|
|
|
caching_device, |
|
|
|
name, |
|
|
|
dtype) |
|
|
|
{ |
|
|
|
_in_graph_mode = true; |
|
|
|
|
|
|
|
_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); |
|
|
|
if (variable_def != null) |
|
|
|
{ |
|
|
|
if (initial_value != null) |
|
|
|
throw new ValueError("variable_def and initial_value are mutually exclusive."); |
|
|
|
_init_from_proto(variable_def, import_scope: import_scope); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
private void _init_from_proto(VariableDef variable_def, string import_scope = "") |
|
|
|
{ |
|
|
|
var g = ops.get_default_graph(); |
|
|
|
|
|
|
|
_variable = g.as_graph_element( |
|
|
|
ops.prepend_name_scope(variable_def.VariableName, |
|
|
|
import_scope: import_scope)) as Tensor; |
|
|
|
|
|
|
|
_initializer_op = g.as_graph_element( |
|
|
|
ops.prepend_name_scope(variable_def.InitializerName, |
|
|
|
import_scope: import_scope)) as Operation; |
|
|
|
|
|
|
|
// Tests whether initial_value_name exists first for backwards compatibility. |
|
|
|
if (!string.IsNullOrEmpty(variable_def.InitialValueName)) |
|
|
|
_initial_value = g.as_graph_element( |
|
|
|
ops.prepend_name_scope(variable_def.InitialValueName, |
|
|
|
import_scope: import_scope)) as Tensor; |
|
|
|
else |
|
|
|
_initial_value = null; |
|
|
|
|
|
|
|
_trainable = variable_def.Trainable; |
|
|
|
_snapshot = g.as_graph_element( |
|
|
|
ops.prepend_name_scope(variable_def.SnapshotName, |
|
|
|
import_scope: import_scope)) as Tensor; |
|
|
|
|
|
|
|
if (variable_def.SaveSliceInfoDef != null) |
|
|
|
throw new NotImplementedException("save_slice_info_def"); |
|
|
|
else |
|
|
|
;// _save_slice_info = null; |
|
|
|
|
|
|
|
//_caching_device = null; |
|
|
|
//_constraint = null; |
|
|
|
} |
|
|
|
|
|
|
|
private void _init_from_args(object initial_value, |
|
|
|