Browse Source

Removed all use-cases of Marshal.PtrToStructure in favor of unsafe ptrs.

tags/v0.12
Eli Belash 6 years ago
parent
commit
bd18f5db17
6 changed files with 28 additions and 24 deletions
  1. +0
    -1
      src/TensorFlowNET.Core/DisposableObject.cs
  2. +4
    -5
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  3. +7
    -3
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  4. +6
    -6
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +3
    -5
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  6. +8
    -4
      src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs

+ 0
- 1
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -54,7 +54,6 @@ namespace Tensorflow
if (_handle != IntPtr.Zero) if (_handle != IntPtr.Zero)
{ {
DisposeUnmanagedResources(_handle); DisposeUnmanagedResources(_handle);

_handle = IntPtr.Zero; _handle = IntPtr.Zero;
} }
} }


+ 4
- 5
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -30,11 +30,10 @@ namespace Tensorflow
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);


c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);
for (int i = 0; i < num_return_outputs; i++)
{
var handle = return_output_handle + i * size;
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
}

var tf_output_ptr = (TF_Output*) return_output_handle;
for (int i = 0; i < num_return_outputs; i++)
return_outputs[i] = *(tf_output_ptr + i);


Marshal.FreeHGlobal(return_output_handle); Marshal.FreeHGlobal(return_output_handle);




+ 7
- 3
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -39,16 +39,20 @@ namespace Tensorflow
return c_api.TF_NewOperation(_handle, opType, opName); return c_api.TF_NewOperation(_handle, opType, opName);
} }
public unsafe Operation[] ReturnOperations(IntPtr results)
public Operation[] ReturnOperations(IntPtr results)
{ {
TF_Operation return_oper_handle = new TF_Operation(); TF_Operation return_oper_handle = new TF_Operation();
int num_return_opers = 0; int num_return_opers = 0;
c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle);
Operation[] return_opers = new Operation[num_return_opers]; Operation[] return_opers = new Operation[num_return_opers];
var tf_op_size = Marshal.SizeOf<TF_Operation>();
for (int i = 0; i < num_return_opers; i++) for (int i = 0; i < num_return_opers; i++)
{ {
var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i;
return_opers[i] = new Operation(*(IntPtr*)handle);
unsafe
{
var handle = return_oper_handle.node + tf_op_size * i;
return_opers[i] = new Operation(*(IntPtr*)handle);
}
} }
return return_opers; return return_opers;


+ 6
- 6
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -369,7 +369,7 @@ namespace Tensorflow
var name_key = name.ToLower(); var name_key = name.ToLower();
int i = 0; int i = 0;
if (_names_in_use.ContainsKey(name_key)) if (_names_in_use.ContainsKey(name_key))
i = _names_in_use[name_key];
i = _names_in_use[name_key];
// Increment the number for "name_key". // Increment the number for "name_key".
if (mark_as_used) if (mark_as_used)
_names_in_use[name_key] = i + 1; _names_in_use[name_key] = i + 1;
@@ -399,13 +399,13 @@ namespace Tensorflow
int num_return_outputs = 0; int num_return_outputs = 0;
c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle);
TF_Output[] return_outputs = new TF_Output[num_return_outputs]; TF_Output[] return_outputs = new TF_Output[num_return_outputs];
for (int i = 0; i < num_return_outputs; i++)
unsafe
{ {
var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i);
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
var tf_output_ptr = (TF_Output*) return_output_handle;
for (int i = 0; i < num_return_outputs; i++)
return_outputs[i] = *(tf_output_ptr + i);
return return_outputs;
} }

return return_outputs;
} }


public string[] get_all_collection_keys() public string[] get_all_collection_keys()


+ 3
- 5
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -50,14 +50,12 @@ namespace Tensorflow


public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
{ {
int size = Marshal.SizeOf<TF_Input>();
var handle = Marshal.AllocHGlobal(size);
var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Input>());
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
var consumers = new TF_Input[num]; var consumers = new TF_Input[num];
var inputptr = (TF_Input*) handle;
for (int i = 0; i < num; i++) for (int i = 0; i < num; i++)
{
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
}
consumers[i] = *(inputptr + i);


return consumers; return consumers;
} }


+ 8
- 4
src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs View File

@@ -27,13 +27,17 @@ namespace Tensorflow
var handle = Marshal.AllocHGlobal(size * num_consumers); var handle = Marshal.AllocHGlobal(size * num_consumers);
int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers);
var consumers = new string[num_consumers]; var consumers = new string[num_consumers];
for (int i = 0; i < num; i++)
unsafe
{ {
TF_Input input = Marshal.PtrToStructure<TF_Input>(handle + i * size);
consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper));
var inputptr = (TF_Input*) handle;
for (int i = 0; i < num; i++)
{
var oper = (inputptr + i)->oper;
consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper));
}
} }


return consumers; return consumers;
} }
} }
}
}

Loading…
Cancel
Save