Browse Source

Merge pull request #1113 from SciSharp/rnn-dev

fix: error pf training RNN on Linux
tags/v0.110.0-LSTM-Model
Haiping GitHub 2 years ago
parent
commit
7449fa5c5b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 59 additions and 17 deletions
  1. +13
    -2
      src/TensorFlowNET.Core/APIs/c_api.cs
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/c_api.customize.cs
  3. +25
    -0
      src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs
  4. +6
    -6
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/list_ops.cs
  6. +4
    -5
      src/TensorFlowNET.Core/Operations/while_v2.cs
  7. +8
    -1
      src/TensorFlowNET.Core/ops.cs
  8. +1
    -1
      tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj

+ 13
- 2
src/TensorFlowNET.Core/APIs/c_api.cs View File

@@ -51,7 +51,17 @@ namespace Tensorflow
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

public unsafe static byte[] ByteStringPiece(IntPtr handle)
public unsafe static byte[] ByteStringPiece(Buffer? handle)
{
if (handle is null)
{
return new byte[0];
}
var data = handle.ToArray();
return data;
}
public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle)
{
if (handle == IntPtr.Zero)
{
@@ -66,7 +76,8 @@ namespace Tensorflow
current = *(str_data++);
bytes.Add(current);
}
return bytes.Take(bytes.Count - 1).ToArray();
var data = bytes.ToArray();
return data;
}

[UnmanagedFunctionPointer(CallingConvention.Winapi)]


+ 1
- 1
src/TensorFlowNET.Core/APIs/c_api.customize.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}


+ 25
- 0
src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs View File

@@ -0,0 +1,25 @@
using Tensorflow;

internal static class GraphOnlyOps
{
/// <summary>
/// Graph-only version of tf.compat.v1.placeholder(), for internal use only.
/// </summary>
/// <param name="dtyype"></param>
/// <param name="shape"></param>
/// <param name="name"></param>
/// <returns></returns>
internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null)
{
var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() };
var shape_value = new AttrValue() { Shape = shape.as_proto() };
var g = ops.get_default_graph();
Dictionary<string, AttrValue> attrs = new();
attrs["dtype"] = dtype_value;
attrs["shape"] = shape_value;
var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype },
new TF_DataType[0], attrs: attrs, name: name);
var result = op.outputs[0];
return result;
}
}

+ 6
- 6
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -544,12 +544,12 @@ public class FuncGraph : Graph, IDisposable
Tensor placeholder;
try
{
placeholder = tf.placeholder(tensor.dtype, tensor.shape, name);
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name);
}
catch (ValueError)
catch (ValueError ex)
{
// TODO(Rinne): Add warning here.
placeholder = tf.placeholder(tensor.dtype, tensor.shape);
tf.Logger.Warning(ex.ToString());
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape);
}
handle_data_util.copy_handle_data(tensor, placeholder);
if (name is not null)
@@ -575,12 +575,12 @@ public class FuncGraph : Graph, IDisposable
Tensor placeholder;
try
{
placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name);
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name);
}
catch (ValueError)
{
// TODO(Rinne): Add warning here.
placeholder = tf.placeholder(spec.dtype, spec.shape);
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape);
}
if (name is not null)
{


+ 1
- 1
src/TensorFlowNET.Core/Operations/list_ops.cs View File

@@ -31,7 +31,7 @@ namespace Tensorflow.Operations
}
else
{
return ops.convert_to_tensor(shape);
return ops.convert_to_tensor(shape, dtype: dtypes.int32);
}
}



+ 4
- 5
src/TensorFlowNET.Core/Operations/while_v2.cs View File

@@ -38,9 +38,9 @@ namespace Tensorflow.Operations
int len_orig_loop_vars = orig_loop_vars.Length;

loop_vars = _tensor_array_to_flow(loop_vars);
loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors();
loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x), loop_vars).ToTensors();

var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars));
var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), loop_vars);

var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray();

@@ -379,10 +379,9 @@ namespace Tensorflow.Operations
return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder");
}

private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype,
string name)
private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value)
{
return ops.convert_to_tensor(value, dtype, name, false);
return ops.convert_to_tensor(value, as_ref: false);
}

private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1)


+ 8
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -576,7 +576,14 @@ namespace Tensorflow
public static HandleData get_resource_handle_data(Tensor graph_op)
{
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data));
try{
var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data));
return HandleData.Parser.ParseFrom(handle_str);
}
catch(Exception){
var handle_str = c_api.ByteStringPieceFromNativeString(handle_data.DangerousGetHandle());
return HandleData.Parser.ParseFrom(handle_str);
}
}

public static void dismantle_graph(Graph graph)


+ 1
- 1
tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj View File

@@ -5,7 +5,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.3" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.4" />
<PackageReference Include="SciSharp.TensorFlow.Redist-Lite" Version="2.6.0" />
</ItemGroup>



Loading…
Cancel
Save