# Conflicts: # src/TensorFlowNET.Core/Tensorflow.Binding.csproj # src/TensorFlowNET.Keras/Datasets/Imdb.cstags/v0.110.4-Transformer-Model
@@ -16,6 +16,7 @@ | |||
using System; | |||
using System.Runtime.InteropServices; | |||
using static Tensorflow.CppShapeInferenceResult.Types; | |||
namespace Tensorflow | |||
{ | |||
@@ -50,6 +51,35 @@ namespace Tensorflow | |||
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | |||
} | |||
public unsafe static byte[] ByteStringPiece(Buffer? handle) | |||
{ | |||
if (handle is null) | |||
{ | |||
return new byte[0]; | |||
} | |||
var data = handle.ToArray(); | |||
return data; | |||
} | |||
public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle) | |||
{ | |||
if (handle == IntPtr.Zero) | |||
{ | |||
return new byte[0]; | |||
} | |||
byte* str_data = (byte*)handle.ToPointer(); | |||
List<byte> bytes = new List<byte>(); | |||
byte current = 255; | |||
while (current != ((byte)'\0')) | |||
{ | |||
current = *(str_data++); | |||
bytes.Add(current); | |||
} | |||
var data = bytes.ToArray(); | |||
return data; | |||
} | |||
[UnmanagedFunctionPointer(CallingConvention.Winapi)] | |||
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); | |||
@@ -10,7 +10,7 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | |||
} | |||
@@ -91,8 +91,7 @@ namespace Tensorflow | |||
return identity(values.First(), name: scope); | |||
}); | |||
} | |||
return gen_array_ops.concat_v2(values.ToArray(), ops.convert_to_tensor(axis), name: name); | |||
return array_ops.concat(values.ToArray(), axis, name: name); | |||
} | |||
/// <summary> | |||
@@ -163,14 +162,17 @@ namespace Tensorflow | |||
/// Reverses specific dimensions of a tensor. | |||
/// </summary> | |||
/// <param name="tensor"></param> | |||
/// <param name="axis"></param> | |||
/// <param name="axis">The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).</param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||
=> gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name); | |||
public Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||
=> gen_array_ops.reverse(tensor, axis, name: name); | |||
public Tensor reverse(Tensor tensor, Axis axis, string name = null) | |||
{ | |||
if (axis.IsScalar) | |||
{ | |||
axis = new Axis(axis.axis); | |||
} | |||
return array_ops.reverse(tensor, axis, name: name); | |||
} | |||
/// <summary> | |||
/// Returns the rank of a tensor. | |||
@@ -46,10 +46,10 @@ namespace Tensorflow | |||
Tensor loop_vars, | |||
int parallel_iterations = 10) | |||
{ | |||
Func<Tensor[], Tensor> cond1 = x | |||
Func<Tensors, Tensor> cond1 = x | |||
=> cond(x[0]); | |||
Func<Tensor[], Tensor[]> body1 = x | |||
Func<Tensors, Tensors> body1 = x | |||
=> new[] { body(x[0]) }; | |||
var results = control_flow_ops.while_loop(cond1, | |||
@@ -58,9 +58,9 @@ namespace Tensorflow | |||
return results[0]; | |||
} | |||
public Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||
Func<Tensor[], Tensor[]> body, | |||
Tensor[] loop_vars, | |||
public Tensor[] while_loop(Func<Tensors, Tensor> cond, | |||
Func<Tensors, Tensors> body, | |||
Tensors loop_vars, | |||
int parallel_iterations = 10, | |||
string name = null) | |||
=> control_flow_ops.while_loop(cond, body, loop_vars, | |||
@@ -14,6 +14,10 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using OneOf.Types; | |||
using System; | |||
using System.Buffers.Text; | |||
using Tensorflow.Contexts; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -162,17 +166,108 @@ namespace Tensorflow | |||
public Tensor sobel_edges(Tensor image) | |||
=> image_ops_impl.sobel_edges(image); | |||
public Tensor decode_jpeg(Tensor contents, | |||
int channels = 0, | |||
int ratio = 1, | |||
bool fancy_upscaling = true, | |||
bool try_recover_truncated = false, | |||
int acceptable_fraction = 1, | |||
string dct_method = "", | |||
string name = null) | |||
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||
acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||
/// <summary> | |||
/// Adjust contrast of RGB or grayscale images. | |||
/// </summary> | |||
/// <param name="images">Images to adjust. At least 3-D.</param> | |||
/// <param name="contrast_factor"></param> | |||
/// <param name="name">A float multiplier for adjusting contrast.</param> | |||
/// <returns>The contrast-adjusted image or images.</returns> | |||
public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null) | |||
=> gen_image_ops.adjust_contrastv2(images, contrast_factor, name); | |||
/// <summary> | |||
/// Adjust hue of RGB images. | |||
/// </summary> | |||
/// <param name="images">RGB image or images. The size of the last dimension must be 3.</param> | |||
/// <param name="delta">float. How much to add to the hue channel.</param> | |||
/// <param name="name">A name for this operation (optional).</param> | |||
/// <returns>Adjusted image(s), same shape and DType as `image`.</returns> | |||
/// <exception cref="ValueError">if `delta` is not in the interval of `[-1, 1]`.</exception> | |||
public Tensor adjust_hue(Tensor images, float delta, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
if (delta < -1f || delta > 1f) | |||
throw new ValueError("delta must be in the interval [-1, 1]"); | |||
} | |||
return gen_image_ops.adjust_hue(images, delta, name: name); | |||
} | |||
/// <summary> | |||
/// Adjust saturation of RGB images. | |||
/// </summary> | |||
/// <param name="image">RGB image or images. The size of the last dimension must be 3.</param> | |||
/// <param name="saturation_factor">float. Factor to multiply the saturation by.</param> | |||
/// <param name="name">A name for this operation (optional).</param> | |||
/// <returns>Adjusted image(s), same shape and DType as `image`.</returns> | |||
public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null) | |||
=> gen_image_ops.adjust_saturation(image, saturation_factor, name); | |||
/// <summary> | |||
/// Greedily selects a subset of bounding boxes in descending order of score. | |||
/// </summary> | |||
/// <param name="boxes"> | |||
/// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q` | |||
/// is 1 then same boxes are used for all classes otherwise, if `q` is equal | |||
/// to number of classes, class-specific boxes are used. | |||
/// </param> | |||
/// <param name="scores"> | |||
/// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]` | |||
/// representing a single score corresponding to each box(each row of boxes). | |||
/// </param> | |||
/// <param name="max_output_size_per_class"> | |||
/// A scalar integer `Tensor` representing the | |||
/// maximum number of boxes to be selected by non-max suppression per class | |||
/// </param> | |||
/// <param name="max_total_size"> | |||
/// A int32 scalar representing maximum number of boxes retained | |||
/// over all classes.Note that setting this value to a large number may | |||
/// result in OOM error depending on the system workload. | |||
/// </param> | |||
/// <param name="iou_threshold"> | |||
/// A float representing the threshold for deciding whether boxes | |||
/// overlap too much with respect to IOU. | |||
/// </param> | |||
/// <param name="score_threshold"> | |||
/// A float representing the threshold for deciding when to | |||
/// remove boxes based on score. | |||
/// </param> | |||
/// <param name="pad_per_class"> | |||
/// If false, the output nmsed boxes, scores and classes are | |||
/// padded/clipped to `max_total_size`. If true, the output nmsed boxes, scores and classes are padded to be of length `max_size_per_class`*`num_classes`, | |||
/// unless it exceeds `max_total_size` in which case it is clipped to `max_total_size`. Defaults to false. | |||
/// </param> | |||
/// <param name="clip_boxes"> | |||
/// If true, the coordinates of output nmsed boxes will be clipped | |||
/// to[0, 1]. If false, output the box coordinates as it is. Defaults to true. | |||
/// </param> | |||
/// <returns> | |||
/// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes. | |||
/// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes. | |||
/// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for boxes. | |||
/// 'valid_detections': A [batch_size] int32 tensor indicating the number of | |||
/// valid detections per batch item. Only the top valid_detections[i] entries | |||
/// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the | |||
/// entries are zero paddings. | |||
/// </returns> | |||
public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression( | |||
Tensor boxes, | |||
Tensor scores, | |||
int max_output_size_per_class, | |||
int max_total_size, | |||
float iou_threshold, | |||
float score_threshold, | |||
bool pad_per_class = false, | |||
bool clip_boxes = true) | |||
{ | |||
var iou_threshold_t = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold"); | |||
var score_threshold_t = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold"); | |||
var max_total_size_t = ops.convert_to_tensor(max_total_size); | |||
var max_output_size_per_class_t = ops.convert_to_tensor(max_output_size_per_class); | |||
return gen_image_ops.combined_non_max_suppression(boxes, scores, max_output_size_per_class_t, max_total_size_t, | |||
iou_threshold_t, score_threshold_t, pad_per_class, clip_boxes); | |||
} | |||
/// <summary> | |||
/// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change. | |||
@@ -187,7 +282,19 @@ namespace Tensorflow | |||
/// <param name="name">A name for the operation (optional).</param> | |||
/// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns> | |||
public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) => | |||
image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); | |||
gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); | |||
public Tensor decode_jpeg(Tensor contents, | |||
int channels = 0, | |||
int ratio = 1, | |||
bool fancy_upscaling = true, | |||
bool try_recover_truncated = false, | |||
int acceptable_fraction = 1, | |||
string dct_method = "", | |||
string name = null) | |||
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||
acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||
public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true, | |||
bool uniform_noise = true, string name = null) | |||
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
@@ -42,10 +43,20 @@ namespace Tensorflow | |||
public Tensor multiply(Tensor x, Tensor y, string name = null) | |||
=> math_ops.multiply(x, y, name: name); | |||
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) | |||
=> math_ops.div_no_nan(a, b); | |||
/// <summary> | |||
/// Computes the Euclidean norm of elements across dimensions of a tensor. | |||
/// </summary> | |||
/// <param name="input_tensor">The tensor to reduce. Should have numeric type.</param> | |||
/// <param name="axis">The dimensions to reduce. If `None` (the default), reduces all dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`</param> | |||
/// <param name="keepdims">If true, retains reduced dimensions with length 1.</param> | |||
/// <param name="name">A name for the operation (optional).</param> | |||
/// <returns>The reduced tensor, of the same dtype as the input_tensor.</returns> | |||
public Tensor reduce_euclidean_norm(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) | |||
=> math_ops.reduce_euclidean_norm(input_tensor, axis: axis, keepdims: keepdims, name); | |||
public Tensor square(Tensor x, string name = null) | |||
=> math_ops.square(x, name: name); | |||
@@ -354,7 +365,7 @@ namespace Tensorflow | |||
=> a / b; | |||
public Tensor sqrt(Tensor a, string name = null) | |||
=> gen_math_ops.sqrt(a, name); | |||
=> math_ops.sqrt(a, name); | |||
public Tensor sign(Tensor a, string name = null) | |||
=> gen_math_ops.sign(a, name); | |||
@@ -452,7 +463,18 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
/// <summary> | |||
/// return scalar product | |||
/// </summary> | |||
/// <typeparam name="Tx"></typeparam> | |||
/// <typeparam name="Ty"></typeparam> | |||
/// <param name="x"></param> | |||
/// <param name="y"></param> | |||
/// <param name="axes"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor dot_prod<Tx, Ty>(Tx x, Ty y, NDArray axes, string name = null) | |||
=> math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name); | |||
public Tensor negative(Tensor x, string name = null) | |||
=> gen_math_ops.neg(x, name); | |||
@@ -600,5 +622,7 @@ namespace Tensorflow | |||
=> gen_math_ops.squared_difference(x: x, y: y, name: name); | |||
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null, | |||
string name = null) => gen_ops.complex(real, imag, dtype, name); | |||
public Tensor exp(Tensor x, | |||
string name = null) => gen_math_ops.exp(x, name); | |||
} | |||
} |
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System.Xml.Linq; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Operations.Activation; | |||
using static Tensorflow.Binding; | |||
@@ -126,6 +127,26 @@ namespace Tensorflow | |||
name: name, | |||
exponential_avg_factor: exponential_avg_factor); | |||
/// <summary> | |||
/// Normalizes a tensor by `mean` and `variance`, and applies (optionally) a`scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\). | |||
/// </summary> | |||
/// <param name="x">A floating point tensor.</param> | |||
/// <param name="mean">A mean `Tensor`.</param> | |||
/// <param name="variance">A variance `Tensor`.</param> | |||
/// <param name="offset"> An offset `Tensor`, often denoted \\(\beta\\) in equations, or NULL. If present, will be added to the normalized tensor.</param> | |||
/// <param name="scale"> A scale `Tensor`, often denoted \\(\gamma\\) in equations, or NULL. If present, the scale is applied to the normalized tensor.</param> | |||
/// <param name="variance_epsilon"> A small float number to avoid dividing by 0.</param> | |||
/// <param name="name">A name for this operation.</param> | |||
/// <returns>the normalized, scaled, offset tensor.</returns> | |||
public Tensor batch_normalization(Tensor x, | |||
Tensor mean, | |||
Tensor variance, | |||
Tensor offset, | |||
Tensor scale, | |||
float variance_epsilon, | |||
string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name); | |||
public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null) | |||
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); | |||
@@ -31,6 +31,6 @@ namespace Tensorflow | |||
public Tensor reshape(Tensor tensor, | |||
object[] shape, | |||
string name = null) | |||
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name); | |||
=> array_ops.reshape(tensor, shape, name); | |||
} | |||
} |
@@ -68,20 +68,27 @@ namespace Tensorflow | |||
/// <param name="name">A name for the operation (optional)</param> | |||
/// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; | |||
/// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns> | |||
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) | |||
public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) | |||
=> array_ops.split( | |||
value: value, | |||
num_split: num_split, | |||
num_or_size_splits: num_split, | |||
axis: axis, | |||
name: name); | |||
public Tensor[] split(Tensor value, int num_split, int axis, string name = null) | |||
public Tensor[] split(Tensor value, int[] num_split, Axis axis, string name = null) | |||
=> array_ops.split( | |||
value: value, | |||
num_split: num_split, | |||
num_or_size_splits: num_split, | |||
axis: axis, | |||
name: name); | |||
//public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) | |||
// => array_ops.split( | |||
// value: value, | |||
// num_or_size_splits: num_split, | |||
// axis: axis, | |||
// name: name); | |||
public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | |||
{ | |||
return gen_ops.ensure_shape(x, shape, name); | |||
@@ -23,7 +23,7 @@ namespace Tensorflow | |||
=> gen_array_ops.tile(input, multiples, name); | |||
public Tensor tile(Tensor input, object[] multiples, string name = null) | |||
=> gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name); | |||
=> array_ops.tile(input, constant_op.constant(shape_utils.from_object_array(multiples).dims), name); | |||
public Tensor tile(Tensor input, Shape multiples, string name = null) | |||
{ | |||
@@ -486,7 +486,28 @@ namespace Tensorflow | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
public static NDArray GetFlattenArray(NDArray x) | |||
{ | |||
switch (x.GetDataType()) | |||
{ | |||
case TF_DataType.TF_FLOAT: | |||
x = x.ToArray<float>(); | |||
break; | |||
case TF_DataType.TF_DOUBLE: | |||
x = x.ToArray<double>(); | |||
break; | |||
case TF_DataType.TF_INT16: | |||
case TF_DataType.TF_INT32: | |||
x = x.ToArray<int>(); | |||
break; | |||
case TF_DataType.TF_INT64: | |||
x = x.ToArray<long>(); | |||
break; | |||
default: | |||
break; | |||
} | |||
return x; | |||
} | |||
public static TF_DataType GetDataType(this object data) | |||
{ | |||
var type = data.GetType(); | |||
@@ -503,7 +524,7 @@ namespace Tensorflow | |||
case Tensors tensors: | |||
return tensors.dtype; | |||
case IEnumerable<Tensor> tensors: | |||
return tensors.First().dtype; | |||
return tensors.Where(x => x is not null).First().dtype; | |||
case RefVariable variable: | |||
return variable.dtype; | |||
case ResourceVariable variable: | |||
@@ -3,16 +3,16 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Extensions | |||
namespace Tensorflow.Common.Extensions | |||
{ | |||
public static class JObjectExtensions | |||
{ | |||
public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | |||
{ | |||
var res = obj[key]; | |||
if(res is null) | |||
if (res is null) | |||
{ | |||
return default(T); | |||
return default; | |||
} | |||
else | |||
{ |
@@ -0,0 +1,38 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
namespace Tensorflow.Common.Extensions | |||
{ | |||
public static class LinqExtensions | |||
{ | |||
#if NETSTANDARD2_0 | |||
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count) | |||
{ | |||
return sequence.Skip(sequence.Count() - count); | |||
} | |||
public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count) | |||
{ | |||
return sequence.Take(sequence.Count() - count); | |||
} | |||
#endif | |||
public static Tensors ToTensors(this Tensor[] tensors) | |||
{ | |||
return new Tensors(tensors); | |||
} | |||
public static Tensors ToTensors(this IList<Tensor> tensors) | |||
{ | |||
return new Tensors(tensors); | |||
} | |||
public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||
{ | |||
first = values.Item1; | |||
second = values.Item2; | |||
third = values.Item3; | |||
} | |||
} | |||
} |
@@ -0,0 +1,33 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Common.Types; | |||
namespace Tensorflow.Common.Extensions | |||
{ | |||
public static class NestExtensions | |||
{ | |||
public static Tensors ToTensors(this INestable<Tensor> tensors) | |||
{ | |||
return new Tensors(tensors.AsNest()); | |||
} | |||
public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||
{ | |||
return Tensors.FromNest(tensors); | |||
} | |||
/// <summary> | |||
/// If the nested object is already a nested type, this function could reduce it. | |||
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||
/// </summary> | |||
/// <typeparam name="TIn"></typeparam> | |||
/// <typeparam name="TOut"></typeparam> | |||
/// <param name="input"></param> | |||
/// <returns></returns> | |||
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||
{ | |||
return Nest<TOut>.ReduceFrom(input); | |||
} | |||
} | |||
} |
@@ -0,0 +1,20 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
/// <summary> | |||
/// This is a temp solution, which should be removed after refactoring `Tensors` | |||
/// </summary> | |||
[Obsolete] | |||
public class FakeTensorByTensorArray: Tensor | |||
{ | |||
public TensorArray TensorArray { get; set; } | |||
public FakeTensorByTensorArray(TensorArray array) | |||
{ | |||
TensorArray = array; | |||
} | |||
} | |||
} |
@@ -0,0 +1,69 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public class GeneralizedTensorShape: Nest<Shape> | |||
{ | |||
public GeneralizedTensorShape(Shape value, string? name = null) | |||
{ | |||
NodeValue = value; | |||
NestType = NestType.Node; | |||
} | |||
public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null) | |||
{ | |||
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList(); | |||
Name = name; | |||
NestType = NestType.List; | |||
} | |||
public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null) | |||
{ | |||
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>); | |||
Name = name; | |||
NestType = NestType.Dictionary; | |||
} | |||
public GeneralizedTensorShape(Nest<Shape> other) | |||
{ | |||
NestType = other.NestType; | |||
NodeValue = other.NodeValue; | |||
DictValue = other.DictValue; | |||
ListValue = other.ListValue; | |||
Name = other.Name; | |||
} | |||
public Shape ToSingleShape() | |||
{ | |||
var shapes = Flatten().ToList(); | |||
if (shapes.Count != 1) | |||
{ | |||
throw new ValueError("The generalized shape contains more than 1 dim."); | |||
} | |||
return shapes[0]; | |||
} | |||
public long ToNumber() | |||
{ | |||
var shapes = Flatten().ToList(); | |||
if (shapes.Count != 1 || shapes[0].ndim != 1) | |||
{ | |||
throw new ValueError("The generalized shape contains more than 1 dim."); | |||
} | |||
return shapes[0].dims[0]; | |||
} | |||
public INestStructure<TensorShapeConfig> ToTensorShapeConfigs() | |||
{ | |||
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() }); | |||
} | |||
public static implicit operator GeneralizedTensorShape(Shape shape) | |||
{ | |||
return new GeneralizedTensorShape(shape); | |||
} | |||
} | |||
} |
@@ -0,0 +1,40 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
/// <summary> | |||
/// This interface indicates that a class may have a nested structure and provide | |||
/// methods to manipulate with the structure. | |||
/// </summary> | |||
public interface INestStructure<T>: INestable<T> | |||
{ | |||
NestType NestType { get; } | |||
/// <summary> | |||
/// The item count of depth 1 of the nested structure. | |||
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3. | |||
/// </summary> | |||
int ShallowNestedCount { get; } | |||
/// <summary> | |||
/// The total item count of depth 1 of the nested structure. | |||
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||
/// </summary> | |||
int TotalNestedCount { get; } | |||
/// <summary> | |||
/// Flatten the Nestable object. Node that if the object contains only one value, | |||
/// it will be flattened to an enumerable with one element. | |||
/// </summary> | |||
/// <returns></returns> | |||
IEnumerable<T> Flatten(); | |||
/// <summary> | |||
/// Construct a new object with the same nested structure. | |||
/// </summary> | |||
/// <typeparam name="TOut"></typeparam> | |||
/// <param name="func"></param> | |||
/// <returns></returns> | |||
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||
} | |||
} |
@@ -0,0 +1,11 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public interface INestable<T> | |||
{ | |||
Nest<T> AsNest(); | |||
} | |||
} |
@@ -0,0 +1,21 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
/// <summary> | |||
/// This interface is used when some corresponding python methods have optional args. | |||
/// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while | |||
/// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs` | |||
/// as the parameter of the method. | |||
/// </summary> | |||
public interface IOptionalArgs | |||
{ | |||
/// <summary> | |||
/// The identifier of the class. It is not an argument but only something to | |||
/// separate different OptionalArgs. | |||
/// </summary> | |||
string Identifier { get; } | |||
} | |||
} |
@@ -0,0 +1,62 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public static class Nest | |||
{ | |||
/// <summary> | |||
/// Pack the flat items to a nested sequence by the template. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <param name="template"></param> | |||
/// <param name="flatItems"></param> | |||
/// <returns></returns> | |||
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems) | |||
{ | |||
return template.AsNest().PackSequence(flatItems); | |||
} | |||
/// <summary> | |||
/// Pack the flat items to a nested sequence by the template. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <param name="template"></param> | |||
/// <param name="flatItems"></param> | |||
/// <returns></returns> | |||
public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||
{ | |||
return template.AsNest().PackSequence(flatItems.ToArray()); | |||
} | |||
/// <summary> | |||
/// Flatten the nested object. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <param name="nestedObject"></param> | |||
/// <returns></returns> | |||
public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||
{ | |||
return nestedObject.AsNest().Flatten(); | |||
} | |||
/// <summary> | |||
/// Map the structure with specified function. | |||
/// </summary> | |||
/// <typeparam name="TIn"></typeparam> | |||
/// <typeparam name="TOut"></typeparam> | |||
/// <param name="func"></param> | |||
/// <param name="nestedObject"></param> | |||
/// <returns></returns> | |||
public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||
{ | |||
return nestedObject.AsNest().MapStructure(func); | |||
} | |||
public static bool IsNested<T>(INestable<T> obj) | |||
{ | |||
return obj.AsNest().IsNested(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,485 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Common.Extensions; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public enum NestType | |||
{ | |||
Empty, | |||
Node, | |||
List, | |||
Dictionary | |||
} | |||
/// <summary> | |||
/// A nested structure which may inclulde value, list and dictionary. | |||
/// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||
/// its order is depth-first. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||
{ | |||
private static readonly Nest<T> _empty = new Nest<T>() | |||
{ | |||
NestType = NestType.Empty, | |||
}; | |||
public static Nest<T> Empty => _empty; | |||
public NestType NestType { get; protected set; } | |||
public string? Name { get; set; } | |||
public T? NodeValue { get; protected set; } | |||
public List<INestStructure<T>>? ListValue { get; protected set; } | |||
public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; } | |||
public int ShallowNestedCount | |||
{ | |||
get | |||
{ | |||
if (NestType == NestType.Empty) | |||
{ | |||
return 0; | |||
} | |||
else if (NestType == NestType.Node) | |||
{ | |||
return 1; | |||
} | |||
else if (NestType == NestType.List) | |||
{ | |||
return ListValue!.Count; | |||
} | |||
else // dict | |||
{ | |||
return DictValue!.Count; | |||
} | |||
} | |||
} | |||
public int TotalNestedCount | |||
{ | |||
get | |||
{ | |||
return Flatten().Count(); | |||
} | |||
} | |||
protected Nest() { } | |||
public Nest(T value, string? name = null) | |||
{ | |||
NodeValue = value; | |||
Name = name; | |||
NestType = NestType.Node; | |||
} | |||
public Nest(IEnumerable<INestStructure<T>> values, string? name = null) | |||
{ | |||
ListValue = values.ToList(); | |||
Name = name; | |||
NestType = NestType.List; | |||
} | |||
public Nest(Dictionary<string, INestStructure<T>> value, string? name = null) | |||
{ | |||
DictValue = value; | |||
Name = name; | |||
NestType = NestType.Dictionary; | |||
} | |||
public Nest(Nest<T> other) | |||
{ | |||
NestType = other.NestType; | |||
NodeValue = other.NodeValue; | |||
DictValue = other.DictValue; | |||
ListValue = other.ListValue; | |||
Name = other.Name; | |||
} | |||
public virtual IEnumerable<T> Flatten() | |||
{ | |||
return FlattenInternal(this); | |||
} | |||
public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
{ | |||
return MapStructureInternal(func); | |||
} | |||
/// <summary> | |||
/// Pack the flat items to a nested sequence by the template. | |||
/// </summary> | |||
/// <param name="flatItems"></param> | |||
/// <returns></returns> | |||
public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems) | |||
{ | |||
if(flatItems.Length == 0) | |||
{ | |||
return Nest<TOut>.Empty; | |||
} | |||
int index = 0; | |||
return PackSequenceInternal(this, flatItems, ref index); | |||
} | |||
private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index) | |||
{ | |||
if(template.NestType == NestType.Node) | |||
{ | |||
if(index >= flatItems.Length) | |||
{ | |||
throw new InvalidArgumentError("The template and flat items are not matched."); | |||
} | |||
return new Nest<TOut>(flatItems[index++]); | |||
} | |||
else if(template.NestType == NestType.List) | |||
{ | |||
List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>(); | |||
for (int i = 0; i < template.ListValue!.Count; i++) | |||
{ | |||
nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index)); | |||
} | |||
return new Nest<TOut>(nestedObjects); | |||
} | |||
else if(template.NestType == NestType.Node) | |||
{ | |||
Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>(); | |||
foreach(var (key, value) in template.DictValue!) | |||
{ | |||
dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index); | |||
} | |||
return new Nest<TOut>(dict); | |||
} | |||
// Consider Empty as invalid type. | |||
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||
} | |||
public virtual Nest<T> AsNest() | |||
{ | |||
return this; | |||
} | |||
public virtual Nest<T> MergeWith(Nest<T>? other) | |||
{ | |||
if(other is null || other == Nest<T>.Empty) | |||
{ | |||
return this; | |||
} | |||
if(this == Nest<T>.Empty) | |||
{ | |||
return other; | |||
} | |||
if(NestType == NestType.Node && other.NestType == NestType.Node) | |||
{ | |||
return new Nest<T>(new Nest<T>[] { this, other }); | |||
} | |||
else if(NestType == NestType.List && other.NestType == NestType.List) | |||
{ | |||
return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||
} | |||
else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) | |||
{ | |||
return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); | |||
} | |||
else | |||
{ | |||
return new Nest<T>(new Nest<T>[] { this, other }); | |||
} | |||
} | |||
/// <summary> | |||
/// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not | |||
/// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||
/// </summary> | |||
/// <returns></returns> | |||
public bool IsNested() | |||
{ | |||
if(NestType is NestType.Empty or NestType.Node) | |||
{ | |||
return false; | |||
} | |||
else if(NestType is NestType.List) | |||
{ | |||
return ListValue!.Count > 0; | |||
} | |||
else | |||
{ | |||
return DictValue!.Count > 0; | |||
} | |||
} | |||
[Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||
public T this[int index] | |||
{ | |||
get | |||
{ | |||
bool success = FindInternal(this, index, out var result); | |||
if (success) | |||
{ | |||
return result; | |||
} | |||
else | |||
{ | |||
throw new IndexOutOfRangeException(); | |||
} | |||
} | |||
set | |||
{ | |||
bool success = SetInternal(this, index, value); | |||
if (!success) | |||
{ | |||
throw new IndexOutOfRangeException(); | |||
} | |||
} | |||
} | |||
/// <summary> | |||
/// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it | |||
/// to `Nest[T]`. | |||
/// </summary> | |||
/// <typeparam name="TOut"></typeparam> | |||
/// <param name="input"></param> | |||
/// <returns></returns> | |||
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||
{ | |||
var nested = input.AsNest(); | |||
return ReduceInternal(nested).AsNest(); | |||
} | |||
private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||
{ | |||
if(node.NestType == NestType.Empty) | |||
{ | |||
return Nest<T>.Empty; | |||
} | |||
else if(node.NestType == NestType.Node) | |||
{ | |||
return node.NodeValue!.AsNest(); | |||
} | |||
else if(node.NestType == NestType.List) | |||
{ | |||
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest()))); | |||
} | |||
else // Dictionary type | |||
{ | |||
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest()))); | |||
} | |||
} | |||
private static bool FindInternal(Nest<T> node, int index, out T? result) | |||
{ | |||
if (node.NestType == NestType.Node) | |||
{ | |||
if(index == 0) | |||
{ | |||
result = node.NodeValue!; | |||
return true; | |||
} | |||
result = default(T); | |||
return false; | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
foreach (var item in node.ListValue!) | |||
{ | |||
if(index == 0) | |||
{ | |||
return FindInternal(item.AsNest(), index, out result); | |||
} | |||
index--; | |||
} | |||
result = default(T); | |||
return false; | |||
} | |||
else if(node.NestType == NestType.Dictionary) | |||
{ | |||
foreach (var item in node.DictValue!.Values) | |||
{ | |||
if (index == 0) | |||
{ | |||
return FindInternal(item.AsNest(), index, out result); | |||
} | |||
index--; | |||
} | |||
result = default(T); | |||
return false; | |||
} | |||
else | |||
{ | |||
result = default(T); | |||
return false; | |||
} | |||
} | |||
private static bool SetInternal(Nest<T> node, int index, T newValue) | |||
{ | |||
if (node.NestType == NestType.Node) | |||
{ | |||
if (index == 0) | |||
{ | |||
node.NodeValue = newValue; | |||
return true; | |||
} | |||
return false; | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
foreach (var item in node.ListValue!) | |||
{ | |||
if (index == 0) | |||
{ | |||
return SetInternal(item.AsNest(), index, newValue); | |||
} | |||
index--; | |||
} | |||
return false; | |||
} | |||
else if (node.NestType == NestType.Dictionary) | |||
{ | |||
foreach (var item in node.DictValue!.Values) | |||
{ | |||
if (index == 0) | |||
{ | |||
return SetInternal(item.AsNest(), index, newValue); | |||
} | |||
index--; | |||
} | |||
return false; | |||
} | |||
else | |||
{ | |||
return false; | |||
} | |||
} | |||
private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||
{ | |||
if (node.NestType == NestType.Node) | |||
{ | |||
yield return node.NodeValue!; | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
foreach (var item in node.ListValue!) | |||
{ | |||
foreach(var val in FlattenInternal(item.AsNest())) | |||
{ | |||
yield return val; | |||
} | |||
} | |||
} | |||
else if (node.NestType == NestType.Dictionary) | |||
{ | |||
foreach (var item in node.DictValue!.Values) | |||
{ | |||
foreach (var val in FlattenInternal(item.AsNest())) | |||
{ | |||
yield return val; | |||
} | |||
} | |||
} | |||
} | |||
private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||
{ | |||
if (NestType == NestType.Node) | |||
{ | |||
return new Nest<TOut>(func(NodeValue!)); | |||
} | |||
else if (NestType == NestType.List) | |||
{ | |||
List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||
foreach (var item in ListValue!) | |||
{ | |||
outs.Add(item.AsNest().MapStructureInternal(func)); | |||
} | |||
return new Nest<TOut>(outs); | |||
} | |||
else if (NestType == NestType.Dictionary) | |||
{ | |||
Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>(); | |||
foreach (var (key, value) in DictValue!) | |||
{ | |||
outs.Add(key, value.AsNest().MapStructureInternal(func)); | |||
} | |||
return new Nest<TOut>(outs); | |||
} | |||
else | |||
{ | |||
return Nest<TOut>.Empty; | |||
} | |||
} | |||
public IEnumerator<T> GetEnumerator() | |||
{ | |||
return Flatten().GetEnumerator(); | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
return GetEnumerator(); | |||
} | |||
public override string ToString() | |||
{ | |||
StringBuilder sb = new StringBuilder(); | |||
sb.Append("("); | |||
WriteString(this, sb); | |||
sb.Append(")"); | |||
return sb.ToString(); | |||
} | |||
private static void WriteString(Nest<T> node, StringBuilder sb) | |||
{ | |||
if (!string.IsNullOrEmpty(node.Name)) | |||
{ | |||
sb.Append($"{node.Name}: "); | |||
} | |||
if (node.NestType == NestType.Node) | |||
{ | |||
sb.Append(node.NodeValue!.ToString()); | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
sb.Append("["); | |||
for(int i = 0; i < node.ListValue!.Count; i++) | |||
{ | |||
WriteString(node.ListValue![i].AsNest(), sb); | |||
if(i != node.ListValue!.Count - 1) | |||
{ | |||
sb.Append(", "); | |||
} | |||
} | |||
sb.Append("]"); | |||
} | |||
else if (node.NestType == NestType.Dictionary) | |||
{ | |||
sb.Append("{"); | |||
int count = node.DictValue!.Count; | |||
int i = 0; | |||
foreach (var (key, value) in node.DictValue!) | |||
{ | |||
sb.Append($"{key}: "); | |||
WriteString(value.AsNest(), sb); | |||
if (i != count - 1) | |||
{ | |||
sb.Append(", "); | |||
} | |||
i++; | |||
} | |||
sb.Append("}"); | |||
} | |||
else | |||
{ | |||
sb.Append("<empty>"); | |||
} | |||
} | |||
public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs) | |||
{ | |||
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 }); | |||
} | |||
public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs) | |||
{ | |||
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 }); | |||
} | |||
} | |||
} |
@@ -0,0 +1,103 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||
{ | |||
public NestType NestType => NestType.Dictionary; | |||
public IDictionary<TKey, TValue> Value { get; set; } | |||
public int ShallowNestedCount => Values.Count; | |||
public int TotalNestedCount => Values.Count; | |||
public NestDictionary(IDictionary<TKey, TValue> dict) | |||
{ | |||
Value = dict; | |||
} | |||
public IEnumerable<TValue> Flatten() | |||
{ | |||
return Value.Select(x => x.Value); | |||
} | |||
public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||
{ | |||
return new NestList<TOut>(Value.Select(x => func(x.Value))); | |||
} | |||
public Nest<TValue> AsNest() | |||
{ | |||
return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x))); | |||
} | |||
// Required IDictionary<TKey, TValue> members | |||
public int Count => Value.Count; | |||
public bool IsReadOnly => Value.IsReadOnly; | |||
public ICollection<TKey> Keys => Value.Keys; | |||
public ICollection<TValue> Values => Value.Values; | |||
public void Add(TKey key, TValue value) | |||
{ | |||
Value.Add(key, value); | |||
} | |||
public void Add(KeyValuePair<TKey, TValue> item) | |||
{ | |||
Value.Add(item); | |||
} | |||
public void Clear() | |||
{ | |||
Value.Clear(); | |||
} | |||
public bool Contains(KeyValuePair<TKey, TValue> item) | |||
{ | |||
return Value.Contains(item); | |||
} | |||
public bool ContainsKey(TKey key) | |||
{ | |||
return Value.ContainsKey(key); | |||
} | |||
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||
{ | |||
Value.CopyTo(array, arrayIndex); | |||
} | |||
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||
{ | |||
return Value.GetEnumerator(); | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
return GetEnumerator(); | |||
} | |||
public bool Remove(TKey key) | |||
{ | |||
return Value.Remove(key); | |||
} | |||
public bool Remove(KeyValuePair<TKey, TValue> item) | |||
{ | |||
return Value.Remove(item); | |||
} | |||
public bool TryGetValue(TKey key, out TValue value) | |||
{ | |||
return Value.TryGetValue(key, out value); | |||
} | |||
// Optional IDictionary<TKey, TValue> members | |||
public TValue this[TKey key] | |||
{ | |||
get => Value[key]; | |||
set => Value[key] = value; | |||
} | |||
} | |||
} |
@@ -0,0 +1,53 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
/// <summary> | |||
/// The implementation of a list that support nest structure, in which the depth is 1. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||
{ | |||
public NestType NestType => NestType.List; | |||
public List<T> Values { get; set; } | |||
public int ShallowNestedCount => Values.Count; | |||
public int TotalNestedCount => Values.Count; | |||
public NestList(params T[] values) | |||
{ | |||
Values = new List<T>(values); | |||
} | |||
public NestList(IEnumerable<T> values) | |||
{ | |||
Values = new List<T>(values); | |||
} | |||
public IEnumerable<T> Flatten() | |||
{ | |||
return Values; | |||
} | |||
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
{ | |||
return new NestList<TOut>(Values.Select(x => func(x))); | |||
} | |||
public Nest<T> AsNest() | |||
{ | |||
return new Nest<T>(Values.Select(x => new Nest<T>(x))); | |||
} | |||
// Enumerator implementation | |||
public IEnumerator<T> GetEnumerator() | |||
{ | |||
return Values.GetEnumerator(); | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
return GetEnumerator(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,36 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
/// <summary> | |||
/// A nested structure with only one element. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
public class NestNode<T> : INestStructure<T> | |||
{ | |||
public NestType NestType => NestType.Node; | |||
public T Value { get; set; } | |||
public int ShallowNestedCount => 1; | |||
public int TotalNestedCount => 1; | |||
public NestNode(T value) | |||
{ | |||
Value = value; | |||
} | |||
public IEnumerable<T> Flatten() | |||
{ | |||
yield return Value; | |||
} | |||
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
{ | |||
return new NestNode<TOut>(func(Value)); | |||
} | |||
public Nest<T> AsNest() | |||
{ | |||
return new Nest<T>(Value); | |||
} | |||
} | |||
} |
@@ -3,7 +3,7 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
namespace Tensorflow.Keras.Saving | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public class TensorShapeConfig | |||
{ |
@@ -161,8 +161,8 @@ namespace Tensorflow | |||
break; | |||
} | |||
yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ? | |||
null : new Tensors(results.Skip(FirstInputTensorCount))); | |||
yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ? | |||
null : new Tensors(results.Skip(FirstInputTensorCount).ToArray())); | |||
} | |||
} | |||
@@ -352,13 +352,19 @@ namespace Tensorflow.Eager | |||
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); | |||
break; | |||
case TF_AttrType.TF_ATTR_SHAPE: | |||
var dims = (value as long[]).ToArray(); | |||
long[] dims; | |||
if (value is Shape shape) dims = shape.dims.ToArray(); | |||
else if (value is long[] longs) dims = longs.ToArray(); | |||
else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray(); | |||
else dims = ((long[])value).ToArray(); | |||
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | |||
status.Check(true); | |||
break; | |||
case TF_AttrType.TF_ATTR_FUNC: | |||
if (value is ConcreteFunction func) | |||
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | |||
else if(value is string str) | |||
c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length); | |||
else | |||
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | |||
break; | |||
@@ -65,7 +65,7 @@ namespace Tensorflow.Eager | |||
{ | |||
outgrad_vec = output_gradients.ToList(); | |||
} | |||
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true); | |||
bool unconnected_gradients_zero = unconnected_gradients == "zero"; | |||
@@ -137,7 +137,6 @@ namespace Tensorflow.Eager | |||
{ | |||
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | |||
} | |||
Shape tensor_shape = new(dims); | |||
if(status.Code != TF_Code.TF_OK) | |||
{ | |||
@@ -145,6 +144,7 @@ namespace Tensorflow.Eager | |||
} | |||
else | |||
{ | |||
Shape tensor_shape = new(dims); | |||
return new TapeTensor(id, dtype, tensor_shape); | |||
} | |||
} | |||
@@ -173,8 +173,12 @@ namespace Tensorflow.Eager | |||
return dtype == dtypes.variant || dtype == dtypes.resource; | |||
} | |||
bool ListContainNone(long[] list) | |||
bool ListContainNone(long[]? list) | |||
{ | |||
if(list is null) | |||
{ | |||
return true; | |||
} | |||
int len = list.Length; | |||
if(len == 0) | |||
{ | |||
@@ -10,6 +10,11 @@ namespace Tensorflow.Eager | |||
var str = NDArrayRender.ToString(nd); | |||
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | |||
} | |||
public string ToString(int maxLength) | |||
{ | |||
var nd = new NDArray(this); | |||
var str = NDArrayRender.ToString(nd, maxLength); | |||
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | |||
} | |||
} | |||
} |
@@ -0,0 +1,25 @@ | |||
using Tensorflow; | |||
internal static class GraphOnlyOps | |||
{ | |||
/// <summary> | |||
/// Graph-only version of tf.compat.v1.placeholder(), for internal use only. | |||
/// </summary> | |||
/// <param name="dtyype"></param> | |||
/// <param name="shape"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null) | |||
{ | |||
var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() }; | |||
var shape_value = new AttrValue() { Shape = shape.as_proto() }; | |||
var g = ops.get_default_graph(); | |||
Dictionary<string, AttrValue> attrs = new(); | |||
attrs["dtype"] = dtype_value; | |||
attrs["shape"] = shape_value; | |||
var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype }, | |||
new TF_DataType[0], attrs: attrs, name: name); | |||
var result = op.outputs[0]; | |||
return result; | |||
} | |||
} |
@@ -0,0 +1,19 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Exceptions | |||
{ | |||
public class NotOkStatusException : TensorflowException | |||
{ | |||
public NotOkStatusException() : base() | |||
{ | |||
} | |||
public NotOkStatusException(string message) : base(message) | |||
{ | |||
} | |||
} | |||
} |
@@ -49,12 +49,25 @@ namespace Tensorflow.Framework | |||
public static implicit operator Tensor(IndexedSlices indexedSlices) | |||
{ | |||
return indexedSlices.values; | |||
return _indexed_slices_to_tensor(indexedSlices); | |||
} | |||
public static implicit operator IndexedSlices(Tensor tensor) | |||
{ | |||
return tensor.Tag as IndexedSlices; | |||
} | |||
/// <summary> | |||
/// Converts an IndexedSlices object `value` to a Tensor. | |||
/// </summary> | |||
/// <param name="indexedSlices"></param> | |||
/// <param name="dtype"></param> | |||
/// <param name="name"></param> | |||
/// <param name="as_ref"></param> | |||
/// <returns></returns> | |||
public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false) | |||
{ | |||
return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0)); | |||
} | |||
} | |||
} |
@@ -1,4 +1,5 @@ | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow.Framework.Models | |||
{ | |||
@@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models | |||
shapes.Insert(0, dim); | |||
return new TensorSpec(shapes.ToArray(), _dtype); | |||
} | |||
public static TensorSpec FromTensor(Tensor tensor, string? name = null) | |||
{ | |||
if(tensor is EagerTensor) | |||
{ | |||
return new TensorSpec(tensor.shape, tensor.dtype, name); | |||
} | |||
else | |||
{ | |||
return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,89 @@ | |||
using Tensorflow.Graphs; | |||
namespace Tensorflow.Framework | |||
{ | |||
internal static class auto_control_deps_utils | |||
{ | |||
public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"; | |||
public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph) | |||
{ | |||
List<int> result = new List<int>(); | |||
// A cache to store the read only resource inputs of an Op. | |||
// Operation -> ObjectIdentitySet of resource handles. | |||
Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs = | |||
new Dictionary<Operation, HashSet<Tensor>>(); | |||
for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++) | |||
{ | |||
Tensor t = func_graph.Inputs[inputIndex]; | |||
if (t.dtype != dtypes.resource) | |||
continue; | |||
bool readOnly = true; | |||
foreach (var op in t.consumers()) | |||
{ | |||
if (opReadOnlyResourceInputs.ContainsKey(op)) | |||
{ | |||
if (!opReadOnlyResourceInputs[op].Contains(t)) | |||
{ | |||
readOnly = false; | |||
break; | |||
} | |||
} | |||
else | |||
{ | |||
List<int> indices = _get_read_only_resource_input_indices_op(op); | |||
opReadOnlyResourceInputs[op] = new HashSet<Tensor>( | |||
indices.Select(i => op.inputs[i])); | |||
if (!opReadOnlyResourceInputs[op].Contains(t)) | |||
{ | |||
readOnly = false; | |||
break; | |||
} | |||
} | |||
} | |||
if (readOnly) | |||
result.Add(inputIndex); | |||
} | |||
return result; | |||
} | |||
private static List<int> _get_read_only_resource_input_indices_op(Operation op) | |||
{ | |||
// ignore the RESOURCE_READ_OPS | |||
int[] read_only_input_indices; | |||
try | |||
{ | |||
read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR); | |||
} | |||
catch (InvalidArgumentError) | |||
{ | |||
return new List<int>(); | |||
} | |||
int read_only_index = 0; | |||
List<int> result = new(); | |||
for (int i = 0; i < op.inputs.Length; i++) | |||
{ | |||
if (read_only_index >= read_only_input_indices.Length) | |||
{ | |||
break; | |||
} | |||
if (op.inputs[i].dtype != dtypes.resource) | |||
{ | |||
continue; | |||
} | |||
if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index]) | |||
{ | |||
result.Add(i); | |||
read_only_index++; | |||
} | |||
} | |||
return result; | |||
} | |||
} | |||
} |
@@ -42,10 +42,10 @@ namespace Tensorflow.Framework | |||
func_graph.as_default(); | |||
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||
var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||
// TODO(Rinne): func_graph.ControlOutputs | |||
_set_handle_data(func_graph, fdef); | |||
@@ -8,6 +8,7 @@ using Tensorflow.Gradients; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Train; | |||
using Tensorflow.Util; | |||
using Tensorflow.Common.Extensions; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Functions | |||
@@ -40,6 +41,18 @@ namespace Tensorflow.Functions | |||
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | |||
public IEnumerable<IVariableV1> Variables => func_graph.Variables; | |||
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | |||
internal NameAttrList AsNameAttrList | |||
{ | |||
get | |||
{ | |||
NameAttrList ret = new() { Name = this.Name }; | |||
foreach (var (name, value) in _attrs) | |||
{ | |||
ret.Attr[name] = value; | |||
} | |||
return ret; | |||
} | |||
} | |||
public ConcreteFunction(string name) | |||
{ | |||
@@ -3,4 +3,7 @@ global using System.Collections.Generic; | |||
global using System.Text; | |||
global using System.Collections; | |||
global using System.Data; | |||
global using System.Linq; | |||
global using System.Linq; | |||
global using Tensorflow.Keras.Engine; | |||
global using Tensorflow.Framework.Models; | |||
global using static Tensorflow.Binding; |
@@ -90,8 +90,7 @@ namespace Tensorflow.Gradients | |||
? input_values[0].rank + dim_int | |||
: dim_int % input_values[0].rank; | |||
var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); | |||
var sizes_tensor = constant_op.constant(sizes); | |||
out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList(); | |||
out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList(); | |||
} | |||
else if (constant_op.is_constant(concat_dim)) | |||
{ | |||
@@ -127,7 +126,7 @@ namespace Tensorflow.Gradients | |||
new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | |||
new Tensor[] { tf.constant(1), tf.constant(-1) }); | |||
var squeeze_sizes = array_ops.squeeze(slice); | |||
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList(); | |||
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList(); | |||
} | |||
else | |||
{ | |||
@@ -374,5 +373,13 @@ namespace Tensorflow.Gradients | |||
var p = op.inputs[1]; | |||
return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; | |||
} | |||
[RegisterGradient("ReverseV2")] | |||
public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads) | |||
{ | |||
var grad = grads[0]; | |||
var axis = op.inputs[1]; | |||
return new Tensor[] { array_ops.reverse(grad, axis), null }; | |||
} | |||
} | |||
} |
@@ -117,6 +117,137 @@ namespace Tensorflow.Gradients | |||
}; | |||
} | |||
public static string ellipsis = "..."; | |||
[RegisterGradient("Einsum")] | |||
public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads) | |||
{ | |||
// Gradient for Einsum. | |||
string equation = (string)op.get_attr("equation"); | |||
string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None); | |||
var input_subs = split_equation[0]; | |||
var output_subs = split_equation[1]; | |||
if (op.inputs.Length == 1) | |||
{ | |||
var input_shape = array_ops.shape(op.inputs[0]); | |||
var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis))); | |||
if (reduced_label_set.Count == 0) | |||
return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) }; | |||
return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) }; | |||
} | |||
string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None); | |||
var x_subs = split_input_subs[0]; | |||
var y_subs = split_input_subs[1]; | |||
// Add ellipsis for broadcasted dimensions if any operand does not have it. | |||
// This is because the equation "...ij,jk->ik" may be valid if the 0th input's | |||
// batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid | |||
// because only the output subscripts contain ellipsis. | |||
if (output_subs.Contains(ellipsis)) | |||
{ | |||
if (!x_subs.Contains(ellipsis)) | |||
x_subs += ellipsis; | |||
if (!y_subs.Contains(ellipsis)) | |||
y_subs += ellipsis; | |||
} | |||
// Obtain the gradients wrt the inputs x and y, without taking into account | |||
// the unbroadcasting. | |||
var x = op.inputs[0]; | |||
var y = op.inputs[1]; | |||
if (grads.GetDataType().is_complex()) | |||
{ | |||
x = math_ops.conj(x); | |||
y = math_ops.conj(y); | |||
} | |||
var x_shape = array_ops.shape(x); | |||
var y_shape = array_ops.shape(y); | |||
var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs); | |||
var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs); | |||
if (!output_subs.Contains(ellipsis)) | |||
return new Tensor[] { grad_x, grad_y }; | |||
var bx = _GetBcastSubshape(x_subs); | |||
int bx_start = bx[0], bx_end = bx[1]; | |||
var by = _GetBcastSubshape(y_subs); | |||
int by_start = by[0], by_end = by[1]; | |||
var x_shape_static = x.shape; | |||
var y_shape_static = y.shape; | |||
if(x_shape_static.IsFullyDefined && | |||
y_shape_static.IsFullyDefined && | |||
x_shape_static[string.Format("{0}:{1}",bx_start,bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)]) | |||
return new Tensor[] { grad_x, grad_y }; | |||
var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)], | |||
y_shape[string.Format("{0}:{1}", by_start, by_end)]); | |||
var rx = r[0]; | |||
var ry = r[1]; | |||
grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape); | |||
grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape); | |||
return new Tensor[] { grad_x, grad_y }; | |||
} | |||
protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape, | |||
string input_subs, string other_subs, string output_subs) | |||
{ | |||
var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + "."))); | |||
var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||
var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand)); | |||
if (reduced_label_set.Count == 0) | |||
return grad_reduced; | |||
return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set); | |||
} | |||
protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set) | |||
{ | |||
string reduced_subs; | |||
Tensor reduced_dims; | |||
List<int> reduced_axes; | |||
_GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes); | |||
bool has_repeated_labels = ( | |||
new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count < | |||
input_subs.Length + output_subs.Length); | |||
var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||
if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs) | |||
{ | |||
var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes)); | |||
return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape); | |||
} | |||
else | |||
{ | |||
var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||
var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||
var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels); | |||
return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad)); | |||
} | |||
} | |||
protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes) | |||
{ | |||
reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString())); | |||
reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList(); | |||
reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList()); | |||
} | |||
protected static int _GetAxisFromLabel(string subscripts, char label) | |||
{ | |||
var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None); | |||
var index = splits[0].IndexOf(label); | |||
if (index != -1) return index; | |||
if (splits.Length < 2) throw new OutOfRangeError(); | |||
index = splits[1].IndexOf(label); | |||
if (index != -1) return index; | |||
throw new ValueError(); | |||
} | |||
protected static int[] _GetBcastSubshape(string subscripts) | |||
{ | |||
int start = subscripts.IndexOf(ellipsis); | |||
if (start == -1) return new int[] { 0, 0 }; | |||
int remaining = subscripts.Length - (start + ellipsis.Length); | |||
int end; | |||
if (remaining > 0) end = remaining; | |||
else throw new Exception(); | |||
return new int[] { start, end }; | |||
} | |||
/// <summary> | |||
/// Returns grad * exp(x). | |||
/// </summary> | |||
@@ -365,6 +365,23 @@ namespace Tensorflow.Gradients | |||
}; | |||
} | |||
[RegisterGradient("AvgPool")] | |||
public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads) | |||
{ | |||
Tensor grad = grads[0]; | |||
return new Tensor[] | |||
{ | |||
gen_nn_ops.avg_pool_grad( | |||
array_ops.shape(op.inputs[0]), | |||
grad, | |||
op.get_attr_list<int>("ksize"), | |||
op.get_attr_list<int>("strides"), | |||
op.get_attr<string>("padding"), | |||
op.get_attr<string>("data_format")) | |||
}; | |||
} | |||
/// <summary> | |||
/// Return the gradients for TopK. | |||
/// </summary> | |||
@@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable | |||
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | |||
public Dictionary<string, AttrValue> Attrs { get; set; } | |||
Dictionary<long, (Tensor, Tensor)> _captures | |||
internal Dictionary<long, (Tensor, Tensor)> _captures | |||
= new Dictionary<long, (Tensor, Tensor)>(); | |||
public Tensor[] external_captures | |||
@@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable | |||
var flat_func_args = nest.flatten(func_args as object); | |||
var flat_func_kwargs = nest.flatten(func_kwargs as object); | |||
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | |||
.Where(x => x is Tensor).Select(x => (Tensor)x)); | |||
.Where(x => x is Tensor).Select(x => (Tensor)x).ToArray()); | |||
//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | |||
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | |||
@@ -544,12 +544,12 @@ public class FuncGraph : Graph, IDisposable | |||
Tensor placeholder; | |||
try | |||
{ | |||
placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name); | |||
} | |||
catch (ValueError) | |||
catch (ValueError ex) | |||
{ | |||
// TODO(Rinne): Add warning here. | |||
placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||
tf.Logger.Warning(ex.ToString()); | |||
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape); | |||
} | |||
handle_data_util.copy_handle_data(tensor, placeholder); | |||
if (name is not null) | |||
@@ -575,12 +575,12 @@ public class FuncGraph : Graph, IDisposable | |||
Tensor placeholder; | |||
try | |||
{ | |||
placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name); | |||
} | |||
catch (ValueError) | |||
{ | |||
// TODO(Rinne): Add warning here. | |||
placeholder = tf.placeholder(spec.dtype, spec.shape); | |||
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape); | |||
} | |||
if (name is not null) | |||
{ | |||
@@ -129,7 +129,7 @@ namespace Tensorflow | |||
} | |||
} | |||
protected Graph outer_graph; | |||
internal Graph outer_graph; | |||
public Graph OuterGraph => outer_graph; | |||
public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||
public SafeGraphHandle c_graph => _handle; | |||
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class ExponentialArgs : LayerArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class HardSigmoidArgs : LayerArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,11 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class SELUArgs : LayerArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class SoftplusArgs : LayerArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class SoftsignArgs : LayerArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class SwishArgs : LayerArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class TanhArgs : LayerArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class Conv2DTransposeArgs : Conv2DArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class AddArgs : MergeArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class ConcatenateArgs : MergeArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class SubtractArgs : MergeArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class GlobalAveragePooling1DArgs : Pooling1DArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class GlobalAveragePooling2DArgs : Pooling2DArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class GlobalMaxPooling1DArgs : Pooling1DArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class GlobalMaxPooling2DArgs : Pooling2DArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class MaxPooling1DArgs : Pooling1DArgs | |||
{ | |||
} | |||
} |
@@ -7,7 +7,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
[JsonProperty("size")] | |||
public Shape Size { get; set; } | |||
[JsonProperty("data_format")] | |||
public string DataFormat { get; set; } | |||
public string DataFormat { get; set; } = "channels_last"; | |||
/// <summary> | |||
/// 'nearest', 'bilinear' | |||
/// </summary> | |||
@@ -0,0 +1,10 @@ | |||
using Newtonsoft.Json; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class UpSampling1DArgs : AutoSerializeLayerArgs | |||
{ | |||
[JsonProperty("size")] | |||
public int Size { get; set; } | |||
} | |||
} |
@@ -0,0 +1,20 @@ | |||
using Newtonsoft.Json; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class BidirectionalArgs : AutoSerializeLayerArgs | |||
{ | |||
[JsonProperty("layer")] | |||
public ILayer Layer { get; set; } | |||
[JsonProperty("merge_mode")] | |||
public string? MergeMode { get; set; } | |||
[JsonProperty("backward_layer")] | |||
public ILayer BackwardLayer { get; set; } | |||
public NDArray Weights { get; set; } | |||
} | |||
} |
@@ -0,0 +1,29 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class GRUArgs : AutoSerializeLayerArgs | |||
{ | |||
public int Units { get; set; } | |||
public Activation Activation { get; set; } | |||
public Activation RecurrentActivation { get; set; } | |||
public bool UseBias { get; set; } = true; | |||
public float Dropout { get; set; } = .0f; | |||
public float RecurrentDropout { get; set; } = .0f; | |||
public IInitializer KernelInitializer { get; set; } | |||
public IInitializer RecurrentInitializer { get; set; } | |||
public IInitializer BiasInitializer { get; set; } | |||
public bool ReturnSequences { get;set; } | |||
public bool ReturnState { get;set; } | |||
public bool GoBackwards { get;set; } | |||
public bool Stateful { get;set; } | |||
public bool Unroll { get;set; } | |||
public bool TimeMajor { get;set; } | |||
public bool ResetAfter { get;set; } | |||
public int Implementation { get; set; } = 2; | |||
} | |||
} |
@@ -0,0 +1,39 @@ | |||
using Newtonsoft.Json; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class GRUCellArgs : AutoSerializeLayerArgs | |||
{ | |||
[JsonProperty("units")] | |||
public int Units { get; set; } | |||
// TODO(Rinne): lack of initialized value of Activation. Merging keras | |||
// into tf.net could resolve it. | |||
[JsonProperty("activation")] | |||
public Activation Activation { get; set; } | |||
[JsonProperty("recurrent_activation")] | |||
public Activation RecurrentActivation { get; set; } | |||
[JsonProperty("use_bias")] | |||
public bool UseBias { get; set; } = true; | |||
[JsonProperty("dropout")] | |||
public float Dropout { get; set; } = .0f; | |||
[JsonProperty("recurrent_dropout")] | |||
public float RecurrentDropout { get; set; } = .0f; | |||
[JsonProperty("kernel_initializer")] | |||
public IInitializer KernelInitializer { get; set; } | |||
[JsonProperty("recurrent_initializer")] | |||
public IInitializer RecurrentInitializer { get; set; } | |||
[JsonProperty("bias_initializer")] | |||
public IInitializer BiasInitializer { get; set; } | |||
[JsonProperty("reset_after")] | |||
public bool ResetAfter { get;set; } | |||
[JsonProperty("implementation")] | |||
public int Implementation { get; set; } = 2; | |||
} | |||
} |
@@ -0,0 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class GRUOptionalArgs | |||
{ | |||
public string Identifier => "GRU"; | |||
public Tensor Mask { get; set; } = null; | |||
} | |||
} |
@@ -1,11 +1,14 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class LSTMArgs : RNNArgs | |||
{ | |||
// TODO: maybe change the `RNNArgs` and implement this class. | |||
public bool UnitForgetBias { get; set; } | |||
public float Dropout { get; set; } | |||
public float RecurrentDropout { get; set; } | |||
public int Implementation { get; set; } | |||
public LSTMArgs Clone() | |||
{ | |||
return (LSTMArgs)MemberwiseClone(); | |||
} | |||
} | |||
} |
@@ -1,7 +1,35 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
using Newtonsoft.Json; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
// TODO: complete the implementation | |||
public class LSTMCellArgs : LayerArgs | |||
public class LSTMCellArgs : AutoSerializeLayerArgs | |||
{ | |||
[JsonProperty("units")] | |||
public int Units { get; set; } | |||
// TODO(Rinne): lack of initialized value of Activation. Merging keras | |||
// into tf.net could resolve it. | |||
[JsonProperty("activation")] | |||
public Activation Activation { get; set; } | |||
[JsonProperty("recurrent_activation")] | |||
public Activation RecurrentActivation { get; set; } | |||
[JsonProperty("use_bias")] | |||
public bool UseBias { get; set; } = true; | |||
[JsonProperty("dropout")] | |||
public float Dropout { get; set; } = .0f; | |||
[JsonProperty("recurrent_dropout")] | |||
public float RecurrentDropout { get; set; } = .0f; | |||
[JsonProperty("kernel_initializer")] | |||
public IInitializer KernelInitializer { get; set; } | |||
[JsonProperty("recurrent_initializer")] | |||
public IInitializer RecurrentInitializer { get; set; } | |||
[JsonProperty("bias_initializer")] | |||
public IInitializer BiasInitializer { get; set; } | |||
[JsonProperty("unit_forget_bias")] | |||
public bool UnitForgetBias { get; set; } = true; | |||
[JsonProperty("implementation")] | |||
public int Implementation { get; set; } = 2; | |||
} | |||
} |
@@ -1,17 +1,12 @@ | |||
using Newtonsoft.Json; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.Layers; | |||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
// TODO(Rinne): add regularizers. | |||
public class RNNArgs : AutoSerializeLayerArgs | |||
{ | |||
public interface IRnnArgCell : ILayer | |||
{ | |||
object state_size { get; } | |||
} | |||
[JsonProperty("cell")] | |||
// TODO: the cell should be serialized with `serialize_keras_object`. | |||
public IRnnArgCell Cell { get; set; } = null; | |||
[JsonProperty("return_sequences")] | |||
public bool ReturnSequences { get; set; } = false; | |||
[JsonProperty("return_state")] | |||
@@ -24,31 +19,31 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
public bool Unroll { get; set; } = false; | |||
[JsonProperty("time_major")] | |||
public bool TimeMajor { get; set; } = false; | |||
// TODO: Add `num_constants` and `zero_output_for_mask`. | |||
public Dictionary<string, object> Kwargs { get; set; } = null; | |||
public int? InputDim { get; set; } | |||
public int? InputLength { get; set; } | |||
// TODO: Add `num_constants` and `zero_output_for_mask`. | |||
[JsonProperty("units")] | |||
public int Units { get; set; } | |||
[JsonProperty("activation")] | |||
public Activation Activation { get; set; } | |||
[JsonProperty("recurrent_activation")] | |||
public Activation RecurrentActivation { get; set; } | |||
[JsonProperty("use_bias")] | |||
public bool UseBias { get; set; } = true; | |||
public IInitializer KernelInitializer { get; set; } | |||
public IInitializer RecurrentInitializer { get; set; } | |||
public IInitializer BiasInitializer { get; set; } | |||
[JsonProperty("dropout")] | |||
public float Dropout { get; set; } = .0f; | |||
[JsonProperty("zero_output_for_mask")] | |||
public bool ZeroOutputForMask { get; set; } = false; | |||
[JsonProperty("recurrent_dropout")] | |||
public float RecurrentDropout { get; set; } = .0f; | |||
// kernel_regularizer=None, | |||
// recurrent_regularizer=None, | |||
// bias_regularizer=None, | |||
// activity_regularizer=None, | |||
// kernel_constraint=None, | |||
// recurrent_constraint=None, | |||
// bias_constraint=None, | |||
// dropout=0., | |||
// recurrent_dropout=0., | |||
// return_sequences=False, | |||
// return_state=False, | |||
// go_backwards=False, | |||
// stateful=False, | |||
// unroll=False, | |||
// **kwargs): | |||
public RNNArgs Clone() | |||
{ | |||
return (RNNArgs)MemberwiseClone(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Common.Types; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class RnnOptionalArgs: IOptionalArgs | |||
{ | |||
public string Identifier => "Rnn"; | |||
public Tensor Mask { get; set; } = null; | |||
public Tensors Constants { get; set; } = null; | |||
} | |||
} |
@@ -1,4 +1,4 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class SimpleRNNArgs : RNNArgs | |||
{ | |||
@@ -0,0 +1,27 @@ | |||
using Newtonsoft.Json; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class SimpleRNNCellArgs: AutoSerializeLayerArgs | |||
{ | |||
[JsonProperty("units")] | |||
public int Units { get; set; } | |||
// TODO(Rinne): lack of initialized value of Activation. Merging keras | |||
// into tf.net could resolve it. | |||
[JsonProperty("activation")] | |||
public Activation Activation { get; set; } | |||
[JsonProperty("use_bias")] | |||
public bool UseBias { get; set; } = true; | |||
[JsonProperty("dropout")] | |||
public float Dropout { get; set; } = .0f; | |||
[JsonProperty("recurrent_dropout")] | |||
public float RecurrentDropout { get; set; } = .0f; | |||
[JsonProperty("kernel_initializer")] | |||
public IInitializer KernelInitializer { get; set; } | |||
[JsonProperty("recurrent_initializer")] | |||
public IInitializer RecurrentInitializer { get; set; } | |||
[JsonProperty("bias_initializer")] | |||
public IInitializer BiasInitializer { get; set; } | |||
} | |||
} |
@@ -1,10 +1,10 @@ | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.Layers; | |||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class StackedRNNCellsArgs : LayerArgs | |||
{ | |||
public IList<RnnCell> Cells { get; set; } | |||
public Dictionary<string, object> Kwargs { get; set; } = null; | |||
public bool ReverseStateOrder = false; | |||
} | |||
} |
@@ -0,0 +1,24 @@ | |||
using Newtonsoft.Json; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.CompilerServices; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class WrapperArgs : AutoSerializeLayerArgs | |||
{ | |||
[JsonProperty("layer")] | |||
public ILayer Layer { get; set; } | |||
public WrapperArgs(ILayer layer) | |||
{ | |||
Layer = layer; | |||
} | |||
public static implicit operator WrapperArgs(BidirectionalArgs args) | |||
=> new WrapperArgs(args.Layer); | |||
} | |||
} |
@@ -14,6 +14,9 @@ public interface ICallback | |||
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | |||
void on_predict_end(); | |||
void on_test_begin(); | |||
void on_test_end(Dictionary<string, float> logs); | |||
void on_test_batch_begin(long step); | |||
void on_test_batch_end(long end_step, Dictionary<string, float> logs); | |||
} |
@@ -60,7 +60,7 @@ public interface IModel : ILayer | |||
bool skip_mismatch = false, | |||
object options = null); | |||
Dictionary<string, float> evaluate(Tensor x, Tensor y, | |||
Dictionary<string, float> evaluate(NDArray x, NDArray y, | |||
int batch_size = -1, | |||
int verbose = 1, | |||
int steps = -1, | |||
@@ -0,0 +1,75 @@ | |||
namespace Tensorflow.Keras.Engine; | |||
/// <summary> | |||
/// A representation of a Keras in/output during Functional API construction. | |||
/// </summary> | |||
public class KerasTensor | |||
{ | |||
private Tensors _original_tensors; | |||
public Tensors original_tensors | |||
{ | |||
get => _original_tensors; | |||
set => _original_tensors = value; | |||
} | |||
private Shape _inferred_value; | |||
public Shape inferred_value => _inferred_value; | |||
private string _name; | |||
private TensorSpec _type_spec; | |||
public Shape shape => _type_spec.shape; | |||
public TF_DataType dtype => _type_spec.dtype; | |||
public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null) | |||
{ | |||
_type_spec = type_spec; | |||
_inferred_value = inferred_value; | |||
_name = name; | |||
} | |||
public static KerasTensor from_tensor(Tensor tensor) | |||
{ | |||
var type_spec = tensor.ToTensorSpec(); | |||
Shape? inferred_value = default; | |||
if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2) | |||
{ | |||
inferred_value = tf.ones(tensor).shape; | |||
} | |||
var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name); | |||
kt.original_tensors = tensor; | |||
return kt; | |||
} | |||
public KerasTensor this[int idx] | |||
=> _original_tensors.First()[idx]; | |||
public KerasTensor this[params Slice[] slices] | |||
=> _original_tensors.First()[slices]; | |||
public override string ToString() | |||
=> _original_tensors.Length switch | |||
{ | |||
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]", | |||
1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}", | |||
_ => _original_tensors.ToString(), | |||
}; | |||
private string GetInferredValueString() | |||
=> _inferred_value == null ? "" : $" inferred_value={_inferred_value}"; | |||
public static implicit operator Tensors(KerasTensor kt) | |||
=> kt._original_tensors; | |||
public static implicit operator Tensor(KerasTensor kt) | |||
{ | |||
Tensor tensor = kt._original_tensors; | |||
tensor.IsFromKerasTensor = true; | |||
return tensor; | |||
} | |||
public static implicit operator KerasTensor(Tensor tensor) | |||
=> from_tensor(tensor); | |||
public static implicit operator KerasTensor(Tensors tensors) | |||
=> from_tensor(tensors.First()); | |||
} |
@@ -25,6 +25,27 @@ namespace Tensorflow.Keras | |||
bool amsgrad = false, | |||
string name = "Adam"); | |||
/// <summary> | |||
/// Adam enables L2 weight decay on gradients. | |||
/// </summary> | |||
/// <param name="learning_rate"></param> | |||
/// <param name="weight_decay"></param> | |||
/// <param name="beta_1"></param> | |||
/// <param name="beta_2"></param> | |||
/// <param name="epsilon"></param> | |||
/// <param name="amsgrad"></param> | |||
/// <param name="decay_params"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
IOptimizer AdamW(float learning_rate = 0.001f, | |||
float weight_decay = 0.004f, | |||
float beta_1 = 0.9f, | |||
float beta_2 = 0.999f, | |||
float epsilon = 1e-7f, | |||
bool amsgrad = false, | |||
List<string> no_decay_params = null, | |||
string name = "AdamW"); | |||
/// <summary> | |||
/// Construct a new RMSprop optimizer. | |||
/// </summary> | |||
@@ -42,6 +63,6 @@ namespace Tensorflow.Keras | |||
bool centered = false, | |||
string name = "RMSprop"); | |||
IOptimizer SGD(float learning_rate); | |||
IOptimizer SGD(float learning_rate = 0.01f, float momentum = 0f); | |||
} | |||
} |
@@ -1,4 +1,5 @@ | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Training; | |||
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras | |||
List<ILayer> Layers { get; } | |||
List<INode> InboundNodes { get; } | |||
List<INode> OutboundNodes { get; } | |||
Tensors Apply(Tensors inputs, Tensor state = null, bool training = false); | |||
Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null); | |||
List<IVariableV1> TrainableVariables { get; } | |||
List<IVariableV1> TrainableWeights { get; } | |||
List<IVariableV1> NonTrainableWeights { get; } | |||
@@ -9,6 +9,10 @@ namespace Tensorflow.Keras.Layers | |||
public ILayer Reshape(Shape target_shape); | |||
public ILayer Reshape(object[] target_shape); | |||
public ILayer UpSampling1D( | |||
int size | |||
); | |||
public ILayer UpSampling2D(Shape size = null, | |||
string data_format = null, | |||
string interpolation = "nearest"); | |||
@@ -1,5 +1,7 @@ | |||
using System; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.NumPy; | |||
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||
@@ -134,7 +136,7 @@ namespace Tensorflow.Keras.Layers | |||
public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | |||
public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | |||
public Tensors Input(Shape shape = null, | |||
public KerasTensor Input(Shape shape = null, | |||
int batch_size = -1, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
@@ -159,6 +161,18 @@ namespace Tensorflow.Keras.Layers | |||
public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | |||
public ILayer LeakyReLU(float alpha = 0.3f); | |||
public IRnnCell LSTMCell(int uints, | |||
string activation = "tanh", | |||
string recurrent_activation = "sigmoid", | |||
bool use_bias = true, | |||
string kernel_initializer = "glorot_uniform", | |||
string recurrent_initializer = "orthogonal", | |||
string bias_initializer = "zeros", | |||
bool unit_forget_bias = true, | |||
float dropout = 0f, | |||
float recurrent_dropout = 0f, | |||
int implementation = 2); | |||
public ILayer LSTM(int units, | |||
Activation activation = null, | |||
Activation recurrent_activation = null, | |||
@@ -192,6 +206,19 @@ namespace Tensorflow.Keras.Layers | |||
float offset = 0, | |||
Shape input_shape = null); | |||
public IRnnCell SimpleRNNCell( | |||
int units, | |||
string activation = "tanh", | |||
bool use_bias = true, | |||
string kernel_initializer = "glorot_uniform", | |||
string recurrent_initializer = "orthogonal", | |||
string bias_initializer = "zeros", | |||
float dropout = 0f, | |||
float recurrent_dropout = 0f); | |||
public IRnnCell StackedRNNCells( | |||
IEnumerable<IRnnCell> cells); | |||
public ILayer SimpleRNN(int units, | |||
string activation = "tanh", | |||
string kernel_initializer = "glorot_uniform", | |||
@@ -200,6 +227,69 @@ namespace Tensorflow.Keras.Layers | |||
bool return_sequences = false, | |||
bool return_state = false); | |||
public ILayer RNN( | |||
IRnnCell cell, | |||
bool return_sequences = false, | |||
bool return_state = false, | |||
bool go_backwards = false, | |||
bool stateful = false, | |||
bool unroll = false, | |||
bool time_major = false | |||
); | |||
public ILayer RNN( | |||
IEnumerable<IRnnCell> cell, | |||
bool return_sequences = false, | |||
bool return_state = false, | |||
bool go_backwards = false, | |||
bool stateful = false, | |||
bool unroll = false, | |||
bool time_major = false | |||
); | |||
public IRnnCell GRUCell( | |||
int units, | |||
string activation = "tanh", | |||
string recurrent_activation = "sigmoid", | |||
bool use_bias = true, | |||
string kernel_initializer = "glorot_uniform", | |||
string recurrent_initializer = "orthogonal", | |||
string bias_initializer = "zeros", | |||
float dropout = 0f, | |||
float recurrent_dropout = 0f, | |||
bool reset_after = true); | |||
public ILayer GRU( | |||
int units, | |||
string activation = "tanh", | |||
string recurrent_activation = "sigmoid", | |||
bool use_bias = true, | |||
string kernel_initializer = "glorot_uniform", | |||
string recurrent_initializer = "orthogonal", | |||
string bias_initializer = "zeros", | |||
float dropout = 0f, | |||
float recurrent_dropout = 0f, | |||
bool return_sequences = false, | |||
bool return_state = false, | |||
bool go_backwards = false, | |||
bool stateful = false, | |||
bool unroll = false, | |||
bool time_major = false, | |||
bool reset_after = true | |||
); | |||
/// <summary> | |||
/// Bidirectional wrapper for RNNs. | |||
/// </summary> | |||
/// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param> | |||
/// automatically.</param> | |||
/// <returns></returns> | |||
public ILayer Bidirectional( | |||
ILayer layer, | |||
string merge_mode = "concat", | |||
NDArray weights = null, | |||
ILayer backward_layer = null); | |||
public ILayer Subtract(); | |||
} | |||
} |
@@ -0,0 +1,25 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Common.Types; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public interface IRnnCell: ILayer | |||
{ | |||
/// <summary> | |||
/// If the derived class tends to not implement it, please return null. | |||
/// </summary> | |||
INestStructure<long>? StateSize { get; } | |||
/// <summary> | |||
/// If the derived class tends to not implement it, please return null. | |||
/// </summary> | |||
INestStructure<long>? OutputSize { get; } | |||
/// <summary> | |||
/// Whether the optional RNN args are supported when appying the layer. | |||
/// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. | |||
/// </summary> | |||
bool SupportOptionalArgs { get; } | |||
Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype); | |||
} | |||
} |
@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public interface IStackedRnnCells : IRnnCell | |||
{ | |||
int Count { get; } | |||
IRnnCell this[int idx] { get; } | |||
} | |||
} |
@@ -3,6 +3,7 @@ using Newtonsoft.Json; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Common.Types; | |||
namespace Tensorflow.Keras.Saving.Json | |||
{ | |||
@@ -6,6 +6,7 @@ using System.Text; | |||
using System.Diagnostics; | |||
using OneOf.Types; | |||
using Tensorflow.Keras.Saving.Json; | |||
using Tensorflow.Common.Types; | |||
namespace Tensorflow.Keras.Saving | |||
{ | |||
@@ -74,8 +74,3 @@ namespace Tensorflow | |||
=> IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; | |||
} | |||
} | |||
namespace System.Runtime.CompilerServices | |||
{ | |||
internal static class IsExternalInit { } | |||
} |
@@ -107,9 +107,15 @@ namespace Tensorflow.NumPy | |||
public static implicit operator NDArray(bool value) | |||
=> new NDArray(value); | |||
public static implicit operator NDArray(byte value) | |||
=> new NDArray(value); | |||
public static implicit operator NDArray(int value) | |||
=> new NDArray(value); | |||
public static implicit operator NDArray(long value) | |||
=> new NDArray(value); | |||
public static implicit operator NDArray(float value) | |||
=> new NDArray(value); | |||
@@ -7,7 +7,7 @@ namespace Tensorflow.NumPy | |||
{ | |||
public class NDArrayRender | |||
{ | |||
public static string ToString(NDArray array) | |||
public static string ToString(NDArray array, int maxLength = 10) | |||
{ | |||
Shape shape = array.shape; | |||
if (shape.IsScalar) | |||
@@ -15,12 +15,12 @@ namespace Tensorflow.NumPy | |||
var s = new StringBuilder(); | |||
s.Append("array("); | |||
Build(s, array); | |||
Build(s, array, maxLength); | |||
s.Append(")"); | |||
return s.ToString(); | |||
} | |||
static void Build(StringBuilder s, NDArray array) | |||
static void Build(StringBuilder s, NDArray array, int maxLength) | |||
{ | |||
var shape = array.shape; | |||
@@ -35,11 +35,11 @@ namespace Tensorflow.NumPy | |||
var len = shape[0]; | |||
s.Append("["); | |||
if (len <= 10) | |||
if (len <= maxLength) | |||
{ | |||
for (int i = 0; i < len; i++) | |||
{ | |||
Build(s, array[i]); | |||
Build(s, array[i], maxLength); | |||
if (i < len - 1) | |||
{ | |||
s.Append(", "); | |||
@@ -49,9 +49,9 @@ namespace Tensorflow.NumPy | |||
} | |||
else | |||
{ | |||
for (int i = 0; i < 5; i++) | |||
for (int i = 0; i < maxLength / 2; i++) | |||
{ | |||
Build(s, array[i]); | |||
Build(s, array[i], maxLength); | |||
if (i < len - 1) | |||
{ | |||
s.Append(", "); | |||
@@ -62,9 +62,9 @@ namespace Tensorflow.NumPy | |||
s.Append(" ... "); | |||
s.AppendLine(); | |||
for (int i = (int)len - 5; i < len; i++) | |||
for (int i = (int)len - maxLength / 2; i < len; i++) | |||
{ | |||
Build(s, array[i]); | |||
Build(s, array[i], maxLength); | |||
if (i < len - 1) | |||
{ | |||
s.Append(", "); | |||
@@ -13,6 +13,10 @@ namespace Tensorflow.NumPy | |||
public static NDArray argmax(NDArray a, Axis? axis = null) | |||
=> new NDArray(math_ops.argmax(a, axis ?? 0)); | |||
[AutoNumPy] | |||
public static NDArray argmin(NDArray a, Axis? axis = null) | |||
=> new NDArray(math_ops.argmin(a, axis ?? 0)); | |||
[AutoNumPy] | |||
public static NDArray argsort(NDArray a, Axis? axis = null) | |||
=> new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); | |||
@@ -10,10 +10,10 @@ namespace Tensorflow.NumPy | |||
public partial class np | |||
{ | |||
[AutoNumPy] | |||
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis)); | |||
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.min(x, axis)); | |||
[AutoNumPy] | |||
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); | |||
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.max(x, axis)); | |||
[AutoNumPy] | |||
public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | |||
@@ -49,9 +49,30 @@ namespace Tensorflow.NumPy | |||
[AutoNumPy] | |||
public static NDArray prod<T>(params T[] array) where T : unmanaged | |||
=> new NDArray(tf.reduce_prod(new NDArray(array))); | |||
[AutoNumPy] | |||
public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null) | |||
{ | |||
//if axes mentioned | |||
if (axes != null) | |||
{ | |||
return new NDArray(tf.dot_prod(x1, x2, axes, name)); | |||
} | |||
if (x1.shape.ndim > 1) | |||
{ | |||
x1 = GetFlattenArray(x1); | |||
} | |||
if (x2.shape.ndim > 1) | |||
{ | |||
x2 = GetFlattenArray(x2); | |||
} | |||
//if axes not mentioned, default 0,0 | |||
return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name)); | |||
} | |||
[AutoNumPy] | |||
public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y)); | |||
[AutoNumPy] | |||
public static NDArray square(NDArray x) => new NDArray(tf.square(x)); | |||
[AutoNumPy] | |||
public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x)); | |||
@@ -19,13 +19,14 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Keras.Saving.Common; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow | |||
{ | |||
[JsonConverter(typeof(CustomizedShapeJsonConverter))] | |||
public class Shape | |||
public class Shape : INestStructure<long> | |||
{ | |||
public int ndim => _dims == null ? -1 : _dims.Length; | |||
long[] _dims; | |||
@@ -41,6 +42,27 @@ namespace Tensorflow | |||
} | |||
} | |||
public NestType NestType => NestType.List; | |||
public int ShallowNestedCount => ndim; | |||
/// <summary> | |||
/// The total item count of depth 1 of the nested structure. | |||
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||
/// </summary> | |||
public int TotalNestedCount => ndim; | |||
public IEnumerable<long> Flatten() => dims.Select(x => x); | |||
public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func) | |||
{ | |||
return new NestList<TOut>(dims.Select(x => func(x))); | |||
} | |||
public Nest<long> AsNest() | |||
{ | |||
return new NestList<long>(Flatten()).AsNest(); | |||
} | |||
#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges | |||
public int Length => ndim; | |||
public long[] Slice(int start, int length) | |||
@@ -0,0 +1,22 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow.Operations.Initializers | |||
{ | |||
/// <summary> | |||
/// An initializer specially used for debugging (to load weights from disk). | |||
/// </summary> | |||
class NpyLoadInitializer : IInitializer | |||
{ | |||
string _path; | |||
public NpyLoadInitializer(string path) { _path = path; } | |||
public string ClassName => ""; | |||
public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||
public Tensor Apply(InitializerArgs args) | |||
{ | |||
return np.load(_path); | |||
} | |||
} | |||
} |
@@ -53,13 +53,12 @@ public class Orthogonal : IInitializer | |||
// Compute the qr factorization | |||
var (q, r) = tf.linalg.qr(a, full_matrices: false); | |||
// Make Q uniform | |||
var d = tf.linalg.tensor_diag_part(r); | |||
var d = tf.linalg.tensor_diag_part(r.Single); | |||
q *= tf.sign(d); | |||
if (num_rows < num_cols) | |||
{ | |||
// q = tf.linalg.matrix_transpose(q); | |||
throw new NotImplementedException(""); | |||
q = array_ops.matrix_transpose(q); | |||
} | |||
return _gain * tf.reshape(q, shape); | |||
@@ -11,6 +11,7 @@ namespace Tensorflow | |||
/// Basic LSTM recurrent network cell. | |||
/// The implementation is based on: http://arxiv.org/abs/1409.2329. | |||
/// </summary> | |||
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
public class BasicLstmCell : LayerRnnCell | |||
{ | |||
int _num_units; | |||
@@ -88,7 +89,7 @@ namespace Tensorflow | |||
gate_inputs = nn_ops.bias_add(gate_inputs, _bias); | |||
// i = input_gate, j = new_input, f = forget_gate, o = output_gate | |||
var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); | |||
var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); | |||
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | |||
var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | |||
@@ -20,6 +20,7 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
public class BasicRnnCell : LayerRnnCell | |||
{ | |||
int _num_units; | |||
@@ -19,6 +19,7 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
public class LayerRnnCell : RnnCell | |||
{ | |||
protected InputSpec inputSpec; | |||
@@ -16,10 +16,11 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Keras; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Operations; | |||
@@ -50,7 +51,8 @@ namespace Tensorflow | |||
/// matching structure of Tensors having shape `[batch_size].concatenate(s)` | |||
/// for each `s` in `self.batch_size`. | |||
/// </summary> | |||
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell | |||
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
public abstract class RnnCell : ILayer, IRnnCell | |||
{ | |||
/// <summary> | |||
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight | |||
@@ -142,7 +144,7 @@ namespace Tensorflow | |||
throw new NotImplementedException("_zero_state_tensors"); | |||
} | |||
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | |||
public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
@@ -173,5 +175,18 @@ namespace Tensorflow | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public INestStructure<long> StateSize => throw new NotImplementedException(); | |||
public INestStructure<long> OutputSize => throw new NotImplementedException(); | |||
public bool IsTFRnnCell => throw new NotImplementedException(); | |||
public bool SupportOptionalArgs => throw new NotImplementedException(); | |||
} | |||
} |
@@ -15,9 +15,11 @@ | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using Google.Protobuf.Collections; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Functions; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.OpDef.Types; | |||
@@ -387,9 +389,13 @@ namespace Tensorflow | |||
case "list(type)": | |||
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); | |||
break; | |||
case "list(float)": | |||
if (value != null) | |||
attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray()); | |||
break; | |||
case "list(int)": | |||
if (value != null) | |||
attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x))); | |||
attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x))); | |||
break; | |||
case "bool": | |||
attr_value.B = (bool)value; | |||
@@ -420,6 +426,15 @@ namespace Tensorflow | |||
case "list(shape)": | |||
attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); | |||
break; | |||
case "func": | |||
attr_value.Func = _MakeFunc(value, attr_def.Name); | |||
break; | |||
case "list(func)": | |||
attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); | |||
break; | |||
case "list(string)": | |||
attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x))); | |||
break; | |||
default: | |||
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | |||
} | |||
@@ -427,6 +442,47 @@ namespace Tensorflow | |||
return attr_value; | |||
} | |||
private NameAttrList _MakeFunc(object func, string arg_name) | |||
{ | |||
if(func is NameAttrList attrList) | |||
{ | |||
return attrList; | |||
} | |||
NameAttrList fn_attr; | |||
if(func is string funcStr) | |||
{ | |||
fn_attr = new NameAttrList() { Name = funcStr }; | |||
} | |||
else if(func is ConcreteFunction concrete) | |||
{ | |||
concrete.AddTograph(ops.get_default_graph()); | |||
fn_attr = concrete.AsNameAttrList; | |||
} | |||
else if(func is EagerDefinedFunction eager) | |||
{ | |||
eager.AddToGraph(ops.get_default_graph()); | |||
fn_attr = new NameAttrList() { Name = eager.Name }; | |||
} | |||
else | |||
{ | |||
throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}"); | |||
} | |||
return fn_attr; | |||
} | |||
private List<NameAttrList> _MakeFuncList(object funcList, string arg_name) | |||
{ | |||
List<NameAttrList> res = new List<NameAttrList>(); | |||
if(funcList is IEnumerable enumerable) | |||
{ | |||
foreach(var func in enumerable) | |||
{ | |||
res.Add(_MakeFunc(func, arg_name)); | |||
} | |||
} | |||
return res; | |||
} | |||
private bool _IsListParameter(ArgDef arg) | |||
{ | |||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||