diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index 5733e08d..88b9280a 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -66,15 +66,30 @@ namespace Tensorflow.Eager Tensor placeholder = null; tf_with(ops.control_dependencies(null), delegate { - placeholder = tf.placeholder(dtype, shape: shape, name: name ?? this.name); + placeholder = tf.placeholder(dtype, name: name); }); - // custom_gradient.copy_handle_data(value, placeholder) + copy_handle_data(placeholder); return placeholder; } - void copy_handle_data() + public Tensor AsContatnt(string name = null) { + Tensor constant = null; + tf_with(ops.control_dependencies(null), delegate + { + constant = tf.constant(numpy(), name: name); + }); + return constant; + } + void copy_handle_data(Tensor target_t) + { + if(target_t.dtype == TF_DataType.TF_RESOURCE || + target_t.dtype == TF_DataType.TF_VARIANT) + { + // need to export + // c_api.TF_GraphSetOutputHandleShapesAndTypes(target_t.graph, target_t._as_tf_output(), 0, new IntPtr[0], new int[0], new DataType[0], tf.Status.Handle); + } } public override IntPtr ToPointer()