Browse Source

_read_variable_op

tags/v0.12
Oceania2018 6 years ago
parent
commit
267d775ab1
4 changed files with 40 additions and 6 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  2. +9
    -1
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  3. +20
    -2
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
  4. +10
    -2
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -224,7 +224,7 @@ namespace Tensorflow.Keras.Layers
overwrite: true,
initializer: initializer,
trainable: trainable.Value);
backend.track_variable(variable);
//backend.track_variable(variable);
_trainable_weights.Add(variable);

return variable;


+ 9
- 1
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -601,7 +601,15 @@ namespace Tensorflow
}
public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null, int axis = 0)
=> gen_array_ops.gather_v2(@params, indices, axis, name: name);
{
if (axis != 0)
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
if (@params is ResourceVariable variable)
return variable.sparse_read();
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
}
public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false)
{


+ 20
- 2
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs View File

@@ -31,7 +31,7 @@ namespace Tensorflow
{
var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource });

return _op;
return _op.output;
}

/// <summary>
@@ -53,7 +53,25 @@ namespace Tensorflow
shared_name
});

return _op;
return _op.output;
}

/// <summary>
/// Reads the value of a variable.
/// </summary>
/// <param name="resource"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new
{
resource,
dtype
});

return _op.output;
}
}
}

+ 10
- 2
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -149,12 +149,20 @@ namespace Tensorflow
// messages.
tf_with(ops.name_scope("Read"), delegate
{
var value = _read_variable_op();
_graph_element = value;
});
}

ops.add_to_collections(collections, this);
});
}

throw new NotImplementedException("");
private Tensor _read_variable_op()
{
var result = gen_resource_variable_ops.read_variable_op(_handle, _dtype);
// _maybe_set_handle_data(_dtype, _handle, result);
return result;
}

private void _init_from_proto(VariableDef variable_def, string import_scope = null)


Loading…
Cancel
Save