@@ -17,6 +17,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.UnitTest", "test\Kera | |||||
EndProject | EndProject | ||||
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}" | Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}" | ||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{92762DCB-64C8-41B4-BEF7-780A969CE68F}" | |||||
EndProject | |||||
Global | Global | ||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
@@ -51,6 +53,10 @@ Global | |||||
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU | {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU | {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.Build.0 = Release|Any CPU | {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
EndGlobalSection | EndGlobalSection | ||||
GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
@@ -142,6 +142,7 @@ namespace Tensorflow | |||||
var layer = new Dense(units, activation, | var layer = new Dense(units, activation, | ||||
use_bias: use_bias, | use_bias: use_bias, | ||||
bias_initializer: bias_initializer, | |||||
kernel_initializer: kernel_initializer); | kernel_initializer: kernel_initializer); | ||||
return layer.apply(inputs); | return layer.apply(inputs); | ||||
@@ -0,0 +1,13 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Framework | |||||
{ | |||||
/// <summary> | |||||
/// Abstract base class for Tensor-like objects that are composed from Tensors. | |||||
/// </summary> | |||||
public abstract class CompositeTensor | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,25 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Framework | |||||
{ | |||||
/// <summary> | |||||
/// A sparse representation of a set of tensor slices at given indices. | |||||
/// </summary> | |||||
public class IndexedSlices : CompositeTensor | |||||
{ | |||||
Tensor _values; | |||||
public Tensor values => _values; | |||||
public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null) | |||||
{ | |||||
} | |||||
public static implicit operator Tensor(IndexedSlices indexedSlices) | |||||
{ | |||||
return indexedSlices.values; | |||||
} | |||||
} | |||||
} |
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Framework; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using static Tensorflow.Python; | using static Tensorflow.Python; | ||||
@@ -42,9 +43,9 @@ namespace Tensorflow.Gradients | |||||
return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad }; | return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad }; | ||||
var concat_dim = op.inputs[dim_index]; | var concat_dim = op.inputs[dim_index]; | ||||
if (end_value_index == -1) | |||||
end_value_index = op.inputs.Length - 1; | |||||
var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray(); | |||||
var input_values = op.inputs._inputs.Skip(start_value_index) | |||||
.Take(end_value_index == -1 ? op.inputs.Length - 1 : end_value_index - start_value_index) | |||||
.ToArray(); | |||||
var out_grads = new List<Tensor>(); | var out_grads = new List<Tensor>(); | ||||
if (constant_op.is_constant(concat_dim)) | if (constant_op.is_constant(concat_dim)) | ||||
@@ -92,10 +93,16 @@ namespace Tensorflow.Gradients | |||||
} | } | ||||
return (end_value_index <= dim_index ? | return (end_value_index <= dim_index ? | ||||
out_grads.ToArray().Concat(null) : | |||||
out_grads.ToArray().Concat(new Tensor[] { null }) : | |||||
new Tensor[] { null }.Concat(out_grads)).ToArray(); | new Tensor[] { null }.Concat(out_grads)).ToArray(); | ||||
} | } | ||||
[RegisterGradient("ExpandDims")] | |||||
public static Tensor[] _ExpandDimsGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
return new Tensor[] { _ReshapeToInput(op, grads[0]), null }; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Extract the shapes of a set of input tensors. | /// Extract the shapes of a set of input tensors. | ||||
/// </summary> | /// </summary> | ||||
@@ -125,6 +132,45 @@ namespace Tensorflow.Gradients | |||||
return gen_ops.shape_n(inputs); | return gen_ops.shape_n(inputs); | ||||
} | } | ||||
/// <summary> | |||||
/// Gradient for GatherV2 op. | |||||
/// </summary> | |||||
/// <param name="op"></param> | |||||
/// <param name="grads"></param> | |||||
/// <returns></returns> | |||||
[RegisterGradient("GatherV2")] | |||||
public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads) | |||||
{ | |||||
var grad = grads[0]; | |||||
var @params = op.inputs[0]; | |||||
ops.colocate_with(@params); | |||||
var params_shape = array_ops.shape(@params, out_type: tf.int64); | |||||
params_shape = math_ops.cast(params_shape, tf.int32); | |||||
var indices = op.inputs[1]; | |||||
var indices_size = array_ops.expand_dims(array_ops.size(indices), 0); | |||||
var axis = op.inputs[2]; | |||||
var axis_static = tensor_util.constant_value(axis); | |||||
// For axis 0 gathers, build an appropriately shaped IndexedSlices. | |||||
if((int)axis_static == 0) | |||||
{ | |||||
var params_tail_shape = params_shape[1]; | |||||
var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0); | |||||
var values = array_ops.reshape(grad, values_shape); | |||||
indices = array_ops.reshape(indices, indices_size); | |||||
return new Tensor[] | |||||
{ | |||||
new IndexedSlices(values, indices, params_shape), | |||||
null, | |||||
null | |||||
}; | |||||
} | |||||
return new Tensor[] { null, null }; | |||||
} | |||||
[RegisterGradient("Reshape")] | [RegisterGradient("Reshape")] | ||||
public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads) | public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads) | ||||
{ | { | ||||
@@ -106,10 +106,10 @@ namespace Tensorflow.Gradients | |||||
[RegisterGradient("Conv2D")] | [RegisterGradient("Conv2D")] | ||||
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) | public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) | ||||
{ | { | ||||
var dilations = op.get_attr("dilations"); | |||||
var strides = op.get_attr("strides"); | |||||
var dilations = (op.get_attr("dilations") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); | |||||
var strides = (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); | |||||
var padding = op.get_attr("padding"); | var padding = op.get_attr("padding"); | ||||
var explicit_paddings = op.get_attr("explicit_paddings"); | |||||
var explicit_paddings = (op.get_attr("explicit_paddings") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(); | |||||
var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); | var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); | ||||
var data_format = op.get_attr("data_format"); | var data_format = op.get_attr("data_format"); | ||||
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); | var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); | ||||
@@ -120,21 +120,23 @@ namespace Tensorflow.Gradients | |||||
{ | { | ||||
InputSizes = shape[0], | InputSizes = shape[0], | ||||
Filter = op.inputs[1], | Filter = op.inputs[1], | ||||
Dilations = dilations == null ? null : dilations as int[], | |||||
Strides = strides == null ? null : strides as int[], | |||||
OutBackProp = grads[0], | |||||
Dilations = dilations, | |||||
Strides = strides, | |||||
Padding = padding.ToString(), | Padding = padding.ToString(), | ||||
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[], | |||||
ExplicitPaddings = explicit_paddings, | |||||
UseCudnnOnGpu = (bool)use_cudnn_on_gpu, | UseCudnnOnGpu = (bool)use_cudnn_on_gpu, | ||||
DataFormat = data_format.ToString() | |||||
DataFormat = data_format.ToString(), | |||||
}), | }), | ||||
gen_nn_ops.conv2d_backprop_filter(new Conv2dParams | gen_nn_ops.conv2d_backprop_filter(new Conv2dParams | ||||
{ | { | ||||
Input = op.inputs[0], | Input = op.inputs[0], | ||||
FilterSizes = shape[1], | FilterSizes = shape[1], | ||||
Dilations = dilations == null ? null : dilations as int[], | |||||
Strides = strides == null ? null : strides as int[], | |||||
OutBackProp = grads[0], | |||||
Dilations = dilations, | |||||
Strides = strides, | |||||
Padding = padding.ToString(), | Padding = padding.ToString(), | ||||
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[], | |||||
ExplicitPaddings = explicit_paddings, | |||||
UseCudnnOnGpu = (bool)use_cudnn_on_gpu, | UseCudnnOnGpu = (bool)use_cudnn_on_gpu, | ||||
DataFormat = data_format.ToString() | DataFormat = data_format.ToString() | ||||
}) | }) | ||||
@@ -155,6 +157,23 @@ namespace Tensorflow.Gradients | |||||
return vec * mat; | return vec * mat; | ||||
} | } | ||||
[RegisterGradient("MaxPool")] | |||||
public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
var grad = grads[0]; | |||||
return new Tensor[] | |||||
{ | |||||
gen_nn_ops.max_pool_grad( | |||||
op.inputs[0], | |||||
op.outputs[0], | |||||
grad, | |||||
(op.get_attr("ksize") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), | |||||
(op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray(), | |||||
padding: op.get_attr("padding").ToString(), | |||||
data_format: op.get_attr("data_format").ToString()) | |||||
}; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Return the gradients for TopK. | /// Return the gradients for TopK. | ||||
/// </summary> | /// </summary> | ||||
@@ -179,6 +179,23 @@ namespace Tensorflow.Operations | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, | |||||
string data_format= "NHWC", string name= null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("MaxPoolGrad", name: name, args: new | |||||
{ | |||||
orig_input, | |||||
orig_output, | |||||
grad, | |||||
ksize, | |||||
strides, | |||||
padding, | |||||
data_format | |||||
}); | |||||
return _op.outputs[0]; | |||||
} | |||||
public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null) | public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null) | ||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("TopKV2", name: name, args: new | var _op = _op_def_lib._apply_op_helper("TopKV2", name: name, args: new | ||||
@@ -1,5 +1,7 @@ | |||||
using Google.Protobuf.Collections; | using Google.Protobuf.Collections; | ||||
//using Newtonsoft.Json; | |||||
#if GRAPH_SERIALIZE | |||||
using Newtonsoft.Json; | |||||
#endif | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
@@ -33,16 +35,23 @@ namespace Tensorflow | |||||
private readonly IntPtr _operDesc; | private readonly IntPtr _operDesc; | ||||
private Graph _graph; | private Graph _graph; | ||||
//[JsonIgnore] | |||||
public string type => OpType; | |||||
#if GRAPH_SERIALIZE | |||||
[JsonIgnore] | |||||
public Graph graph => _graph; | |||||
[JsonIgnore] | |||||
public int _id => _id_value; | |||||
[JsonIgnore] | |||||
public int _id_value; | |||||
[JsonIgnore] | |||||
public Operation op => this; | |||||
#else | |||||
public Graph graph => _graph; | public Graph graph => _graph; | ||||
//[JsonIgnore] | |||||
public int _id => _id_value; | public int _id => _id_value; | ||||
//[JsonIgnore] | |||||
public int _id_value; | public int _id_value; | ||||
public string type => OpType; | |||||
//[JsonIgnore] | |||||
public Operation op => this; | public Operation op => this; | ||||
#endif | |||||
public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
private Status status = new Status(); | private Status status = new Status(); | ||||
@@ -51,7 +60,7 @@ namespace Tensorflow | |||||
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | ||||
private NodeDef _node_def; | private NodeDef _node_def; | ||||
//[JsonIgnore] | |||||
[JsonIgnore] | |||||
public NodeDef node_def | public NodeDef node_def | ||||
{ | { | ||||
get | get | ||||
@@ -277,7 +277,7 @@ namespace Tensorflow | |||||
var input_shape = tensor_util.to_shape(input_tensor.shape); | var input_shape = tensor_util.to_shape(input_tensor.shape); | ||||
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | ||||
{ | { | ||||
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); | |||||
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); | |||||
return constant_op.constant(nd, name: name); | return constant_op.constant(nd, name: name); | ||||
} | } | ||||
} | } | ||||
@@ -123,7 +123,7 @@ namespace Tensorflow | |||||
return with(ops.name_scope(name, "tuple", tensors), scope => | return with(ops.name_scope(name, "tuple", tensors), scope => | ||||
{ | { | ||||
name = scope; | name = scope; | ||||
var gating_ops = tensors.Select(x => x.op).ToList(); | |||||
var gating_ops = tensors.Where(x => x != null).Select(x => x.op).ToList(); | |||||
if(control_inputs != null) | if(control_inputs != null) | ||||
{ | { | ||||
@@ -139,7 +139,10 @@ namespace Tensorflow | |||||
var tpl = new List<Tensor>(); | var tpl = new List<Tensor>(); | ||||
foreach(var t in tensors) | foreach(var t in tensors) | ||||
{ | { | ||||
tpl.Add(with_dependencies(new Operation[] { gate }, t)); | |||||
if (t != null) | |||||
tpl.Add(with_dependencies(new Operation[] { gate }, t)); | |||||
else | |||||
tpl.Add(null); | |||||
} | } | ||||
return tpl.ToArray(); | return tpl.ToArray(); | ||||
@@ -29,7 +29,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | ||||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||||
<DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | ||||
@@ -47,7 +47,8 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Google.Protobuf" Version="3.7.0" /> | |||||
<PackageReference Include="Google.Protobuf" Version="3.8.0" /> | |||||
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | |||||
<PackageReference Include="NumSharp" Version="0.10.2" /> | <PackageReference Include="NumSharp" Version="0.10.2" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -62,4 +63,8 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
<Folder Include="Keras\Initializers\" /> | <Folder Include="Keras\Initializers\" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | |||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
</ItemGroup> | |||||
</Project> | </Project> |
@@ -1,4 +1,6 @@ | |||||
//using Newtonsoft.Json; | |||||
#if GRAPH_SERIALIZE | |||||
using Newtonsoft.Json; | |||||
#endif | |||||
using NumSharp; | using NumSharp; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
@@ -19,15 +21,22 @@ namespace Tensorflow | |||||
private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
private int _id; | private int _id; | ||||
//[JsonIgnore] | |||||
private Operation _op; | |||||
#if GRAPH_SERIALIZE | |||||
[JsonIgnore] | |||||
public int Id => _id; | |||||
[JsonIgnore] | |||||
public Graph graph => op?.graph; | |||||
[JsonIgnore] | |||||
public Operation op => _op; | |||||
[JsonIgnore] | |||||
public Tensor[] outputs => op.outputs; | |||||
#else | |||||
public int Id => _id; | public int Id => _id; | ||||
//[JsonIgnore] | |||||
public Graph graph => op?.graph; | public Graph graph => op?.graph; | ||||
private Operation _op; | |||||
//[JsonIgnore] | |||||
public Operation op => _op; | public Operation op => _op; | ||||
//[JsonIgnore] | |||||
public Tensor[] outputs => op.outputs; | public Tensor[] outputs => op.outputs; | ||||
#endif | |||||
/// <summary> | /// <summary> | ||||
/// The string name of this tensor. | /// The string name of this tensor. | ||||
@@ -210,11 +219,11 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public Tensor this[int slice_spec] | |||||
public Tensor this[int start, int? stop, int? step] | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
var slice_spec_s = new int[] { slice_spec }; | |||||
var slice_spec = new int[] { start }; | |||||
var begin = new List<int>(); | var begin = new List<int>(); | ||||
var end = new List<int>(); | var end = new List<int>(); | ||||
var strides = new List<int>(); | var strides = new List<int>(); | ||||
@@ -224,22 +233,25 @@ namespace Tensorflow | |||||
var (begin_mask, end_mask) = (0, 0); | var (begin_mask, end_mask) = (0, 0); | ||||
var ellipsis_mask = 0; | var ellipsis_mask = 0; | ||||
foreach(var s in slice_spec_s) | |||||
foreach (var s in slice_spec) | |||||
{ | { | ||||
begin.Add(s); | |||||
if (stop == null) | |||||
{ | { | ||||
begin.Add(s); | |||||
end.Add(s + 1); | |||||
strides.Add(1); | |||||
shrink_axis_mask |= (1 << index); | |||||
end.Add(0); | |||||
end_mask |= (1 << index); | |||||
} | } | ||||
else | |||||
end.Add(s + 1); | |||||
strides.Add(1); | |||||
index += 1; | index += 1; | ||||
} | } | ||||
return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | ||||
{ | { | ||||
string name = scope; | string name = scope; | ||||
if(begin != null) | |||||
if (begin != null) | |||||
{ | { | ||||
var (packed_begin, packed_end, packed_strides) = | var (packed_begin, packed_end, packed_strides) = | ||||
(array_ops.stack(begin.ToArray()), | (array_ops.stack(begin.ToArray()), | ||||
@@ -256,15 +268,17 @@ namespace Tensorflow | |||||
shrink_axis_mask: shrink_axis_mask, | shrink_axis_mask: shrink_axis_mask, | ||||
new_axis_mask: new_axis_mask, | new_axis_mask: new_axis_mask, | ||||
ellipsis_mask: ellipsis_mask, | ellipsis_mask: ellipsis_mask, | ||||
name: name); | name: name); | ||||
} | } | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
}); | }); | ||||
} | } | ||||
} | } | ||||
public Tensor this[int slice_spec] => this[slice_spec, null, null]; | |||||
public override string ToString() | public override string ToString() | ||||
{ | { | ||||
// this can throw IndexOutOfRangeException | // this can throw IndexOutOfRangeException | ||||
@@ -16,6 +16,8 @@ namespace Tensorflow | |||||
{ | { | ||||
case TF_DataType.TF_BOOL: | case TF_DataType.TF_BOOL: | ||||
return typeof(bool); | return typeof(bool); | ||||
case TF_DataType.TF_INT64: | |||||
return typeof(long); | |||||
case TF_DataType.TF_INT32: | case TF_DataType.TF_INT32: | ||||
return typeof(int); | return typeof(int); | ||||
case TF_DataType.TF_INT16: | case TF_DataType.TF_INT16: | ||||
@@ -57,24 +57,24 @@ namespace Tensorflow | |||||
if (initializer is IInitializer init) | if (initializer is IInitializer init) | ||||
{ | { | ||||
return _get_single_variable(name: name, | return _get_single_variable(name: name, | ||||
shape: shape, | |||||
dtype: dtype, | |||||
initializer: init, | |||||
trainable: trainable, | |||||
validate_shape: validate_shape, | |||||
synchronization: synchronization, | |||||
aggregation: aggregation); | |||||
shape: shape, | |||||
dtype: dtype, | |||||
initializer: init, | |||||
trainable: trainable, | |||||
validate_shape: validate_shape, | |||||
synchronization: synchronization, | |||||
aggregation: aggregation); | |||||
} | } | ||||
else if (initializer is Tensor tensor) | else if (initializer is Tensor tensor) | ||||
{ | { | ||||
return _get_single_variable(name: name, | return _get_single_variable(name: name, | ||||
shape: shape, | |||||
dtype: dtype, | |||||
initializer: tensor, | |||||
trainable: trainable, | |||||
validate_shape: validate_shape, | |||||
synchronization: synchronization, | |||||
aggregation: aggregation); | |||||
shape: shape, | |||||
dtype: dtype, | |||||
initializer: tensor, | |||||
trainable: trainable, | |||||
validate_shape: validate_shape, | |||||
synchronization: synchronization, | |||||
aggregation: aggregation); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -141,7 +141,7 @@ namespace Tensorflow | |||||
v = variable_scope.default_variable_creator(init_val, | v = variable_scope.default_variable_creator(init_val, | ||||
name: name, | name: name, | ||||
trainable: trainable, | trainable: trainable, | ||||
dtype: TF_DataType.DtInvalid, | |||||
dtype: variable_dtype, | |||||
validate_shape: validate_shape, | validate_shape: validate_shape, | ||||
synchronization: synchronization, | synchronization: synchronization, | ||||
aggregation: aggregation); | aggregation: aggregation); | ||||
@@ -5,6 +5,7 @@ using System.Diagnostics; | |||||
using System.IO; | using System.IO; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Newtonsoft.Json; | |||||
using NumSharp; | using NumSharp; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
@@ -197,17 +198,17 @@ namespace TensorFlowNET.Examples | |||||
var h_pool = tf.concat(pooled_outputs, 3); | var h_pool = tf.concat(pooled_outputs, 3); | ||||
var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank)); | var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank)); | ||||
Tensor h_drop = null; | |||||
with(tf.name_scope("dropout"), delegate | with(tf.name_scope("dropout"), delegate | ||||
{ | { | ||||
var h_drop = tf.nn.dropout(h_pool_flat, keep_prob); | |||||
h_drop = tf.nn.dropout(h_pool_flat, keep_prob); | |||||
}); | }); | ||||
Tensor logits = null; | Tensor logits = null; | ||||
Tensor predictions = null; | Tensor predictions = null; | ||||
with(tf.name_scope("output"), delegate | with(tf.name_scope("output"), delegate | ||||
{ | { | ||||
logits = tf.layers.dense(h_pool_flat, NUM_CLASS); | |||||
logits = tf.layers.dense(h_drop, NUM_CLASS); | |||||
predictions = tf.argmax(logits, -1, output_type: tf.int32); | predictions = tf.argmax(logits, -1, output_type: tf.int32); | ||||
}); | }); | ||||
@@ -307,6 +308,8 @@ namespace TensorFlowNET.Examples | |||||
{ | { | ||||
var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | ||||
var imported_graph = JsonConvert.SerializeObject(graph, new JsonSerializerSettings { Formatting = Formatting.Indented }); | |||||
return with(tf.Session(graph), sess => Train(sess, graph)); | return with(tf.Session(graph), sess => Train(sess, graph)); | ||||
} | } | ||||