@@ -224,7 +224,7 @@ namespace Tensorflow.Keras.Layers | |||||
overwrite: true, | overwrite: true, | ||||
initializer: initializer, | initializer: initializer, | ||||
trainable: trainable.Value); | trainable: trainable.Value); | ||||
backend.track_variable(variable); | |||||
//backend.track_variable(variable); | |||||
_trainable_weights.Add(variable); | _trainable_weights.Add(variable); | ||||
return variable; | return variable; | ||||
@@ -601,7 +601,15 @@ namespace Tensorflow | |||||
} | } | ||||
public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null, int axis = 0) | public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null, int axis = 0) | ||||
=> gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||||
{ | |||||
if (axis != 0) | |||||
return gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||||
if (@params is ResourceVariable variable) | |||||
return variable.sparse_read(); | |||||
return gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||||
} | |||||
public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false) | public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false) | ||||
{ | { | ||||
@@ -31,7 +31,7 @@ namespace Tensorflow | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); | var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); | ||||
return _op; | |||||
return _op.output; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -53,7 +53,25 @@ namespace Tensorflow | |||||
shared_name | shared_name | ||||
}); | }); | ||||
return _op; | |||||
return _op.output; | |||||
} | |||||
/// <summary> | |||||
/// Reads the value of a variable. | |||||
/// </summary> | |||||
/// <param name="resource"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new | |||||
{ | |||||
resource, | |||||
dtype | |||||
}); | |||||
return _op.output; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -149,12 +149,20 @@ namespace Tensorflow | |||||
// messages. | // messages. | ||||
tf_with(ops.name_scope("Read"), delegate | tf_with(ops.name_scope("Read"), delegate | ||||
{ | { | ||||
var value = _read_variable_op(); | |||||
_graph_element = value; | |||||
}); | }); | ||||
} | } | ||||
ops.add_to_collections(collections, this); | |||||
}); | }); | ||||
} | |||||
throw new NotImplementedException(""); | |||||
private Tensor _read_variable_op() | |||||
{ | |||||
var result = gen_resource_variable_ops.read_variable_op(_handle, _dtype); | |||||
// _maybe_set_handle_data(_dtype, _handle, result); | |||||
return result; | |||||
} | } | ||||
private void _init_from_proto(VariableDef variable_def, string import_scope = null) | private void _init_from_proto(VariableDef variable_def, string import_scope = null) | ||||