Browse Source

build_results can't handle when is_op is false condition.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
aecaa78c6e
10 changed files with 93 additions and 27 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +18
    -6
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  3. +16
    -2
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  4. +4
    -1
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  5. +0
    -5
      src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
  6. +24
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs
  7. +16
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs
  8. +3
    -8
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  9. +9
    -2
      src/TensorFlowNET.Core/ops.py.cs
  10. +2
    -2
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -72,7 +72,7 @@ namespace Tensorflow
// or if the call is a partial run that specifies feeds.
var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);

return fetch_handler.build_results(null, results);
return fetch_handler.build_results(this, results);
}

/// <summary>


+ 18
- 6
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -11,9 +11,9 @@ namespace Tensorflow
public class _ElementFetchMapper<T> : _FetchMapper<T>
{
private List<object> _unique_fetches = new List<object>();
private Func<List<object>, NDArray> _contraction_fn;
private Func<List<object>, object> _contraction_fn;

public _ElementFetchMapper(List<T> fetches, Func<List<object>, NDArray> contraction_fn)
public _ElementFetchMapper(List<T> fetches, Func<List<object>, object> contraction_fn)
{
foreach(var fetch in fetches)
{
@@ -32,10 +32,22 @@ namespace Tensorflow
/// <returns></returns>
public NDArray build_results(List<object> values)
{
if (values.Count == 0)
return null;
else
return _contraction_fn(values);
NDArray result = null;

if (values.Count > 0)
{
var ret = _contraction_fn(values);
switch (ret)
{
case NDArray value:
result = value;
break;
default:
break;
}
}

return result;
}

public List<object> unique_fetches()


+ 16
- 2
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow
private List<Tensor> _final_fetches = new List<Tensor>();
private List<T> _targets = new List<T>();

public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, Action feed_handles = null)
{
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches())
@@ -40,18 +40,32 @@ namespace Tensorflow
_final_fetches = _fetches;
}

public NDArray build_results(Session session, NDArray[] tensor_values)
public NDArray build_results(BaseSession session, NDArray[] tensor_values)
{
var full_values = new List<object>();
if (_final_fetches.Count != tensor_values.Length)
throw new InvalidOperationException("_final_fetches mismatch tensor_values");

int i = 0;
int j = 0;
foreach(var is_op in _ops)
{
if (is_op)
{
full_values.Add(null);
}
else
{
var value = tensor_values[j];
j += 1;
full_values.Add(value);
}
i += 1;
}

if (j != tensor_values.Length)
throw new InvalidOperationException("j mismatch tensor_values");

return _fetch_mapper.build_results(full_values);
}



+ 4
- 1
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -10,7 +10,10 @@ namespace Tensorflow
{
var fetches = new List<T> { fetch };

return new _ElementFetchMapper<T>(fetches, null);
return new _ElementFetchMapper<T>(fetches, (List<object> fetched_vals) =>
{
return fetched_vals[0];
});
}
}
}

+ 0
- 5
src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs View File

@@ -25,10 +25,5 @@ namespace Tensorflow
{
return new Tensor(handle);
}

public static implicit operator Tensor(RefVariable var)
{
return var._initial_value;
}
}
}

+ 24
- 0
src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public partial class RefVariable
{
public static implicit operator _VariableScopeStore(RefVariable variable)
{
return null;
}

public static implicit operator RefVariable(_VariableScopeStore store)
{
return null;
}

public static implicit operator Tensor(RefVariable var)
{
return var._AsTensor();
}
}
}

+ 16
- 0
src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public partial class RefVariable
{
public static Tensor operator +(RefVariable t1, int t2)
{
var tensor1 = t1._AsTensor();
var tensor2 = ops.convert_to_tensor(t2, tensor1.dtype, "y");
return gen_math_ops.add(tensor1, tensor2);
}
}
}

+ 3
- 8
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -4,7 +4,7 @@ using System.Text;

namespace Tensorflow
{
public class RefVariable : VariableV1
public partial class RefVariable : VariableV1
{
public bool _in_graph_mode = true;
public Tensor _initial_value;
@@ -106,14 +106,9 @@ namespace Tensorflow
return _variable;
}

public static implicit operator _VariableScopeStore(RefVariable variable)
public Tensor _AsTensor()
{
return null;
}

public static implicit operator RefVariable(_VariableScopeStore store)
{
return null;
return _snapshot;
}
}
}

+ 9
- 2
src/TensorFlowNET.Core/ops.py.cs View File

@@ -59,7 +59,14 @@ namespace Tensorflow
return get_default_graph();
}

public static Tensor convert_to_tensor(object value, string name = "")
/// <summary>
/// Converts the given `value` to a `Tensor`.
/// </summary>
/// <param name="value"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
{
switch (value)
{
@@ -67,7 +74,7 @@ namespace Tensorflow
return val;
default:
var nd = tensor_util.convert_to_numpy_ndarray(value);
return tf.constant(nd, name);
return constant_op.Constant(nd, name);
}
}



+ 2
- 2
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -38,8 +38,8 @@ namespace TensorFlowNET.UnitTest
session.run(model);
for(int i = 0; i < 5; i++)
{
// x = x + 1;
var result = session.run(x);
var x1 = x + 1;
var result = session.run(x1);
print(result);
}
}


Loading…
Cancel
Save