@@ -193,8 +193,7 @@ namespace Tensorflow | |||
Name = name | |||
}); | |||
throw new NotImplementedException(""); | |||
//return layer.apply(inputs).Item1; | |||
return layer.Apply(inputs); | |||
} | |||
/// <summary> | |||
@@ -66,8 +66,8 @@ namespace Tensorflow | |||
Tensor keep = null; | |||
if (keep_prob != null) | |||
keep = 1.0f - keep_prob; | |||
return nn_ops.dropout_v2(x, rate: rate.Value, noise_shape: noise_shape, seed: seed, name: name); | |||
var rate_tensor = rate.HasValue ? tf.constant(rate.Value) : keep; | |||
return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name); | |||
} | |||
/// <summary> | |||
@@ -150,7 +150,7 @@ namespace Tensorflow | |||
var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, | |||
scope: scope_to_prepend_to_names); | |||
var var_list = new Dictionary<string, IVariableV1>(); | |||
// variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v); | |||
variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v); | |||
return (var_list, imported_return_elements); | |||
} | |||
@@ -277,6 +277,11 @@ namespace Tensorflow | |||
var proto = x_ref_var.to_proto(export_scope); | |||
col_def.BytesList.Value.Add(proto.ToByteString()); | |||
} | |||
else if(x is ResourceVariable x_res_var) | |||
{ | |||
var proto = x_res_var.to_proto(export_scope); | |||
col_def.BytesList.Value.Add(proto.ToByteString()); | |||
} | |||
} | |||
break; | |||
case List<RefVariable> collection_list: | |||
@@ -31,8 +31,23 @@ namespace Tensorflow | |||
/// <param name="output_func_def"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_FunctionToFunctionDef(IntPtr func, IntPtr output_func_def, SafeStatusHandle status); | |||
public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status); | |||
/// <summary>
/// P/Invoke binding for the native <c>TF_GraphToFunction</c> export of the TensorFlow library.
/// </summary>
/// <param name="fn_body">Native graph handle whose operations form the function body.</param>
/// <param name="fn_name">Name to give the resulting function.</param>
/// <param name="append_hash_to_fn_name">
/// Declared as <c>unsigned char</c> in the C API, so it is marshalled as a single byte
/// instead of the default 4-byte Win32 BOOL.
/// </param>
/// <param name="num_opers">Number of entries in <paramref name="opers"/>.</param>
/// <param name="opers">Native operation handles to include in the function.</param>
/// <param name="ninputs">Number of entries in <paramref name="inputs"/>.</param>
/// <param name="inputs">Function inputs.</param>
/// <param name="noutputs">Number of entries in <paramref name="outputs"/>.</param>
/// <param name="outputs">Function outputs.</param>
/// <param name="output_names">Pointer to the native array of output names (<c>const char* const*</c>).</param>
/// <param name="opts">Pointer to the native function options.</param>
/// <param name="description">Optional human-readable description of the function.</param>
/// <param name="status">Status handle that receives the outcome of the call.</param>
/// <returns>Native handle (TF_Function*) of the created function.</returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name,
    [MarshalAs(UnmanagedType.U1)] bool append_hash_to_fn_name,
    int num_opers, IntPtr[] opers,
    int ninputs, TF_Output[] inputs,
    int noutputs, TF_Output[] outputs,
    IntPtr output_names,
    IntPtr opts,
    string description,
    SafeStatusHandle status);
/// <summary>
/// P/Invoke binding for the native <c>TF_FunctionName</c> export.
/// </summary>
/// <param name="func">Native function handle.</param>
/// <returns>
/// Raw pointer returned by the native call.
/// NOTE(review): presumably a null-terminated C string owned by the function object — confirm before freeing or caching it.
/// </returns>
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_FunctionName(IntPtr func); | |||
/// <summary>
/// P/Invoke binding for the native <c>TF_GraphCopyFunction</c> export.
/// </summary>
/// <param name="g">Native graph handle.</param>
/// <param name="func">Native function handle.</param>
/// <param name="grad">Native handle of the gradient function. NOTE(review): presumably may be IntPtr.Zero when there is no gradient — confirm against the C API.</param>
/// <param name="status">Status handle that receives the outcome of the call.</param>
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status); | |||
} | |||
} |
@@ -327,8 +327,9 @@ namespace Tensorflow.Gradients | |||
var output_shape = op.outputs[0]._shape_tuple(); | |||
Tensor result, factor_tensor; | |||
if(input_shape != null && | |||
output_shape != null) | |||
if(tf.executing_eagerly() | |||
&& input_shape != null | |||
&& output_shape != null) | |||
{ | |||
var input_size = np.prod(input_shape); | |||
var output_size = np.prod(output_shape); | |||
@@ -339,11 +340,7 @@ namespace Tensorflow.Gradients | |||
{ | |||
var input_shape_tensor = array_ops.shape(op.inputs[0]); | |||
var output_shape_tensor = array_ops.shape(op.outputs[0]); | |||
var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor)); | |||
throw new NotImplementedException(""); | |||
#pragma warning disable CS0162 // Unreachable code detected | |||
factor_tensor = null; | |||
#pragma warning restore CS0162 // Unreachable code detected | |||
factor_tensor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor)); | |||
} | |||
result = math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype)); | |||
@@ -128,10 +128,10 @@ namespace Tensorflow.Gradients | |||
[RegisterGradient("Conv2D")] | |||
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) | |||
{ | |||
var dilations = op.get_attr<int[]>("dilations"); | |||
var strides = op.get_attr<int[]>("strides"); | |||
var dilations = op.get_attr_list<int>("dilations"); | |||
var strides = op.get_attr_list<int>("strides"); | |||
var padding = op.get_attr<string>("padding"); | |||
var explicit_paddings = op.get_attr<int[]>("explicit_paddings"); | |||
var explicit_paddings = op.get_attr_list<int>("explicit_paddings"); | |||
var use_cudnn_on_gpu = op.get_attr<bool>("use_cudnn_on_gpu"); | |||
var data_format = op.get_attr<string>("data_format"); | |||
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); | |||
@@ -287,8 +287,8 @@ namespace Tensorflow.Gradients | |||
op.inputs[0], | |||
op.outputs[0], | |||
grad, | |||
op.get_attr("ksize") as int[], | |||
op.get_attr("strides") as int[], | |||
op.get_attr_list<int>("ksize"), | |||
op.get_attr_list<int>("strides"), | |||
padding: op.get_attr("padding").ToString(), | |||
data_format: op.get_attr("data_format").ToString()) | |||
}; | |||
@@ -293,12 +293,6 @@ namespace Tensorflow | |||
_create_op_helper(op, compute_device); | |||
/*Console.Write($"create_op: {op_type} '{node_def.Name}'"); | |||
Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); | |||
Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}"); | |||
Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}"); | |||
Console.WriteLine();*/ | |||
return op; | |||
} | |||
@@ -139,7 +139,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status); | |||
/// <summary> | |||
/// Returns the number of dimensions of the Tensor referenced by `output` | |||
/// in `graph`. | |||
@@ -0,0 +1,16 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
/// <summary>
/// Plain settable-property container holding the arguments for a tensor-like data adapter.
/// </summary>
public class TensorLikeDataAdapterArgs | |||
{ | |||
/// <summary>First tensor argument. NOTE(review): presumably the input features — confirm.</summary>
public Tensor X { get; set; } | |||
/// <summary>Second tensor argument. NOTE(review): presumably the targets/labels — confirm.</summary>
public Tensor Y { get; set; } | |||
/// <summary>Batch size setting; defaults to 0 when not assigned.</summary>
public int BatchSize { get; set; } | |||
/// <summary>Steps setting; defaults to 0 when not assigned.</summary>
public int Steps { get; set; } | |||
/// <summary>Epochs setting; defaults to 0 when not assigned.</summary>
public int Epochs { get; set; } | |||
/// <summary>Shuffle flag; defaults to false when not assigned.</summary>
public bool Shuffle { get; set; } | |||
} | |||
} |
@@ -27,7 +27,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
/// <summary>
/// Stores the supplied arguments and instantiates a <c>TensorLikeDataAdapter</c>.
/// </summary>
/// <param name="args">Configuration kept in the <c>args</c> field.</param>
public DataHandler(DataHandlerArgs args) | |||
{ | |||
this.args = args; | |||
// NOTE(review): the adapter is built from an empty argument object and the local is
// never read afterwards in this constructor — looks like a placeholder; confirm.
var adapter_cls = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs { }); | |||
} | |||
} | |||
} |
@@ -1,6 +1,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine.DataAdapters | |||
@@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
/// </summary> | |||
public class TensorLikeDataAdapter : IDataAdapter | |||
{ | |||
public TensorLikeDataAdapter() | |||
public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) | |||
{ | |||
tf.data.Dataset.range(5); | |||
} | |||
@@ -0,0 +1,21 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Engine
{
    public partial class Layer
    {
        /// <summary>
        /// Map reset by <see cref="_get_trainable_state"/> each time it is called.
        /// </summary>
        Dictionary<Layer, object> trainable_state;

        /// <summary>
        /// Resets <see cref="trainable_state"/> to an empty dictionary and then throws,
        /// because the rest of the method is not implemented.
        /// </summary>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        Dictionary<Layer, object> _get_trainable_state()
        {
            this.trainable_state = new Dictionary<Layer, object>();
            throw new NotImplementedException("");
        }

        /// <summary>
        /// Not implemented; throws unconditionally.
        /// </summary>
        /// <param name="trainable_state">Unused.</param>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        void _set_trainable_state(Dictionary<Layer, object> trainable_state)
            => throw new NotImplementedException("");
    }
}
@@ -1,6 +1,7 @@ | |||
using NumSharp; | |||
using static Tensorflow.Binding; | |||
using System; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine.DataAdapters; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Optimizers; | |||
@@ -21,6 +22,7 @@ namespace Tensorflow.Keras.Engine | |||
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword | |||
string loss; | |||
IOptimizer optimizer; | |||
IVariableV1 _steps_per_execution; | |||
public Model(ModelArgs args) | |||
: base(args) | |||
@@ -37,10 +39,25 @@ namespace Tensorflow.Keras.Engine | |||
break; | |||
} | |||
int experimental_steps_per_execution = 1; | |||
_configure_steps_per_execution(experimental_steps_per_execution); | |||
_reset_compile_cache(); | |||
loss = lossName; | |||
_is_compiled = true; | |||
} | |||
/// <summary>
/// Creates the int64 variable that holds the steps-per-execution value and
/// stores it in <c>_steps_per_execution</c>.
/// </summary>
/// <param name="steps_per_execution">Initial value of the variable.</param>
void _configure_steps_per_execution(int steps_per_execution)
    => _steps_per_execution = tf.Variable(
        steps_per_execution,
        dtype: TF_DataType.TF_INT64,
        aggregation: VariableAggregation.OnlyFirstReplica);
/// <summary>
/// Currently a no-op: the body contains only a comment describing intended work.
/// </summary>
void _reset_compile_cache() | |||
{ | |||
// Prepare list of loss functions, same size of model outputs. | |||
} | |||
public void compile(string optimizerName, ILossFunc lossName) | |||
@@ -70,6 +87,20 @@ namespace Tensorflow.Keras.Engine | |||
int workers = 1, | |||
bool use_multiprocessing = false) | |||
{ | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = x, | |||
BatchSize = batch_size, | |||
StepsPerEpoch = steps, | |||
InitialEpoch = 0, | |||
Epochs = 1, | |||
MaxQueueSize = max_queue_size, | |||
Workers = workers, | |||
UseMultiprocessing = use_multiprocessing, | |||
Model = this, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.Binding; | |||
@@ -44,6 +45,9 @@ namespace Tensorflow.Keras.Layers | |||
if (args.InputShape == null) | |||
args.InputShape = args.InputLength; | |||
if (args.BatchInputShape == null) | |||
args.BatchInputShape = new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); | |||
embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer; | |||
SupportsMasking = mask_zero; | |||
} | |||
@@ -34,10 +34,13 @@ namespace Tensorflow.Keras.Layers | |||
/// <summary> | |||
/// Turns positive integers (indexes) into dense vectors of fixed size. | |||
/// This layer can only be used as the first layer in a model. | |||
/// e.g. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]] | |||
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding | |||
/// </summary> | |||
/// <param name="input_dim"></param> | |||
/// <param name="output_dim"></param> | |||
/// <param name="embeddings_initializer"></param> | |||
/// <param name="input_dim">Size of the vocabulary, i.e. maximum integer index + 1.</param> | |||
/// <param name="output_dim">Dimension of the dense embedding.</param> | |||
/// <param name="embeddings_initializer">Initializer for the embeddings matrix (see keras.initializers).</param> | |||
/// <param name="mask_zero"></param> | |||
/// <returns></returns> | |||
public Embedding Embedding(int input_dim, | |||
@@ -36,9 +36,9 @@ namespace Tensorflow.Operations.Initializers | |||
public Tensor Apply(InitializerArgs args) | |||
{ | |||
if (args.DType == TF_DataType.DtInvalid) | |||
args.DType = this.dtype; | |||
return random_ops.truncated_normal(args.Shape, mean, stddev, dtype : dtype, seed: seed); | |||
if (args.DType != TF_DataType.DtInvalid) | |||
dtype = args.DType; | |||
return random_ops.truncated_normal(args.Shape, mean, stddev, dtype: dtype, seed: seed); | |||
} | |||
} | |||
} |
@@ -230,6 +230,35 @@ namespace Tensorflow | |||
/// <summary>
/// Typed convenience wrapper over the untyped <c>get_attr(string)</c> overload.
/// </summary>
/// <typeparam name="T">Type the attribute value is cast to.</typeparam>
/// <param name="name">Attribute name.</param>
/// <returns>The attribute value cast to <typeparamref name="T"/>.</returns>
public virtual T get_attr<T>(string name)
{
    object raw = get_attr(name);
    return (T)raw;
}
/// <summary>
/// Reads a list-valued attribute of this operation and returns it as an array.
/// </summary>
/// <typeparam name="T">Element type; only <see cref="int"/> is converted in graph mode.</typeparam>
/// <param name="name">Attribute name.</param>
/// <returns>
/// The attribute values, or <c>null</c> when the attribute does not hold a list
/// or <typeparamref name="T"/> is not a supported element type.
/// </returns>
public virtual T[] get_attr_list<T>(string name)
{
    // Eager mode: the value comes straight from get_attr and is cast as-is.
    if (tf.executing_eagerly())
        return (T[])get_attr(name);

    AttrValue attr;
    lock (Locks.ProcessWide)
    {
        using var buf = new Buffer();
        c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle);
        tf.Status.Check(true);
        attr = AttrValue.Parser.ParseFrom(buf.DangerousMemoryBlock.Stream());
    }

    // A protobuf oneof accessor yields null when a different case is set, so this
    // guards against dereferencing a non-list attribute.
    var list = attr.List;
    if (list == null)
        return null;

    // Compare the type itself rather than its short name.
    if (typeof(T) == typeof(int))
        return list.I.Select(item => (T)Convert.ChangeType(item, typeof(T))).ToArray();

    return null;
}
public virtual object get_attr(string name) | |||
{ | |||
AttrValue x = null; | |||
@@ -250,7 +279,7 @@ namespace Tensorflow | |||
if (oneof_value == "list") | |||
throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | |||
if (oneof_value == "type") | |||
if (string.Equals("type", oneof_value, StringComparison.OrdinalIgnoreCase)) | |||
return x.Type; | |||
object result = x.GetType().GetProperty(oneof_value).GetValue(x); | |||
@@ -85,26 +85,56 @@ namespace Tensorflow | |||
allow_broadcast: false); | |||
public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||
=> tf_with(ops.name_scope(name, "zeros", shape), scope => | |||
{ | |||
dtype = dtype.as_base_dtype(); | |||
if (tf.executing_eagerly()) | |||
{ | |||
dtype = dtype.as_base_dtype(); | |||
name = scope; | |||
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||
Tensor zeros = null; | |||
switch (dtype) | |||
return tf_with(ops.name_scope(name, "zeros", shape), scope => | |||
{ | |||
case TF_DataType.TF_DOUBLE: | |||
zeros = constant(0d); | |||
break; | |||
case TF_DataType.TF_FLOAT: | |||
zeros = constant(0f); | |||
break; | |||
default: | |||
zeros = constant(0); | |||
break; | |||
} | |||
return fill(shape_tensor, zeros, name: name); | |||
}); | |||
name = scope; | |||
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||
Tensor zeros = null; | |||
switch (dtype) | |||
{ | |||
case TF_DataType.TF_DOUBLE: | |||
zeros = constant(0d); | |||
break; | |||
case TF_DataType.TF_FLOAT: | |||
zeros = constant(0f); | |||
break; | |||
default: | |||
zeros = constant(0); | |||
break; | |||
} | |||
return fill(shape_tensor, zeros, name: name); | |||
}); | |||
} | |||
else | |||
{ | |||
return tf_with(ops.name_scope(name, "zeros", shape), scope => | |||
{ | |||
name = scope; | |||
switch (dtype) | |||
{ | |||
case TF_DataType.TF_BOOL: | |||
return _constant_if_small(false, shape, dtype, name); | |||
case TF_DataType.TF_DOUBLE: | |||
return _constant_if_small(0.0D, shape, dtype, name); | |||
case TF_DataType.TF_FLOAT: | |||
return _constant_if_small(0.0F, shape, dtype, name); | |||
case TF_DataType.TF_INT64: | |||
return _constant_if_small(0l, shape, dtype, name); | |||
case TF_DataType.TF_INT32: | |||
return _constant_if_small(0, shape, dtype, name); | |||
case TF_DataType.TF_INT8: | |||
return _constant_if_small<byte>(0, shape, dtype, name); | |||
default: | |||
throw new TypeError("can't find type for zeros"); | |||
} | |||
}); | |||
} | |||
} | |||
public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | |||
{ | |||
@@ -68,7 +68,7 @@ namespace Tensorflow | |||
/// <param name="seed"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor dropout_v2(Tensor x, float rate, Tensor noise_shape = null, int? seed = null, string name = null) | |||
public static Tensor dropout_v2(Tensor x, Tensor rate, Tensor noise_shape = null, int? seed = null, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "dropout", x), scope => | |||
{ | |||
@@ -60,17 +60,17 @@ namespace Tensorflow.Train | |||
}); | |||
} | |||
public override Operation _apply_dense(Tensor grad, RefVariable var) | |||
public override Operation _apply_dense(Tensor grad, ResourceVariable var) | |||
{ | |||
var m = get_slot(var, "m"); | |||
var v = get_slot(var, "v"); | |||
var (beta1_power, beta2_power) = _get_beta_accumulators(); | |||
return gen_training_ops.apply_adam( | |||
var, | |||
m, | |||
v, | |||
math_ops.cast(beta1_power, var.dtype.as_base_dtype()), | |||
math_ops.cast(beta2_power, var.dtype.as_base_dtype()), | |||
var.Handle, | |||
m.Handle, | |||
v.Handle, | |||
math_ops.cast(beta1_power.Handle, var.dtype.as_base_dtype()), | |||
math_ops.cast(beta2_power.Handle, var.dtype.as_base_dtype()), | |||
math_ops.cast(_lr_t, var.dtype.as_base_dtype()), | |||
math_ops.cast(_beta1_t, var.dtype.as_base_dtype()), | |||
math_ops.cast(_beta2_t, var.dtype.as_base_dtype()), | |||
@@ -278,8 +278,16 @@ namespace Tensorflow | |||
public virtual Operation _apply_dense(Tensor grad, ResourceVariable var) | |||
{ | |||
var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | |||
return gen_training_ops.resource_apply_gradient_descent(var.Handle, alpha, grad, use_locking: _use_locking).op; | |||
if (tf.executing_eagerly()) | |||
{ | |||
var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | |||
return gen_training_ops.resource_apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; | |||
} | |||
else | |||
{ | |||
var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | |||
return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; | |||
} | |||
} | |||
public virtual Operation _apply_dense(Tensor grad, RefVariable var) | |||
@@ -314,6 +322,11 @@ namespace Tensorflow | |||
return _apply_sparse(gradient_no_duplicate_indices, var); | |||
} | |||
/// <summary>
/// Sparse update for resource variables; not implemented in this class.
/// </summary>
/// <param name="grad">Unused.</param>
/// <param name="var">Unused.</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public virtual Operation _apply_sparse(IndexedSlices grad, ResourceVariable var)
    => throw new NotImplementedException("_apply_sparse");
public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var) | |||
{ | |||
throw new NotImplementedException("_apply_sparse"); | |||
@@ -224,7 +224,7 @@ namespace Tensorflow | |||
var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length); | |||
idx += saveable.specs.Length; | |||
var restored = saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray()); | |||
assign_ops.Add(restored as ITensorOrOperation); | |||
assign_ops.Add(restored); | |||
} | |||
return control_flow_ops.group(assign_ops.ToArray(), name: name); | |||
@@ -13,6 +13,7 @@ | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -67,9 +67,7 @@ namespace Tensorflow | |||
{ | |||
ops.init_scope(); | |||
var variable = ops.internal_convert_to_tensor(op, as_ref: true); | |||
if (variable.op.type == "Variable" || | |||
variable.op.type == "VariableV2" || | |||
variable.op.type == "AutoReloadVariable") | |||
if (variable.dtype.is_ref_dtype()) | |||
yield return new ReferenceVariableSaveable(variable, "", name); | |||
else | |||
yield return new ResourceVariableSaveable(variable, "", name); | |||
@@ -102,7 +100,7 @@ namespace Tensorflow | |||
if (convert_variable_to_tensor) | |||
{ | |||
if (var is ResourceVariable) | |||
if (!var.dtype.is_ref_dtype()) | |||
tensor = var.GraphElement; | |||
else | |||
tensor = ops.internal_convert_to_tensor(var, as_ref: true); | |||
@@ -41,7 +41,7 @@ namespace Tensorflow | |||
throw new NotImplementedException(""); | |||
} | |||
public static Tensor apply_adam(IVariableV1 var, IVariableV1 m, IVariableV1 v, Tensor beta1_power, Tensor beta2_power, | |||
public static Tensor apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, | |||
Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, | |||
bool use_locking = false, bool use_nesterov = false, string name = null) | |||
{ | |||
@@ -64,7 +64,7 @@ namespace Tensorflow | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor apply_gradient_descent(RefVariable var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) | |||
public static Tensor apply_gradient_descent(IVariableV1 var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) | |||
{ | |||
var _op = tf.OpDefLib._apply_op_helper("ApplyGradientDescent", name, new | |||
{ | |||
@@ -82,7 +82,7 @@ namespace Tensorflow | |||
if (tf.executing_eagerly()) | |||
{ | |||
var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ResourceApplyGradientDescent", name, | |||
"ResourceApplyGradientDescent", name, | |||
null, | |||
var, alpha, delta, | |||
"use_locking", use_locking); | |||
@@ -28,6 +28,8 @@ namespace Tensorflow | |||
protected Tensor _initial_value; | |||
public Tensor initial_value => _initial_value; | |||
public Operation initializer => initializer_op; | |||
protected Tensor _parent_op; | |||
public Tensor parent_op => _parent_op; | |||
@@ -73,6 +75,14 @@ namespace Tensorflow | |||
public ITensorOrOperation assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true) | |||
{ | |||
if(value.GetType() == typeof(Tensor)) | |||
{ | |||
var assign = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name); | |||
if (read_value) | |||
return assign; | |||
return assign.op; | |||
} | |||
var value_tensor = ops.convert_to_tensor(value, dtype: dtype); | |||
var assign_op = gen_resource_variable_ops.assign_variable_op( | |||
handle, value_tensor, name: name); | |||
@@ -82,7 +92,7 @@ namespace Tensorflow | |||
return assign_op; | |||
} | |||
public Tensor value() => _read_variable_op(); | |||
public Tensor value() => tf.executing_eagerly() ? _read_variable_op() : GraphElement; | |||
protected Tensor _read_variable_op() | |||
{ | |||
@@ -149,6 +159,7 @@ namespace Tensorflow | |||
{ | |||
} | |||
public Tensor AsTensor() => read_value(); | |||
public Tensor AsTensor() | |||
=> tf.executing_eagerly() ? read_value() : GraphElement; | |||
} | |||
} |
@@ -33,10 +33,16 @@ namespace Tensorflow | |||
{ | |||
public string UniqueId { get; } | |||
public string Name { get; } | |||
/// <summary> | |||
/// Handle is ref type | |||
/// </summary> | |||
public Tensor Handle { get; } | |||
public string Device { get; } | |||
public Operation Initializer { get; } | |||
public Operation Op { get; } | |||
/// <summary> | |||
/// GraphElement is a copy of Handle | |||
/// </summary> | |||
public Tensor GraphElement { get; } | |||
public Graph Graph { get; } | |||
public TF_DataType dtype { get; } | |||
@@ -1,5 +1,6 @@ | |||
using System; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -21,11 +22,6 @@ namespace Tensorflow | |||
public static implicit operator EagerTensor(ResourceVariable var) | |||
=> var._dense_var_to_tensor() as EagerTensor; | |||
public static implicit operator RefVariable(ResourceVariable var) | |||
{ | |||
return null; | |||
} | |||
public static implicit operator IntPtr(ResourceVariable var) | |||
=> var._handle; | |||
@@ -35,5 +31,13 @@ namespace Tensorflow | |||
{ | |||
return value(); | |||
} | |||
/// <summary>
/// Converts this variable to a tensor.
/// </summary>
/// <param name="dtype">Not read by this method.</param>
/// <param name="name">Not read by this method.</param>
/// <param name="as_ref">When true the variable's <c>handle</c> is returned instead of its value.</param>
/// <returns>
/// <c>handle</c> if <paramref name="as_ref"/> is set; otherwise <c>AsTensor()</c> when
/// executing eagerly and <c>value()</c> when not.
/// </returns>
public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
{
    if (!as_ref)
    {
        if (tf.executing_eagerly())
            return AsTensor();
        return value();
    }

    return handle;
}
} | |||
} |
@@ -49,6 +49,7 @@ namespace Tensorflow | |||
VariableDef variable_def = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
string import_scope = "", | |||
VariableAggregation aggregation = VariableAggregation.None, | |||
TensorShape shape = null) | |||
{ | |||
if (variable_def != null) | |||
@@ -65,6 +66,7 @@ namespace Tensorflow | |||
caching_device: caching_device, | |||
name: name, | |||
dtype: dtype, | |||
aggregation: aggregation, | |||
shape: shape); | |||
} | |||
} | |||
@@ -75,6 +77,7 @@ namespace Tensorflow | |||
string caching_device = "", | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
VariableAggregation aggregation = VariableAggregation.None, | |||
TensorShape shape = null) | |||
{ | |||
var init_from_fn = initial_value.GetType().Name == "Func`1" || | |||
@@ -114,55 +117,43 @@ namespace Tensorflow | |||
if (initial_value.GetType().GetInterface("IInitializer") != null) | |||
initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); | |||
else | |||
initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func<Tensor>)() : initial_value, | |||
{ | |||
var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value; | |||
initial_value = ops.convert_to_tensor(value, | |||
name: "initial_value", | |||
dtype: dtype); | |||
} | |||
}); | |||
_shape = shape ?? (initial_value as Tensor).TensorShape; | |||
_initial_value = initial_value as Tensor; | |||
handle = resource_variable_ops.eager_safe_variable_handle( | |||
initial_value: _initial_value, | |||
shape: _shape, | |||
shared_name: shared_name, | |||
name: name, | |||
graph_mode: _in_graph_mode); | |||
_dtype = _initial_value.dtype.as_base_dtype(); | |||
if (_in_graph_mode) | |||
{ | |||
tf_with(ops.name_scope("IsInitialized"), delegate | |||
{ | |||
is_initialized_op = gen_resource_variable_ops.var_is_initialized_op(handle); | |||
}); | |||
if(initial_value != null) | |||
{ | |||
tf_with(ops.name_scope("Assign"), scope1 => | |||
{ | |||
string n = scope1; | |||
var _initial_value2 = variables._try_guard_against_uninitialized_dependencies(name, _initial_value); | |||
initializer_op = gen_resource_variable_ops.assign_variable_op(handle, _initial_value2, name: n); | |||
}); | |||
} | |||
handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | |||
// Manually assign reads to the handle's device to avoid log | |||
// messages. | |||
tf_with(ops.name_scope("Read"), delegate | |||
{ | |||
var value = gen_resource_variable_ops.read_variable_op(handle, _dtype); | |||
// _maybe_set_handle_data(dtype, handle, value); | |||
_graph_element = value; | |||
}); | |||
ops.colocate_with(initializer_op); | |||
_graph_element = gen_array_ops.identity(handle, name = "read"); | |||
ops.add_to_collections<IVariableV1>(collections, this); | |||
_dtype = handle.dtype; | |||
} | |||
else | |||
{ | |||
handle = resource_variable_ops.eager_safe_variable_handle( | |||
initial_value: _initial_value, | |||
shape: _shape, | |||
shared_name: shared_name, | |||
name: name, | |||
graph_mode: _in_graph_mode); | |||
gen_resource_variable_ops.assign_variable_op(handle, _initial_value); | |||
is_initialized_op = null; | |||
initializer_op = null; | |||
_graph_element = null; | |||
_dtype = _initial_value.dtype.as_base_dtype(); | |||
initial_value = _in_graph_mode ? initial_value : null; | |||
} | |||
@@ -237,5 +228,23 @@ namespace Tensorflow | |||
return array_ops.identity(value); | |||
}); | |||
} | |||
/// <summary>
/// Builds a <c>VariableDef</c> describing this variable, with names made relative
/// to <paramref name="export_scope"/>.
/// </summary>
/// <param name="export_scope">Name scope to strip; null or empty means no scope restriction.</param>
/// <returns>The populated definition.</returns>
/// <exception cref="NotImplementedException">
/// Thrown when the variable's handle name does not start with <paramref name="export_scope"/>.
/// </exception>
public VariableDef to_proto(string export_scope)
{
    // Scope names are identifiers, so compare ordinally rather than by culture.
    var in_scope = string.IsNullOrEmpty(export_scope)
        || Handle.name.StartsWith(export_scope, StringComparison.Ordinal);
    if (!in_scope)
        throw new NotImplementedException("to_proto ResourceVariable");

    var var_def = new VariableDef();
    var_def.VariableName = ops.strip_name_scope(Handle.name, export_scope);
    if (_initial_value != null)
        var_def.InitialValueName = ops.strip_name_scope(_initial_value.name, export_scope);
    var_def.Trainable = _trainable;

    // The initializer and graph element may be absent (e.g. for variables created
    // eagerly); skip those fields instead of dereferencing null.
    if (initializer != null)
        var_def.InitializerName = ops.strip_name_scope(initializer.name, export_scope);
    if (_graph_element != null)
        var_def.SnapshotName = ops.strip_name_scope(_graph_element.name, export_scope);

    return var_def;
}
} | |||
} |
@@ -467,7 +467,7 @@ namespace Tensorflow | |||
case RefVariable varVal: | |||
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); | |||
case ResourceVariable varVal: | |||
return varVal.value(); | |||
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); | |||
case TensorShape ts: | |||
return constant_op.constant(ts.dims, dtype: dtype, name: name); | |||
case int[] dims: | |||
@@ -70,12 +70,14 @@ namespace Tensorflow | |||
bool use_resource = true, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
VariableAggregation aggregation = VariableAggregation.None, | |||
int[] shape = null) | |||
=> new ResourceVariable(data, | |||
trainable: trainable, | |||
validate_shape: validate_shape, | |||
name: name, | |||
dtype: dtype, | |||
aggregation: aggregation, | |||
shape: shape); | |||
public Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | |||