@@ -20,6 +20,8 @@ namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public InitializersImpl initializers { get; } = new InitializersImpl(); | |||
public IInitializer constant_initializer<T>(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | |||
=> new Constant<T>(value, dtype: dtype, verify_shape: verify_shape); | |||
public IInitializer zeros_initializer => new Zeros(); | |||
@@ -82,5 +84,20 @@ namespace Tensorflow | |||
uniform: uniform, | |||
seed: seed, | |||
dtype: dtype); | |||
/// <summary>
/// Groups factory methods that build <see cref="IInitializer"/> instances.
/// </summary>
public class InitializersImpl
{
    /// <summary>
    /// Builds a <see cref="RandomNormal"/> initializer, forwarding every argument by name.
    /// </summary>
    /// <param name="mean">Forwarded to <see cref="RandomNormal"/> as <c>mean</c>.</param>
    /// <param name="stddev">Forwarded to <see cref="RandomNormal"/> as <c>stddev</c>.</param>
    /// <param name="seed">Optional seed, forwarded as <c>seed</c>.</param>
    /// <param name="dtype">Forwarded as <c>dtype</c>.</param>
    /// <returns>The newly constructed initializer.</returns>
    public IInitializer random_normal_initializer(float mean = 0.0f,
        float stddev = 1.0f,
        int? seed = null,
        TF_DataType dtype = TF_DataType.TF_FLOAT)
    {
        return new RandomNormal(mean: mean, stddev: stddev, seed: seed, dtype: dtype);
    }

    /// <summary>
    /// Builds a <see cref="Zeros"/> initializer for the given shape and dtype.
    /// </summary>
    /// <param name="shape">Optional shape, forwarded as <c>shape</c>.</param>
    /// <param name="dtype">Forwarded as <c>dtype</c>.</param>
    /// <returns>The newly constructed initializer.</returns>
    public IInitializer zeros_initializer(TensorShape shape = null,
        TF_DataType dtype = TF_DataType.TF_FLOAT)
    {
        return new Zeros(shape: shape, dtype: dtype);
    }
}
} | |||
} |
@@ -27,6 +27,18 @@ namespace Tensorflow | |||
/// <summary>
/// Factory for Keras optimizers; each method constructs the optimizer type of the same name.
/// </summary>
public class KerasOptimizers
{
    /// <summary>Creates an <c>SGD</c> optimizer with the given learning rate.</summary>
    public SGD SGD(float learning_rate)
    {
        return new SGD(learning_rate);
    }

    /// <summary>
    /// Creates an <c>Adam</c> optimizer, forwarding every argument by name to its constructor.
    /// </summary>
    public Adam Adam(float learning_rate = 0.001f,
        float beta_1 = 0.9f,
        float beta_2 = 0.999f,
        float epsilon = 1e-7f,
        bool amsgrad = false,
        string name = "Adam")
    {
        return new Adam(
            learning_rate: learning_rate,
            beta_1: beta_1,
            beta_2: beta_2,
            epsilon: epsilon,
            amsgrad: amsgrad,
            name: name);
    }
}
} | |||
} |
@@ -51,28 +51,14 @@ namespace Tensorflow.Eager | |||
public override object get_attr(string attr_name) | |||
{ | |||
object value = null; | |||
byte isList = 0; | |||
var attrType = c_api.TFE_OpNameGetAttrType(tf.Context.Handle, Name, attr_name, ref isList, tf.Status.Handle); | |||
switch (attrType) | |||
{ | |||
case TF_AttrType.TF_ATTR_BOOL: | |||
value = get_attr_bool(attr_name); | |||
break; | |||
default: | |||
break; | |||
} | |||
return value; | |||
} | |||
public bool get_attr_bool(string attr_name) | |||
{ | |||
// var attrType = c_api.TFE_OpNameGetAttrType(tf.Context.Handle, Name, attr_name, ref isList, tf.Status.Handle); | |||
for (int i = 0; i < Attrs.Length; i = i + 2) | |||
{ | |||
if (Attrs[i].Equals(attr_name)) | |||
return Attrs[i + 1].Equals("1"); | |||
return Attrs[i + 1]; | |||
} | |||
throw new ValueError($"Can't find attr: {attr_name}"); | |||
return null; | |||
} | |||
public override string ToString() | |||
@@ -344,6 +344,11 @@ namespace Tensorflow.Eager | |||
c_api.TFE_OpSetAttrTypeList(op, key, values2, values2.Length); | |||
attr_list_sizes[key] = values2.Length; | |||
} | |||
else if (type == TF_AttrType.TF_ATTR_INT && values is int[] values4) | |||
{ | |||
c_api.TFE_OpSetAttrIntList(op, key, values4.Select(x => Convert.ToInt64(x)).ToArray(), values4.Length); | |||
attr_list_sizes[key] = values4.Length; | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException(""); | |||
@@ -209,6 +209,9 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrIntList(SafeOpHandle op, string attr_name, long[] values, int num_values); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrValueProto(SafeOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||
@@ -119,7 +119,7 @@ namespace Tensorflow.Gradients | |||
return (results[0], results[1]); | |||
} | |||
public Tensor[] gradient(Tensor target, List<IVariableV1> sources) | |||
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources) | |||
{ | |||
if (_recording) | |||
{ | |||
@@ -128,12 +128,12 @@ namespace Tensorflow.Gradients | |||
[RegisterGradient("Conv2D")] | |||
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) | |||
{ | |||
var dilations = (op.get_attr("dilations") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); | |||
var strides = (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); | |||
var padding = op.get_attr("padding"); | |||
var explicit_paddings = (op.get_attr("explicit_paddings") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); | |||
var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); | |||
var data_format = op.get_attr("data_format"); | |||
var dilations = op.get_attr<int[]>("dilations"); | |||
var strides = op.get_attr<int[]>("strides"); | |||
var padding = op.get_attr<string>("padding"); | |||
var explicit_paddings = op.get_attr<int[]>("explicit_paddings"); | |||
var use_cudnn_on_gpu = op.get_attr<bool>("use_cudnn_on_gpu"); | |||
var data_format = op.get_attr<string>("data_format"); | |||
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); | |||
return new Tensor[] | |||
@@ -287,8 +287,8 @@ namespace Tensorflow.Gradients | |||
op.inputs[0], | |||
op.outputs[0], | |||
grad, | |||
(op.get_attr("ksize") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), | |||
(op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), | |||
op.get_attr("ksize") as int[], | |||
op.get_attr("strides") as int[], | |||
padding: op.get_attr("padding").ToString(), | |||
data_format: op.get_attr("data_format").ToString()) | |||
}; | |||
@@ -0,0 +1,91 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow.Keras.Optimizers | |||
{ | |||
/// <summary>
/// Optimizer that implements the Adam algorithm.
/// Adam optimization is a stochastic gradient descent method that is based on
/// adaptive estimation of first-order and second-order moments.
/// </summary>
public class Adam : OptimizerV2
{
    protected override string _name => "Adam";

    // Passed to the apply op as the "epsilon" coefficient.
    float epsilon = 1e-7f;
    // When true, an additional "vhat" slot is created per variable; the dense
    // apply step for that mode is not implemented (see _resource_apply_dense).
    bool amsgrad = false;

    /// <summary>
    /// Registers the learning rate and beta hyper-parameters and stores
    /// <paramref name="epsilon"/> and <paramref name="amsgrad"/>.
    /// </summary>
    /// <param name="learning_rate">Stored as the "learning_rate" hyper-parameter.</param>
    /// <param name="beta_1">Stored as the "beta_1" hyper-parameter.</param>
    /// <param name="beta_2">Stored as the "beta_2" hyper-parameter.</param>
    /// <param name="epsilon">Constant later supplied to the apply op.</param>
    /// <param name="amsgrad">Enables creation of the "vhat" slots.</param>
    /// <param name="name">
    /// NOTE(review): currently unused; <c>_name</c> is fixed to "Adam".
    /// </param>
    public Adam(float learning_rate = 0.001f,
        float beta_1 = 0.9f,
        float beta_2 = 0.999f,
        float epsilon = 1e-7f,
        bool amsgrad = false,
        string name = "Adam")
    {
        _set_hyper("learning_rate", learning_rate);
        // _set_hyper("decay", _initial_decay);
        _set_hyper("beta_1", beta_1);
        _set_hyper("beta_2", beta_2);
        this.epsilon = epsilon;
        this.amsgrad = amsgrad;
    }

    /// <summary>
    /// Creates the "m" and "v" slots (and "vhat" when amsgrad is enabled) for every variable.
    /// </summary>
    /// <param name="var_list">Variables that need slots.</param>
    protected override void _create_slots(IVariableV1[] var_list)
    {
        // Slots are created one kind at a time (all "m", then all "v", then all "vhat")
        // so the creation order of the underlying variables stays the same.
        foreach (var var in var_list)
            add_slot(var, "m");
        foreach (var var in var_list)
            add_slot(var, "v");
        if (amsgrad)
            foreach (var var in var_list)
                add_slot(var, "vhat");
    }

    /// <summary>
    /// Adds the Adam coefficients for one device/dtype pair to <paramref name="apply_state"/>.
    /// </summary>
    /// <param name="device_dtype">Key of the state entry to populate.</param>
    /// <param name="apply_state">State dictionary; the entry for <paramref name="device_dtype"/> is updated in place.</param>
    protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
    {
        // presumably the base call creates the entry and its "lr_t" value read below — confirm in OptimizerV2
        base._prepare_local(device_dtype, apply_state);

        var var_dtype = device_dtype.DType;
        var local_step = math_ops.cast(iterations + 1, var_dtype);
        var beta_1_t = array_ops.identity(_get_hyper("beta_1", var_dtype));
        var beta_2_t = array_ops.identity(_get_hyper("beta_2", var_dtype));
        var beta_1_power = math_ops.pow(beta_1_t, local_step);
        var beta_2_power = math_ops.pow(beta_2_t, local_step);

        var state = apply_state[device_dtype];
        var lr = state["lr_t"] * (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power));

        // update state
        state["lr"] = lr;
        // Convert epsilon with the variable dtype so it matches the other coefficients,
        // all of which are cast to var_dtype above.
        state["epsilon"] = ops.convert_to_tensor(epsilon, dtype: var_dtype);
        state["beta_1_t"] = beta_1_t;
        state["beta_1_power"] = beta_1_power;
        state["one_minus_beta_1_t"] = 1 - beta_1_t;
        state["beta_2_t"] = beta_2_t;
        state["beta_2_power"] = beta_2_power;
        state["one_minus_beta_2_t"] = 1 - beta_2_t;
    }

    /// <summary>
    /// Applies one Adam update to <paramref name="var"/> through <c>gen_training_ops.resource_apply_adam</c>.
    /// </summary>
    /// <param name="var">Variable to update.</param>
    /// <param name="grad">Gradient for <paramref name="var"/>.</param>
    /// <param name="apply_state">Coefficients keyed by device/dtype.</param>
    /// <returns>Whatever <c>resource_apply_adam</c> returns.</returns>
    /// <exception cref="NotImplementedException">Thrown when amsgrad is enabled.</exception>
    protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
    {
        var (var_device, var_dtype) = (var.Device, var.dtype.as_base_dtype());
        // FirstOrDefault yields a default pair (null Value) when no entry matches,
        // in which case the fallback state is requested instead.
        var coefficients = apply_state.FirstOrDefault(x => x.Key.Device == var_device && x.Key.DType == var_dtype).Value
            ?? _fallback_apply_state(var_device, var_dtype);

        var m = get_slot(var, "m");
        var v = get_slot(var, "v");

        if (amsgrad)
            throw new NotImplementedException("Adam with amsgrad=true is not implemented.");

        return gen_training_ops.resource_apply_adam(var.Handle,
            m.Handle,
            v.Handle,
            coefficients["beta_1_power"],
            coefficients["beta_2_power"],
            coefficients["lr_t"],
            coefficients["beta_1_t"],
            coefficients["beta_2_t"],
            coefficients["epsilon"],
            grad,
            use_locking: _use_locking);
    }
}
} |
@@ -18,22 +18,25 @@ namespace Tensorflow.Keras.Optimizers | |||
protected bool _hypers_created; | |||
protected virtual string _name { get; } | |||
ResourceVariable _iterations; | |||
List<ResourceVariable> _weight; | |||
IVariableV1 _iterations; | |||
protected ResourceVariable iterations => _iterations as ResourceVariable; | |||
List<IVariableV1> _weights; | |||
Dictionary<string, float> _hyper; | |||
Dictionary<string, ResourceVariable> _hyper_variables; | |||
Dictionary<string, IVariableV1> _hyper_variables; | |||
protected bool _momentum; | |||
protected float _initial_decay = 0.0f; | |||
protected bool _use_locking = true; | |||
Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state; | |||
Dictionary<string, Dictionary<string, IVariableV1>> _slots; | |||
List<string> _slot_names; | |||
public OptimizerV2() : base() | |||
{ | |||
_weight = new List<ResourceVariable>(); | |||
_weights = new List<IVariableV1>(); | |||
_hyper = new Dictionary<string, float>(); | |||
_hyper_variables = new Dictionary<string, ResourceVariable>(); | |||
apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>(); | |||
_hyper_variables = new Dictionary<string, IVariableV1>(); | |||
_slots = new Dictionary<string, Dictionary<string, IVariableV1>>(); | |||
_slot_names = new List<string>(); | |||
} | |||
public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | |||
@@ -61,7 +64,7 @@ namespace Tensorflow.Keras.Optimizers | |||
if (grads_and_vars == null || grads_and_vars.Count() == 0) | |||
return control_flow_ops.no_op(); | |||
apply_state = _prepare(var_list); | |||
var apply_state = _prepare(var_list); | |||
if(experimental_aggregate_gradients) | |||
{ | |||
// var reduced_grads = _aggregate_gradients(grads_and_vars); | |||
@@ -72,13 +75,13 @@ namespace Tensorflow.Keras.Optimizers | |||
}); | |||
} | |||
void apply_grad_to_update_var(ResourceVariable var, EagerTensor grad) | |||
void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) | |||
{ | |||
_resource_apply_dense(var, grad, apply_state); | |||
} | |||
protected virtual Operation _resource_apply_dense(IVariableV1 var, | |||
EagerTensor grad, | |||
Tensor grad, | |||
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||
{ | |||
throw new NotImplementedException("_resource_apply_dense"); | |||
@@ -94,7 +97,7 @@ namespace Tensorflow.Keras.Optimizers | |||
{ | |||
tf_with(ops.name_scope("update"), delegate | |||
{ | |||
apply_grad_to_update_var(var, grad as EagerTensor); | |||
apply_grad_to_update_var(var, grad, _apply_state); | |||
}); | |||
} | |||
@@ -107,6 +110,12 @@ namespace Tensorflow.Keras.Optimizers | |||
return grads_and_vars.Select(x => x.Item1).ToArray(); | |||
} | |||
/// <summary>
/// Returns the slot variable stored under <paramref name="slot_name"/> for <paramref name="var"/>.
/// </summary>
/// <param name="var">Variable whose <c>UniqueId</c> keys the slot table.</param>
/// <param name="slot_name">Name of the slot to fetch.</param>
/// <returns>The stored slot variable.</returns>
/// <remarks>
/// Both lookups use the dictionary indexer, so an unknown variable id or slot
/// name surfaces as the indexer's key-not-found exception.
/// </remarks>
protected IVariableV1 get_slot(IVariableV1 var, string slot_name)
    => _slots[var.UniqueId][slot_name];
Dictionary<DeviceDType, Dictionary<string, Tensor>> _prepare(IVariableV1[] var_list) | |||
{ | |||
var _apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>(); | |||
@@ -125,6 +134,11 @@ namespace Tensorflow.Keras.Optimizers | |||
return _apply_state; | |||
} | |||
/// <summary>
/// Placeholder for building apply-state for a device/dtype pair; not implemented.
/// </summary>
/// <param name="var_device">Device of the variable (unused).</param>
/// <param name="var_dtype">Dtype of the variable (unused).</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
protected Dictionary<string, Tensor> _fallback_apply_state(string var_device, TF_DataType var_dtype)
    => throw new NotImplementedException("");
protected virtual void _prepare_local(DeviceDType device_dtype, | |||
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||
{ | |||
@@ -145,7 +159,7 @@ namespace Tensorflow.Keras.Optimizers | |||
return lr_t; | |||
} | |||
protected ResourceVariable _get_hyper(string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
protected Tensor _get_hyper(string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
var value = _hyper_variables[name]; | |||
return math_ops.cast(value, dtype); | |||
@@ -160,7 +174,7 @@ namespace Tensorflow.Keras.Optimizers | |||
dtype: TF_DataType.TF_INT64, | |||
trainable: false, | |||
aggregation: VariableAggregation.OnlyFirstReplica); | |||
_weight.Add(_iterations); | |||
_weights.Add(_iterations); | |||
} | |||
_create_hypers(); | |||
@@ -190,7 +204,7 @@ namespace Tensorflow.Keras.Optimizers | |||
_hypers_created = true; | |||
} | |||
void _create_slots(IVariableV1[] var_list) | |||
protected virtual void _create_slots(IVariableV1[] var_list) | |||
{ | |||
if(_momentum) | |||
{ | |||
@@ -199,6 +213,35 @@ namespace Tensorflow.Keras.Optimizers | |||
} | |||
} | |||
/// <summary>
/// Returns the slot variable named <paramref name="slot_name"/> for <paramref name="var"/>,
/// creating and registering it first when it does not exist yet.
/// </summary>
/// <param name="var">Variable the slot belongs to; its <c>UniqueId</c> keys the slot table.</param>
/// <param name="slot_name">Name of the slot.</param>
/// <param name="initializer">Initializer for a newly created slot; <c>tf.zeros_initializer</c> is used when null.</param>
/// <returns>The existing or newly created slot variable.</returns>
protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
{
    initializer = initializer ?? tf.zeros_initializer;

    // Track each slot name once.
    if (!_slot_names.Contains(slot_name))
        _slot_names.append(slot_name);

    // Fetch, or lazily create, the per-variable slot table.
    if (!_slots.TryGetValue(var.UniqueId, out Dictionary<string, IVariableV1> slots_for_var))
    {
        slots_for_var = new Dictionary<string, IVariableV1>();
        _slots[var.UniqueId] = slots_for_var;
    }

    if (slots_for_var.TryGetValue(slot_name, out IVariableV1 existing))
        return existing;

    // The new slot copies the dtype and shape of the variable and is not trainable.
    var created = tf.Variable(initializer,
        dtype: var.dtype,
        trainable: false,
        shape: var.shape,
        name: $"{var.Name}/{slot_name}");
    slots_for_var[slot_name] = created;
    _weights.append(created);
    return created;
}
ResourceVariable add_weight(string name, | |||
TensorShape shape, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
@@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Optimizers | |||
_get_hyper("momentum", device_dtype.DType)); | |||
} | |||
protected override Operation _resource_apply_dense(IVariableV1 var, EagerTensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||
protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||
{ | |||
if (_momentum) | |||
{ | |||
@@ -36,11 +36,7 @@ namespace Tensorflow.Keras.Utils | |||
ops.init_scope(); | |||
Func<Tensor> init_val = () => args.Initializer.Apply(new InitializerArgs | |||
{ | |||
Shape = args.Shape, | |||
DType = args.DType | |||
}); | |||
Func<Tensor> init_val = () => args.Initializer.Apply(new InitializerArgs(args.Shape, dtype: args.DType)); | |||
var variable_dtype = args.DType.as_base_dtype(); | |||
var v = tf.Variable(init_val, | |||
@@ -6,8 +6,20 @@ namespace Tensorflow | |||
{ | |||
/// <summary>
/// Argument bundle for an initializer: shape, dtype, an optional
/// shape-verification flag and an optional name.
/// </summary>
public class InitializerArgs
{
    /// <summary>
    /// Creates the argument bundle; each parameter is stored in the property of the same meaning.
    /// </summary>
    /// <param name="shape">Stored in <see cref="Shape"/>.</param>
    /// <param name="dtype">Stored in <see cref="DType"/>.</param>
    /// <param name="verify_shape">Stored in <see cref="VerifyShape"/>.</param>
    /// <param name="name">Stored in <see cref="Name"/>.</param>
    public InitializerArgs(TensorShape shape,
        TF_DataType dtype = TF_DataType.DtInvalid,
        bool? verify_shape = null,
        string name = null)
    {
        Name = name;
        Shape = shape;
        DType = dtype;
        VerifyShape = verify_shape;
    }

    public string Name { get; set; }

    public TensorShape Shape { get; set; }

    public TF_DataType DType { get; set; }

    public bool? VerifyShape { get; set; }
}
} |
@@ -18,17 +18,21 @@ namespace Tensorflow.Operations.Initializers | |||
{ | |||
public class Zeros : IInitializer | |||
{ | |||
private TF_DataType dtype; | |||
TensorShape shape; | |||
TF_DataType dtype; | |||
public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
public Zeros(TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
{ | |||
this.shape = shape; | |||
this.dtype = dtype; | |||
} | |||
/// <summary>
/// Creates a zero-filled tensor for the requested shape and dtype.
/// </summary>
/// <param name="args">
/// Requested shape and dtype. A dtype of <c>DtInvalid</c> or a null shape is
/// filled in, in place, from the values this initializer was constructed with.
/// </param>
/// <returns>The result of <c>array_ops.zeros</c> for the resolved shape and dtype.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = dtype;
    if (args.Shape == null)
        args.Shape = shape;

    // Use the resolved dtype from args: passing the field here would silently
    // ignore a dtype the caller set explicitly.
    return array_ops.zeros(args.Shape, args.DType);
}
@@ -71,7 +71,7 @@ namespace Tensorflow.Operations | |||
public bool UseCudnnOnGpu { get; set; } = true; | |||
public int[] Dilations { get; set; } = new [] { 1, 1, 1, 1 }; | |||
public int[] Dilations { get; set; } = new int[] { 1, 1, 1, 1 }; | |||
public Conv2dParams() | |||
{ | |||
@@ -42,6 +42,22 @@ namespace Tensorflow.Operations | |||
/// <returns></returns> | |||
public static Tensor conv2d(Conv2dParams parameters) | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Conv2D", parameters.Name, | |||
null, | |||
parameters.Input, parameters.Filter, | |||
"strides", parameters.Strides, | |||
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||
"padding", parameters.Padding, | |||
"explicit_paddings", parameters.ExplicitPaddings, | |||
"data_format", parameters.DataFormat, | |||
"dilations", parameters.Dilations); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Conv2D", name: parameters.Name, args: new | |||
{ | |||
input = parameters.Input, | |||
@@ -64,6 +80,22 @@ namespace Tensorflow.Operations | |||
/// <returns></returns> | |||
public static Tensor conv2d_backprop_filter(Conv2dParams parameters) | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Conv2DBackpropFilter", parameters.Name, | |||
null, | |||
parameters.Input, parameters.FilterSizes, parameters.OutBackProp, | |||
"strides", parameters.Strides, | |||
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||
"padding", parameters.Padding, | |||
"explicit_paddings", parameters.ExplicitPaddings, | |||
"data_format", parameters.DataFormat, | |||
"dilations", parameters.Dilations); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new | |||
{ | |||
input = parameters.Input, | |||
@@ -87,6 +119,22 @@ namespace Tensorflow.Operations | |||
/// <returns></returns> | |||
public static Tensor conv2d_backprop_input(Conv2dParams parameters) | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Conv2DBackpropInput", parameters.Name, | |||
null, | |||
parameters.InputSizes, parameters.Filter, parameters.OutBackProp, | |||
"strides", parameters.Strides, | |||
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||
"padding", parameters.Padding, | |||
"explicit_paddings", parameters.ExplicitPaddings, | |||
"data_format", parameters.DataFormat, | |||
"dilations", parameters.Dilations); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new | |||
{ | |||
input_sizes = parameters.InputSizes, | |||
@@ -341,6 +389,20 @@ namespace Tensorflow.Operations | |||
string data_format = "NHWC", | |||
string name = null) | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"MaxPool", name, | |||
null, | |||
input, | |||
"ksize", ksize, | |||
"strides", strides, | |||
"padding", padding, | |||
"data_format", data_format); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("MaxPool", name: name, args: new | |||
{ | |||
input, | |||
@@ -356,6 +418,20 @@ namespace Tensorflow.Operations | |||
public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, | |||
string data_format= "NHWC", string name= null) | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"MaxPoolGrad", name, | |||
null, | |||
orig_input, orig_output, grad, | |||
"ksize", ksize, | |||
"strides", strides, | |||
"padding", padding, | |||
"data_format", data_format); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("MaxPoolGrad", name: name, args: new | |||
{ | |||
orig_input, | |||
@@ -384,7 +460,7 @@ namespace Tensorflow.Operations | |||
public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ReluGrad", name, | |||
@@ -227,7 +227,7 @@ namespace Tensorflow | |||
return grouped_inputs.ToArray(); | |||
} | |||
public T get_attr<T>(string name) | |||
public virtual T get_attr<T>(string name) | |||
=> (T)get_attr(name); | |||
public virtual object get_attr(string name) | |||
@@ -424,6 +424,17 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
    // Graph mode: build the "ShapeN" op and hand back its outputs.
    if (!tf.executing_eagerly())
    {
        var graph_op = tf.OpDefLib._apply_op_helper("ShapeN", name, new { input, out_type });
        return graph_op.outputs;
    }

    // Eager mode: run "ShapeN" through the fast path and return every result.
    return tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
        "ShapeN", name,
        null,
        input,
        "out_type", out_type);
}
@@ -450,7 +461,7 @@ namespace Tensorflow | |||
public static Tensor tile<T>(Tensor input, T multiples, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Tile", name, | |||
@@ -320,7 +320,7 @@ namespace Tensorflow | |||
/// </remarks> | |||
public static Tensor sigmoid(Tensor x, string name = "Sigmoid") | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Sigmoid", name, | |||
@@ -1074,23 +1074,6 @@ namespace Tensorflow | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Pow", name, | |||
null, | |||
x, y); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x, y }); | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor _sum<Tx, Ty>(Tx input, Ty axis = default, bool keep_dims = false, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
@@ -681,7 +681,19 @@ namespace Tensorflow | |||
var x_tensor = ops.convert_to_tensor(x, name: "x"); | |||
var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype()); | |||
return gen_math_ops.pow(x_tensor, y_tensor, name: name); | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Pow", name, | |||
null, | |||
x_tensor, y_tensor); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x_tensor, y_tensor }); | |||
return _op.output; | |||
}); | |||
public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") | |||
@@ -754,9 +766,6 @@ namespace Tensorflow | |||
if (transpose_b && adjoint_b) | |||
throw new ValueError("Only one of transpose_b and adjoint_b can be True."); | |||
a = ops.convert_to_tensor(a, name: "a"); | |||
b = ops.convert_to_tensor(b, name: "b"); | |||
result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name); | |||
}); | |||
@@ -30,7 +30,7 @@ namespace Tensorflow | |||
/// <param name="seed"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor random_normal(int[] shape, | |||
public static Tensor random_normal(TensorShape shape, | |||
float mean = 0.0f, | |||
float stddev = 1.0f, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
@@ -23,6 +23,24 @@ namespace Tensorflow | |||
{ | |||
public class gen_training_ops | |||
{ | |||
/// <summary>
/// Runs the "ResourceApplyAdam" op through the eager fast path.
/// </summary>
/// <returns>Always null in eager mode; the fast-path result is not surfaced.</returns>
/// <exception cref="NotImplementedException">Thrown when not executing eagerly.</exception>
public static Operation resource_apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power,
    Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad,
    bool use_locking = false, bool use_nesterov = false, string name = null)
{
    // Only the eager path exists so far.
    if (!tf.executing_eagerly())
        throw new NotImplementedException("");

    tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
        "ResourceApplyAdam", name,
        null,
        var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad,
        "use_locking", use_locking,
        "use_nesterov", use_nesterov);
    return null;
}
public static Tensor apply_adam(IVariableV1 var, IVariableV1 m, IVariableV1 v, Tensor beta1_power, Tensor beta2_power, | |||
Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, | |||
bool use_locking = false, bool use_nesterov = false, string name = null) | |||
@@ -56,12 +74,12 @@ namespace Tensorflow | |||
use_locking | |||
}); | |||
return _op.outputs[0]; | |||
return _op.output; | |||
} | |||
public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
if (tf.executing_eagerly()) | |||
{ | |||
var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ResourceApplyGradientDescent", name, | |||
@@ -18,7 +18,7 @@ namespace Tensorflow | |||
protected string handle_name => _handle_name; | |||
protected string _unique_id; | |||
public string unique_id => _unique_id; | |||
public string UniqueId => _unique_id; | |||
protected bool _in_graph_mode; | |||
@@ -31,6 +31,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public interface IVariableV1 | |||
{ | |||
public string UniqueId { get; } | |||
public string Name { get; } | |||
public Tensor Handle { get; } | |||
public string Device { get; } | |||
@@ -25,6 +25,7 @@ namespace Tensorflow | |||
public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable> | |||
{ | |||
protected string _name; | |||
public string UniqueId => _name; | |||
public Tensor GraphElement { get; } | |||
public Tensor _variable; | |||
public Tensor Handle => _variable; | |||
@@ -67,8 +67,6 @@ namespace Tensorflow | |||
dtype: dtype, | |||
shape: shape); | |||
} | |||
// handle.ResourceVar = this; | |||
} | |||
private void _init_from_args(object initial_value = null, | |||
@@ -79,7 +77,8 @@ namespace Tensorflow | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
TensorShape shape = null) | |||
{ | |||
var init_from_fn = initial_value.GetType().Name == "Func`1"; | |||
var init_from_fn = initial_value.GetType().Name == "Func`1" || | |||
initial_value.GetType().GetInterface("IInitializer") != null; | |||
if(collections == null) | |||
collections = new List<string>() { tf.GraphKeys.GLOBAL_VARIABLES }; | |||
_trainable = trainable; | |||
@@ -112,9 +111,12 @@ namespace Tensorflow | |||
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); | |||
tf_with(ops.name_scope("Initializer"), delegate | |||
{ | |||
initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func<Tensor>)() : initial_value, | |||
name: "initial_value", | |||
dtype: dtype); | |||
if (initial_value.GetType().GetInterface("IInitializer") != null) | |||
initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); | |||
else | |||
initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func<Tensor>)() : initial_value, | |||
name: "initial_value", | |||
dtype: dtype); | |||
}); | |||
_shape = shape ?? (initial_value as Tensor).TensorShape; | |||
_initial_value = initial_value as Tensor; | |||
@@ -162,11 +162,7 @@ namespace Tensorflow | |||
} | |||
else | |||
{ | |||
Func<Tensor> init_val = () => initializer.Apply(new InitializerArgs | |||
{ | |||
Shape = shape, | |||
DType = dtype | |||
}); | |||
Func<Tensor> init_val = () => initializer.Apply(new InitializerArgs(shape, dtype: dtype)); | |||
var variable_dtype = dtype.as_base_dtype(); | |||
v = variable_scope.default_variable_creator(init_val, | |||