Browse Source

VariableScope, vairables.py.cs

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
9f18881d89
15 changed files with 213 additions and 12 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/APIs/c_api.cs
  2. +19
    -0
      src/TensorFlowNET.Core/Exceptions/ValueError.cs
  3. +15
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +1
    -0
      src/TensorFlowNET.Core/Train/Optimizer.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  6. +11
    -0
      src/TensorFlowNET.Core/Variables/VariableScope.cs
  7. +14
    -0
      src/TensorFlowNET.Core/Variables/VariableSynchronization.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Variables/VariableV1.cs
  9. +16
    -0
      src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs
  10. +74
    -0
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  11. +18
    -0
      src/TensorFlowNET.Core/Variables/variables.py.cs
  12. +26
    -0
      src/TensorFlowNET.Core/ops.GraphKeys.py.cs
  13. +12
    -1
      src/TensorFlowNET.Core/ops.py.cs
  14. +1
    -1
      src/TensorFlowNET.Core/tf.cs
  15. +2
    -7
      test/TensorFlowNET.Examples/LinearRegression.cs

+ 1
- 0
src/TensorFlowNET.Core/APIs/c_api.cs View File

@@ -25,6 +25,7 @@ namespace Tensorflow
/// size_t* => ref uint
/// void* => IntPtr
/// string => IntPtr c_api.StringPiece(IntPtr)
/// unsigned char => byte
/// </summary>
public static partial class c_api
{


+ 19
- 0
src/TensorFlowNET.Core/Exceptions/ValueError.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class ValueError : Exception
{
public ValueError() : base()
{

}

public ValueError(string message) : base(message)
{

}
}
}

+ 15
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -27,6 +27,11 @@ namespace Tensorflow
public string _graph_key;
public Status Status { get; }

/// <summary>
/// Arbitrary collections of objects.
/// </summary>
private Dictionary<string, object> _collections = new Dictionary<string, object>();

public Graph()
{
_handle = c_api.TF_NewGraph();
@@ -86,6 +91,11 @@ namespace Tensorflow
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
}

public void add_to_collection(string name, object value)
{
_collections[name] = value;
}

public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = "",
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
@@ -221,6 +231,11 @@ namespace Tensorflow
return _nodes_by_name.Values.Select(x => x).ToArray();
}

public Dictionary<string, object> get_collection(string name)
{
return _collections;
}

public void Dispose()
{
c_api.TF_DeleteGraph(_handle);


+ 1
- 0
src/TensorFlowNET.Core/Train/Optimizer.cs View File

@@ -49,6 +49,7 @@ namespace Tensorflow
}

var var_list = variables.trainable_variables();
return null;
}
}


src/TensorFlowNET.Core/Tensors/RefVariable.cs → src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -4,7 +4,7 @@ using System.Text;

namespace Tensorflow
{
public class RefVariable : Variable
public class RefVariable : VariableV1
{
public bool _in_graph_mode = true;
public Tensor _initial_value;

+ 11
- 0
src/TensorFlowNET.Core/Variables/VariableScope.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class VariableScope
{
public bool? use_resource { get; set; }
}
}

+ 14
- 0
src/TensorFlowNET.Core/Variables/VariableSynchronization.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public enum VariableSynchronization
{
AUTO = 0,
NONE = 1,
ON_WRITE = 2,
ON_READ = 3
}
}

src/TensorFlowNET.Core/Tensors/Variable.cs → src/TensorFlowNET.Core/Variables/VariableV1.cs View File

@@ -14,9 +14,9 @@ namespace Tensorflow
/// the variable are fixed. The value can be changed using one of the assign methods.
/// https://tensorflow.org/guide/variables
/// </summary>
public class Variable
public class VariableV1
{
public Variable(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true)
public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true)
{

}

+ 16
- 0
src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class _VariableScopeStore
{
public VariableScope current_scope { get; set; }

public _VariableScopeStore()
{
current_scope = new VariableScope();
}
}
}

+ 74
- 0
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -0,0 +1,74 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class variable_scope
{
public static string _VARSCOPESTORE_KEY = "__varscope";
public static bool _DEFAULT_USE_RESOURCE = false;

public static RefVariable default_variable_creator(object initial_value, string name = "", TF_DataType dtype = TF_DataType.DtInvalid, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO)
{
var trainable = _get_trainable_value(synchronization);
if (!use_resource.HasValue)
{
use_resource = get_variable_scope().use_resource;
}

if(!use_resource.HasValue)
use_resource = _DEFAULT_USE_RESOURCE;

if (use_resource.Value)
{
throw new NotImplementedException();
}
else
{
return new RefVariable(initial_value);
}
}

public static VariableScope get_variable_scope()
{
return get_variable_scope_store().current_scope;
}

public static _VariableScopeStore get_variable_scope_store()
{
var scope_store = ops.get_collection(_VARSCOPESTORE_KEY);
if (scope_store == null)
{
scope_store = new _VariableScopeStore();
ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store);
}
else
{
// scope_store = scope_store[0];
}

return scope_store;
}

public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null)
{
if(synchronization == VariableSynchronization.ON_READ)
{
if (trainable.Value)
throw new ValueError("Synchronization value can be set to " +
"VariableSynchronization.ON_READ only for non-trainable variables. " +
"You have specified trainable=True and " +
"synchronization=VariableSynchronization.ON_READ.");
else
trainable = false;
}
else if (!trainable.HasValue)
{
trainable = true;
}

return trainable.Value;
}
}
}

+ 18
- 0
src/TensorFlowNET.Core/Variables/variables.py.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class variables
{
/// <summary>
/// Returns all variables created with `trainable=True`
/// </summary>
/// <returns></returns>
public static object trainable_variables()
{
return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES);
}
}
}

+ 26
- 0
src/TensorFlowNET.Core/ops.GraphKeys.py.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public partial class ops
{
/// <summary>
/// Standard names to use for graph collections.
/// The standard library uses various well-known names to collect and
/// retrieve values associated with a graph. For example, the
/// `tf.Optimizer` subclasses default to optimizing the variables
/// collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
/// specified, but it is also possible to pass an explicit list of
/// variables.
/// </summary>
public static class GraphKey
{
/// <summary>
/// the subset of `Variable` objects that will be trained by an optimizer.
/// </summary>
public static string TRAINABLE_VARIABLES = "trainable_variables";
}
}
}

src/TensorFlowNET.Core/Operations/ops.cs → src/TensorFlowNET.Core/ops.py.cs View File

@@ -10,8 +10,19 @@ using System.Linq;

namespace Tensorflow
{
public static class ops
public partial class ops
{
public static void add_to_collection(string name, object value)
{
var graph = tf.get_default_graph();
graph.add_to_collection(name, value);
}

public static _VariableScopeStore get_collection(string key)
{
return null;// get_default_graph().get_collection(key);
}

public static Graph get_default_graph()
{
return tf.Graph();

+ 1
- 1
src/TensorFlowNET.Core/tf.cs View File

@@ -22,7 +22,7 @@ namespace Tensorflow

public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid)
{
return new RefVariable(data, name, dtype);
return variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid);
}

public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null)


+ 2
- 7
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -43,18 +43,13 @@ namespace TensorFlowNET.Examples
var sub = pred - Y;
var pow = tf.pow(sub, 2);







var reduce = tf.reduce_sum(pow);
var cost = reduce / (2d * n_samples);

// radient descent
// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
var optimizer = tf.train.GradientDescentOptimizer(learning_rate);
optimizer.minimize(cost);
}
}
}

Loading…
Cancel
Save