@@ -17,6 +17,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.UnitTest", "test\Kera
EndProject
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{92762DCB-64C8-41B4-BEF7-780A969CE68F}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -51,6 +53,10 @@ Global
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.Build.0 = Release|Any CPU
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.Build.0 = Debug|Any CPU
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.ActiveCfg = Release|Any CPU
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -142,6 +142,7 @@ namespace Tensorflow
var layer = new Dense(units, activation,
use_bias: use_bias,
bias_initializer: bias_initializer,
kernel_initializer: kernel_initializer);
return layer.apply(inputs);
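Note: a hypothetical call into this overload once kernel_initializer is forwarded; the parameter names are assumed from the arguments passed to Dense above, not taken from the published signature, and he_init / zeros stand in for IInitializer instances.

// Hypothetical usage sketch, for illustration only.
var output = tf.layers.dense(inputs, units,
    activation: null,
    use_bias: true,
    kernel_initializer: he_init,
    bias_initializer: zeros);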
@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Framework
{
/// <summary>
/// Abstract base class for Tensor-like objects that are composed from Tensors.
/// </summary>
public abstract class CompositeTensor
{
}
}
@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Framework
{
/// <summary>
/// A sparse representation of a set of tensor slices at given indices.
/// </summary>
public class IndexedSlices : CompositeTensor
{
Tensor _values;
public Tensor values => _values;
public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
{
// keep the values so the implicit Tensor conversion below has something to return
_values = values;
}
public static implicit operator Tensor(IndexedSlices indexedSlices)
{
return indexedSlices.values;
}
}
}
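Note: a minimal usage sketch for the new type (values, indices and dense_shape are hypothetical tensors built elsewhere); the implicit operator lets an IndexedSlices flow through APIs that expect a plain Tensor.

// Hypothetical example, not part of the change.
var slices = new IndexedSlices(values, indices, dense_shape);
Tensor as_tensor = slices;   // implicit conversion: returns slices.values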
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework;
using Tensorflow.Operations;
using static Tensorflow.Python;
@@ -42,9 +43,9 @@ namespace Tensorflow.Gradients
return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad };
var concat_dim = op.inputs[dim_index];
if (end_value_index == -1)
end_value_index = op.inputs.Length - 1;
var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray();
var input_values = op.inputs._inputs.Skip(start_value_index)
.Take(end_value_index == -1 ? op.inputs.Length - 1 : end_value_index - start_value_index)
.ToArray();
var out_grads = new List<Tensor>();
if (constant_op.is_constant(concat_dim))
@@ -92,10 +93,16 @@ namespace Tensorflow.Gradients
}
return (end_value_index <= dim_index ?
out_grads.ToArray().Concat(null) :
out_grads.ToArray().Concat(new Tensor[] { null }) :
new Tensor[] { null }.Concat(out_grads)).ToArray();
}
[RegisterGradient("ExpandDims")]
public static Tensor[] _ExpandDimsGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { _ReshapeToInput(op, grads[0]), null };
}
/// <summary>
/// Extract the shapes of a set of input tensors.
/// </summary>
@@ -125,6 +132,45 @@ namespace Tensorflow.Gradients
return gen_ops.shape_n(inputs);
}
/// <summary>
/// Gradient for GatherV2 op.
/// </summary>
/// <param name="op"></param>
/// <param name="grads"></param>
/// <returns></returns>
[RegisterGradient("GatherV2")]
public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var @params = op.inputs[0];
ops.colocate_with(@params);
var params_shape = array_ops.shape(@params, out_type: tf.int64);
params_shape = math_ops.cast(params_shape, tf.int32);
var indices = op.inputs[1];
var indices_size = array_ops.expand_dims(array_ops.size(indices), 0);
var axis = op.inputs[2];
var axis_static = tensor_util.constant_value(axis);
// For axis 0 gathers, build an appropriately shaped IndexedSlices.
if((int)axis_static == 0)
{
var params_tail_shape = params_shape[1];
var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0);
var values = array_ops.reshape(grad, values_shape);
indices = array_ops.reshape(indices, indices_size);
return new Tensor[]
{
new IndexedSlices(values, indices, params_shape),
null,
null
};
}
return new Tensor[] { null, null };
}
[RegisterGradient("Reshape")]
public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
{
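Note: a worked shape trace for the axis-0 branch of _GatherV2Grad above, using made-up shapes purely for illustration.

// Assume params has shape [5, 3] and indices has shape [2]; the incoming grad then has shape [2, 3].
//   params_shape        = [5, 3]
//   indices_size        = [2]            (size of indices, expanded to rank 1)
//   params_tail_shape   = the trailing [3] part of params_shape
//   values_shape        = concat(indices_size, params_tail_shape) = [2, 3]
//   values              = reshape(grad, values_shape)             -> shape [2, 3]
//   indices (reshaped)                                            -> shape [2]
// The gradient w.r.t. params is IndexedSlices(values, indices, dense_shape: params_shape);
// the indices and axis inputs receive no gradient (null).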
@@ -106,10 +106,10 @@ namespace Tensorflow.Gradients
[RegisterGradient("Conv2D")]
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
{
var dilations = op.get_attr("dilations");
var strides = op.get_attr("strides");
var dilations = (op.get_attr("dilations") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray();
var strides = (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray();
var padding = op.get_attr("padding");
var explicit_paddings = op.get_attr("explicit_paddings");
var explicit_paddings = (op.get_attr("explicit_paddings") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray();
var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu");
var data_format = op.get_attr("data_format");
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
@@ -120,21 +120,23 @@ namespace Tensorflow.Gradients
{
InputSizes = shape[0],
Filter = op.inputs[1],
Dilations = dilations == null ? null : dilations as int[],
Strides = strides == null ? null : strides as int[],
OutBackProp = grads[0],
Dilations = dilations,
Strides = strides,
Padding = padding.ToString(),
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[],
ExplicitPaddings = explicit_paddings,
UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
DataFormat = data_format.ToString()
DataFormat = data_format.ToString(),
}),
gen_nn_ops.conv2d_backprop_filter(new Conv2dParams
{
Input = op.inputs[0],
FilterSizes = shape[1],
Dilations = dilations == null ? null : dilations as int[],
Strides = strides == null ? null : strides as int[],
OutBackProp = grads[0],
Dilations = dilations,
Strides = strides,
Padding = padding.ToString(),
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[],
ExplicitPaddings = explicit_paddings,
UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
DataFormat = data_format.ToString()
})
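Note: the cast-and-convert pattern for list(int) attributes recurs in _Conv2DGrad above and in _MaxPoolGrad below; a small hypothetical helper (not part of this change) could centralise it.

// Hypothetical helper; assumes, as the code above already does, that get_attr returns an
// AttrValue.Types.ListValue for list(int) attributes such as "strides" and "dilations".
static int[] GetIntListAttr(Operation op, string name)
    => (op.get_attr(name) as AttrValue.Types.ListValue)
           .I.Select(x => Convert.ToInt32(x))
           .ToArray();

// e.g. var strides = GetIntListAttr(op, "strides");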
@@ -155,6 +157,23 @@ namespace Tensorflow.Gradients
return vec * mat;
}
[RegisterGradient("MaxPool")]
public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
return new Tensor[]
{
gen_nn_ops.max_pool_grad(
op.inputs[0],
op.outputs[0],
grad,
(op.get_attr("ksize") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(),
(op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(),
padding: op.get_attr("padding").ToString(),
data_format: op.get_attr("data_format").ToString())
};
}
/// <summary>
/// Return the gradients for TopK.
/// </summary>
@@ -179,6 +179,23 @@ namespace Tensorflow.Operations
return _op.outputs[0];
}
public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding,
string data_format = "NHWC", string name = null)
{
var _op = _op_def_lib._apply_op_helper("MaxPoolGrad", name: name, args: new
{
orig_input,
orig_output,
grad,
ksize,
strides,
padding,
data_format
});
return _op.outputs[0];
}
public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null)
{
var _op = _op_def_lib._apply_op_helper("TopKV2", name: name, args: new
@@ -1,5 +1,7 @@
using Google.Protobuf.Collections;
//using Newtonsoft.Json;
#if GRAPH_SERIALIZE
using Newtonsoft.Json;
#endif
using System;
using System.Collections.Generic;
using System.Linq;
@@ -33,16 +35,23 @@ namespace Tensorflow
private readonly IntPtr _operDesc;
private Graph _graph;
//[JsonIgnore]
public string type => OpType;
#if GRAPH_SERIALIZE
[JsonIgnore]
public Graph graph => _graph;
[JsonIgnore]
public int _id => _id_value;
[JsonIgnore]
public int _id_value;
[JsonIgnore]
public Operation op => this;
#else
public Graph graph => _graph;
//[JsonIgnore]
public int _id => _id_value;
//[JsonIgnore]
public int _id_value;
public string type => OpType;
//[JsonIgnore]
public Operation op => this;
#endif
public TF_DataType dtype => TF_DataType.DtInvalid;
private Status status = new Status();
@@ -51,7 +60,7 @@ namespace Tensorflow
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
private NodeDef _node_def;
//[JsonIgnore]
[JsonIgnore]
public NodeDef node_def
{
get
@@ -277,7 +277,7 @@ namespace Tensorflow
var input_shape = tensor_util.to_shape(input_tensor.shape);
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
{
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype());
return constant_op.constant(nd, name: name);
}
}
@@ -123,7 +123,7 @@ namespace Tensorflow
return with(ops.name_scope(name, "tuple", tensors), scope =>
{
name = scope;
var gating_ops = tensors.Select(x => x.op).ToList();
var gating_ops = tensors.Where(x => x != null).Select(x => x.op).ToList();
if(control_inputs != null)
{
@@ -139,7 +139,10 @@ namespace Tensorflow
var tpl = new List<Tensor>();
foreach(var t in tensors)
{
tpl.Add(with_dependencies(new Operation[] { gate }, t));
if (t != null)
tpl.Add(with_dependencies(new Operation[] { gate }, t));
else
tpl.Add(null);
}
return tpl.ToArray();
@@ -29,7 +29,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
@@ -47,7 +47,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
</ItemGroup>
<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.7.0" />
<PackageReference Include="Google.Protobuf" Version="3.8.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
<PackageReference Include="NumSharp" Version="0.10.2" />
</ItemGroup>
@@ -62,4 +63,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
<Folder Include="Keras\Initializers\" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
</ItemGroup>
</Project>
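Note: the GRAPH_SERIALIZE constant added to DefineConstants above is what activates the #if blocks and [JsonIgnore] attributes in Operation and Tensor. With it defined, a graph can be dumped to JSON along the lines of the example near the end of this diff; a sketch, assuming graph is any Graph instance and Newtonsoft.Json is referenced as shown.

#if GRAPH_SERIALIZE
// [JsonIgnore] on op/graph/outputs keeps back-references out of the output
// and avoids reference cycles during serialization.
var json = JsonConvert.SerializeObject(graph,
    new JsonSerializerSettings { Formatting = Formatting.Indented });
#endif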
@@ -1,4 +1,6 @@
//using Newtonsoft.Json;
#if GRAPH_SERIALIZE
using Newtonsoft.Json;
#endif
using NumSharp;
using System;
using System.Collections.Generic;
@@ -19,15 +21,22 @@ namespace Tensorflow
private readonly IntPtr _handle;
private int _id;
//[JsonIgnore]
private Operation _op;
#if GRAPH_SERIALIZE
[JsonIgnore]
public int Id => _id;
[JsonIgnore]
public Graph graph => op?.graph;
[JsonIgnore]
public Operation op => _op;
[JsonIgnore]
public Tensor[] outputs => op.outputs;
#else
public int Id => _id;
//[JsonIgnore]
public Graph graph => op?.graph;
private Operation _op;
//[JsonIgnore]
public Operation op => _op;
//[JsonIgnore]
public Tensor[] outputs => op.outputs;
#endif
/// <summary>
/// The string name of this tensor.
@@ -210,11 +219,11 @@ namespace Tensorflow
}
}
public Tensor this[int slice_spec]
public Tensor this[int start, int? stop, int? step]
{
get
{
var slice_spec_s = new int[] { slice_spec };
var slice_spec = new int[] { start };
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();
@@ -224,22 +233,25 @@ namespace Tensorflow
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;
foreach(var s in slice_spec_s)
foreach (var s in slice_spec)
{
begin.Add(s);
if (stop == null)
{
begin.Add(s);
end.Add(s + 1);
strides.Add(1);
shrink_axis_mask |= (1 << index);
end.Add(0);
end_mask |= (1 << index);
}
else
end.Add(s + 1);
strides.Add(1);
index += 1;
}
return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
{
string name = scope;
if(begin != null)
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
@@ -256,15 +268,17 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
name: name);
}
throw new NotImplementedException("");
});
}
}
public Tensor this[int slice_spec] => this[slice_spec, null, null];
public override string ToString()
{
// this can throw IndexOutOfRangeException
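Note: a usage sketch for the reworked indexer (t is a hypothetical Tensor); both forms end up in the strided-slice path above, with the masks set according to whether stop is supplied.

var a = t[1];               // forwards to t[1, null, null]
var b = t[1, null, null];   // start only; stop and step omitted
var c = t[1, 4, null];      // explicit start and stop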
@@ -16,6 +16,8 @@ namespace Tensorflow
{
case TF_DataType.TF_BOOL:
return typeof(bool);
case TF_DataType.TF_INT64:
return typeof(long);
case TF_DataType.TF_INT32:
return typeof(int);
case TF_DataType.TF_INT16:
@@ -57,24 +57,24 @@ namespace Tensorflow
if (initializer is IInitializer init)
{
return _get_single_variable(name: name,
shape: shape,
dtype: dtype,
initializer: init,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
shape: shape,
dtype: dtype,
initializer: init,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
}
else if (initializer is Tensor tensor)
{
return _get_single_variable(name: name,
shape: shape,
dtype: dtype,
initializer: tensor,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
shape: shape,
dtype: dtype,
initializer: tensor,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
}
else
{
@@ -141,7 +141,7 @@ namespace Tensorflow
v = variable_scope.default_variable_creator(init_val,
name: name,
trainable: trainable,
dtype: TF_DataType.DtInvalid,
dtype: variable_dtype,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
@@ -5,6 +5,7 @@ using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using Newtonsoft.Json;
using NumSharp;
using Tensorflow;
using Tensorflow.Keras.Engine;
@@ -197,17 +198,17 @@ namespace TensorFlowNET.Examples
var h_pool = tf.concat(pooled_outputs, 3);
var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank));
Tensor h_drop = null;
with(tf.name_scope("dropout"), delegate
{
var h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
});
Tensor logits = null;
Tensor predictions = null;
with(tf.name_scope("output"), delegate
{
logits = tf.layers.dense(h_pool_flat, NUM_CLASS);
logits = tf.layers.dense(h_drop, NUM_CLASS);
predictions = tf.argmax(logits, -1, output_type: tf.int32);
});
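Note: the net effect of the hunk above, condensed into a short sketch using only the calls that appear in the example; shown only to make the fix explicit.

Tensor h_drop = null;
with(tf.name_scope("dropout"), delegate
{
    h_drop = tf.nn.dropout(h_pool_flat, keep_prob);   // no longer shadowed by a scope-local variable
});
logits = tf.layers.dense(h_drop, NUM_CLASS);          // previously fed h_pool_flat, bypassing dropout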
@@ -307,6 +308,8 @@ namespace TensorFlowNET.Examples
{
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
var imported_graph = JsonConvert.SerializeObject(graph, new JsonSerializerSettings { Formatting = Formatting.Indented });
return with(tf.Session(graph), sess => Train(sess, graph));
}