Browse Source

added InputList

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
c45361cca3
4 changed files with 27 additions and 9 deletions
  1. +2
    -5
      src/TensorFlowNET.Core/Eager/Execute.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs
  3. +16
    -0
      src/TensorFlowNET.Core/Operations/InputList.cs
  4. +6
    -4
      src/TensorFlowNET.Core/Operations/Operation.cs

+ 2
- 5
src/TensorFlowNET.Core/Eager/Execute.cs View File

@@ -6,12 +6,9 @@ namespace Tensorflow.Eager
{
public class Execute
{
public void record_gradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
public void record_gradient(string op_name, InputList inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
{
if (inputs == null)
inputs = new Tensor[0];

pywrap_tfe_src.RecordGradient(op_name, inputs, attrs, results, name);
pywrap_tfe_src.RecordGradient(op_name, inputs._inputs, attrs, results, name);
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs View File

@@ -25,6 +25,9 @@ namespace Tensorflow.Eager
}
}
if (!should_record) return;

var op_outputs = results;
var op_inputs = inputs;
}
}
}

+ 16
- 0
src/TensorFlowNET.Core/Operations/InputList.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class InputList
{
public Tensor[] _inputs;

public InputList(Tensor[] inputs)
{
_inputs = inputs;
}
}
}

+ 6
- 4
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -101,21 +101,23 @@ namespace Tensorflow
}
}

private Tensor[] _inputs;
public Tensor[] inputs
private InputList _inputs;
public InputList inputs
{
get
{
if(_inputs == null)
{
_inputs = new Tensor[NumInputs];
var retval = new Tensor[NumInputs];

for (int i = 0; i < NumInputs; i++)
{
var tf_outpus = Input(i);
var op = new Operation(tf_outpus.oper);
_inputs[i] = op.outputs[tf_outpus.index];
retval[i] = op.outputs[tf_outpus.index];
}

_inputs = new InputList(retval);
}

return _inputs;


Loading…
Cancel
Save