@@ -51,17 +51,13 @@ namespace Tensorflow | |||||
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | ||||
} | } | ||||
public unsafe static byte[] ByteStringPiece(IntPtr handle) | |||||
public unsafe static byte[] ByteStringPiece(Buffer? handle) | |||||
{ | { | ||||
byte* str_data = (byte*)handle.ToPointer(); | |||||
List<byte> bytes = new List<byte>(); | |||||
byte current = 255; | |||||
while (current != ((byte)'\0')) | |||||
{ | |||||
current = *(str_data++); | |||||
bytes.Add(current); | |||||
if(handle is null){ | |||||
return new byte[0]; | |||||
} | } | ||||
return bytes.Take(bytes.Count - 1).ToArray(); | |||||
var data = handle.ToArray(); | |||||
return data; | |||||
} | } | ||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)] | [UnmanagedFunctionPointer(CallingConvention.Winapi)] | ||||
@@ -10,7 +10,7 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | ||||
} | } | ||||
@@ -0,0 +1,25 @@ | |||||
using Tensorflow; | |||||
internal static class GraphOnlyOps | |||||
{ | |||||
/// <summary> | |||||
/// Graph-only version of tf.compat.v1.placeholder(), for internal use only. | |||||
/// </summary> | |||||
/// <param name="dtyype"></param> | |||||
/// <param name="shape"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null) | |||||
{ | |||||
var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() }; | |||||
var shape_value = new AttrValue() { Shape = shape.as_proto() }; | |||||
var g = ops.get_default_graph(); | |||||
Dictionary<string, AttrValue> attrs = new(); | |||||
attrs["dtype"] = dtype_value; | |||||
attrs["shape"] = shape_value; | |||||
var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype }, | |||||
new TF_DataType[0], attrs: attrs, name: name); | |||||
var result = op.outputs[0]; | |||||
return result; | |||||
} | |||||
} |
@@ -544,12 +544,12 @@ public class FuncGraph : Graph, IDisposable | |||||
Tensor placeholder; | Tensor placeholder; | ||||
try | try | ||||
{ | { | ||||
placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||||
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name); | |||||
} | } | ||||
catch (ValueError) | |||||
catch (ValueError ex) | |||||
{ | { | ||||
// TODO(Rinne): Add warning here. | |||||
placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||||
tf.Logger.Warning(ex.ToString()); | |||||
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape); | |||||
} | } | ||||
handle_data_util.copy_handle_data(tensor, placeholder); | handle_data_util.copy_handle_data(tensor, placeholder); | ||||
if (name is not null) | if (name is not null) | ||||
@@ -575,12 +575,12 @@ public class FuncGraph : Graph, IDisposable | |||||
Tensor placeholder; | Tensor placeholder; | ||||
try | try | ||||
{ | { | ||||
placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||||
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name); | |||||
} | } | ||||
catch (ValueError) | catch (ValueError) | ||||
{ | { | ||||
// TODO(Rinne): Add warning here. | // TODO(Rinne): Add warning here. | ||||
placeholder = tf.placeholder(spec.dtype, spec.shape); | |||||
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape); | |||||
} | } | ||||
if (name is not null) | if (name is not null) | ||||
{ | { | ||||
@@ -31,7 +31,7 @@ namespace Tensorflow.Operations | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
return ops.convert_to_tensor(shape); | |||||
return ops.convert_to_tensor(shape, dtype: dtypes.int32); | |||||
} | } | ||||
} | } | ||||
@@ -38,9 +38,9 @@ namespace Tensorflow.Operations | |||||
int len_orig_loop_vars = orig_loop_vars.Length; | int len_orig_loop_vars = orig_loop_vars.Length; | ||||
loop_vars = _tensor_array_to_flow(loop_vars); | loop_vars = _tensor_array_to_flow(loop_vars); | ||||
loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors(); | |||||
loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x), loop_vars).ToTensors(); | |||||
var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars)); | |||||
var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), loop_vars); | |||||
var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); | var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); | ||||
@@ -379,10 +379,9 @@ namespace Tensorflow.Operations | |||||
return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); | return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); | ||||
} | } | ||||
private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype, | |||||
string name) | |||||
private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value) | |||||
{ | { | ||||
return ops.convert_to_tensor(value, dtype, name, false); | |||||
return ops.convert_to_tensor(value, as_ref: false); | |||||
} | } | ||||
private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) | private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) | ||||
@@ -576,7 +576,8 @@ namespace Tensorflow | |||||
public static HandleData get_resource_handle_data(Tensor graph_op) | public static HandleData get_resource_handle_data(Tensor graph_op) | ||||
{ | { | ||||
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | ||||
return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data)); | |||||
var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data)); | |||||
return HandleData.Parser.ParseFrom(handle_str); | |||||
} | } | ||||
public static void dismantle_graph(Graph graph) | public static void dismantle_graph(Graph graph) | ||||