Browse Source

Merge branch 'master' into ndarrayload

# Conflicts:
#	src/TensorFlowNET.Core/Tensorflow.Binding.csproj
#	src/TensorFlowNET.Keras/Datasets/Imdb.cs
tags/v0.110.4-Transformer-Model
lingbai-kong 2 years ago
parent
commit
10f6819f08
100 changed files with 2396 additions and 129 deletions
  1. +30
    -0
      src/TensorFlowNET.Core/APIs/c_api.cs
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/c_api.customize.cs
  3. +10
    -8
      src/TensorFlowNET.Core/APIs/tf.array.cs
  4. +5
    -5
      src/TensorFlowNET.Core/APIs/tf.control_flow.cs
  5. +119
    -12
      src/TensorFlowNET.Core/APIs/tf.image.cs
  6. +27
    -3
      src/TensorFlowNET.Core/APIs/tf.math.cs
  7. +21
    -0
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  8. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.reshape.cs
  9. +11
    -4
      src/TensorFlowNET.Core/APIs/tf.tensor.cs
  10. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.tile.cs
  11. +23
    -2
      src/TensorFlowNET.Core/Binding.Util.cs
  12. +0
    -0
      src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs
  13. +3
    -3
      src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs
  14. +38
    -0
      src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs
  15. +33
    -0
      src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs
  16. +0
    -0
      src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs
  17. +20
    -0
      src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs
  18. +69
    -0
      src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
  19. +40
    -0
      src/TensorFlowNET.Core/Common/Types/INestStructure.cs
  20. +11
    -0
      src/TensorFlowNET.Core/Common/Types/INestable.cs
  21. +21
    -0
      src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs
  22. +0
    -0
      src/TensorFlowNET.Core/Common/Types/NamedTuple.cs
  23. +62
    -0
      src/TensorFlowNET.Core/Common/Types/Nest.Static.cs
  24. +485
    -0
      src/TensorFlowNET.Core/Common/Types/Nest.cs
  25. +103
    -0
      src/TensorFlowNET.Core/Common/Types/NestDictionary.cs
  26. +53
    -0
      src/TensorFlowNET.Core/Common/Types/NestList.cs
  27. +36
    -0
      src/TensorFlowNET.Core/Common/Types/NestNode.cs
  28. +1
    -1
      src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs
  29. +2
    -2
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  30. +7
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  31. +7
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
  32. +6
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs
  33. +25
    -0
      src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs
  34. +19
    -0
      src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs
  35. +14
    -1
      src/TensorFlowNET.Core/Framework/IndexedSlices.cs
  36. +13
    -0
      src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
  37. +89
    -0
      src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs
  38. +2
    -2
      src/TensorFlowNET.Core/Framework/function_def_lib.cs
  39. +13
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  40. +4
    -1
      src/TensorFlowNET.Core/GlobalUsing.cs
  41. +10
    -3
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  42. +131
    -0
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  43. +17
    -0
      src/TensorFlowNET.Core/Gradients/nn_grad.cs
  44. +8
    -8
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  45. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  46. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ExponentialArgs.cs
  47. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/HardSigmoidArgs.cs
  48. +11
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SELUArgs.cs
  49. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftplusArgs.cs
  50. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftsignArgs.cs
  51. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SwishArgs.cs
  52. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/TanhArgs.cs
  53. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DTransposeArgs.cs
  54. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/AddArgs.cs
  55. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/ConcatenateArgs.cs
  56. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/SubtractArgs.cs
  57. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling1DArgs.cs
  58. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling2DArgs.cs
  59. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling1DArgs.cs
  60. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling2DArgs.cs
  61. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling1DArgs.cs
  62. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs
  63. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Upsampling1DArgs.cs
  64. +20
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs
  65. +29
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs
  66. +39
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs
  67. +13
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
  68. +6
    -3
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
  69. +30
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
  70. +20
    -25
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
  71. +14
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs
  72. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs
  73. +27
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs
  74. +3
    -3
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
  75. +24
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs
  76. +3
    -0
      src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
  77. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/IModel.cs
  78. +75
    -0
      src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs
  79. +22
    -1
      src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
  80. +3
    -2
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  81. +4
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs
  82. +91
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  83. +25
    -0
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
  84. +12
    -0
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs
  85. +1
    -0
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs
  86. +1
    -0
      src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs
  87. +0
    -5
      src/TensorFlowNET.Core/NumPy/Axis.cs
  88. +6
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
  89. +9
    -9
      src/TensorFlowNET.Core/NumPy/NDArrayRender.cs
  90. +4
    -0
      src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs
  91. +2
    -2
      src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs
  92. +21
    -0
      src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
  93. +23
    -1
      src/TensorFlowNET.Core/Numpy/Shape.cs
  94. +22
    -0
      src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs
  95. +2
    -3
      src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
  96. +2
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  97. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  98. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs
  99. +18
    -3
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  100. +57
    -1
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

+ 30
- 0
src/TensorFlowNET.Core/APIs/c_api.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
{
@@ -50,6 +51,35 @@ namespace Tensorflow
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

public unsafe static byte[] ByteStringPiece(Buffer? handle)
{
if (handle is null)
{
return new byte[0];
}
var data = handle.ToArray();
return data;
}
public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle)
{
if (handle == IntPtr.Zero)
{
return new byte[0];
}

byte* str_data = (byte*)handle.ToPointer();
List<byte> bytes = new List<byte>();
byte current = 255;
while (current != ((byte)'\0'))
{
current = *(str_data++);
bytes.Add(current);
}
var data = bytes.ToArray();
return data;
}

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);



+ 1
- 1
src/TensorFlowNET.Core/APIs/c_api.customize.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}


+ 10
- 8
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -91,8 +91,7 @@ namespace Tensorflow
return identity(values.First(), name: scope);
});
}

return gen_array_ops.concat_v2(values.ToArray(), ops.convert_to_tensor(axis), name: name);
return array_ops.concat(values.ToArray(), axis, name: name);
}

/// <summary>
@@ -163,14 +162,17 @@ namespace Tensorflow
/// Reverses specific dimensions of a tensor.
/// </summary>
/// <param name="tensor"></param>
/// <param name="axis"></param>
/// <param name="axis">The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).</param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor reverse(Tensor tensor, int[] axis, string name = null)
=> gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name);

public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
=> gen_array_ops.reverse(tensor, axis, name: name);
public Tensor reverse(Tensor tensor, Axis axis, string name = null)
{
if (axis.IsScalar)
{
axis = new Axis(axis.axis);
}
return array_ops.reverse(tensor, axis, name: name);
}

/// <summary>
/// Returns the rank of a tensor.


+ 5
- 5
src/TensorFlowNET.Core/APIs/tf.control_flow.cs View File

@@ -46,10 +46,10 @@ namespace Tensorflow
Tensor loop_vars,
int parallel_iterations = 10)
{
Func<Tensor[], Tensor> cond1 = x
Func<Tensors, Tensor> cond1 = x
=> cond(x[0]);

Func<Tensor[], Tensor[]> body1 = x
Func<Tensors, Tensors> body1 = x
=> new[] { body(x[0]) };

var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ namespace Tensorflow
return results[0];
}

public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
=> control_flow_ops.while_loop(cond, body, loop_vars,


+ 119
- 12
src/TensorFlowNET.Core/APIs/tf.image.cs View File

@@ -14,6 +14,10 @@
limitations under the License.
******************************************************************************/

using OneOf.Types;
using System;
using System.Buffers.Text;
using Tensorflow.Contexts;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -162,17 +166,108 @@ namespace Tensorflow
public Tensor sobel_edges(Tensor image)
=> image_ops_impl.sobel_edges(image);

public Tensor decode_jpeg(Tensor contents,
int channels = 0,
int ratio = 1,
bool fancy_upscaling = true,
bool try_recover_truncated = false,
int acceptable_fraction = 1,
string dct_method = "",
string name = null)
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
acceptable_fraction: acceptable_fraction, dct_method: dct_method);
/// <summary>
/// Adjust contrast of RGB or grayscale images.
/// </summary>
/// <param name="images">Images to adjust. At least 3-D.</param>
/// <param name="contrast_factor"></param>
/// <param name="name">A float multiplier for adjusting contrast.</param>
/// <returns>The contrast-adjusted image or images.</returns>
public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null)
=> gen_image_ops.adjust_contrastv2(images, contrast_factor, name);

/// <summary>
/// Adjust hue of RGB images.
/// </summary>
/// <param name="images">RGB image or images. The size of the last dimension must be 3.</param>
/// <param name="delta">float. How much to add to the hue channel.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>Adjusted image(s), same shape and DType as `image`.</returns>
/// <exception cref="ValueError">if `delta` is not in the interval of `[-1, 1]`.</exception>
public Tensor adjust_hue(Tensor images, float delta, string name = null)
{
if (tf.Context.executing_eagerly())
{
if (delta < -1f || delta > 1f)
throw new ValueError("delta must be in the interval [-1, 1]");
}
return gen_image_ops.adjust_hue(images, delta, name: name);
}

/// <summary>
/// Adjust saturation of RGB images.
/// </summary>
/// <param name="image">RGB image or images. The size of the last dimension must be 3.</param>
/// <param name="saturation_factor">float. Factor to multiply the saturation by.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>Adjusted image(s), same shape and DType as `image`.</returns>
public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null)
=> gen_image_ops.adjust_saturation(image, saturation_factor, name);

/// <summary>
/// Greedily selects a subset of bounding boxes in descending order of score.
/// </summary>
/// <param name="boxes">
/// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q`
/// is 1 then same boxes are used for all classes otherwise, if `q` is equal
/// to number of classes, class-specific boxes are used.
/// </param>
/// <param name="scores">
/// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]`
/// representing a single score corresponding to each box(each row of boxes).
/// </param>
/// <param name="max_output_size_per_class">
/// A scalar integer `Tensor` representing the
/// maximum number of boxes to be selected by non-max suppression per class
/// </param>
/// <param name="max_total_size">
/// A int32 scalar representing maximum number of boxes retained
/// over all classes.Note that setting this value to a large number may
/// result in OOM error depending on the system workload.
/// </param>
/// <param name="iou_threshold">
/// A float representing the threshold for deciding whether boxes
/// overlap too much with respect to IOU.
/// </param>
/// <param name="score_threshold">
/// A float representing the threshold for deciding when to
/// remove boxes based on score.
/// </param>
/// <param name="pad_per_class">
/// If false, the output nmsed boxes, scores and classes are
/// padded/clipped to `max_total_size`. If true, the output nmsed boxes, scores and classes are padded to be of length `max_size_per_class`*`num_classes`,
/// unless it exceeds `max_total_size` in which case it is clipped to `max_total_size`. Defaults to false.
/// </param>
/// <param name="clip_boxes">
/// If true, the coordinates of output nmsed boxes will be clipped
/// to[0, 1]. If false, output the box coordinates as it is. Defaults to true.
/// </param>
/// <returns>
/// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes.
/// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes.
/// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for boxes.
/// 'valid_detections': A [batch_size] int32 tensor indicating the number of
/// valid detections per batch item. Only the top valid_detections[i] entries
/// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the
/// entries are zero paddings.
/// </returns>
public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(
Tensor boxes,
Tensor scores,
int max_output_size_per_class,
int max_total_size,
float iou_threshold,
float score_threshold,
bool pad_per_class = false,
bool clip_boxes = true)
{
var iou_threshold_t = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold");
var score_threshold_t = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold");
var max_total_size_t = ops.convert_to_tensor(max_total_size);
var max_output_size_per_class_t = ops.convert_to_tensor(max_output_size_per_class);
return gen_image_ops.combined_non_max_suppression(boxes, scores, max_output_size_per_class_t, max_total_size_t,
iou_threshold_t, score_threshold_t, pad_per_class, clip_boxes);
}

/// <summary>
/// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change.
@@ -187,7 +282,19 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns>
public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) =>
image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);
gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);

public Tensor decode_jpeg(Tensor contents,
int channels = 0,
int ratio = 1,
bool fancy_upscaling = true,
bool try_recover_truncated = false,
int acceptable_fraction = 1,
string dct_method = "",
string name = null)
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
acceptable_fraction: acceptable_fraction, dct_method: dct_method);

public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true,
bool uniform_noise = true, string name = null)


+ 27
- 3
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using Tensorflow.NumPy;
using Tensorflow.Operations;

namespace Tensorflow
@@ -42,10 +43,20 @@ namespace Tensorflow

public Tensor multiply(Tensor x, Tensor y, string name = null)
=> math_ops.multiply(x, y, name: name);

public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
=> math_ops.div_no_nan(a, b);

/// <summary>
/// Computes the Euclidean norm of elements across dimensions of a tensor.
/// </summary>
/// <param name="input_tensor">The tensor to reduce. Should have numeric type.</param>
/// <param name="axis">The dimensions to reduce. If `None` (the default), reduces all dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`</param>
/// <param name="keepdims">If true, retains reduced dimensions with length 1.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>The reduced tensor, of the same dtype as the input_tensor.</returns>
public Tensor reduce_euclidean_norm(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_euclidean_norm(input_tensor, axis: axis, keepdims: keepdims, name);

public Tensor square(Tensor x, string name = null)
=> math_ops.square(x, name: name);

@@ -354,7 +365,7 @@ namespace Tensorflow
=> a / b;

public Tensor sqrt(Tensor a, string name = null)
=> gen_math_ops.sqrt(a, name);
=> math_ops.sqrt(a, name);

public Tensor sign(Tensor a, string name = null)
=> gen_math_ops.sign(a, name);
@@ -452,7 +463,18 @@ namespace Tensorflow
/// <returns></returns>
public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);

/// <summary>
/// return scalar product
/// </summary>
/// <typeparam name="Tx"></typeparam>
/// <typeparam name="Ty"></typeparam>
/// <param name="x"></param>
/// <param name="y"></param>
/// <param name="axes"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor dot_prod<Tx, Ty>(Tx x, Ty y, NDArray axes, string name = null)
=> math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name);
public Tensor negative(Tensor x, string name = null)
=> gen_math_ops.neg(x, name);

@@ -600,5 +622,7 @@ namespace Tensorflow
=> gen_math_ops.squared_difference(x: x, y: y, name: name);
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null,
string name = null) => gen_ops.complex(real, imag, dtype, name);
public Tensor exp(Tensor x,
string name = null) => gen_math_ops.exp(x, name);
}
}

+ 21
- 0
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System.Xml.Linq;
using Tensorflow.Operations;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;
@@ -126,6 +127,26 @@ namespace Tensorflow
name: name,
exponential_avg_factor: exponential_avg_factor);

/// <summary>
/// Normalizes a tensor by `mean` and `variance`, and applies (optionally) a`scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\).
/// </summary>
/// <param name="x">A floating point tensor.</param>
/// <param name="mean">A mean `Tensor`.</param>
/// <param name="variance">A variance `Tensor`.</param>
/// <param name="offset"> An offset `Tensor`, often denoted \\(\beta\\) in equations, or NULL. If present, will be added to the normalized tensor.</param>
/// <param name="scale"> A scale `Tensor`, often denoted \\(\gamma\\) in equations, or NULL. If present, the scale is applied to the normalized tensor.</param>
/// <param name="variance_epsilon"> A small float number to avoid dividing by 0.</param>
/// <param name="name">A name for this operation.</param>
/// <returns>the normalized, scaled, offset tensor.</returns>
public Tensor batch_normalization(Tensor x,
Tensor mean,
Tensor variance,
Tensor offset,
Tensor scale,
float variance_epsilon,
string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name);


public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);



+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.reshape.cs View File

@@ -31,6 +31,6 @@ namespace Tensorflow
public Tensor reshape(Tensor tensor,
object[] shape,
string name = null)
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name);
=> array_ops.reshape(tensor, shape, name);
}
}

+ 11
- 4
src/TensorFlowNET.Core/APIs/tf.tensor.cs View File

@@ -68,20 +68,27 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional)</param>
/// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects;
/// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns>
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
num_or_size_splits: num_split,
axis: axis,
name: name);

public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
public Tensor[] split(Tensor value, int[] num_split, Axis axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
num_or_size_splits: num_split,
axis: axis,
name: name);

//public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null)
// => array_ops.split(
// value: value,
// num_or_size_splits: num_split,
// axis: axis,
// name: name);

public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
{
return gen_ops.ensure_shape(x, shape, name);


+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.tile.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow
=> gen_array_ops.tile(input, multiples, name);

public Tensor tile(Tensor input, object[] multiples, string name = null)
=> gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name);
=> array_ops.tile(input, constant_op.constant(shape_utils.from_object_array(multiples).dims), name);

public Tensor tile(Tensor input, Shape multiples, string name = null)
{


+ 23
- 2
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -486,7 +486,28 @@ namespace Tensorflow
throw new NotImplementedException("");
}
}

public static NDArray GetFlattenArray(NDArray x)
{
switch (x.GetDataType())
{
case TF_DataType.TF_FLOAT:
x = x.ToArray<float>();
break;
case TF_DataType.TF_DOUBLE:
x = x.ToArray<double>();
break;
case TF_DataType.TF_INT16:
case TF_DataType.TF_INT32:
x = x.ToArray<int>();
break;
case TF_DataType.TF_INT64:
x = x.ToArray<long>();
break;
default:
break;
}
return x;
}
public static TF_DataType GetDataType(this object data)
{
var type = data.GetType();
@@ -503,7 +524,7 @@ namespace Tensorflow
case Tensors tensors:
return tensors.dtype;
case IEnumerable<Tensor> tensors:
return tensors.First().dtype;
return tensors.Where(x => x is not null).First().dtype;
case RefVariable variable:
return variable.dtype;
case ResourceVariable variable:


src/TensorFlowNET.Core/Extensions/DictionaryExtension.cs → src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs View File


src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs → src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs View File

@@ -3,16 +3,16 @@ using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Extensions
namespace Tensorflow.Common.Extensions
{
public static class JObjectExtensions
{
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
{
var res = obj[key];
if(res is null)
if (res is null)
{
return default(T);
return default;
}
else
{

+ 38
- 0
src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs View File

@@ -0,0 +1,38 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Common.Extensions
{
public static class LinqExtensions
{
#if NETSTANDARD2_0
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count)
{
return sequence.Skip(sequence.Count() - count);
}

public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count)
{
return sequence.Take(sequence.Count() - count);
}
#endif
public static Tensors ToTensors(this Tensor[] tensors)
{
return new Tensors(tensors);
}

public static Tensors ToTensors(this IList<Tensor> tensors)
{
return new Tensors(tensors);
}

public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third)
{
first = values.Item1;
second = values.Item2;
third = values.Item3;
}
}
}

+ 33
- 0
src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Common.Extensions
{
public static class NestExtensions
{
public static Tensors ToTensors(this INestable<Tensor> tensors)
{
return new Tensors(tensors.AsNest());
}

public static Tensors? ToTensors(this Nest<Tensor> tensors)
{
return Tensors.FromNest(tensors);
}

/// <summary>
/// If the nested object is already a nested type, this function could reduce it.
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`.
/// </summary>
/// <typeparam name="TIn"></typeparam>
/// <typeparam name="TOut"></typeparam>
/// <param name="input"></param>
/// <returns></returns>
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut>
{
return Nest<TOut>.ReduceFrom(input);
}
}
}

src/TensorFlowNET.Core/Extensions/OneofExtension.cs → src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs View File


+ 20
- 0
src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This is a temp solution, which should be removed after refactoring `Tensors`
/// </summary>
[Obsolete]
public class FakeTensorByTensorArray: Tensor
{
public TensorArray TensorArray { get; set; }

public FakeTensorByTensorArray(TensorArray array)
{
TensorArray = array;
}
}
}

+ 69
- 0
src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs View File

@@ -0,0 +1,69 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;

namespace Tensorflow.Common.Types
{
public class GeneralizedTensorShape: Nest<Shape>
{
public GeneralizedTensorShape(Shape value, string? name = null)
{
NodeValue = value;
NestType = NestType.Node;
}

public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
{
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
Name = name;
NestType = NestType.List;
}

public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
{
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
Name = name;
NestType = NestType.Dictionary;
}

public GeneralizedTensorShape(Nest<Shape> other)
{
NestType = other.NestType;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public Shape ToSingleShape()
{
var shapes = Flatten().ToList();
if (shapes.Count != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
return shapes[0];
}

public long ToNumber()
{
var shapes = Flatten().ToList();
if (shapes.Count != 1 || shapes[0].ndim != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
return shapes[0].dims[0];
}

public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
{
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
}

public static implicit operator GeneralizedTensorShape(Shape shape)
{
return new GeneralizedTensorShape(shape);
}
}
}

+ 40
- 0
src/TensorFlowNET.Core/Common/Types/INestStructure.cs View File

@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This interface indicates that a class may have a nested structure and provide
/// methods to manipulate with the structure.
/// </summary>
public interface INestStructure<T>: INestable<T>
{
NestType NestType { get; }

/// <summary>
/// The item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
/// </summary>
int ShallowNestedCount { get; }
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
int TotalNestedCount { get; }

/// <summary>
/// Flatten the Nestable object. Node that if the object contains only one value,
/// it will be flattened to an enumerable with one element.
/// </summary>
/// <returns></returns>
IEnumerable<T> Flatten();
/// <summary>
/// Construct a new object with the same nested structure.
/// </summary>
/// <typeparam name="TOut"></typeparam>
/// <param name="func"></param>
/// <returns></returns>
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func);
}
}

+ 11
- 0
src/TensorFlowNET.Core/Common/Types/INestable.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public interface INestable<T>
{
Nest<T> AsNest();
}
}

+ 21
- 0
src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This interface is used when some corresponding python methods have optional args.
/// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while
/// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs`
/// as the parameter of the method.
/// </summary>
public interface IOptionalArgs
{
/// <summary>
/// The identifier of the class. It is not an argument but only something to
/// separate different OptionalArgs.
/// </summary>
string Identifier { get; }
}
}

src/TensorFlowNET.Core/Extensions/NamedTuple.cs → src/TensorFlowNET.Core/Common/Types/NamedTuple.cs View File


+ 62
- 0
src/TensorFlowNET.Core/Common/Types/Nest.Static.cs View File

@@ -0,0 +1,62 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public static class Nest
{
/// <summary>
/// Pack the flat items to a nested sequence by the template.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems)
{
return template.AsNest().PackSequence(flatItems);
}

/// <summary>
/// Pack the flat items to a nested sequence by the template.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems)
{
return template.AsNest().PackSequence(flatItems.ToArray());
}

/// <summary>
/// Flatten the nested object.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="nestedObject"></param>
/// <returns></returns>
public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject)
{
return nestedObject.AsNest().Flatten();
}

/// <summary>
/// Map the structure with specified function.
/// </summary>
/// <typeparam name="TIn"></typeparam>
/// <typeparam name="TOut"></typeparam>
/// <param name="func"></param>
/// <param name="nestedObject"></param>
/// <returns></returns>
public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject)
{
return nestedObject.AsNest().MapStructure(func);
}

public static bool IsNested<T>(INestable<T> obj)
{
return obj.AsNest().IsNested();
}
}
}

+ 485
- 0
src/TensorFlowNET.Core/Common/Types/Nest.cs View File

@@ -0,0 +1,485 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Common.Types
{
public enum NestType
{
Empty,
Node,
List,
Dictionary
}

/// <summary>
/// A nested structure which may inclulde value, list and dictionary.
/// Note that dictionary does not ensure the data order. When using it as IEnumerable,
/// its order is depth-first.
/// </summary>
/// <typeparam name="T"></typeparam>
public class Nest<T> : INestStructure<T>, IEnumerable<T>
{
private static readonly Nest<T> _empty = new Nest<T>()
{
NestType = NestType.Empty,
};
public static Nest<T> Empty => _empty;
public NestType NestType { get; protected set; }
public string? Name { get; set; }
public T? NodeValue { get; protected set; }
public List<INestStructure<T>>? ListValue { get; protected set; }
public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; }

public int ShallowNestedCount
{
get
{
if (NestType == NestType.Empty)
{
return 0;
}
else if (NestType == NestType.Node)
{
return 1;
}
else if (NestType == NestType.List)
{
return ListValue!.Count;
}
else // dict
{
return DictValue!.Count;
}
}
}

public int TotalNestedCount
{
get
{
return Flatten().Count();
}
}

protected Nest() { }

public Nest(T value, string? name = null)
{
NodeValue = value;
Name = name;
NestType = NestType.Node;
}

public Nest(IEnumerable<INestStructure<T>> values, string? name = null)
{
ListValue = values.ToList();
Name = name;
NestType = NestType.List;
}

public Nest(Dictionary<string, INestStructure<T>> value, string? name = null)
{
DictValue = value;
Name = name;
NestType = NestType.Dictionary;
}

public Nest(Nest<T> other)
{
NestType = other.NestType;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public virtual IEnumerable<T> Flatten()
{
return FlattenInternal(this);
}
public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return MapStructureInternal(func);
}

/// <summary>
/// Pack the flat items to a nested sequence by the template.
/// </summary>
/// <param name="flatItems"></param>
/// <returns></returns>
public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems)
{
if(flatItems.Length == 0)
{
return Nest<TOut>.Empty;
}
int index = 0;
return PackSequenceInternal(this, flatItems, ref index);
}

private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index)
{
if(template.NestType == NestType.Node)
{
if(index >= flatItems.Length)
{
throw new InvalidArgumentError("The template and flat items are not matched.");
}
return new Nest<TOut>(flatItems[index++]);
}
else if(template.NestType == NestType.List)
{
List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>();
for (int i = 0; i < template.ListValue!.Count; i++)
{
nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index));
}
return new Nest<TOut>(nestedObjects);
}
else if(template.NestType == NestType.Node)
{
Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>();
foreach(var (key, value) in template.DictValue!)
{
dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index);
}
return new Nest<TOut>(dict);
}
// Consider Empty as invalid type.
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node.");
}

public virtual Nest<T> AsNest()
{
return this;
}

public virtual Nest<T> MergeWith(Nest<T>? other)
{
if(other is null || other == Nest<T>.Empty)
{
return this;
}
if(this == Nest<T>.Empty)
{
return other;
}
if(NestType == NestType.Node && other.NestType == NestType.Node)
{
return new Nest<T>(new Nest<T>[] { this, other });
}
else if(NestType == NestType.List && other.NestType == NestType.List)
{
return new Nest<T>(this.ListValue!.Concat(other.ListValue!));
}
else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary)
{
return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value));
}
else
{
return new Nest<T>(new Nest<T>[] { this, other });
}
}

/// <summary>
/// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not
/// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested.
/// </summary>
/// <returns></returns>
public bool IsNested()
{
if(NestType is NestType.Empty or NestType.Node)
{
return false;
}
else if(NestType is NestType.List)
{
return ListValue!.Count > 0;
}
else
{
return DictValue!.Count > 0;
}
}

[Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")]
public T this[int index]
{
get
{
bool success = FindInternal(this, index, out var result);
if (success)
{
return result;
}
else
{
throw new IndexOutOfRangeException();
}
}
set
{
bool success = SetInternal(this, index, value);
if (!success)
{
throw new IndexOutOfRangeException();
}
}
}

/// <summary>
/// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it
/// to `Nest[T]`.
/// </summary>
/// <typeparam name="TOut"></typeparam>
/// <param name="input"></param>
/// <returns></returns>
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T>
{
var nested = input.AsNest();
return ReduceInternal(nested).AsNest();
}

private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
{
if(node.NestType == NestType.Empty)
{
return Nest<T>.Empty;
}
else if(node.NestType == NestType.Node)
{
return node.NodeValue!.AsNest();
}
else if(node.NestType == NestType.List)
{
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest())));
}
else // Dictionary type
{
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest())));
}
}

private static bool FindInternal(Nest<T> node, int index, out T? result)
{
if (node.NestType == NestType.Node)
{
if(index == 0)
{
result = node.NodeValue!;
return true;
}
result = default(T);
return false;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
if(index == 0)
{
return FindInternal(item.AsNest(), index, out result);
}
index--;
}
result = default(T);
return false;
}
else if(node.NestType == NestType.Dictionary)
{
foreach (var item in node.DictValue!.Values)
{
if (index == 0)
{
return FindInternal(item.AsNest(), index, out result);
}
index--;
}
result = default(T);
return false;
}
else
{
result = default(T);
return false;
}
}

private static bool SetInternal(Nest<T> node, int index, T newValue)
{
if (node.NestType == NestType.Node)
{
if (index == 0)
{
node.NodeValue = newValue;
return true;
}
return false;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
if (index == 0)
{
return SetInternal(item.AsNest(), index, newValue);
}
index--;
}
return false;
}
else if (node.NestType == NestType.Dictionary)
{
foreach (var item in node.DictValue!.Values)
{
if (index == 0)
{
return SetInternal(item.AsNest(), index, newValue);
}
index--;
}
return false;
}
else
{
return false;
}
}

private static IEnumerable<T> FlattenInternal(Nest<T> node)
{
if (node.NestType == NestType.Node)
{
yield return node.NodeValue!;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
foreach(var val in FlattenInternal(item.AsNest()))
{
yield return val;
}
}
}
else if (node.NestType == NestType.Dictionary)
{
foreach (var item in node.DictValue!.Values)
{
foreach (var val in FlattenInternal(item.AsNest()))
{
yield return val;
}
}
}
}

private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func)
{
if (NestType == NestType.Node)
{
return new Nest<TOut>(func(NodeValue!));
}
else if (NestType == NestType.List)
{
List<Nest<TOut>> outs = new List<Nest<TOut>>();
foreach (var item in ListValue!)
{
outs.Add(item.AsNest().MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
else if (NestType == NestType.Dictionary)
{
Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>();
foreach (var (key, value) in DictValue!)
{
outs.Add(key, value.AsNest().MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
else
{
return Nest<TOut>.Empty;
}
}

public IEnumerator<T> GetEnumerator()
{
return Flatten().GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}

public override string ToString()
{
StringBuilder sb = new StringBuilder();
sb.Append("(");
WriteString(this, sb);
sb.Append(")");
return sb.ToString();
}

private static void WriteString(Nest<T> node, StringBuilder sb)
{
if (!string.IsNullOrEmpty(node.Name))
{
sb.Append($"{node.Name}: ");
}
if (node.NestType == NestType.Node)
{
sb.Append(node.NodeValue!.ToString());
}
else if (node.NestType == NestType.List)
{
sb.Append("[");
for(int i = 0; i < node.ListValue!.Count; i++)
{
WriteString(node.ListValue![i].AsNest(), sb);
if(i != node.ListValue!.Count - 1)
{
sb.Append(", ");
}
}
sb.Append("]");
}
else if (node.NestType == NestType.Dictionary)
{
sb.Append("{");
int count = node.DictValue!.Count;
int i = 0;
foreach (var (key, value) in node.DictValue!)
{
sb.Append($"{key}: ");
WriteString(value.AsNest(), sb);
if (i != count - 1)
{
sb.Append(", ");
}
i++;
}
sb.Append("}");
}
else
{
sb.Append("<empty>");
}
}

public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs)
{
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 });
}

public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs)
{
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 });
}
}
}

+ 103
- 0
src/TensorFlowNET.Core/Common/Types/NestDictionary.cs View File

@@ -0,0 +1,103 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull
{
public NestType NestType => NestType.Dictionary;
public IDictionary<TKey, TValue> Value { get; set; }
public int ShallowNestedCount => Values.Count;

public int TotalNestedCount => Values.Count;
public NestDictionary(IDictionary<TKey, TValue> dict)
{
Value = dict;
}
public IEnumerable<TValue> Flatten()
{
return Value.Select(x => x.Value);
}
public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func)
{
return new NestList<TOut>(Value.Select(x => func(x.Value)));
}

public Nest<TValue> AsNest()
{
return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x)));
}

// Required IDictionary<TKey, TValue> members
public int Count => Value.Count;

public bool IsReadOnly => Value.IsReadOnly;

public ICollection<TKey> Keys => Value.Keys;

public ICollection<TValue> Values => Value.Values;

public void Add(TKey key, TValue value)
{
Value.Add(key, value);
}

public void Add(KeyValuePair<TKey, TValue> item)
{
Value.Add(item);
}

public void Clear()
{
Value.Clear();
}

public bool Contains(KeyValuePair<TKey, TValue> item)
{
return Value.Contains(item);
}

public bool ContainsKey(TKey key)
{
return Value.ContainsKey(key);
}

public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
{
Value.CopyTo(array, arrayIndex);
}

public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
{
return Value.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}

public bool Remove(TKey key)
{
return Value.Remove(key);
}

public bool Remove(KeyValuePair<TKey, TValue> item)
{
return Value.Remove(item);
}

public bool TryGetValue(TKey key, out TValue value)
{
return Value.TryGetValue(key, out value);
}

// Optional IDictionary<TKey, TValue> members
public TValue this[TKey key]
{
get => Value[key];
set => Value[key] = value;
}
}
}

+ 53
- 0
src/TensorFlowNET.Core/Common/Types/NestList.cs View File

@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// The implementation of a list that support nest structure, in which the depth is 1.
/// </summary>
/// <typeparam name="T"></typeparam>
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T>
{
public NestType NestType => NestType.List;
public List<T> Values { get; set; }
public int ShallowNestedCount => Values.Count;

public int TotalNestedCount => Values.Count;

public NestList(params T[] values)
{
Values = new List<T>(values);
}

public NestList(IEnumerable<T> values)
{
Values = new List<T>(values);
}
public IEnumerable<T> Flatten()
{
return Values;
}
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return new NestList<TOut>(Values.Select(x => func(x)));
}

public Nest<T> AsNest()
{
return new Nest<T>(Values.Select(x => new Nest<T>(x)));
}

// Enumerator implementation
public IEnumerator<T> GetEnumerator()
{
return Values.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}

+ 36
- 0
src/TensorFlowNET.Core/Common/Types/NestNode.cs View File

@@ -0,0 +1,36 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// A nested structure with only one element.
/// </summary>
/// <typeparam name="T"></typeparam>
public class NestNode<T> : INestStructure<T>
{
public NestType NestType => NestType.Node;
public T Value { get; set; }
public int ShallowNestedCount => 1;

public int TotalNestedCount => 1;
public NestNode(T value)
{
Value = value;
}
public IEnumerable<T> Flatten()
{
yield return Value;
}
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return new NestNode<TOut>(func(Value));
}

public Nest<T> AsNest()
{
return new Nest<T>(Value);
}
}
}

src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs → src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs View File

@@ -3,7 +3,7 @@ using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Saving
namespace Tensorflow.Common.Types
{
public class TensorShapeConfig
{

+ 2
- 2
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -161,8 +161,8 @@ namespace Tensorflow
break;
}

yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount)));
yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount).ToArray()));
}
}



+ 7
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -352,13 +352,19 @@ namespace Tensorflow.Eager
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
break;
case TF_AttrType.TF_ATTR_SHAPE:
var dims = (value as long[]).ToArray();
long[] dims;
if (value is Shape shape) dims = shape.dims.ToArray();
else if (value is long[] longs) dims = longs.ToArray();
else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray();
else dims = ((long[])value).ToArray();
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
status.Check(true);
break;
case TF_AttrType.TF_ATTR_FUNC:
if (value is ConcreteFunction func)
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length);
else if(value is string str)
c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length);
else
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
break;


+ 7
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs View File

@@ -65,7 +65,7 @@ namespace Tensorflow.Eager
{
outgrad_vec = output_gradients.ToList();
}
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true);


bool unconnected_gradients_zero = unconnected_gradients == "zero";
@@ -137,7 +137,6 @@ namespace Tensorflow.Eager
{
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
}
Shape tensor_shape = new(dims);

if(status.Code != TF_Code.TF_OK)
{
@@ -145,6 +144,7 @@ namespace Tensorflow.Eager
}
else
{
Shape tensor_shape = new(dims);
return new TapeTensor(id, dtype, tensor_shape);
}
}
@@ -173,8 +173,12 @@ namespace Tensorflow.Eager
return dtype == dtypes.variant || dtype == dtypes.resource;
}

bool ListContainNone(long[] list)
bool ListContainNone(long[]? list)
{
if(list is null)
{
return true;
}
int len = list.Length;
if(len == 0)
{


+ 6
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs View File

@@ -10,6 +10,11 @@ namespace Tensorflow.Eager
var str = NDArrayRender.ToString(nd);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
}
public string ToString(int maxLength)
{
var nd = new NDArray(this);
var str = NDArrayRender.ToString(nd, maxLength);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
}
}
}

+ 25
- 0
src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs View File

@@ -0,0 +1,25 @@
using Tensorflow;

internal static class GraphOnlyOps
{
/// <summary>
/// Graph-only version of tf.compat.v1.placeholder(), for internal use only.
/// </summary>
/// <param name="dtyype"></param>
/// <param name="shape"></param>
/// <param name="name"></param>
/// <returns></returns>
internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null)
{
var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() };
var shape_value = new AttrValue() { Shape = shape.as_proto() };
var g = ops.get_default_graph();
Dictionary<string, AttrValue> attrs = new();
attrs["dtype"] = dtype_value;
attrs["shape"] = shape_value;
var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype },
new TF_DataType[0], attrs: attrs, name: name);
var result = op.outputs[0];
return result;
}
}

+ 19
- 0
src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Exceptions
{
public class NotOkStatusException : TensorflowException
{
public NotOkStatusException() : base()
{

}

public NotOkStatusException(string message) : base(message)
{

}
}
}

+ 14
- 1
src/TensorFlowNET.Core/Framework/IndexedSlices.cs View File

@@ -49,12 +49,25 @@ namespace Tensorflow.Framework

public static implicit operator Tensor(IndexedSlices indexedSlices)
{
return indexedSlices.values;
return _indexed_slices_to_tensor(indexedSlices);
}

public static implicit operator IndexedSlices(Tensor tensor)
{
return tensor.Tag as IndexedSlices;
}

/// <summary>
/// Converts an IndexedSlices object `value` to a Tensor.
/// </summary>
/// <param name="indexedSlices"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <param name="as_ref"></param>
/// <returns></returns>
public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false)
{
return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0));
}
}
}

+ 13
- 0
src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs View File

@@ -1,4 +1,5 @@
using System.Linq;
using Tensorflow.Eager;

namespace Tensorflow.Framework.Models
{
@@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models
shapes.Insert(0, dim);
return new TensorSpec(shapes.ToArray(), _dtype);
}

public static TensorSpec FromTensor(Tensor tensor, string? name = null)
{
if(tensor is EagerTensor)
{
return new TensorSpec(tensor.shape, tensor.dtype, name);
}
else
{
return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name);
}
}
}
}

+ 89
- 0
src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs View File

@@ -0,0 +1,89 @@
using Tensorflow.Graphs;

namespace Tensorflow.Framework
{
internal static class auto_control_deps_utils
{
public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs";
public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph)
{
List<int> result = new List<int>();
// A cache to store the read only resource inputs of an Op.
// Operation -> ObjectIdentitySet of resource handles.
Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs =
new Dictionary<Operation, HashSet<Tensor>>();

for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++)
{
Tensor t = func_graph.Inputs[inputIndex];
if (t.dtype != dtypes.resource)
continue;

bool readOnly = true;
foreach (var op in t.consumers())
{
if (opReadOnlyResourceInputs.ContainsKey(op))
{
if (!opReadOnlyResourceInputs[op].Contains(t))
{
readOnly = false;
break;
}
}
else
{
List<int> indices = _get_read_only_resource_input_indices_op(op);
opReadOnlyResourceInputs[op] = new HashSet<Tensor>(
indices.Select(i => op.inputs[i]));
if (!opReadOnlyResourceInputs[op].Contains(t))
{
readOnly = false;
break;
}
}
}

if (readOnly)
result.Add(inputIndex);
}

return result;
}

private static List<int> _get_read_only_resource_input_indices_op(Operation op)
{
// ignore the RESOURCE_READ_OPS

int[] read_only_input_indices;

try
{
read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR);
}
catch (InvalidArgumentError)
{
return new List<int>();
}

int read_only_index = 0;
List<int> result = new();
for (int i = 0; i < op.inputs.Length; i++)
{
if (read_only_index >= read_only_input_indices.Length)
{
break;
}
if (op.inputs[i].dtype != dtypes.resource)
{
continue;
}
if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index])
{
result.Add(i);
read_only_index++;
}
}
return result;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Framework/function_def_lib.cs View File

@@ -42,10 +42,10 @@ namespace Tensorflow.Framework
func_graph.as_default();
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false);
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]);
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());

var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]);
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());
// TODO(Rinne): func_graph.ControlOutputs
_set_handle_data(func_graph, fdef);



+ 13
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -8,6 +8,7 @@ using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
@@ -40,6 +41,18 @@ namespace Tensorflow.Functions
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs;
public IEnumerable<IVariableV1> Variables => func_graph.Variables;
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables;
internal NameAttrList AsNameAttrList
{
get
{
NameAttrList ret = new() { Name = this.Name };
foreach (var (name, value) in _attrs)
{
ret.Attr[name] = value;
}
return ret;
}
}

public ConcreteFunction(string name)
{


+ 4
- 1
src/TensorFlowNET.Core/GlobalUsing.cs View File

@@ -3,4 +3,7 @@ global using System.Collections.Generic;
global using System.Text;
global using System.Collections;
global using System.Data;
global using System.Linq;
global using System.Linq;
global using Tensorflow.Keras.Engine;
global using Tensorflow.Framework.Models;
global using static Tensorflow.Binding;

+ 10
- 3
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -90,8 +90,7 @@ namespace Tensorflow.Gradients
? input_values[0].rank + dim_int
: dim_int % input_values[0].rank;
var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
var sizes_tensor = constant_op.constant(sizes);
out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList();
out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList();
}
else if (constant_op.is_constant(concat_dim))
{
@@ -127,7 +126,7 @@ namespace Tensorflow.Gradients
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
new Tensor[] { tf.constant(1), tf.constant(-1) });
var squeeze_sizes = array_ops.squeeze(slice);
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList();
}
else
{
@@ -374,5 +373,13 @@ namespace Tensorflow.Gradients
var p = op.inputs[1];
return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null };
}

[RegisterGradient("ReverseV2")]
public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var axis = op.inputs[1];
return new Tensor[] { array_ops.reverse(grad, axis), null };
}
}
}

+ 131
- 0
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -117,6 +117,137 @@ namespace Tensorflow.Gradients
};
}

public static string ellipsis = "...";
[RegisterGradient("Einsum")]
public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads)
{
// Gradient for Einsum.
string equation = (string)op.get_attr("equation");
string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None);
var input_subs = split_equation[0];
var output_subs = split_equation[1];

if (op.inputs.Length == 1)
{
var input_shape = array_ops.shape(op.inputs[0]);
var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis)));
if (reduced_label_set.Count == 0)
return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) };
return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) };
}

string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None);
var x_subs = split_input_subs[0];
var y_subs = split_input_subs[1];
// Add ellipsis for broadcasted dimensions if any operand does not have it.
// This is because the equation "...ij,jk->ik" may be valid if the 0th input's
// batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
// because only the output subscripts contain ellipsis.
if (output_subs.Contains(ellipsis))
{
if (!x_subs.Contains(ellipsis))
x_subs += ellipsis;
if (!y_subs.Contains(ellipsis))
y_subs += ellipsis;
}
// Obtain the gradients wrt the inputs x and y, without taking into account
// the unbroadcasting.
var x = op.inputs[0];
var y = op.inputs[1];
if (grads.GetDataType().is_complex())
{
x = math_ops.conj(x);
y = math_ops.conj(y);
}

var x_shape = array_ops.shape(x);
var y_shape = array_ops.shape(y);
var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs);
var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs);

if (!output_subs.Contains(ellipsis))
return new Tensor[] { grad_x, grad_y };
var bx = _GetBcastSubshape(x_subs);
int bx_start = bx[0], bx_end = bx[1];
var by = _GetBcastSubshape(y_subs);
int by_start = by[0], by_end = by[1];

var x_shape_static = x.shape;
var y_shape_static = y.shape;
if(x_shape_static.IsFullyDefined &&
y_shape_static.IsFullyDefined &&
x_shape_static[string.Format("{0}:{1}",bx_start,bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)])
return new Tensor[] { grad_x, grad_y };

var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)],
y_shape[string.Format("{0}:{1}", by_start, by_end)]);
var rx = r[0];
var ry = r[1];
grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape);
grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape);
return new Tensor[] { grad_x, grad_y };
}
protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape,
string input_subs, string other_subs, string output_subs)
{
var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + ".")));
var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));
var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand));
if (reduced_label_set.Count == 0)
return grad_reduced;
return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set);
}
protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set)
{
string reduced_subs;
Tensor reduced_dims;
List<int> reduced_axes;
_GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes);
bool has_repeated_labels = (
new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count <
input_subs.Length + output_subs.Length);
var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));

if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs)
{
var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes));
return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape);
}
else
{
var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0);
var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0);
var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels);
return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad));
}
}
protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes)
{
reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString()));
reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList();
reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList());
}
protected static int _GetAxisFromLabel(string subscripts, char label)
{
var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None);
var index = splits[0].IndexOf(label);
if (index != -1) return index;
if (splits.Length < 2) throw new OutOfRangeError();
index = splits[1].IndexOf(label);
if (index != -1) return index;
throw new ValueError();
}
protected static int[] _GetBcastSubshape(string subscripts)
{
int start = subscripts.IndexOf(ellipsis);
if (start == -1) return new int[] { 0, 0 };
int remaining = subscripts.Length - (start + ellipsis.Length);
int end;
if (remaining > 0) end = remaining;
else throw new Exception();
return new int[] { start, end };
}

/// <summary>
/// Returns grad * exp(x).
/// </summary>


+ 17
- 0
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -365,6 +365,23 @@ namespace Tensorflow.Gradients
};
}

[RegisterGradient("AvgPool")]
public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads)
{
Tensor grad = grads[0];
return new Tensor[]
{
gen_nn_ops.avg_pool_grad(
array_ops.shape(op.inputs[0]),
grad,
op.get_attr_list<int>("ksize"),
op.get_attr_list<int>("strides"),
op.get_attr<string>("padding"),
op.get_attr<string>("data_format"))
};
}

/// <summary>
/// Return the gradients for TopK.
/// </summary>


+ 8
- 8
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable);
public Dictionary<string, AttrValue> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures
internal Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();

public Tensor[] external_captures
@@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable
var flat_func_args = nest.flatten(func_args as object);
var flat_func_kwargs = nest.flatten(func_kwargs as object);
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs)
.Where(x => x is Tensor).Select(x => (Tensor)x));
.Where(x => x is Tensor).Select(x => (Tensor)x).ToArray());

//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true);
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true);
@@ -544,12 +544,12 @@ public class FuncGraph : Graph, IDisposable
Tensor placeholder;
try
{
placeholder = tf.placeholder(tensor.dtype, tensor.shape, name);
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name);
}
catch (ValueError)
catch (ValueError ex)
{
// TODO(Rinne): Add warning here.
placeholder = tf.placeholder(tensor.dtype, tensor.shape);
tf.Logger.Warning(ex.ToString());
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape);
}
handle_data_util.copy_handle_data(tensor, placeholder);
if (name is not null)
@@ -575,12 +575,12 @@ public class FuncGraph : Graph, IDisposable
Tensor placeholder;
try
{
placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name);
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name);
}
catch (ValueError)
{
// TODO(Rinne): Add warning here.
placeholder = tf.placeholder(spec.dtype, spec.shape);
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape);
}
if (name is not null)
{


+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -129,7 +129,7 @@ namespace Tensorflow
}
}

protected Graph outer_graph;
internal Graph outer_graph;
public Graph OuterGraph => outer_graph;
public Dictionary<string, EagerDefinedFunction> Functions => _functions;
public SafeGraphHandle c_graph => _handle;


+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ExponentialArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class ExponentialArgs : LayerArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/HardSigmoidArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class HardSigmoidArgs : LayerArgs
{
}
}

+ 11
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SELUArgs.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class SELUArgs : LayerArgs
{

}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftplusArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class SoftplusArgs : LayerArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftsignArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class SoftsignArgs : LayerArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SwishArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class SwishArgs : LayerArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/TanhArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class TanhArgs : LayerArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DTransposeArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class Conv2DTransposeArgs : Conv2DArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/AddArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class AddArgs : MergeArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/ConcatenateArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class ConcatenateArgs : MergeArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/SubtractArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class SubtractArgs : MergeArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling1DArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class GlobalAveragePooling1DArgs : Pooling1DArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling2DArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class GlobalAveragePooling2DArgs : Pooling2DArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling1DArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class GlobalMaxPooling1DArgs : Pooling1DArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling2DArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class GlobalMaxPooling2DArgs : Pooling2DArgs
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling1DArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class MaxPooling1DArgs : Pooling1DArgs
{
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.Keras.ArgsDefinition
[JsonProperty("size")]
public Shape Size { get; set; }
[JsonProperty("data_format")]
public string DataFormat { get; set; }
public string DataFormat { get; set; } = "channels_last";
/// <summary>
/// 'nearest', 'bilinear'
/// </summary>


+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Upsampling1DArgs.cs View File

@@ -0,0 +1,10 @@
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{
public class UpSampling1DArgs : AutoSerializeLayerArgs
{
[JsonProperty("size")]
public int Size { get; set; }
}
}

+ 20
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs View File

@@ -0,0 +1,20 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{
public class BidirectionalArgs : AutoSerializeLayerArgs
{
[JsonProperty("layer")]
public ILayer Layer { get; set; }
[JsonProperty("merge_mode")]
public string? MergeMode { get; set; }
[JsonProperty("backward_layer")]
public ILayer BackwardLayer { get; set; }
public NDArray Weights { get; set; }
}

}

+ 29
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class GRUArgs : AutoSerializeLayerArgs
{
public int Units { get; set; }
public Activation Activation { get; set; }
public Activation RecurrentActivation { get; set; }
public bool UseBias { get; set; } = true;
public float Dropout { get; set; } = .0f;
public float RecurrentDropout { get; set; } = .0f;
public IInitializer KernelInitializer { get; set; }
public IInitializer RecurrentInitializer { get; set; }
public IInitializer BiasInitializer { get; set; }
public bool ReturnSequences { get;set; }
public bool ReturnState { get;set; }
public bool GoBackwards { get;set; }
public bool Stateful { get;set; }
public bool Unroll { get;set; }
public bool TimeMajor { get;set; }
public bool ResetAfter { get;set; }
public int Implementation { get; set; } = 2;

}

}

+ 39
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs View File

@@ -0,0 +1,39 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class GRUCellArgs : AutoSerializeLayerArgs
{
[JsonProperty("units")]
public int Units { get; set; }
// TODO(Rinne): lack of initialized value of Activation. Merging keras
// into tf.net could resolve it.
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("recurrent_activation")]
public Activation RecurrentActivation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; }
[JsonProperty("recurrent_initializer")]
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }
[JsonProperty("reset_after")]
public bool ResetAfter { get;set; }
[JsonProperty("implementation")]
public int Implementation { get; set; } = 2;



}

}

+ 13
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class GRUOptionalArgs
{
public string Identifier => "GRU";

public Tensor Mask { get; set; } = null;
}
}

+ 6
- 3
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs View File

@@ -1,11 +1,14 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
public class LSTMArgs : RNNArgs
{
// TODO: maybe change the `RNNArgs` and implement this class.
public bool UnitForgetBias { get; set; }
public float Dropout { get; set; }
public float RecurrentDropout { get; set; }
public int Implementation { get; set; }

public LSTMArgs Clone()
{
return (LSTMArgs)MemberwiseClone();
}
}
}

+ 30
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs View File

@@ -1,7 +1,35 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
using Newtonsoft.Json;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition
{
// TODO: complete the implementation
public class LSTMCellArgs : LayerArgs
public class LSTMCellArgs : AutoSerializeLayerArgs
{
[JsonProperty("units")]
public int Units { get; set; }
// TODO(Rinne): lack of initialized value of Activation. Merging keras
// into tf.net could resolve it.
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("recurrent_activation")]
public Activation RecurrentActivation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; }
[JsonProperty("recurrent_initializer")]
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }
[JsonProperty("unit_forget_bias")]
public bool UnitForgetBias { get; set; } = true;
[JsonProperty("implementation")]
public int Implementation { get; set; } = 2;

}
}

+ 20
- 25
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs View File

@@ -1,17 +1,12 @@
using Newtonsoft.Json;
using System.Collections.Generic;
using Tensorflow.Keras.Layers;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
// TODO(Rinne): add regularizers.
public class RNNArgs : AutoSerializeLayerArgs
{
public interface IRnnArgCell : ILayer
{
object state_size { get; }
}
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnArgCell Cell { get; set; } = null;
[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]
@@ -24,31 +19,31 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public bool Unroll { get; set; } = false;
[JsonProperty("time_major")]
public bool TimeMajor { get; set; } = false;
// TODO: Add `num_constants` and `zero_output_for_mask`.
public Dictionary<string, object> Kwargs { get; set; } = null;

public int? InputDim { get; set; }
public int? InputLength { get; set; }
// TODO: Add `num_constants` and `zero_output_for_mask`.
[JsonProperty("units")]
public int Units { get; set; }
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("recurrent_activation")]
public Activation RecurrentActivation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
public IInitializer KernelInitializer { get; set; }
public IInitializer RecurrentInitializer { get; set; }
public IInitializer BiasInitializer { get; set; }
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
[JsonProperty("zero_output_for_mask")]
public bool ZeroOutputForMask { get; set; } = false;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;

// kernel_regularizer=None,
// recurrent_regularizer=None,
// bias_regularizer=None,
// activity_regularizer=None,
// kernel_constraint=None,
// recurrent_constraint=None,
// bias_constraint=None,
// dropout=0.,
// recurrent_dropout=0.,
// return_sequences=False,
// return_state=False,
// go_backwards=False,
// stateful=False,
// unroll=False,
// **kwargs):
public RNNArgs Clone()
{
return (RNNArgs)MemberwiseClone();
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.ArgsDefinition
{
public class RnnOptionalArgs: IOptionalArgs
{
public string Identifier => "Rnn";
public Tensor Mask { get; set; } = null;
public Tensors Constants { get; set; } = null;
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs View File

@@ -1,4 +1,4 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
public class SimpleRNNArgs : RNNArgs
{


+ 27
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs View File

@@ -0,0 +1,27 @@
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{
public class SimpleRNNCellArgs: AutoSerializeLayerArgs
{
[JsonProperty("units")]
public int Units { get; set; }
// TODO(Rinne): lack of initialized value of Activation. Merging keras
// into tf.net could resolve it.
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; }
[JsonProperty("recurrent_initializer")]
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }

}
}

+ 3
- 3
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs View File

@@ -1,10 +1,10 @@
using System.Collections.Generic;
using Tensorflow.Keras.Layers;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
public class StackedRNNCellsArgs : LayerArgs
{
public IList<RnnCell> Cells { get; set; }
public Dictionary<string, object> Kwargs { get; set; } = null;
public bool ReverseStateOrder = false;
}
}

+ 24
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs View File

@@ -0,0 +1,24 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;


namespace Tensorflow.Keras.ArgsDefinition
{
public class WrapperArgs : AutoSerializeLayerArgs
{
[JsonProperty("layer")]
public ILayer Layer { get; set; }

public WrapperArgs(ILayer layer)
{
Layer = layer;
}

public static implicit operator WrapperArgs(BidirectionalArgs args)
=> new WrapperArgs(args.Layer);
}

}

+ 3
- 0
src/TensorFlowNET.Core/Keras/Engine/ICallback.cs View File

@@ -14,6 +14,9 @@ public interface ICallback
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
void on_predict_end();
void on_test_begin();
void on_test_end(Dictionary<string, float> logs);
void on_test_batch_begin(long step);
void on_test_batch_end(long end_step, Dictionary<string, float> logs);


}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/IModel.cs View File

@@ -60,7 +60,7 @@ public interface IModel : ILayer
bool skip_mismatch = false,
object options = null);

Dictionary<string, float> evaluate(Tensor x, Tensor y,
Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
int steps = -1,


+ 75
- 0
src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs View File

@@ -0,0 +1,75 @@
namespace Tensorflow.Keras.Engine;

/// <summary>
/// A representation of a Keras in/output during Functional API construction.
/// </summary>
public class KerasTensor
{
private Tensors _original_tensors;
public Tensors original_tensors
{
get => _original_tensors;
set => _original_tensors = value;
}

private Shape _inferred_value;
public Shape inferred_value => _inferred_value;

private string _name;
private TensorSpec _type_spec;
public Shape shape => _type_spec.shape;
public TF_DataType dtype => _type_spec.dtype;

public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null)
{
_type_spec = type_spec;
_inferred_value = inferred_value;
_name = name;
}

public static KerasTensor from_tensor(Tensor tensor)
{
var type_spec = tensor.ToTensorSpec();
Shape? inferred_value = default;
if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2)
{
inferred_value = tf.ones(tensor).shape;
}
var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name);
kt.original_tensors = tensor;
return kt;
}

public KerasTensor this[int idx]
=> _original_tensors.First()[idx];

public KerasTensor this[params Slice[] slices]
=> _original_tensors.First()[slices];

public override string ToString()
=> _original_tensors.Length switch
{
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]",
1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}",
_ => _original_tensors.ToString(),
};

private string GetInferredValueString()
=> _inferred_value == null ? "" : $" inferred_value={_inferred_value}";

public static implicit operator Tensors(KerasTensor kt)
=> kt._original_tensors;

public static implicit operator Tensor(KerasTensor kt)
{
Tensor tensor = kt._original_tensors;
tensor.IsFromKerasTensor = true;
return tensor;
}

public static implicit operator KerasTensor(Tensor tensor)
=> from_tensor(tensor);

public static implicit operator KerasTensor(Tensors tensors)
=> from_tensor(tensors.First());
}

+ 22
- 1
src/TensorFlowNET.Core/Keras/IOptimizerApi.cs View File

@@ -25,6 +25,27 @@ namespace Tensorflow.Keras
bool amsgrad = false,
string name = "Adam");

/// <summary>
/// Adam enables L2 weight decay on gradients.
/// </summary>
/// <param name="learning_rate"></param>
/// <param name="weight_decay"></param>
/// <param name="beta_1"></param>
/// <param name="beta_2"></param>
/// <param name="epsilon"></param>
/// <param name="amsgrad"></param>
/// <param name="decay_params"></param>
/// <param name="name"></param>
/// <returns></returns>
IOptimizer AdamW(float learning_rate = 0.001f,
float weight_decay = 0.004f,
float beta_1 = 0.9f,
float beta_2 = 0.999f,
float epsilon = 1e-7f,
bool amsgrad = false,
List<string> no_decay_params = null,
string name = "AdamW");

/// <summary>
/// Construct a new RMSprop optimizer.
/// </summary>
@@ -42,6 +63,6 @@ namespace Tensorflow.Keras
bool centered = false,
string name = "RMSprop");

IOptimizer SGD(float learning_rate);
IOptimizer SGD(float learning_rate = 0.01f, float momentum = 0f);
}
}

+ 3
- 2
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,4 +1,5 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Training;
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras
List<ILayer> Layers { get; }
List<INode> InboundNodes { get; }
List<INode> OutboundNodes { get; }
Tensors Apply(Tensors inputs, Tensor state = null, bool training = false);
Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null);
List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }


+ 4
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs View File

@@ -9,6 +9,10 @@ namespace Tensorflow.Keras.Layers
public ILayer Reshape(Shape target_shape);
public ILayer Reshape(object[] target_shape);

public ILayer UpSampling1D(
int size
);

public ILayer UpSampling2D(Shape size = null,
string data_format = null,
string interpolation = "nearest");


+ 91
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -1,5 +1,7 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

@@ -134,7 +136,7 @@ namespace Tensorflow.Keras.Layers
public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
public ILayer GlobalMaxPooling2D(string data_format = "channels_last");

public Tensors Input(Shape shape = null,
public KerasTensor Input(Shape shape = null,
int batch_size = -1,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
@@ -159,6 +161,18 @@ namespace Tensorflow.Keras.Layers
public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
public ILayer LeakyReLU(float alpha = 0.3f);

public IRnnCell LSTMCell(int uints,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
int implementation = 2);

public ILayer LSTM(int units,
Activation activation = null,
Activation recurrent_activation = null,
@@ -192,6 +206,19 @@ namespace Tensorflow.Keras.Layers
float offset = 0,
Shape input_shape = null);

public IRnnCell SimpleRNNCell(
int units,
string activation = "tanh",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f);

public IRnnCell StackedRNNCells(
IEnumerable<IRnnCell> cells);

public ILayer SimpleRNN(int units,
string activation = "tanh",
string kernel_initializer = "glorot_uniform",
@@ -200,6 +227,69 @@ namespace Tensorflow.Keras.Layers
bool return_sequences = false,
bool return_state = false);

public ILayer RNN(
IRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

public ILayer RNN(
IEnumerable<IRnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

public IRnnCell GRUCell(
int units,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f,
bool reset_after = true);

public ILayer GRU(
int units,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false,
bool reset_after = true
);

/// <summary>
/// Bidirectional wrapper for RNNs.
/// </summary>
/// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param>
/// automatically.</param>
/// <returns></returns>
public ILayer Bidirectional(
ILayer layer,
string merge_mode = "concat",
NDArray weights = null,
ILayer backward_layer = null);

public ILayer Subtract();
}
}

+ 25
- 0
src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
public interface IRnnCell: ILayer
{
/// <summary>
/// If the derived class tends to not implement it, please return null.
/// </summary>
INestStructure<long>? StateSize { get; }
/// <summary>
/// If the derived class tends to not implement it, please return null.
/// </summary>
INestStructure<long>? OutputSize { get; }
/// <summary>
/// Whether the optional RNN args are supported when appying the layer.
/// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`.
/// </summary>
bool SupportOptionalArgs { get; }
Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype);
}
}

+ 12
- 0
src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Layers
{
public interface IStackedRnnCells : IRnnCell
{
int Count { get; }
IRnnCell this[int idx] { get; }
}
}

+ 1
- 0
src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs View File

@@ -3,6 +3,7 @@ using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Saving.Json
{


+ 1
- 0
src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs View File

@@ -6,6 +6,7 @@ using System.Text;
using System.Diagnostics;
using OneOf.Types;
using Tensorflow.Keras.Saving.Json;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Saving
{


+ 0
- 5
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -74,8 +74,3 @@ namespace Tensorflow
=> IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})";
}
}

namespace System.Runtime.CompilerServices
{
internal static class IsExternalInit { }
}

+ 6
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs View File

@@ -107,9 +107,15 @@ namespace Tensorflow.NumPy
public static implicit operator NDArray(bool value)
=> new NDArray(value);

public static implicit operator NDArray(byte value)
=> new NDArray(value);

public static implicit operator NDArray(int value)
=> new NDArray(value);

public static implicit operator NDArray(long value)
=> new NDArray(value);

public static implicit operator NDArray(float value)
=> new NDArray(value);



+ 9
- 9
src/TensorFlowNET.Core/NumPy/NDArrayRender.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.NumPy
{
public class NDArrayRender
{
public static string ToString(NDArray array)
public static string ToString(NDArray array, int maxLength = 10)
{
Shape shape = array.shape;
if (shape.IsScalar)
@@ -15,12 +15,12 @@ namespace Tensorflow.NumPy

var s = new StringBuilder();
s.Append("array(");
Build(s, array);
Build(s, array, maxLength);
s.Append(")");
return s.ToString();
}

static void Build(StringBuilder s, NDArray array)
static void Build(StringBuilder s, NDArray array, int maxLength)
{
var shape = array.shape;

@@ -35,11 +35,11 @@ namespace Tensorflow.NumPy
var len = shape[0];
s.Append("[");

if (len <= 10)
if (len <= maxLength)
{
for (int i = 0; i < len; i++)
{
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");
@@ -49,9 +49,9 @@ namespace Tensorflow.NumPy
}
else
{
for (int i = 0; i < 5; i++)
for (int i = 0; i < maxLength / 2; i++)
{
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");
@@ -62,9 +62,9 @@ namespace Tensorflow.NumPy
s.Append(" ... ");
s.AppendLine();

for (int i = (int)len - 5; i < len; i++)
for (int i = (int)len - maxLength / 2; i < len; i++)
{
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");


+ 4
- 0
src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs View File

@@ -13,6 +13,10 @@ namespace Tensorflow.NumPy
public static NDArray argmax(NDArray a, Axis? axis = null)
=> new NDArray(math_ops.argmax(a, axis ?? 0));

[AutoNumPy]
public static NDArray argmin(NDArray a, Axis? axis = null)
=> new NDArray(math_ops.argmin(a, axis ?? 0));

[AutoNumPy]
public static NDArray argsort(NDArray a, Axis? axis = null)
=> new NDArray(sort_ops.argsort(a, axis: axis ?? -1));


+ 2
- 2
src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs View File

@@ -10,10 +10,10 @@ namespace Tensorflow.NumPy
public partial class np
{
[AutoNumPy]
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis));
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.min(x, axis));

[AutoNumPy]
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis));
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.max(x, axis));

[AutoNumPy]
public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false)


+ 21
- 0
src/TensorFlowNET.Core/NumPy/Numpy.Math.cs View File

@@ -49,9 +49,30 @@ namespace Tensorflow.NumPy
[AutoNumPy]
public static NDArray prod<T>(params T[] array) where T : unmanaged
=> new NDArray(tf.reduce_prod(new NDArray(array)));
[AutoNumPy]
public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null)
{
//if axes mentioned
if (axes != null)
{
return new NDArray(tf.dot_prod(x1, x2, axes, name));
}
if (x1.shape.ndim > 1)
{
x1 = GetFlattenArray(x1);
}
if (x2.shape.ndim > 1)
{
x2 = GetFlattenArray(x2);
}
//if axes not mentioned, default 0,0
return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name));

}
[AutoNumPy]
public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y));
[AutoNumPy]
public static NDArray square(NDArray x) => new NDArray(tf.square(x));

[AutoNumPy]
public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x));


+ 23
- 1
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -19,13 +19,14 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving.Common;
using Tensorflow.NumPy;

namespace Tensorflow
{
[JsonConverter(typeof(CustomizedShapeJsonConverter))]
public class Shape
public class Shape : INestStructure<long>
{
public int ndim => _dims == null ? -1 : _dims.Length;
long[] _dims;
@@ -41,6 +42,27 @@ namespace Tensorflow
}
}

public NestType NestType => NestType.List;

public int ShallowNestedCount => ndim;
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
public int TotalNestedCount => ndim;

public IEnumerable<long> Flatten() => dims.Select(x => x);

public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func)
{
return new NestList<TOut>(dims.Select(x => func(x)));
}

public Nest<long> AsNest()
{
return new NestList<long>(Flatten()).AsNest();
}

#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
public int Length => ndim;
public long[] Slice(int start, int length)


+ 22
- 0
src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow.Operations.Initializers
{
/// <summary>
/// An initializer specially used for debugging (to load weights from disk).
/// </summary>
class NpyLoadInitializer : IInitializer
{
string _path;
public NpyLoadInitializer(string path) { _path = path; }
public string ClassName => "";
public IDictionary<string, object> Config => new Dictionary<string, object>();
public Tensor Apply(InitializerArgs args)
{
return np.load(_path);
}
}
}

+ 2
- 3
src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs View File

@@ -53,13 +53,12 @@ public class Orthogonal : IInitializer
// Compute the qr factorization
var (q, r) = tf.linalg.qr(a, full_matrices: false);
// Make Q uniform
var d = tf.linalg.tensor_diag_part(r);
var d = tf.linalg.tensor_diag_part(r.Single);
q *= tf.sign(d);

if (num_rows < num_cols)
{
// q = tf.linalg.matrix_transpose(q);
throw new NotImplementedException("");
q = array_ops.matrix_transpose(q);
}

return _gain * tf.reshape(q, shape);


+ 2
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -11,6 +11,7 @@ namespace Tensorflow
/// Basic LSTM recurrent network cell.
/// The implementation is based on: http://arxiv.org/abs/1409.2329.
/// </summary>
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public class BasicLstmCell : LayerRnnCell
{
int _num_units;
@@ -88,7 +89,7 @@ namespace Tensorflow
gate_inputs = nn_ops.bias_add(gate_inputs, _bias);

// i = input_gate, j = new_input, f = forget_gate, o = output_gate
var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one);
var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);

var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);


+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -20,6 +20,7 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public class BasicRnnCell : LayerRnnCell
{
int _num_units;


+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs View File

@@ -19,6 +19,7 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public class LayerRnnCell : RnnCell
{
protected InputSpec inputSpec;


+ 18
- 3
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -16,10 +16,11 @@

using System;
using System.Collections.Generic;
using Tensorflow.Common.Types;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Operations;
@@ -50,7 +51,8 @@ namespace Tensorflow
/// matching structure of Tensors having shape `[batch_size].concatenate(s)`
/// for each `s` in `self.batch_size`.
/// </summary>
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public abstract class RnnCell : ILayer, IRnnCell
{
/// <summary>
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight
@@ -142,7 +144,7 @@ namespace Tensorflow
throw new NotImplementedException("_zero_state_tensors");
}

public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null)
{
throw new NotImplementedException();
}
@@ -173,5 +175,18 @@ namespace Tensorflow
{
throw new NotImplementedException();
}

public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null)
{
throw new NotImplementedException();
}
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
throw new NotImplementedException();
}
public INestStructure<long> StateSize => throw new NotImplementedException();
public INestStructure<long> OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}

+ 57
- 1
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -15,9 +15,11 @@
******************************************************************************/

using Google.Protobuf;
using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using static Tensorflow.Binding;
using static Tensorflow.OpDef.Types;

@@ -387,9 +389,13 @@ namespace Tensorflow
case "list(type)":
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def)));
break;
case "list(float)":
if (value != null)
attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray());
break;
case "list(int)":
if (value != null)
attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x)));
attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x)));
break;
case "bool":
attr_value.B = (bool)value;
@@ -420,6 +426,15 @@ namespace Tensorflow
case "list(shape)":
attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def)));
break;
case "func":
attr_value.Func = _MakeFunc(value, attr_def.Name);
break;
case "list(func)":
attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name));
break;
case "list(string)":
attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x)));
break;
default:
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
}
@@ -427,6 +442,47 @@ namespace Tensorflow
return attr_value;
}

private NameAttrList _MakeFunc(object func, string arg_name)
{
if(func is NameAttrList attrList)
{
return attrList;
}
NameAttrList fn_attr;
if(func is string funcStr)
{
fn_attr = new NameAttrList() { Name = funcStr };
}
else if(func is ConcreteFunction concrete)
{
concrete.AddTograph(ops.get_default_graph());
fn_attr = concrete.AsNameAttrList;
}
else if(func is EagerDefinedFunction eager)
{
eager.AddToGraph(ops.get_default_graph());
fn_attr = new NameAttrList() { Name = eager.Name };
}
else
{
throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}");
}
return fn_attr;
}

private List<NameAttrList> _MakeFuncList(object funcList, string arg_name)
{
List<NameAttrList> res = new List<NameAttrList>();
if(funcList is IEnumerable enumerable)
{
foreach(var func in enumerable)
{
res.Add(_MakeFunc(func, arg_name));
}
}
return res;
}

private bool _IsListParameter(ArgDef arg)
{
if (!String.IsNullOrEmpty(arg.NumberAttr))


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save