@@ -15,6 +15,7 @@ | |||
******************************************************************************/ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow.Contexts | |||
@@ -87,6 +88,29 @@ namespace Tensorflow.Contexts | |||
context_switches.Pop(); | |||
} | |||
public Tensor RunInAutoMode(Func<Tensor> graphAction, Func<Tensor> eagerAction, params Tensor[] tensors) | |||
{ | |||
var shouldRunInEager = executing_eagerly() | |||
&& tensors.Count(x => x.IsEagerTensor) == tensors.Length; | |||
if (shouldRunInEager) | |||
return eagerAction(); | |||
else | |||
{ | |||
if (executing_eagerly()) | |||
{ | |||
graph_mode(); | |||
var result = graphAction(); | |||
restore_mode(); | |||
return result; | |||
} | |||
else | |||
{ | |||
return graphAction(); | |||
} | |||
} | |||
} | |||
public void Dispose() | |||
=> Handle.Dispose(); | |||
} | |||
@@ -0,0 +1,12 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class ZeroPadding2DArgs : LayerArgs | |||
{ | |||
public NDArray Padding { get; set; } | |||
} | |||
} |
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using static Tensorflow.Binding; | |||
@@ -121,6 +122,38 @@ namespace Tensorflow.Keras | |||
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); | |||
} | |||
/// <summary> | |||
/// Pads the 2nd and 3rd dimensions of a 4D tensor. | |||
/// </summary> | |||
/// <param name="x"></param> | |||
/// <param name="padding"></param> | |||
/// <param name="data_format"></param> | |||
/// <returns></returns> | |||
public Tensor spatial_2d_padding(Tensor x, NDArray padding = null, string data_format = null) | |||
{ | |||
if (padding == null) | |||
padding = new[,] { { 1, 1 }, { 1, 1 } }; | |||
NDArray pattern; | |||
if (data_format == "channels_first") | |||
pattern = new int[,] | |||
{ | |||
{ 0, 0 }, | |||
{ 0, 0 }, | |||
{ padding[0][0], padding[0][1] }, | |||
{ padding[1][0], padding[1][1] } | |||
}; | |||
else | |||
pattern = new int[,] | |||
{ | |||
{ 0, 0 }, | |||
{ padding[0][0], padding[0][1] }, | |||
{ padding[1][0], padding[1][1] }, | |||
{ 0, 0 } | |||
}; | |||
return array_ops.pad(x, pattern); | |||
} | |||
public class _DummyEagerGraph | |||
{ } | |||
@@ -0,0 +1,47 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Security.Cryptography.X509Certificates; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public class BaseLayerUtils | |||
{ | |||
public static Layer[] CreateKerasHistoryHelper(Tensors tensors) | |||
{ | |||
var processed_ops = new List<Operation>(); | |||
var created_layers = new List<Layer>(); | |||
foreach (var tensor in tensors) | |||
{ | |||
if (tensor.KerasHistory != null) | |||
continue; | |||
var op = tensor.op; | |||
if (!processed_ops.Contains(op)) | |||
{ | |||
var layer_inputs = new List<Tensor>(); | |||
foreach (var (i, op_input) in enumerate(op.inputs._inputs)) | |||
{ | |||
if (uses_keras_history(op_input)) | |||
layer_inputs.Add(op_input); | |||
else | |||
{ | |||
} | |||
} | |||
} | |||
} | |||
return created_layers.ToArray(); | |||
} | |||
static bool uses_keras_history(Tensor op_input) | |||
{ | |||
return Layer.KerasHistories.Any(x => x.tensor == op_input); | |||
} | |||
} | |||
} |
@@ -1,5 +1,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Security.Cryptography.X509Certificates; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
@@ -47,12 +49,15 @@ namespace Tensorflow.Keras.Engine | |||
// A graph network does not autocast inputs, as its layers will cast them instead. | |||
_autocast = false; | |||
if (outputs.Any(x => x.KerasHistory == null)) | |||
BaseLayerUtils.CreateKerasHistoryHelper(outputs); | |||
// Build self._output_layers: | |||
foreach(var x in outputs) | |||
foreach (var x in outputs) | |||
{ | |||
var (layer, node_index, tensor_index) = x.KerasHistory; | |||
_output_layers.append(layer); | |||
_output_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); | |||
_output_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | |||
} | |||
// Build self._input_layers: | |||
@@ -60,8 +65,9 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
var (layer, node_index, tensor_index) = x.KerasHistory; | |||
_input_layers.append(layer); | |||
_input_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); | |||
_input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | |||
} | |||
} | |||
} | |||
} |
@@ -12,12 +12,15 @@ namespace Tensorflow.Keras.Engine | |||
Layer layer; | |||
int node_index; | |||
int tensor_index; | |||
public Tensor tensor; | |||
public KerasHistory(Layer layer, int node_index, int tensor_index) | |||
public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor) | |||
{ | |||
this.layer = layer; | |||
this.node_index = node_index; | |||
this.tensor_index = tensor_index; | |||
this.tensor = tensor; | |||
Console.WriteLine(tensor.name); | |||
} | |||
public void Deconstruct(out Layer layer, out int node_index, out int tensor_index) | |||
@@ -27,6 +30,9 @@ namespace Tensorflow.Keras.Engine | |||
tensor_index = this.tensor_index; | |||
} | |||
public override string ToString() | |||
=> $"{layer.GetType().Name} {layer.Name} {tensor.name}"; | |||
public static implicit operator Layer(KerasHistory history) | |||
=> history.layer; | |||
} | |||
@@ -0,0 +1,18 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Layer | |||
{ | |||
/// <summary> | |||
/// Loads all layer weights, either from a TensorFlow or an HDF5 weight file. | |||
/// </summary> | |||
/// <param name="filepath"></param> | |||
public void load_weights(string filepath) | |||
{ | |||
} | |||
} | |||
} |
@@ -56,6 +56,7 @@ namespace Tensorflow.Keras.Engine | |||
/// Provides information about which inputs are compatible with the layer. | |||
/// </summary> | |||
protected InputSpec inputSpec; | |||
bool dynamic = true; | |||
public bool SupportsMasking { get; set; } | |||
protected List<IVariableV1> trainableWeights; | |||
public List<IVariableV1> trainable_variables | |||
@@ -88,6 +89,7 @@ namespace Tensorflow.Keras.Engine | |||
ThreadLocal<CallContext> callContext; | |||
public CallContext CallContext => callContext.Value; | |||
public static List<KerasHistory> KerasHistories = new List<KerasHistory>(); | |||
public Layer(LayerArgs args) | |||
{ | |||
@@ -129,6 +131,11 @@ namespace Tensorflow.Keras.Engine | |||
Value = new CallContext() | |||
}; | |||
var history = inputs.Where(x => x.KerasHistory != null | |||
&& !KerasHistories.Contains(x.KerasHistory)) | |||
.Select(x => x.KerasHistory); | |||
KerasHistories.AddRange(history); | |||
if (_in_functional_construction_mode(inputs)) | |||
return _functional_construction_call(inputs); | |||
@@ -166,7 +173,8 @@ namespace Tensorflow.Keras.Engine | |||
bool _in_functional_construction_mode(Tensors inputs) | |||
{ | |||
return inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); | |||
return tf.Context.executing_eagerly() | |||
&& inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); | |||
} | |||
Tensors _functional_construction_call(Tensors inputs) | |||
@@ -191,6 +199,15 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
MaybeBuild(inputs); | |||
// Wrapping `call` function in autograph to allow for dynamic control | |||
// flow and control dependencies in call. We are limiting this to | |||
// subclassed layers as autograph is strictly needed only for | |||
// subclassed layers and models. | |||
// tf_convert will respect the value of autograph setting in the | |||
// enclosing tf.function, if any. | |||
if (!dynamic) | |||
throw new NotImplementedException(""); | |||
outputs = call(inputs); | |||
outputs = _set_connectivity_metadata_(inputs, outputs); | |||
@@ -243,6 +260,13 @@ namespace Tensorflow.Keras.Engine | |||
return null; | |||
} | |||
/// <summary> | |||
/// Subclass has to override this method. | |||
/// </summary> | |||
/// <param name="inputs"></param> | |||
/// <param name="state"></param> | |||
/// <param name="is_training"></param> | |||
/// <returns></returns> | |||
protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
throw new NotImplementedException(""); | |||
@@ -263,9 +287,9 @@ namespace Tensorflow.Keras.Engine | |||
tf.init_scope(); | |||
//tf.Context.eager_mode(); | |||
tf.Context.eager_mode(); | |||
build(inputs.shape); | |||
//tf.Context.restore_mode(); | |||
tf.Context.restore_mode(); | |||
built = true; | |||
} | |||
@@ -282,18 +306,14 @@ namespace Tensorflow.Keras.Engine | |||
protected virtual IVariableV1 add_weight(string name, | |||
TensorShape shape, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
IInitializer initializer = null, | |||
IRegularizer regularizer = null, | |||
bool? trainable = null, | |||
VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
VariableAggregation aggregation = VariableAggregation.None, | |||
bool trainable = true, | |||
Func<VariableArgs, IVariableV1> getter = null) | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = TF_DataType.TF_FLOAT; | |||
if (trainable == null) | |||
trainable = true; | |||
// Initialize variable when no initializer provided | |||
if (initializer == null) | |||
{ | |||
@@ -306,6 +326,9 @@ namespace Tensorflow.Keras.Engine | |||
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | |||
} | |||
if (synchronization == VariableSynchronization.OnRead) | |||
trainable = false; | |||
var args = new VariableArgs | |||
{ | |||
Name = name, | |||
@@ -314,7 +337,9 @@ namespace Tensorflow.Keras.Engine | |||
Getter = getter ?? base_layer_utils.make_variable, | |||
Overwrite = true, | |||
Initializer = initializer, | |||
Trainable = trainable.Value | |||
Synchronization = synchronization, | |||
Aggregation = aggregation, | |||
Trainable = trainable | |||
}; | |||
var variable = _add_variable_with_custom_getter(args); | |||
@@ -58,7 +58,7 @@ namespace Tensorflow.Keras.Engine | |||
// Set metadata on outputs. | |||
var node_index = layer.InboundNodes.Count - 1; | |||
foreach (var (i, tensor) in enumerate(Outputs)) | |||
tensor.KerasHistory = new KerasHistory(layer, node_index, i); | |||
tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); | |||
} | |||
} | |||
} |
@@ -38,8 +38,8 @@ namespace Tensorflow.Keras.Layers | |||
string _data_format; | |||
IInitializer beta_initializer => args.BetaInitializer; | |||
IInitializer gamma_initializer => args.GammaInitializer; | |||
IInitializer moving_mean_initializer; | |||
IInitializer moving_variance_initializer; | |||
IInitializer moving_mean_initializer => args.MovingMeanInitializer; | |||
IInitializer moving_variance_initializer => args.MovingVarianceInitializer; | |||
IRegularizer gamma_regularizer => args.GammaRegularizer; | |||
IVariableV1 gamma; | |||
IVariableV1 beta; | |||
@@ -101,13 +101,17 @@ namespace Tensorflow.Keras.Layers | |||
param_shape, | |||
dtype: param_dtype, | |||
initializer: moving_mean_initializer, | |||
synchronization: VariableSynchronization.OnRead, | |||
aggregation: VariableAggregation.Mean, | |||
trainable: false); | |||
moving_variance = add_weight("moving_variance", | |||
shape: param_shape, | |||
dtype: param_dtype, | |||
initializer: moving_variance_initializer, | |||
trainable: false); | |||
shape: param_shape, | |||
dtype: param_dtype, | |||
initializer: moving_variance_initializer, | |||
synchronization: VariableSynchronization.OnRead, | |||
aggregation: VariableAggregation.Mean, | |||
trainable: false); | |||
if (renorm) | |||
throw new NotImplementedException("build when renorm is true"); | |||
@@ -131,6 +135,12 @@ namespace Tensorflow.Keras.Layers | |||
private Tensor _fused_batch_norm(Tensor inputs, Tensor training) | |||
{ | |||
TensorShape input_batch_size = null; | |||
var use_fused_avg_updates = true; | |||
float exponential_avg_factor = 0; | |||
if (use_fused_avg_updates) | |||
exponential_avg_factor = 1.0f - momentum; | |||
var beta = this.beta; | |||
var gamma = this.gamma; | |||
@@ -146,17 +156,22 @@ namespace Tensorflow.Keras.Layers | |||
Func<Tensor[]> _fused_batch_norm_inference = () => | |||
{ | |||
var moving_mean_tensor = moving_mean.AsTensor(); | |||
var moving_variance_tensor = moving_variance.AsTensor(); | |||
return tf.nn.fused_batch_norm( | |||
inputs, | |||
gamma, | |||
beta, | |||
mean: moving_mean.AsTensor(), | |||
variance: moving_variance.AsTensor(), | |||
mean: moving_mean_tensor, | |||
variance: moving_variance_tensor, | |||
epsilon: epsilon, | |||
is_training: false, | |||
data_format: _data_format); | |||
}; | |||
if (use_fused_avg_updates && input_batch_size != null) | |||
throw new NotImplementedException(""); | |||
var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | |||
var (output, mean, variance) = (results[0], results[1], results[2]); | |||
var training_value = tf_utils.constant_value(training); | |||
@@ -1,4 +1,5 @@ | |||
using System; | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
@@ -33,6 +34,7 @@ namespace Tensorflow.Keras.Layers | |||
DataFormat = data_format, | |||
DilationRate = dilation_rate == null ? (1, 1) : dilation_rate, | |||
Groups = groups, | |||
UseBias = use_bias, | |||
KernelRegularizer = kernel_regularizer, | |||
KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer, | |||
BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, | |||
@@ -129,6 +131,17 @@ namespace Tensorflow.Keras.Layers | |||
InputShape = input_shape | |||
}); | |||
/// <summary> | |||
/// Zero-padding layer for 2D input (e.g. picture). | |||
/// </summary> | |||
/// <param name="padding"></param> | |||
/// <returns></returns> | |||
public ZeroPadding2D ZeroPadding2D(NDArray padding) | |||
=> new ZeroPadding2D(new ZeroPadding2DArgs | |||
{ | |||
Padding = padding | |||
}); | |||
Activation GetActivationByName(string name) | |||
=> name switch | |||
{ | |||
@@ -0,0 +1,39 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
/// <summary> | |||
/// Zero-padding layer for 2D input (e.g. picture). | |||
/// | |||
/// This layer can add rows and columns of zeros | |||
/// at the top, bottom, left and right side of an image tensor. | |||
/// </summary> | |||
public class ZeroPadding2D : Layer | |||
{ | |||
string data_format; | |||
NDArray padding; | |||
InputSpec input_spec; | |||
public ZeroPadding2D(ZeroPadding2DArgs args, string data_format = null) | |||
: base(args) | |||
{ | |||
this.data_format = conv_utils.normalize_data_format(data_format); | |||
this.padding = args.Padding; | |||
this.input_spec = new InputSpec(ndim: 4); | |||
} | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
return tf.keras.backend.spatial_2d_padding(inputs, | |||
padding: padding, | |||
data_format: data_format); | |||
} | |||
} | |||
} |
@@ -127,7 +127,7 @@ namespace Tensorflow.Layers | |||
int[] shape, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
bool? trainable = null, | |||
bool trainable = true, | |||
VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
VariableAggregation aggregation = VariableAggregation.None) | |||
{ | |||
@@ -137,8 +137,6 @@ namespace Tensorflow.Layers | |||
if (synchronization == VariableSynchronization.OnRead) | |||
trainable = false; | |||
else if (!trainable.HasValue) | |||
trainable = true; | |||
if (default_graph.building_function) | |||
{ | |||
@@ -56,20 +56,24 @@ namespace Tensorflow.Operations | |||
var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index); | |||
Tensor result = null; | |||
tf_with(ops.name_scope(name, default_name: null, (input, filters)), scope => | |||
tf_with(ops.name_scope(name, default_name: null), scope => | |||
{ | |||
name = scope; | |||
if (num_spatial_dims == 2) | |||
{ | |||
var filters_tensor = filters.AsTensor(); | |||
result = gen_nn_ops.conv2d(new Conv2dParams | |||
{ | |||
Input = input, | |||
Filter = filters.AsTensor(), | |||
Filter = filters_tensor, | |||
Strides = strides, | |||
Padding = padding, | |||
DataFormat = data_format, | |||
Dilations = dilations, | |||
Name = name | |||
}); | |||
} | |||
else | |||
throw new NotImplementedException(""); | |||
}); | |||
@@ -263,7 +263,7 @@ namespace Tensorflow | |||
List<TF_DataType> types, | |||
List<TF_DataType> base_types, | |||
List<TF_DataType> input_types, | |||
dynamic values) | |||
object values) | |||
{ | |||
var input_name = input_arg.Name; | |||
@@ -73,6 +73,16 @@ namespace Tensorflow | |||
return _op.output; | |||
} | |||
public static Tensor concat_v2(Tensor[] values, int axis, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("ConcatV2", name: name, | |||
args: new { values, axis }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ConcatV2", name, | |||
null, | |||
values, axis).FirstOrDefault(), | |||
values); | |||
private static Tensor concat_v2_eager_fallback<T1, T2>(T1[] values, T2 axis, string name, Context ctx) | |||
{ | |||
var _attr_N = len(values); | |||
@@ -293,20 +303,13 @@ namespace Tensorflow | |||
} | |||
public static Tensor reshape<T>(Tensor tensor, T shape, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Reshape", name, | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Reshape", name, | |||
null, | |||
tensor, shape); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }); | |||
return _op.output; | |||
} | |||
tensor, shape).FirstOrDefault(), | |||
tensor); | |||
public static Tensor reshape(Tensor tensor, int[] shape, string name = null) | |||
{ | |||
@@ -399,21 +402,15 @@ namespace Tensorflow | |||
} | |||
public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Shape", name, | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Shape", name, | |||
new { input, out_type }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Shape", name, | |||
null, | |||
input, | |||
"out_type", out_type); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Shape", name, new { input, out_type }); | |||
return _op.outputs[0]; | |||
} | |||
"out_type", out_type).FirstOrDefault(), | |||
input); | |||
/// <summary> | |||
/// Returns shape of tensors. | |||
@@ -460,20 +457,13 @@ namespace Tensorflow | |||
} | |||
public static Tensor tile<T>(Tensor input, T multiples, string name = null) | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Tile", name, | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Tile", name, | |||
null, | |||
input, multiples); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }); | |||
return _op.outputs[0]; | |||
} | |||
input, multiples).FirstOrDefault(), | |||
input); | |||
public static Tensor transpose<T1, T2>(T1 x, T2 perm, string name = null) | |||
{ | |||
@@ -510,37 +500,29 @@ namespace Tensorflow | |||
int new_axis_mask = 0, | |||
int shrink_axis_mask = 0, | |||
string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"StridedSlice", name, | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("StridedSlice", name, new | |||
{ | |||
input, | |||
begin, | |||
end, | |||
strides, | |||
begin_mask, | |||
end_mask, | |||
ellipsis_mask, | |||
new_axis_mask, | |||
shrink_axis_mask | |||
}).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"StridedSlice", name, | |||
null, | |||
input, begin, end, strides, | |||
"begin_mask", begin_mask, | |||
"end_mask", end_mask, | |||
"ellipsis_mask", ellipsis_mask, | |||
"new_axis_mask", new_axis_mask, | |||
"shrink_axis_mask", shrink_axis_mask); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("StridedSlice", name, new | |||
{ | |||
input, | |||
begin, | |||
end, | |||
strides, | |||
begin_mask, | |||
end_mask, | |||
ellipsis_mask, | |||
new_axis_mask, | |||
shrink_axis_mask | |||
}); | |||
return _op.outputs[0]; | |||
} | |||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
input, begin, end, strides); | |||
public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides, | |||
int begin_mask = 0, | |||
@@ -319,21 +319,13 @@ namespace Tensorflow | |||
/// Specifically, <c>y = 1 / (1 + exp(-x))</c>. | |||
/// </remarks> | |||
public static Tensor sigmoid(Tensor x, string name = "Sigmoid") | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Sigmoid", name, | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Sigmoid", name, | |||
null, | |||
x); | |||
return results[0]; | |||
} | |||
var op = tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }); | |||
return op.output; | |||
} | |||
x).FirstOrDefault(), | |||
x); | |||
/// <summary> | |||
/// Computes the gradient of the sigmoid of <c>x</c> wrt its input. | |||
@@ -668,11 +660,13 @@ namespace Tensorflow | |||
/// <param name="name"> A name for the operation (optional).</param> | |||
/// <returns> A `Tensor`. Has the same type as `x`.</returns> | |||
public static Tensor exp(Tensor x, string name = null) | |||
{ | |||
var _op = tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }); | |||
return _op.outputs[0]; | |||
} | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Exp", name, | |||
null, | |||
x).FirstOrDefault(), | |||
x); | |||
/// <summary> | |||
/// Computes natural logarithm of x element-wise. | |||
@@ -698,22 +692,14 @@ namespace Tensorflow | |||
} | |||
public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Cast", name, | |||
null, | |||
x, | |||
"DstT", DstT, "Truncate", Truncate); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); | |||
return _op.outputs[0]; | |||
} | |||
"DstT", DstT, "Truncate", Truncate).FirstOrDefault(), | |||
x); | |||
public static Tensor neg(Tensor x, string name = null) | |||
{ | |||
@@ -1151,20 +1137,13 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Range", name, | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Range", name, | |||
null, | |||
start, limit, delta); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }); | |||
return _op.outputs[0]; | |||
} | |||
start, limit, delta).FirstOrDefault(), | |||
start, limit, delta); | |||
/// <summary> | |||
/// Rounds the values of a tensor to the nearest integer, element-wise. | |||
@@ -225,14 +225,12 @@ namespace Tensorflow | |||
public static string name_from_scope_name(string name) | |||
{ | |||
if (name.EndsWith("/")) | |||
{ | |||
if (name == null) | |||
return null; | |||
else if (name.EndsWith("/")) | |||
return name.Substring(0, name.Length - 1); | |||
} | |||
else | |||
{ | |||
return name; | |||
} | |||
} | |||
/// <summary> | |||
@@ -444,7 +442,12 @@ namespace Tensorflow | |||
case NDArray nd: | |||
return constant_op.constant(nd, dtype: dtype, name: name); | |||
case EagerTensor tensor: | |||
return tf.executing_eagerly() ? tensor : tensor.AsPlaceholder(name: name); | |||
if (tf.executing_eagerly()) | |||
return tensor; | |||
else | |||
return tensor.dtype == TF_DataType.TF_RESOURCE | |||
? tensor.AsPlaceholder(name: name) | |||
: tensor.AsContatnt(name: name); | |||
case Tensor tensor: | |||
return tensor; | |||
case Tensor[] tensors: | |||
@@ -48,13 +48,13 @@ namespace Tensorflow | |||
public void __enter__() | |||
{ | |||
_name = _name ?? _default_name; | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
(scope_name, old_scope_name) = enter_eager_name_scope(tf.Context, _name); | |||
} | |||
else | |||
{ | |||
_name = _name ?? _default_name; | |||
Graph g = null; | |||
if (_values is List<Tensor> vList) | |||
@@ -72,7 +72,8 @@ namespace Tensorflow | |||
private (string, string) enter_eager_name_scope(Context ctx, string name) | |||
{ | |||
if (name == null) | |||
return (null, null); | |||
/*if (name == null) | |||
name = ""; | |||
var scope_name = name; | |||
@@ -87,7 +88,7 @@ namespace Tensorflow | |||
} | |||
ctx.ScopeName = scope_name; | |||
return (scope_name, old_name); | |||
return (scope_name, old_name);*/ | |||
} | |||
[DebuggerHidden] | |||