# Conflicts: # src/TensorFlowNET.Core/Tensorflow.Binding.csproj # src/TensorFlowNET.Keras/Datasets/Imdb.cstags/v0.110.4-Transformer-Model
@@ -16,6 +16,7 @@ | |||||
using System; | using System; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using static Tensorflow.CppShapeInferenceResult.Types; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -50,6 +51,35 @@ namespace Tensorflow | |||||
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | ||||
} | } | ||||
public unsafe static byte[] ByteStringPiece(Buffer? handle) | |||||
{ | |||||
if (handle is null) | |||||
{ | |||||
return new byte[0]; | |||||
} | |||||
var data = handle.ToArray(); | |||||
return data; | |||||
} | |||||
public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle) | |||||
{ | |||||
if (handle == IntPtr.Zero) | |||||
{ | |||||
return new byte[0]; | |||||
} | |||||
byte* str_data = (byte*)handle.ToPointer(); | |||||
List<byte> bytes = new List<byte>(); | |||||
byte current = 255; | |||||
while (current != ((byte)'\0')) | |||||
{ | |||||
current = *(str_data++); | |||||
bytes.Add(current); | |||||
} | |||||
var data = bytes.ToArray(); | |||||
return data; | |||||
} | |||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)] | [UnmanagedFunctionPointer(CallingConvention.Winapi)] | ||||
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); | public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); | ||||
@@ -10,7 +10,7 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | ||||
} | } | ||||
@@ -91,8 +91,7 @@ namespace Tensorflow | |||||
return identity(values.First(), name: scope); | return identity(values.First(), name: scope); | ||||
}); | }); | ||||
} | } | ||||
return gen_array_ops.concat_v2(values.ToArray(), ops.convert_to_tensor(axis), name: name); | |||||
return array_ops.concat(values.ToArray(), axis, name: name); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -163,14 +162,17 @@ namespace Tensorflow | |||||
/// Reverses specific dimensions of a tensor. | /// Reverses specific dimensions of a tensor. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
/// <param name="axis"></param> | |||||
/// <param name="axis">The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).</param> | |||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||||
=> gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name); | |||||
public Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||||
=> gen_array_ops.reverse(tensor, axis, name: name); | |||||
public Tensor reverse(Tensor tensor, Axis axis, string name = null) | |||||
{ | |||||
if (axis.IsScalar) | |||||
{ | |||||
axis = new Axis(axis.axis); | |||||
} | |||||
return array_ops.reverse(tensor, axis, name: name); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Returns the rank of a tensor. | /// Returns the rank of a tensor. | ||||
@@ -46,10 +46,10 @@ namespace Tensorflow | |||||
Tensor loop_vars, | Tensor loop_vars, | ||||
int parallel_iterations = 10) | int parallel_iterations = 10) | ||||
{ | { | ||||
Func<Tensor[], Tensor> cond1 = x | |||||
Func<Tensors, Tensor> cond1 = x | |||||
=> cond(x[0]); | => cond(x[0]); | ||||
Func<Tensor[], Tensor[]> body1 = x | |||||
Func<Tensors, Tensors> body1 = x | |||||
=> new[] { body(x[0]) }; | => new[] { body(x[0]) }; | ||||
var results = control_flow_ops.while_loop(cond1, | var results = control_flow_ops.while_loop(cond1, | ||||
@@ -58,9 +58,9 @@ namespace Tensorflow | |||||
return results[0]; | return results[0]; | ||||
} | } | ||||
public Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||||
Func<Tensor[], Tensor[]> body, | |||||
Tensor[] loop_vars, | |||||
public Tensor[] while_loop(Func<Tensors, Tensor> cond, | |||||
Func<Tensors, Tensors> body, | |||||
Tensors loop_vars, | |||||
int parallel_iterations = 10, | int parallel_iterations = 10, | ||||
string name = null) | string name = null) | ||||
=> control_flow_ops.while_loop(cond, body, loop_vars, | => control_flow_ops.while_loop(cond, body, loop_vars, | ||||
@@ -14,6 +14,10 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using OneOf.Types; | |||||
using System; | |||||
using System.Buffers.Text; | |||||
using Tensorflow.Contexts; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -162,17 +166,108 @@ namespace Tensorflow | |||||
public Tensor sobel_edges(Tensor image) | public Tensor sobel_edges(Tensor image) | ||||
=> image_ops_impl.sobel_edges(image); | => image_ops_impl.sobel_edges(image); | ||||
public Tensor decode_jpeg(Tensor contents, | |||||
int channels = 0, | |||||
int ratio = 1, | |||||
bool fancy_upscaling = true, | |||||
bool try_recover_truncated = false, | |||||
int acceptable_fraction = 1, | |||||
string dct_method = "", | |||||
string name = null) | |||||
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||||
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||||
acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||||
/// <summary> | |||||
/// Adjust contrast of RGB or grayscale images. | |||||
/// </summary> | |||||
/// <param name="images">Images to adjust. At least 3-D.</param> | |||||
/// <param name="contrast_factor"></param> | |||||
/// <param name="name">A float multiplier for adjusting contrast.</param> | |||||
/// <returns>The contrast-adjusted image or images.</returns> | |||||
public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null) | |||||
=> gen_image_ops.adjust_contrastv2(images, contrast_factor, name); | |||||
/// <summary> | |||||
/// Adjust hue of RGB images. | |||||
/// </summary> | |||||
/// <param name="images">RGB image or images. The size of the last dimension must be 3.</param> | |||||
/// <param name="delta">float. How much to add to the hue channel.</param> | |||||
/// <param name="name">A name for this operation (optional).</param> | |||||
/// <returns>Adjusted image(s), same shape and DType as `image`.</returns> | |||||
/// <exception cref="ValueError">if `delta` is not in the interval of `[-1, 1]`.</exception> | |||||
public Tensor adjust_hue(Tensor images, float delta, string name = null) | |||||
{ | |||||
if (tf.Context.executing_eagerly()) | |||||
{ | |||||
if (delta < -1f || delta > 1f) | |||||
throw new ValueError("delta must be in the interval [-1, 1]"); | |||||
} | |||||
return gen_image_ops.adjust_hue(images, delta, name: name); | |||||
} | |||||
/// <summary> | |||||
/// Adjust saturation of RGB images. | |||||
/// </summary> | |||||
/// <param name="image">RGB image or images. The size of the last dimension must be 3.</param> | |||||
/// <param name="saturation_factor">float. Factor to multiply the saturation by.</param> | |||||
/// <param name="name">A name for this operation (optional).</param> | |||||
/// <returns>Adjusted image(s), same shape and DType as `image`.</returns> | |||||
public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null) | |||||
=> gen_image_ops.adjust_saturation(image, saturation_factor, name); | |||||
/// <summary> | |||||
/// Greedily selects a subset of bounding boxes in descending order of score. | |||||
/// </summary> | |||||
/// <param name="boxes"> | |||||
/// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q` | |||||
/// is 1 then same boxes are used for all classes otherwise, if `q` is equal | |||||
/// to number of classes, class-specific boxes are used. | |||||
/// </param> | |||||
/// <param name="scores"> | |||||
/// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]` | |||||
/// representing a single score corresponding to each box(each row of boxes). | |||||
/// </param> | |||||
/// <param name="max_output_size_per_class"> | |||||
/// A scalar integer `Tensor` representing the | |||||
/// maximum number of boxes to be selected by non-max suppression per class | |||||
/// </param> | |||||
/// <param name="max_total_size"> | |||||
/// A int32 scalar representing maximum number of boxes retained | |||||
/// over all classes.Note that setting this value to a large number may | |||||
/// result in OOM error depending on the system workload. | |||||
/// </param> | |||||
/// <param name="iou_threshold"> | |||||
/// A float representing the threshold for deciding whether boxes | |||||
/// overlap too much with respect to IOU. | |||||
/// </param> | |||||
/// <param name="score_threshold"> | |||||
/// A float representing the threshold for deciding when to | |||||
/// remove boxes based on score. | |||||
/// </param> | |||||
/// <param name="pad_per_class"> | |||||
/// If false, the output nmsed boxes, scores and classes are | |||||
/// padded/clipped to `max_total_size`. If true, the output nmsed boxes, scores and classes are padded to be of length `max_size_per_class`*`num_classes`, | |||||
/// unless it exceeds `max_total_size` in which case it is clipped to `max_total_size`. Defaults to false. | |||||
/// </param> | |||||
/// <param name="clip_boxes"> | |||||
/// If true, the coordinates of output nmsed boxes will be clipped | |||||
/// to[0, 1]. If false, output the box coordinates as it is. Defaults to true. | |||||
/// </param> | |||||
/// <returns> | |||||
/// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes. | |||||
/// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes. | |||||
/// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for boxes. | |||||
/// 'valid_detections': A [batch_size] int32 tensor indicating the number of | |||||
/// valid detections per batch item. Only the top valid_detections[i] entries | |||||
/// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the | |||||
/// entries are zero paddings. | |||||
/// </returns> | |||||
public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression( | |||||
Tensor boxes, | |||||
Tensor scores, | |||||
int max_output_size_per_class, | |||||
int max_total_size, | |||||
float iou_threshold, | |||||
float score_threshold, | |||||
bool pad_per_class = false, | |||||
bool clip_boxes = true) | |||||
{ | |||||
var iou_threshold_t = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold"); | |||||
var score_threshold_t = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold"); | |||||
var max_total_size_t = ops.convert_to_tensor(max_total_size); | |||||
var max_output_size_per_class_t = ops.convert_to_tensor(max_output_size_per_class); | |||||
return gen_image_ops.combined_non_max_suppression(boxes, scores, max_output_size_per_class_t, max_total_size_t, | |||||
iou_threshold_t, score_threshold_t, pad_per_class, clip_boxes); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change. | /// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change. | ||||
@@ -187,7 +282,19 @@ namespace Tensorflow | |||||
/// <param name="name">A name for the operation (optional).</param> | /// <param name="name">A name for the operation (optional).</param> | ||||
/// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns> | /// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns> | ||||
public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) => | public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) => | ||||
image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); | |||||
gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); | |||||
public Tensor decode_jpeg(Tensor contents, | |||||
int channels = 0, | |||||
int ratio = 1, | |||||
bool fancy_upscaling = true, | |||||
bool try_recover_truncated = false, | |||||
int acceptable_fraction = 1, | |||||
string dct_method = "", | |||||
string name = null) | |||||
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||||
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||||
acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||||
public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true, | public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true, | ||||
bool uniform_noise = true, string name = null) | bool uniform_noise = true, string name = null) | ||||
@@ -14,6 +14,7 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Tensorflow.NumPy; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -42,10 +43,20 @@ namespace Tensorflow | |||||
public Tensor multiply(Tensor x, Tensor y, string name = null) | public Tensor multiply(Tensor x, Tensor y, string name = null) | ||||
=> math_ops.multiply(x, y, name: name); | => math_ops.multiply(x, y, name: name); | ||||
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) | public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) | ||||
=> math_ops.div_no_nan(a, b); | => math_ops.div_no_nan(a, b); | ||||
/// <summary> | |||||
/// Computes the Euclidean norm of elements across dimensions of a tensor. | |||||
/// </summary> | |||||
/// <param name="input_tensor">The tensor to reduce. Should have numeric type.</param> | |||||
/// <param name="axis">The dimensions to reduce. If `None` (the default), reduces all dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`</param> | |||||
/// <param name="keepdims">If true, retains reduced dimensions with length 1.</param> | |||||
/// <param name="name">A name for the operation (optional).</param> | |||||
/// <returns>The reduced tensor, of the same dtype as the input_tensor.</returns> | |||||
public Tensor reduce_euclidean_norm(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) | |||||
=> math_ops.reduce_euclidean_norm(input_tensor, axis: axis, keepdims: keepdims, name); | |||||
public Tensor square(Tensor x, string name = null) | public Tensor square(Tensor x, string name = null) | ||||
=> math_ops.square(x, name: name); | => math_ops.square(x, name: name); | ||||
@@ -354,7 +365,7 @@ namespace Tensorflow | |||||
=> a / b; | => a / b; | ||||
public Tensor sqrt(Tensor a, string name = null) | public Tensor sqrt(Tensor a, string name = null) | ||||
=> gen_math_ops.sqrt(a, name); | |||||
=> math_ops.sqrt(a, name); | |||||
public Tensor sign(Tensor a, string name = null) | public Tensor sign(Tensor a, string name = null) | ||||
=> gen_math_ops.sign(a, name); | => gen_math_ops.sign(a, name); | ||||
@@ -452,7 +463,18 @@ namespace Tensorflow | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | => gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | ||||
/// <summary> | |||||
/// return scalar product | |||||
/// </summary> | |||||
/// <typeparam name="Tx"></typeparam> | |||||
/// <typeparam name="Ty"></typeparam> | |||||
/// <param name="x"></param> | |||||
/// <param name="y"></param> | |||||
/// <param name="axes"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public Tensor dot_prod<Tx, Ty>(Tx x, Ty y, NDArray axes, string name = null) | |||||
=> math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name); | |||||
public Tensor negative(Tensor x, string name = null) | public Tensor negative(Tensor x, string name = null) | ||||
=> gen_math_ops.neg(x, name); | => gen_math_ops.neg(x, name); | ||||
@@ -600,5 +622,7 @@ namespace Tensorflow | |||||
=> gen_math_ops.squared_difference(x: x, y: y, name: name); | => gen_math_ops.squared_difference(x: x, y: y, name: name); | ||||
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null, | public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null, | ||||
string name = null) => gen_ops.complex(real, imag, dtype, name); | string name = null) => gen_ops.complex(real, imag, dtype, name); | ||||
public Tensor exp(Tensor x, | |||||
string name = null) => gen_math_ops.exp(x, name); | |||||
} | } | ||||
} | } |
@@ -14,6 +14,7 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Xml.Linq; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -126,6 +127,26 @@ namespace Tensorflow | |||||
name: name, | name: name, | ||||
exponential_avg_factor: exponential_avg_factor); | exponential_avg_factor: exponential_avg_factor); | ||||
/// <summary> | |||||
/// Normalizes a tensor by `mean` and `variance`, and applies (optionally) a`scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\). | |||||
/// </summary> | |||||
/// <param name="x">A floating point tensor.</param> | |||||
/// <param name="mean">A mean `Tensor`.</param> | |||||
/// <param name="variance">A variance `Tensor`.</param> | |||||
/// <param name="offset"> An offset `Tensor`, often denoted \\(\beta\\) in equations, or NULL. If present, will be added to the normalized tensor.</param> | |||||
/// <param name="scale"> A scale `Tensor`, often denoted \\(\gamma\\) in equations, or NULL. If present, the scale is applied to the normalized tensor.</param> | |||||
/// <param name="variance_epsilon"> A small float number to avoid dividing by 0.</param> | |||||
/// <param name="name">A name for this operation.</param> | |||||
/// <returns>the normalized, scaled, offset tensor.</returns> | |||||
public Tensor batch_normalization(Tensor x, | |||||
Tensor mean, | |||||
Tensor variance, | |||||
Tensor offset, | |||||
Tensor scale, | |||||
float variance_epsilon, | |||||
string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name); | |||||
public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null) | public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null) | ||||
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); | => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); | ||||
@@ -31,6 +31,6 @@ namespace Tensorflow | |||||
public Tensor reshape(Tensor tensor, | public Tensor reshape(Tensor tensor, | ||||
object[] shape, | object[] shape, | ||||
string name = null) | string name = null) | ||||
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name); | |||||
=> array_ops.reshape(tensor, shape, name); | |||||
} | } | ||||
} | } |
@@ -68,20 +68,27 @@ namespace Tensorflow | |||||
/// <param name="name">A name for the operation (optional)</param> | /// <param name="name">A name for the operation (optional)</param> | ||||
/// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; | /// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; | ||||
/// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns> | /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns> | ||||
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) | |||||
public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) | |||||
=> array_ops.split( | => array_ops.split( | ||||
value: value, | value: value, | ||||
num_split: num_split, | |||||
num_or_size_splits: num_split, | |||||
axis: axis, | axis: axis, | ||||
name: name); | name: name); | ||||
public Tensor[] split(Tensor value, int num_split, int axis, string name = null) | |||||
public Tensor[] split(Tensor value, int[] num_split, Axis axis, string name = null) | |||||
=> array_ops.split( | => array_ops.split( | ||||
value: value, | value: value, | ||||
num_split: num_split, | |||||
num_or_size_splits: num_split, | |||||
axis: axis, | axis: axis, | ||||
name: name); | name: name); | ||||
//public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) | |||||
// => array_ops.split( | |||||
// value: value, | |||||
// num_or_size_splits: num_split, | |||||
// axis: axis, | |||||
// name: name); | |||||
public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | ||||
{ | { | ||||
return gen_ops.ensure_shape(x, shape, name); | return gen_ops.ensure_shape(x, shape, name); | ||||
@@ -23,7 +23,7 @@ namespace Tensorflow | |||||
=> gen_array_ops.tile(input, multiples, name); | => gen_array_ops.tile(input, multiples, name); | ||||
public Tensor tile(Tensor input, object[] multiples, string name = null) | public Tensor tile(Tensor input, object[] multiples, string name = null) | ||||
=> gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name); | |||||
=> array_ops.tile(input, constant_op.constant(shape_utils.from_object_array(multiples).dims), name); | |||||
public Tensor tile(Tensor input, Shape multiples, string name = null) | public Tensor tile(Tensor input, Shape multiples, string name = null) | ||||
{ | { | ||||
@@ -486,7 +486,28 @@ namespace Tensorflow | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
} | } | ||||
public static NDArray GetFlattenArray(NDArray x) | |||||
{ | |||||
switch (x.GetDataType()) | |||||
{ | |||||
case TF_DataType.TF_FLOAT: | |||||
x = x.ToArray<float>(); | |||||
break; | |||||
case TF_DataType.TF_DOUBLE: | |||||
x = x.ToArray<double>(); | |||||
break; | |||||
case TF_DataType.TF_INT16: | |||||
case TF_DataType.TF_INT32: | |||||
x = x.ToArray<int>(); | |||||
break; | |||||
case TF_DataType.TF_INT64: | |||||
x = x.ToArray<long>(); | |||||
break; | |||||
default: | |||||
break; | |||||
} | |||||
return x; | |||||
} | |||||
public static TF_DataType GetDataType(this object data) | public static TF_DataType GetDataType(this object data) | ||||
{ | { | ||||
var type = data.GetType(); | var type = data.GetType(); | ||||
@@ -503,7 +524,7 @@ namespace Tensorflow | |||||
case Tensors tensors: | case Tensors tensors: | ||||
return tensors.dtype; | return tensors.dtype; | ||||
case IEnumerable<Tensor> tensors: | case IEnumerable<Tensor> tensors: | ||||
return tensors.First().dtype; | |||||
return tensors.Where(x => x is not null).First().dtype; | |||||
case RefVariable variable: | case RefVariable variable: | ||||
return variable.dtype; | return variable.dtype; | ||||
case ResourceVariable variable: | case ResourceVariable variable: | ||||
@@ -3,16 +3,16 @@ using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow.Extensions | |||||
namespace Tensorflow.Common.Extensions | |||||
{ | { | ||||
public static class JObjectExtensions | public static class JObjectExtensions | ||||
{ | { | ||||
public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | ||||
{ | { | ||||
var res = obj[key]; | var res = obj[key]; | ||||
if(res is null) | |||||
if (res is null) | |||||
{ | { | ||||
return default(T); | |||||
return default; | |||||
} | } | ||||
else | else | ||||
{ | { |
@@ -0,0 +1,38 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Extensions | |||||
{ | |||||
public static class LinqExtensions | |||||
{ | |||||
#if NETSTANDARD2_0 | |||||
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count) | |||||
{ | |||||
return sequence.Skip(sequence.Count() - count); | |||||
} | |||||
public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count) | |||||
{ | |||||
return sequence.Take(sequence.Count() - count); | |||||
} | |||||
#endif | |||||
public static Tensors ToTensors(this Tensor[] tensors) | |||||
{ | |||||
return new Tensors(tensors); | |||||
} | |||||
public static Tensors ToTensors(this IList<Tensor> tensors) | |||||
{ | |||||
return new Tensors(tensors); | |||||
} | |||||
public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||||
{ | |||||
first = values.Item1; | |||||
second = values.Item2; | |||||
third = values.Item3; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,33 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Common.Types; | |||||
namespace Tensorflow.Common.Extensions | |||||
{ | |||||
public static class NestExtensions | |||||
{ | |||||
public static Tensors ToTensors(this INestable<Tensor> tensors) | |||||
{ | |||||
return new Tensors(tensors.AsNest()); | |||||
} | |||||
public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||||
{ | |||||
return Tensors.FromNest(tensors); | |||||
} | |||||
/// <summary> | |||||
/// If the nested object is already a nested type, this function could reduce it. | |||||
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||||
/// </summary> | |||||
/// <typeparam name="TIn"></typeparam> | |||||
/// <typeparam name="TOut"></typeparam> | |||||
/// <param name="input"></param> | |||||
/// <returns></returns> | |||||
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||||
{ | |||||
return Nest<TOut>.ReduceFrom(input); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,20 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
/// <summary> | |||||
/// This is a temp solution, which should be removed after refactoring `Tensors` | |||||
/// </summary> | |||||
[Obsolete] | |||||
public class FakeTensorByTensorArray: Tensor | |||||
{ | |||||
public TensorArray TensorArray { get; set; } | |||||
public FakeTensorByTensorArray(TensorArray array) | |||||
{ | |||||
TensorArray = array; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,69 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public class GeneralizedTensorShape: Nest<Shape> | |||||
{ | |||||
public GeneralizedTensorShape(Shape value, string? name = null) | |||||
{ | |||||
NodeValue = value; | |||||
NestType = NestType.Node; | |||||
} | |||||
public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null) | |||||
{ | |||||
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList(); | |||||
Name = name; | |||||
NestType = NestType.List; | |||||
} | |||||
public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null) | |||||
{ | |||||
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>); | |||||
Name = name; | |||||
NestType = NestType.Dictionary; | |||||
} | |||||
public GeneralizedTensorShape(Nest<Shape> other) | |||||
{ | |||||
NestType = other.NestType; | |||||
NodeValue = other.NodeValue; | |||||
DictValue = other.DictValue; | |||||
ListValue = other.ListValue; | |||||
Name = other.Name; | |||||
} | |||||
public Shape ToSingleShape() | |||||
{ | |||||
var shapes = Flatten().ToList(); | |||||
if (shapes.Count != 1) | |||||
{ | |||||
throw new ValueError("The generalized shape contains more than 1 dim."); | |||||
} | |||||
return shapes[0]; | |||||
} | |||||
public long ToNumber() | |||||
{ | |||||
var shapes = Flatten().ToList(); | |||||
if (shapes.Count != 1 || shapes[0].ndim != 1) | |||||
{ | |||||
throw new ValueError("The generalized shape contains more than 1 dim."); | |||||
} | |||||
return shapes[0].dims[0]; | |||||
} | |||||
public INestStructure<TensorShapeConfig> ToTensorShapeConfigs() | |||||
{ | |||||
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() }); | |||||
} | |||||
public static implicit operator GeneralizedTensorShape(Shape shape) | |||||
{ | |||||
return new GeneralizedTensorShape(shape); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,40 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
/// <summary> | |||||
/// This interface indicates that a class may have a nested structure and provide | |||||
/// methods to manipulate with the structure. | |||||
/// </summary> | |||||
public interface INestStructure<T>: INestable<T> | |||||
{ | |||||
NestType NestType { get; } | |||||
/// <summary> | |||||
/// The item count of depth 1 of the nested structure. | |||||
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3. | |||||
/// </summary> | |||||
int ShallowNestedCount { get; } | |||||
/// <summary> | |||||
/// The total item count of depth 1 of the nested structure. | |||||
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||||
/// </summary> | |||||
int TotalNestedCount { get; } | |||||
/// <summary> | |||||
/// Flatten the Nestable object. Node that if the object contains only one value, | |||||
/// it will be flattened to an enumerable with one element. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
IEnumerable<T> Flatten(); | |||||
/// <summary> | |||||
/// Construct a new object with the same nested structure. | |||||
/// </summary> | |||||
/// <typeparam name="TOut"></typeparam> | |||||
/// <param name="func"></param> | |||||
/// <returns></returns> | |||||
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||||
} | |||||
} |
@@ -0,0 +1,11 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public interface INestable<T> | |||||
{ | |||||
Nest<T> AsNest(); | |||||
} | |||||
} |
@@ -0,0 +1,21 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
/// <summary> | |||||
/// This interface is used when some corresponding python methods have optional args. | |||||
/// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while | |||||
/// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs` | |||||
/// as the parameter of the method. | |||||
/// </summary> | |||||
public interface IOptionalArgs | |||||
{ | |||||
/// <summary> | |||||
/// The identifier of the class. It is not an argument but only something to | |||||
/// separate different OptionalArgs. | |||||
/// </summary> | |||||
string Identifier { get; } | |||||
} | |||||
} |
@@ -0,0 +1,62 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public static class Nest | |||||
{ | |||||
/// <summary> | |||||
/// Pack the flat items to a nested sequence by the template. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <param name="template"></param> | |||||
/// <param name="flatItems"></param> | |||||
/// <returns></returns> | |||||
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems) | |||||
{ | |||||
return template.AsNest().PackSequence(flatItems); | |||||
} | |||||
/// <summary> | |||||
/// Pack the flat items to a nested sequence by the template. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <param name="template"></param> | |||||
/// <param name="flatItems"></param> | |||||
/// <returns></returns> | |||||
public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||||
{ | |||||
return template.AsNest().PackSequence(flatItems.ToArray()); | |||||
} | |||||
/// <summary> | |||||
/// Flatten the nested object. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <param name="nestedObject"></param> | |||||
/// <returns></returns> | |||||
public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||||
{ | |||||
return nestedObject.AsNest().Flatten(); | |||||
} | |||||
/// <summary> | |||||
/// Map the structure with specified function. | |||||
/// </summary> | |||||
/// <typeparam name="TIn"></typeparam> | |||||
/// <typeparam name="TOut"></typeparam> | |||||
/// <param name="func"></param> | |||||
/// <param name="nestedObject"></param> | |||||
/// <returns></returns> | |||||
public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||||
{ | |||||
return nestedObject.AsNest().MapStructure(func); | |||||
} | |||||
public static bool IsNested<T>(INestable<T> obj) | |||||
{ | |||||
return obj.AsNest().IsNested(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,485 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Common.Extensions; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public enum NestType | |||||
{ | |||||
Empty, | |||||
Node, | |||||
List, | |||||
Dictionary | |||||
} | |||||
/// <summary> | |||||
/// A nested structure which may inclulde value, list and dictionary. | |||||
/// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||||
/// its order is depth-first. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||||
{ | |||||
private static readonly Nest<T> _empty = new Nest<T>() | |||||
{ | |||||
NestType = NestType.Empty, | |||||
}; | |||||
public static Nest<T> Empty => _empty; | |||||
public NestType NestType { get; protected set; } | |||||
public string? Name { get; set; } | |||||
public T? NodeValue { get; protected set; } | |||||
public List<INestStructure<T>>? ListValue { get; protected set; } | |||||
public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; } | |||||
public int ShallowNestedCount | |||||
{ | |||||
get | |||||
{ | |||||
if (NestType == NestType.Empty) | |||||
{ | |||||
return 0; | |||||
} | |||||
else if (NestType == NestType.Node) | |||||
{ | |||||
return 1; | |||||
} | |||||
else if (NestType == NestType.List) | |||||
{ | |||||
return ListValue!.Count; | |||||
} | |||||
else // dict | |||||
{ | |||||
return DictValue!.Count; | |||||
} | |||||
} | |||||
} | |||||
public int TotalNestedCount | |||||
{ | |||||
get | |||||
{ | |||||
return Flatten().Count(); | |||||
} | |||||
} | |||||
protected Nest() { } | |||||
public Nest(T value, string? name = null) | |||||
{ | |||||
NodeValue = value; | |||||
Name = name; | |||||
NestType = NestType.Node; | |||||
} | |||||
public Nest(IEnumerable<INestStructure<T>> values, string? name = null) | |||||
{ | |||||
ListValue = values.ToList(); | |||||
Name = name; | |||||
NestType = NestType.List; | |||||
} | |||||
public Nest(Dictionary<string, INestStructure<T>> value, string? name = null) | |||||
{ | |||||
DictValue = value; | |||||
Name = name; | |||||
NestType = NestType.Dictionary; | |||||
} | |||||
public Nest(Nest<T> other) | |||||
{ | |||||
NestType = other.NestType; | |||||
NodeValue = other.NodeValue; | |||||
DictValue = other.DictValue; | |||||
ListValue = other.ListValue; | |||||
Name = other.Name; | |||||
} | |||||
public virtual IEnumerable<T> Flatten() | |||||
{ | |||||
return FlattenInternal(this); | |||||
} | |||||
public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
{ | |||||
return MapStructureInternal(func); | |||||
} | |||||
/// <summary> | |||||
/// Pack the flat items to a nested sequence by the template. | |||||
/// </summary> | |||||
/// <param name="flatItems"></param> | |||||
/// <returns></returns> | |||||
public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems) | |||||
{ | |||||
if(flatItems.Length == 0) | |||||
{ | |||||
return Nest<TOut>.Empty; | |||||
} | |||||
int index = 0; | |||||
return PackSequenceInternal(this, flatItems, ref index); | |||||
} | |||||
private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index) | |||||
{ | |||||
if(template.NestType == NestType.Node) | |||||
{ | |||||
if(index >= flatItems.Length) | |||||
{ | |||||
throw new InvalidArgumentError("The template and flat items are not matched."); | |||||
} | |||||
return new Nest<TOut>(flatItems[index++]); | |||||
} | |||||
else if(template.NestType == NestType.List) | |||||
{ | |||||
List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>(); | |||||
for (int i = 0; i < template.ListValue!.Count; i++) | |||||
{ | |||||
nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index)); | |||||
} | |||||
return new Nest<TOut>(nestedObjects); | |||||
} | |||||
else if(template.NestType == NestType.Node) | |||||
{ | |||||
Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>(); | |||||
foreach(var (key, value) in template.DictValue!) | |||||
{ | |||||
dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index); | |||||
} | |||||
return new Nest<TOut>(dict); | |||||
} | |||||
// Consider Empty as invalid type. | |||||
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||||
} | |||||
public virtual Nest<T> AsNest() | |||||
{ | |||||
return this; | |||||
} | |||||
public virtual Nest<T> MergeWith(Nest<T>? other) | |||||
{ | |||||
if(other is null || other == Nest<T>.Empty) | |||||
{ | |||||
return this; | |||||
} | |||||
if(this == Nest<T>.Empty) | |||||
{ | |||||
return other; | |||||
} | |||||
if(NestType == NestType.Node && other.NestType == NestType.Node) | |||||
{ | |||||
return new Nest<T>(new Nest<T>[] { this, other }); | |||||
} | |||||
else if(NestType == NestType.List && other.NestType == NestType.List) | |||||
{ | |||||
return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||||
} | |||||
else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) | |||||
{ | |||||
return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); | |||||
} | |||||
else | |||||
{ | |||||
return new Nest<T>(new Nest<T>[] { this, other }); | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not | |||||
/// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public bool IsNested() | |||||
{ | |||||
if(NestType is NestType.Empty or NestType.Node) | |||||
{ | |||||
return false; | |||||
} | |||||
else if(NestType is NestType.List) | |||||
{ | |||||
return ListValue!.Count > 0; | |||||
} | |||||
else | |||||
{ | |||||
return DictValue!.Count > 0; | |||||
} | |||||
} | |||||
[Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||||
public T this[int index] | |||||
{ | |||||
get | |||||
{ | |||||
bool success = FindInternal(this, index, out var result); | |||||
if (success) | |||||
{ | |||||
return result; | |||||
} | |||||
else | |||||
{ | |||||
throw new IndexOutOfRangeException(); | |||||
} | |||||
} | |||||
set | |||||
{ | |||||
bool success = SetInternal(this, index, value); | |||||
if (!success) | |||||
{ | |||||
throw new IndexOutOfRangeException(); | |||||
} | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it | |||||
/// to `Nest[T]`. | |||||
/// </summary> | |||||
/// <typeparam name="TOut"></typeparam> | |||||
/// <param name="input"></param> | |||||
/// <returns></returns> | |||||
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||||
{ | |||||
var nested = input.AsNest(); | |||||
return ReduceInternal(nested).AsNest(); | |||||
} | |||||
private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||||
{ | |||||
if(node.NestType == NestType.Empty) | |||||
{ | |||||
return Nest<T>.Empty; | |||||
} | |||||
else if(node.NestType == NestType.Node) | |||||
{ | |||||
return node.NodeValue!.AsNest(); | |||||
} | |||||
else if(node.NestType == NestType.List) | |||||
{ | |||||
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest()))); | |||||
} | |||||
else // Dictionary type | |||||
{ | |||||
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest()))); | |||||
} | |||||
} | |||||
private static bool FindInternal(Nest<T> node, int index, out T? result) | |||||
{ | |||||
if (node.NestType == NestType.Node) | |||||
{ | |||||
if(index == 0) | |||||
{ | |||||
result = node.NodeValue!; | |||||
return true; | |||||
} | |||||
result = default(T); | |||||
return false; | |||||
} | |||||
else if (node.NestType == NestType.List) | |||||
{ | |||||
foreach (var item in node.ListValue!) | |||||
{ | |||||
if(index == 0) | |||||
{ | |||||
return FindInternal(item.AsNest(), index, out result); | |||||
} | |||||
index--; | |||||
} | |||||
result = default(T); | |||||
return false; | |||||
} | |||||
else if(node.NestType == NestType.Dictionary) | |||||
{ | |||||
foreach (var item in node.DictValue!.Values) | |||||
{ | |||||
if (index == 0) | |||||
{ | |||||
return FindInternal(item.AsNest(), index, out result); | |||||
} | |||||
index--; | |||||
} | |||||
result = default(T); | |||||
return false; | |||||
} | |||||
else | |||||
{ | |||||
result = default(T); | |||||
return false; | |||||
} | |||||
} | |||||
private static bool SetInternal(Nest<T> node, int index, T newValue) | |||||
{ | |||||
if (node.NestType == NestType.Node) | |||||
{ | |||||
if (index == 0) | |||||
{ | |||||
node.NodeValue = newValue; | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
else if (node.NestType == NestType.List) | |||||
{ | |||||
foreach (var item in node.ListValue!) | |||||
{ | |||||
if (index == 0) | |||||
{ | |||||
return SetInternal(item.AsNest(), index, newValue); | |||||
} | |||||
index--; | |||||
} | |||||
return false; | |||||
} | |||||
else if (node.NestType == NestType.Dictionary) | |||||
{ | |||||
foreach (var item in node.DictValue!.Values) | |||||
{ | |||||
if (index == 0) | |||||
{ | |||||
return SetInternal(item.AsNest(), index, newValue); | |||||
} | |||||
index--; | |||||
} | |||||
return false; | |||||
} | |||||
else | |||||
{ | |||||
return false; | |||||
} | |||||
} | |||||
private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||||
{ | |||||
if (node.NestType == NestType.Node) | |||||
{ | |||||
yield return node.NodeValue!; | |||||
} | |||||
else if (node.NestType == NestType.List) | |||||
{ | |||||
foreach (var item in node.ListValue!) | |||||
{ | |||||
foreach(var val in FlattenInternal(item.AsNest())) | |||||
{ | |||||
yield return val; | |||||
} | |||||
} | |||||
} | |||||
else if (node.NestType == NestType.Dictionary) | |||||
{ | |||||
foreach (var item in node.DictValue!.Values) | |||||
{ | |||||
foreach (var val in FlattenInternal(item.AsNest())) | |||||
{ | |||||
yield return val; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||||
{ | |||||
if (NestType == NestType.Node) | |||||
{ | |||||
return new Nest<TOut>(func(NodeValue!)); | |||||
} | |||||
else if (NestType == NestType.List) | |||||
{ | |||||
List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||||
foreach (var item in ListValue!) | |||||
{ | |||||
outs.Add(item.AsNest().MapStructureInternal(func)); | |||||
} | |||||
return new Nest<TOut>(outs); | |||||
} | |||||
else if (NestType == NestType.Dictionary) | |||||
{ | |||||
Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>(); | |||||
foreach (var (key, value) in DictValue!) | |||||
{ | |||||
outs.Add(key, value.AsNest().MapStructureInternal(func)); | |||||
} | |||||
return new Nest<TOut>(outs); | |||||
} | |||||
else | |||||
{ | |||||
return Nest<TOut>.Empty; | |||||
} | |||||
} | |||||
public IEnumerator<T> GetEnumerator() | |||||
{ | |||||
return Flatten().GetEnumerator(); | |||||
} | |||||
IEnumerator IEnumerable.GetEnumerator() | |||||
{ | |||||
return GetEnumerator(); | |||||
} | |||||
public override string ToString() | |||||
{ | |||||
StringBuilder sb = new StringBuilder(); | |||||
sb.Append("("); | |||||
WriteString(this, sb); | |||||
sb.Append(")"); | |||||
return sb.ToString(); | |||||
} | |||||
private static void WriteString(Nest<T> node, StringBuilder sb) | |||||
{ | |||||
if (!string.IsNullOrEmpty(node.Name)) | |||||
{ | |||||
sb.Append($"{node.Name}: "); | |||||
} | |||||
if (node.NestType == NestType.Node) | |||||
{ | |||||
sb.Append(node.NodeValue!.ToString()); | |||||
} | |||||
else if (node.NestType == NestType.List) | |||||
{ | |||||
sb.Append("["); | |||||
for(int i = 0; i < node.ListValue!.Count; i++) | |||||
{ | |||||
WriteString(node.ListValue![i].AsNest(), sb); | |||||
if(i != node.ListValue!.Count - 1) | |||||
{ | |||||
sb.Append(", "); | |||||
} | |||||
} | |||||
sb.Append("]"); | |||||
} | |||||
else if (node.NestType == NestType.Dictionary) | |||||
{ | |||||
sb.Append("{"); | |||||
int count = node.DictValue!.Count; | |||||
int i = 0; | |||||
foreach (var (key, value) in node.DictValue!) | |||||
{ | |||||
sb.Append($"{key}: "); | |||||
WriteString(value.AsNest(), sb); | |||||
if (i != count - 1) | |||||
{ | |||||
sb.Append(", "); | |||||
} | |||||
i++; | |||||
} | |||||
sb.Append("}"); | |||||
} | |||||
else | |||||
{ | |||||
sb.Append("<empty>"); | |||||
} | |||||
} | |||||
public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs) | |||||
{ | |||||
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 }); | |||||
} | |||||
public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs) | |||||
{ | |||||
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 }); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,103 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||||
{ | |||||
public NestType NestType => NestType.Dictionary; | |||||
public IDictionary<TKey, TValue> Value { get; set; } | |||||
public int ShallowNestedCount => Values.Count; | |||||
public int TotalNestedCount => Values.Count; | |||||
public NestDictionary(IDictionary<TKey, TValue> dict) | |||||
{ | |||||
Value = dict; | |||||
} | |||||
public IEnumerable<TValue> Flatten() | |||||
{ | |||||
return Value.Select(x => x.Value); | |||||
} | |||||
public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||||
{ | |||||
return new NestList<TOut>(Value.Select(x => func(x.Value))); | |||||
} | |||||
public Nest<TValue> AsNest() | |||||
{ | |||||
return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x))); | |||||
} | |||||
// Required IDictionary<TKey, TValue> members | |||||
public int Count => Value.Count; | |||||
public bool IsReadOnly => Value.IsReadOnly; | |||||
public ICollection<TKey> Keys => Value.Keys; | |||||
public ICollection<TValue> Values => Value.Values; | |||||
public void Add(TKey key, TValue value) | |||||
{ | |||||
Value.Add(key, value); | |||||
} | |||||
public void Add(KeyValuePair<TKey, TValue> item) | |||||
{ | |||||
Value.Add(item); | |||||
} | |||||
public void Clear() | |||||
{ | |||||
Value.Clear(); | |||||
} | |||||
public bool Contains(KeyValuePair<TKey, TValue> item) | |||||
{ | |||||
return Value.Contains(item); | |||||
} | |||||
public bool ContainsKey(TKey key) | |||||
{ | |||||
return Value.ContainsKey(key); | |||||
} | |||||
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||||
{ | |||||
Value.CopyTo(array, arrayIndex); | |||||
} | |||||
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||||
{ | |||||
return Value.GetEnumerator(); | |||||
} | |||||
IEnumerator IEnumerable.GetEnumerator() | |||||
{ | |||||
return GetEnumerator(); | |||||
} | |||||
public bool Remove(TKey key) | |||||
{ | |||||
return Value.Remove(key); | |||||
} | |||||
public bool Remove(KeyValuePair<TKey, TValue> item) | |||||
{ | |||||
return Value.Remove(item); | |||||
} | |||||
public bool TryGetValue(TKey key, out TValue value) | |||||
{ | |||||
return Value.TryGetValue(key, out value); | |||||
} | |||||
// Optional IDictionary<TKey, TValue> members | |||||
public TValue this[TKey key] | |||||
{ | |||||
get => Value[key]; | |||||
set => Value[key] = value; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,53 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
/// <summary> | |||||
/// The implementation of a list that support nest structure, in which the depth is 1. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||||
{ | |||||
public NestType NestType => NestType.List; | |||||
public List<T> Values { get; set; } | |||||
public int ShallowNestedCount => Values.Count; | |||||
public int TotalNestedCount => Values.Count; | |||||
public NestList(params T[] values) | |||||
{ | |||||
Values = new List<T>(values); | |||||
} | |||||
public NestList(IEnumerable<T> values) | |||||
{ | |||||
Values = new List<T>(values); | |||||
} | |||||
public IEnumerable<T> Flatten() | |||||
{ | |||||
return Values; | |||||
} | |||||
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
{ | |||||
return new NestList<TOut>(Values.Select(x => func(x))); | |||||
} | |||||
public Nest<T> AsNest() | |||||
{ | |||||
return new Nest<T>(Values.Select(x => new Nest<T>(x))); | |||||
} | |||||
// Enumerator implementation | |||||
public IEnumerator<T> GetEnumerator() | |||||
{ | |||||
return Values.GetEnumerator(); | |||||
} | |||||
IEnumerator IEnumerable.GetEnumerator() | |||||
{ | |||||
return GetEnumerator(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,36 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
/// <summary> | |||||
/// A nested structure with only one element. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
public class NestNode<T> : INestStructure<T> | |||||
{ | |||||
public NestType NestType => NestType.Node; | |||||
public T Value { get; set; } | |||||
public int ShallowNestedCount => 1; | |||||
public int TotalNestedCount => 1; | |||||
public NestNode(T value) | |||||
{ | |||||
Value = value; | |||||
} | |||||
public IEnumerable<T> Flatten() | |||||
{ | |||||
yield return Value; | |||||
} | |||||
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
{ | |||||
return new NestNode<TOut>(func(Value)); | |||||
} | |||||
public Nest<T> AsNest() | |||||
{ | |||||
return new Nest<T>(Value); | |||||
} | |||||
} | |||||
} |
@@ -3,7 +3,7 @@ using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
namespace Tensorflow.Keras.Saving | |||||
namespace Tensorflow.Common.Types | |||||
{ | { | ||||
public class TensorShapeConfig | public class TensorShapeConfig | ||||
{ | { |
@@ -161,8 +161,8 @@ namespace Tensorflow | |||||
break; | break; | ||||
} | } | ||||
yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ? | |||||
null : new Tensors(results.Skip(FirstInputTensorCount))); | |||||
yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ? | |||||
null : new Tensors(results.Skip(FirstInputTensorCount).ToArray())); | |||||
} | } | ||||
} | } | ||||
@@ -352,13 +352,19 @@ namespace Tensorflow.Eager | |||||
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); | c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); | ||||
break; | break; | ||||
case TF_AttrType.TF_ATTR_SHAPE: | case TF_AttrType.TF_ATTR_SHAPE: | ||||
var dims = (value as long[]).ToArray(); | |||||
long[] dims; | |||||
if (value is Shape shape) dims = shape.dims.ToArray(); | |||||
else if (value is long[] longs) dims = longs.ToArray(); | |||||
else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray(); | |||||
else dims = ((long[])value).ToArray(); | |||||
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | ||||
status.Check(true); | status.Check(true); | ||||
break; | break; | ||||
case TF_AttrType.TF_ATTR_FUNC: | case TF_AttrType.TF_ATTR_FUNC: | ||||
if (value is ConcreteFunction func) | if (value is ConcreteFunction func) | ||||
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | ||||
else if(value is string str) | |||||
c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length); | |||||
else | else | ||||
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | ||||
break; | break; | ||||
@@ -65,7 +65,7 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
outgrad_vec = output_gradients.ToList(); | outgrad_vec = output_gradients.ToList(); | ||||
} | } | ||||
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||||
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true); | |||||
bool unconnected_gradients_zero = unconnected_gradients == "zero"; | bool unconnected_gradients_zero = unconnected_gradients == "zero"; | ||||
@@ -137,7 +137,6 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | ||||
} | } | ||||
Shape tensor_shape = new(dims); | |||||
if(status.Code != TF_Code.TF_OK) | if(status.Code != TF_Code.TF_OK) | ||||
{ | { | ||||
@@ -145,6 +144,7 @@ namespace Tensorflow.Eager | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
Shape tensor_shape = new(dims); | |||||
return new TapeTensor(id, dtype, tensor_shape); | return new TapeTensor(id, dtype, tensor_shape); | ||||
} | } | ||||
} | } | ||||
@@ -173,8 +173,12 @@ namespace Tensorflow.Eager | |||||
return dtype == dtypes.variant || dtype == dtypes.resource; | return dtype == dtypes.variant || dtype == dtypes.resource; | ||||
} | } | ||||
bool ListContainNone(long[] list) | |||||
bool ListContainNone(long[]? list) | |||||
{ | { | ||||
if(list is null) | |||||
{ | |||||
return true; | |||||
} | |||||
int len = list.Length; | int len = list.Length; | ||||
if(len == 0) | if(len == 0) | ||||
{ | { | ||||
@@ -10,6 +10,11 @@ namespace Tensorflow.Eager | |||||
var str = NDArrayRender.ToString(nd); | var str = NDArrayRender.ToString(nd); | ||||
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | ||||
} | } | ||||
public string ToString(int maxLength) | |||||
{ | |||||
var nd = new NDArray(this); | |||||
var str = NDArrayRender.ToString(nd, maxLength); | |||||
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,25 @@ | |||||
using Tensorflow; | |||||
internal static class GraphOnlyOps | |||||
{ | |||||
/// <summary> | |||||
/// Graph-only version of tf.compat.v1.placeholder(), for internal use only. | |||||
/// </summary> | |||||
/// <param name="dtyype"></param> | |||||
/// <param name="shape"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null) | |||||
{ | |||||
var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() }; | |||||
var shape_value = new AttrValue() { Shape = shape.as_proto() }; | |||||
var g = ops.get_default_graph(); | |||||
Dictionary<string, AttrValue> attrs = new(); | |||||
attrs["dtype"] = dtype_value; | |||||
attrs["shape"] = shape_value; | |||||
var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype }, | |||||
new TF_DataType[0], attrs: attrs, name: name); | |||||
var result = op.outputs[0]; | |||||
return result; | |||||
} | |||||
} |
@@ -0,0 +1,19 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Exceptions | |||||
{ | |||||
public class NotOkStatusException : TensorflowException | |||||
{ | |||||
public NotOkStatusException() : base() | |||||
{ | |||||
} | |||||
public NotOkStatusException(string message) : base(message) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -49,12 +49,25 @@ namespace Tensorflow.Framework | |||||
public static implicit operator Tensor(IndexedSlices indexedSlices) | public static implicit operator Tensor(IndexedSlices indexedSlices) | ||||
{ | { | ||||
return indexedSlices.values; | |||||
return _indexed_slices_to_tensor(indexedSlices); | |||||
} | } | ||||
public static implicit operator IndexedSlices(Tensor tensor) | public static implicit operator IndexedSlices(Tensor tensor) | ||||
{ | { | ||||
return tensor.Tag as IndexedSlices; | return tensor.Tag as IndexedSlices; | ||||
} | } | ||||
/// <summary> | |||||
/// Converts an IndexedSlices object `value` to a Tensor. | |||||
/// </summary> | |||||
/// <param name="indexedSlices"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="name"></param> | |||||
/// <param name="as_ref"></param> | |||||
/// <returns></returns> | |||||
public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false) | |||||
{ | |||||
return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0)); | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,4 +1,5 @@ | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Eager; | |||||
namespace Tensorflow.Framework.Models | namespace Tensorflow.Framework.Models | ||||
{ | { | ||||
@@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models | |||||
shapes.Insert(0, dim); | shapes.Insert(0, dim); | ||||
return new TensorSpec(shapes.ToArray(), _dtype); | return new TensorSpec(shapes.ToArray(), _dtype); | ||||
} | } | ||||
public static TensorSpec FromTensor(Tensor tensor, string? name = null) | |||||
{ | |||||
if(tensor is EagerTensor) | |||||
{ | |||||
return new TensorSpec(tensor.shape, tensor.dtype, name); | |||||
} | |||||
else | |||||
{ | |||||
return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name); | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,89 @@ | |||||
using Tensorflow.Graphs; | |||||
namespace Tensorflow.Framework | |||||
{ | |||||
internal static class auto_control_deps_utils | |||||
{ | |||||
public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"; | |||||
public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph) | |||||
{ | |||||
List<int> result = new List<int>(); | |||||
// A cache to store the read only resource inputs of an Op. | |||||
// Operation -> ObjectIdentitySet of resource handles. | |||||
Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs = | |||||
new Dictionary<Operation, HashSet<Tensor>>(); | |||||
for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++) | |||||
{ | |||||
Tensor t = func_graph.Inputs[inputIndex]; | |||||
if (t.dtype != dtypes.resource) | |||||
continue; | |||||
bool readOnly = true; | |||||
foreach (var op in t.consumers()) | |||||
{ | |||||
if (opReadOnlyResourceInputs.ContainsKey(op)) | |||||
{ | |||||
if (!opReadOnlyResourceInputs[op].Contains(t)) | |||||
{ | |||||
readOnly = false; | |||||
break; | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
List<int> indices = _get_read_only_resource_input_indices_op(op); | |||||
opReadOnlyResourceInputs[op] = new HashSet<Tensor>( | |||||
indices.Select(i => op.inputs[i])); | |||||
if (!opReadOnlyResourceInputs[op].Contains(t)) | |||||
{ | |||||
readOnly = false; | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
if (readOnly) | |||||
result.Add(inputIndex); | |||||
} | |||||
return result; | |||||
} | |||||
private static List<int> _get_read_only_resource_input_indices_op(Operation op) | |||||
{ | |||||
// ignore the RESOURCE_READ_OPS | |||||
int[] read_only_input_indices; | |||||
try | |||||
{ | |||||
read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR); | |||||
} | |||||
catch (InvalidArgumentError) | |||||
{ | |||||
return new List<int>(); | |||||
} | |||||
int read_only_index = 0; | |||||
List<int> result = new(); | |||||
for (int i = 0; i < op.inputs.Length; i++) | |||||
{ | |||||
if (read_only_index >= read_only_input_indices.Length) | |||||
{ | |||||
break; | |||||
} | |||||
if (op.inputs[i].dtype != dtypes.resource) | |||||
{ | |||||
continue; | |||||
} | |||||
if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index]) | |||||
{ | |||||
result.Add(i); | |||||
read_only_index++; | |||||
} | |||||
} | |||||
return result; | |||||
} | |||||
} | |||||
} |
@@ -42,10 +42,10 @@ namespace Tensorflow.Framework | |||||
func_graph.as_default(); | func_graph.as_default(); | ||||
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | ||||
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | ||||
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||||
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||||
var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | ||||
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||||
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||||
// TODO(Rinne): func_graph.ControlOutputs | // TODO(Rinne): func_graph.ControlOutputs | ||||
_set_handle_data(func_graph, fdef); | _set_handle_data(func_graph, fdef); | ||||
@@ -8,6 +8,7 @@ using Tensorflow.Gradients; | |||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using Tensorflow.Common.Extensions; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
@@ -40,6 +41,18 @@ namespace Tensorflow.Functions | |||||
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | ||||
public IEnumerable<IVariableV1> Variables => func_graph.Variables; | public IEnumerable<IVariableV1> Variables => func_graph.Variables; | ||||
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | ||||
internal NameAttrList AsNameAttrList | |||||
{ | |||||
get | |||||
{ | |||||
NameAttrList ret = new() { Name = this.Name }; | |||||
foreach (var (name, value) in _attrs) | |||||
{ | |||||
ret.Attr[name] = value; | |||||
} | |||||
return ret; | |||||
} | |||||
} | |||||
public ConcreteFunction(string name) | public ConcreteFunction(string name) | ||||
{ | { | ||||
@@ -3,4 +3,7 @@ global using System.Collections.Generic; | |||||
global using System.Text; | global using System.Text; | ||||
global using System.Collections; | global using System.Collections; | ||||
global using System.Data; | global using System.Data; | ||||
global using System.Linq; | |||||
global using System.Linq; | |||||
global using Tensorflow.Keras.Engine; | |||||
global using Tensorflow.Framework.Models; | |||||
global using static Tensorflow.Binding; |
@@ -90,8 +90,7 @@ namespace Tensorflow.Gradients | |||||
? input_values[0].rank + dim_int | ? input_values[0].rank + dim_int | ||||
: dim_int % input_values[0].rank; | : dim_int % input_values[0].rank; | ||||
var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); | var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); | ||||
var sizes_tensor = constant_op.constant(sizes); | |||||
out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList(); | |||||
out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList(); | |||||
} | } | ||||
else if (constant_op.is_constant(concat_dim)) | else if (constant_op.is_constant(concat_dim)) | ||||
{ | { | ||||
@@ -127,7 +126,7 @@ namespace Tensorflow.Gradients | |||||
new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | ||||
new Tensor[] { tf.constant(1), tf.constant(-1) }); | new Tensor[] { tf.constant(1), tf.constant(-1) }); | ||||
var squeeze_sizes = array_ops.squeeze(slice); | var squeeze_sizes = array_ops.squeeze(slice); | ||||
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList(); | |||||
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList(); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -374,5 +373,13 @@ namespace Tensorflow.Gradients | |||||
var p = op.inputs[1]; | var p = op.inputs[1]; | ||||
return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; | return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; | ||||
} | } | ||||
[RegisterGradient("ReverseV2")] | |||||
public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads) | |||||
{ | |||||
var grad = grads[0]; | |||||
var axis = op.inputs[1]; | |||||
return new Tensor[] { array_ops.reverse(grad, axis), null }; | |||||
} | |||||
} | } | ||||
} | } |
@@ -117,6 +117,137 @@ namespace Tensorflow.Gradients | |||||
}; | }; | ||||
} | } | ||||
public static string ellipsis = "..."; | |||||
[RegisterGradient("Einsum")] | |||||
public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
// Gradient for Einsum. | |||||
string equation = (string)op.get_attr("equation"); | |||||
string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None); | |||||
var input_subs = split_equation[0]; | |||||
var output_subs = split_equation[1]; | |||||
if (op.inputs.Length == 1) | |||||
{ | |||||
var input_shape = array_ops.shape(op.inputs[0]); | |||||
var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis))); | |||||
if (reduced_label_set.Count == 0) | |||||
return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) }; | |||||
return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) }; | |||||
} | |||||
string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None); | |||||
var x_subs = split_input_subs[0]; | |||||
var y_subs = split_input_subs[1]; | |||||
// Add ellipsis for broadcasted dimensions if any operand does not have it. | |||||
// This is because the equation "...ij,jk->ik" may be valid if the 0th input's | |||||
// batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid | |||||
// because only the output subscripts contain ellipsis. | |||||
if (output_subs.Contains(ellipsis)) | |||||
{ | |||||
if (!x_subs.Contains(ellipsis)) | |||||
x_subs += ellipsis; | |||||
if (!y_subs.Contains(ellipsis)) | |||||
y_subs += ellipsis; | |||||
} | |||||
// Obtain the gradients wrt the inputs x and y, without taking into account | |||||
// the unbroadcasting. | |||||
var x = op.inputs[0]; | |||||
var y = op.inputs[1]; | |||||
if (grads.GetDataType().is_complex()) | |||||
{ | |||||
x = math_ops.conj(x); | |||||
y = math_ops.conj(y); | |||||
} | |||||
var x_shape = array_ops.shape(x); | |||||
var y_shape = array_ops.shape(y); | |||||
var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs); | |||||
var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs); | |||||
if (!output_subs.Contains(ellipsis)) | |||||
return new Tensor[] { grad_x, grad_y }; | |||||
var bx = _GetBcastSubshape(x_subs); | |||||
int bx_start = bx[0], bx_end = bx[1]; | |||||
var by = _GetBcastSubshape(y_subs); | |||||
int by_start = by[0], by_end = by[1]; | |||||
var x_shape_static = x.shape; | |||||
var y_shape_static = y.shape; | |||||
if(x_shape_static.IsFullyDefined && | |||||
y_shape_static.IsFullyDefined && | |||||
x_shape_static[string.Format("{0}:{1}",bx_start,bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)]) | |||||
return new Tensor[] { grad_x, grad_y }; | |||||
var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)], | |||||
y_shape[string.Format("{0}:{1}", by_start, by_end)]); | |||||
var rx = r[0]; | |||||
var ry = r[1]; | |||||
grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape); | |||||
grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape); | |||||
return new Tensor[] { grad_x, grad_y }; | |||||
} | |||||
protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape, | |||||
string input_subs, string other_subs, string output_subs) | |||||
{ | |||||
var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + "."))); | |||||
var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||||
var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand)); | |||||
if (reduced_label_set.Count == 0) | |||||
return grad_reduced; | |||||
return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set); | |||||
} | |||||
protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set) | |||||
{ | |||||
string reduced_subs; | |||||
Tensor reduced_dims; | |||||
List<int> reduced_axes; | |||||
_GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes); | |||||
bool has_repeated_labels = ( | |||||
new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count < | |||||
input_subs.Length + output_subs.Length); | |||||
var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||||
if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs) | |||||
{ | |||||
var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes)); | |||||
return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape); | |||||
} | |||||
else | |||||
{ | |||||
var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||||
var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||||
var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels); | |||||
return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad)); | |||||
} | |||||
} | |||||
protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes) | |||||
{ | |||||
reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString())); | |||||
reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList(); | |||||
reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList()); | |||||
} | |||||
protected static int _GetAxisFromLabel(string subscripts, char label) | |||||
{ | |||||
var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None); | |||||
var index = splits[0].IndexOf(label); | |||||
if (index != -1) return index; | |||||
if (splits.Length < 2) throw new OutOfRangeError(); | |||||
index = splits[1].IndexOf(label); | |||||
if (index != -1) return index; | |||||
throw new ValueError(); | |||||
} | |||||
protected static int[] _GetBcastSubshape(string subscripts) | |||||
{ | |||||
int start = subscripts.IndexOf(ellipsis); | |||||
if (start == -1) return new int[] { 0, 0 }; | |||||
int remaining = subscripts.Length - (start + ellipsis.Length); | |||||
int end; | |||||
if (remaining > 0) end = remaining; | |||||
else throw new Exception(); | |||||
return new int[] { start, end }; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Returns grad * exp(x). | /// Returns grad * exp(x). | ||||
/// </summary> | /// </summary> | ||||
@@ -365,6 +365,23 @@ namespace Tensorflow.Gradients | |||||
}; | }; | ||||
} | } | ||||
[RegisterGradient("AvgPool")] | |||||
public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
Tensor grad = grads[0]; | |||||
return new Tensor[] | |||||
{ | |||||
gen_nn_ops.avg_pool_grad( | |||||
array_ops.shape(op.inputs[0]), | |||||
grad, | |||||
op.get_attr_list<int>("ksize"), | |||||
op.get_attr_list<int>("strides"), | |||||
op.get_attr<string>("padding"), | |||||
op.get_attr<string>("data_format")) | |||||
}; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Return the gradients for TopK. | /// Return the gradients for TopK. | ||||
/// </summary> | /// </summary> | ||||
@@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable | |||||
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | ||||
public Dictionary<string, AttrValue> Attrs { get; set; } | public Dictionary<string, AttrValue> Attrs { get; set; } | ||||
Dictionary<long, (Tensor, Tensor)> _captures | |||||
internal Dictionary<long, (Tensor, Tensor)> _captures | |||||
= new Dictionary<long, (Tensor, Tensor)>(); | = new Dictionary<long, (Tensor, Tensor)>(); | ||||
public Tensor[] external_captures | public Tensor[] external_captures | ||||
@@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable | |||||
var flat_func_args = nest.flatten(func_args as object); | var flat_func_args = nest.flatten(func_args as object); | ||||
var flat_func_kwargs = nest.flatten(func_kwargs as object); | var flat_func_kwargs = nest.flatten(func_kwargs as object); | ||||
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | ||||
.Where(x => x is Tensor).Select(x => (Tensor)x)); | |||||
.Where(x => x is Tensor).Select(x => (Tensor)x).ToArray()); | |||||
//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | ||||
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | ||||
@@ -544,12 +544,12 @@ public class FuncGraph : Graph, IDisposable | |||||
Tensor placeholder; | Tensor placeholder; | ||||
try | try | ||||
{ | { | ||||
placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||||
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name); | |||||
} | } | ||||
catch (ValueError) | |||||
catch (ValueError ex) | |||||
{ | { | ||||
// TODO(Rinne): Add warning here. | |||||
placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||||
tf.Logger.Warning(ex.ToString()); | |||||
placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape); | |||||
} | } | ||||
handle_data_util.copy_handle_data(tensor, placeholder); | handle_data_util.copy_handle_data(tensor, placeholder); | ||||
if (name is not null) | if (name is not null) | ||||
@@ -575,12 +575,12 @@ public class FuncGraph : Graph, IDisposable | |||||
Tensor placeholder; | Tensor placeholder; | ||||
try | try | ||||
{ | { | ||||
placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||||
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name); | |||||
} | } | ||||
catch (ValueError) | catch (ValueError) | ||||
{ | { | ||||
// TODO(Rinne): Add warning here. | // TODO(Rinne): Add warning here. | ||||
placeholder = tf.placeholder(spec.dtype, spec.shape); | |||||
placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape); | |||||
} | } | ||||
if (name is not null) | if (name is not null) | ||||
{ | { | ||||
@@ -129,7 +129,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
protected Graph outer_graph; | |||||
internal Graph outer_graph; | |||||
public Graph OuterGraph => outer_graph; | public Graph OuterGraph => outer_graph; | ||||
public Dictionary<string, EagerDefinedFunction> Functions => _functions; | public Dictionary<string, EagerDefinedFunction> Functions => _functions; | ||||
public SafeGraphHandle c_graph => _handle; | public SafeGraphHandle c_graph => _handle; | ||||
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class ExponentialArgs : LayerArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class HardSigmoidArgs : LayerArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,11 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class SELUArgs : LayerArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class SoftplusArgs : LayerArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class SoftsignArgs : LayerArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class SwishArgs : LayerArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class TanhArgs : LayerArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class Conv2DTransposeArgs : Conv2DArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class AddArgs : MergeArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class ConcatenateArgs : MergeArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class SubtractArgs : MergeArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class GlobalAveragePooling1DArgs : Pooling1DArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class GlobalAveragePooling2DArgs : Pooling2DArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class GlobalMaxPooling1DArgs : Pooling1DArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class GlobalMaxPooling2DArgs : Pooling2DArgs | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class MaxPooling1DArgs : Pooling1DArgs | |||||
{ | |||||
} | |||||
} |
@@ -7,7 +7,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
[JsonProperty("size")] | [JsonProperty("size")] | ||||
public Shape Size { get; set; } | public Shape Size { get; set; } | ||||
[JsonProperty("data_format")] | [JsonProperty("data_format")] | ||||
public string DataFormat { get; set; } | |||||
public string DataFormat { get; set; } = "channels_last"; | |||||
/// <summary> | /// <summary> | ||||
/// 'nearest', 'bilinear' | /// 'nearest', 'bilinear' | ||||
/// </summary> | /// </summary> | ||||
@@ -0,0 +1,10 @@ | |||||
using Newtonsoft.Json; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class UpSampling1DArgs : AutoSerializeLayerArgs | |||||
{ | |||||
[JsonProperty("size")] | |||||
public int Size { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,20 @@ | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.NumPy; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class BidirectionalArgs : AutoSerializeLayerArgs | |||||
{ | |||||
[JsonProperty("layer")] | |||||
public ILayer Layer { get; set; } | |||||
[JsonProperty("merge_mode")] | |||||
public string? MergeMode { get; set; } | |||||
[JsonProperty("backward_layer")] | |||||
public ILayer BackwardLayer { get; set; } | |||||
public NDArray Weights { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,29 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class GRUArgs : AutoSerializeLayerArgs | |||||
{ | |||||
public int Units { get; set; } | |||||
public Activation Activation { get; set; } | |||||
public Activation RecurrentActivation { get; set; } | |||||
public bool UseBias { get; set; } = true; | |||||
public float Dropout { get; set; } = .0f; | |||||
public float RecurrentDropout { get; set; } = .0f; | |||||
public IInitializer KernelInitializer { get; set; } | |||||
public IInitializer RecurrentInitializer { get; set; } | |||||
public IInitializer BiasInitializer { get; set; } | |||||
public bool ReturnSequences { get;set; } | |||||
public bool ReturnState { get;set; } | |||||
public bool GoBackwards { get;set; } | |||||
public bool Stateful { get;set; } | |||||
public bool Unroll { get;set; } | |||||
public bool TimeMajor { get;set; } | |||||
public bool ResetAfter { get;set; } | |||||
public int Implementation { get; set; } = 2; | |||||
} | |||||
} |
@@ -0,0 +1,39 @@ | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class GRUCellArgs : AutoSerializeLayerArgs | |||||
{ | |||||
[JsonProperty("units")] | |||||
public int Units { get; set; } | |||||
// TODO(Rinne): lack of initialized value of Activation. Merging keras | |||||
// into tf.net could resolve it. | |||||
[JsonProperty("activation")] | |||||
public Activation Activation { get; set; } | |||||
[JsonProperty("recurrent_activation")] | |||||
public Activation RecurrentActivation { get; set; } | |||||
[JsonProperty("use_bias")] | |||||
public bool UseBias { get; set; } = true; | |||||
[JsonProperty("dropout")] | |||||
public float Dropout { get; set; } = .0f; | |||||
[JsonProperty("recurrent_dropout")] | |||||
public float RecurrentDropout { get; set; } = .0f; | |||||
[JsonProperty("kernel_initializer")] | |||||
public IInitializer KernelInitializer { get; set; } | |||||
[JsonProperty("recurrent_initializer")] | |||||
public IInitializer RecurrentInitializer { get; set; } | |||||
[JsonProperty("bias_initializer")] | |||||
public IInitializer BiasInitializer { get; set; } | |||||
[JsonProperty("reset_after")] | |||||
public bool ResetAfter { get;set; } | |||||
[JsonProperty("implementation")] | |||||
public int Implementation { get; set; } = 2; | |||||
} | |||||
} |
@@ -0,0 +1,13 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class GRUOptionalArgs | |||||
{ | |||||
public string Identifier => "GRU"; | |||||
public Tensor Mask { get; set; } = null; | |||||
} | |||||
} |
@@ -1,11 +1,14 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class LSTMArgs : RNNArgs | public class LSTMArgs : RNNArgs | ||||
{ | { | ||||
// TODO: maybe change the `RNNArgs` and implement this class. | // TODO: maybe change the `RNNArgs` and implement this class. | ||||
public bool UnitForgetBias { get; set; } | public bool UnitForgetBias { get; set; } | ||||
public float Dropout { get; set; } | |||||
public float RecurrentDropout { get; set; } | |||||
public int Implementation { get; set; } | public int Implementation { get; set; } | ||||
public LSTMArgs Clone() | |||||
{ | |||||
return (LSTMArgs)MemberwiseClone(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,7 +1,35 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
using Newtonsoft.Json; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
// TODO: complete the implementation | // TODO: complete the implementation | ||||
public class LSTMCellArgs : LayerArgs | |||||
public class LSTMCellArgs : AutoSerializeLayerArgs | |||||
{ | { | ||||
[JsonProperty("units")] | |||||
public int Units { get; set; } | |||||
// TODO(Rinne): lack of initialized value of Activation. Merging keras | |||||
// into tf.net could resolve it. | |||||
[JsonProperty("activation")] | |||||
public Activation Activation { get; set; } | |||||
[JsonProperty("recurrent_activation")] | |||||
public Activation RecurrentActivation { get; set; } | |||||
[JsonProperty("use_bias")] | |||||
public bool UseBias { get; set; } = true; | |||||
[JsonProperty("dropout")] | |||||
public float Dropout { get; set; } = .0f; | |||||
[JsonProperty("recurrent_dropout")] | |||||
public float RecurrentDropout { get; set; } = .0f; | |||||
[JsonProperty("kernel_initializer")] | |||||
public IInitializer KernelInitializer { get; set; } | |||||
[JsonProperty("recurrent_initializer")] | |||||
public IInitializer RecurrentInitializer { get; set; } | |||||
[JsonProperty("bias_initializer")] | |||||
public IInitializer BiasInitializer { get; set; } | |||||
[JsonProperty("unit_forget_bias")] | |||||
public bool UnitForgetBias { get; set; } = true; | |||||
[JsonProperty("implementation")] | |||||
public int Implementation { get; set; } = 2; | |||||
} | } | ||||
} | } |
@@ -1,17 +1,12 @@ | |||||
using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.Layers; | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
// TODO(Rinne): add regularizers. | |||||
public class RNNArgs : AutoSerializeLayerArgs | public class RNNArgs : AutoSerializeLayerArgs | ||||
{ | { | ||||
public interface IRnnArgCell : ILayer | |||||
{ | |||||
object state_size { get; } | |||||
} | |||||
[JsonProperty("cell")] | |||||
// TODO: the cell should be serialized with `serialize_keras_object`. | |||||
public IRnnArgCell Cell { get; set; } = null; | |||||
[JsonProperty("return_sequences")] | [JsonProperty("return_sequences")] | ||||
public bool ReturnSequences { get; set; } = false; | public bool ReturnSequences { get; set; } = false; | ||||
[JsonProperty("return_state")] | [JsonProperty("return_state")] | ||||
@@ -24,31 +19,31 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
public bool Unroll { get; set; } = false; | public bool Unroll { get; set; } = false; | ||||
[JsonProperty("time_major")] | [JsonProperty("time_major")] | ||||
public bool TimeMajor { get; set; } = false; | public bool TimeMajor { get; set; } = false; | ||||
// TODO: Add `num_constants` and `zero_output_for_mask`. | |||||
public Dictionary<string, object> Kwargs { get; set; } = null; | |||||
public int? InputDim { get; set; } | |||||
public int? InputLength { get; set; } | |||||
// TODO: Add `num_constants` and `zero_output_for_mask`. | |||||
[JsonProperty("units")] | |||||
public int Units { get; set; } | public int Units { get; set; } | ||||
[JsonProperty("activation")] | |||||
public Activation Activation { get; set; } | public Activation Activation { get; set; } | ||||
[JsonProperty("recurrent_activation")] | |||||
public Activation RecurrentActivation { get; set; } | public Activation RecurrentActivation { get; set; } | ||||
[JsonProperty("use_bias")] | |||||
public bool UseBias { get; set; } = true; | public bool UseBias { get; set; } = true; | ||||
public IInitializer KernelInitializer { get; set; } | public IInitializer KernelInitializer { get; set; } | ||||
public IInitializer RecurrentInitializer { get; set; } | public IInitializer RecurrentInitializer { get; set; } | ||||
public IInitializer BiasInitializer { get; set; } | public IInitializer BiasInitializer { get; set; } | ||||
[JsonProperty("dropout")] | |||||
public float Dropout { get; set; } = .0f; | |||||
[JsonProperty("zero_output_for_mask")] | |||||
public bool ZeroOutputForMask { get; set; } = false; | |||||
[JsonProperty("recurrent_dropout")] | |||||
public float RecurrentDropout { get; set; } = .0f; | |||||
// kernel_regularizer=None, | |||||
// recurrent_regularizer=None, | |||||
// bias_regularizer=None, | |||||
// activity_regularizer=None, | |||||
// kernel_constraint=None, | |||||
// recurrent_constraint=None, | |||||
// bias_constraint=None, | |||||
// dropout=0., | |||||
// recurrent_dropout=0., | |||||
// return_sequences=False, | |||||
// return_state=False, | |||||
// go_backwards=False, | |||||
// stateful=False, | |||||
// unroll=False, | |||||
// **kwargs): | |||||
public RNNArgs Clone() | |||||
{ | |||||
return (RNNArgs)MemberwiseClone(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Common.Types; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class RnnOptionalArgs: IOptionalArgs | |||||
{ | |||||
public string Identifier => "Rnn"; | |||||
public Tensor Mask { get; set; } = null; | |||||
public Tensors Constants { get; set; } = null; | |||||
} | |||||
} |
@@ -1,4 +1,4 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class SimpleRNNArgs : RNNArgs | public class SimpleRNNArgs : RNNArgs | ||||
{ | { | ||||
@@ -0,0 +1,27 @@ | |||||
using Newtonsoft.Json; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class SimpleRNNCellArgs: AutoSerializeLayerArgs | |||||
{ | |||||
[JsonProperty("units")] | |||||
public int Units { get; set; } | |||||
// TODO(Rinne): lack of initialized value of Activation. Merging keras | |||||
// into tf.net could resolve it. | |||||
[JsonProperty("activation")] | |||||
public Activation Activation { get; set; } | |||||
[JsonProperty("use_bias")] | |||||
public bool UseBias { get; set; } = true; | |||||
[JsonProperty("dropout")] | |||||
public float Dropout { get; set; } = .0f; | |||||
[JsonProperty("recurrent_dropout")] | |||||
public float RecurrentDropout { get; set; } = .0f; | |||||
[JsonProperty("kernel_initializer")] | |||||
public IInitializer KernelInitializer { get; set; } | |||||
[JsonProperty("recurrent_initializer")] | |||||
public IInitializer RecurrentInitializer { get; set; } | |||||
[JsonProperty("bias_initializer")] | |||||
public IInitializer BiasInitializer { get; set; } | |||||
} | |||||
} |
@@ -1,10 +1,10 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.Layers; | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class StackedRNNCellsArgs : LayerArgs | public class StackedRNNCellsArgs : LayerArgs | ||||
{ | { | ||||
public IList<RnnCell> Cells { get; set; } | |||||
public Dictionary<string, object> Kwargs { get; set; } = null; | |||||
public bool ReverseStateOrder = false; | |||||
} | } | ||||
} | } |
@@ -0,0 +1,24 @@ | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.CompilerServices; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class WrapperArgs : AutoSerializeLayerArgs | |||||
{ | |||||
[JsonProperty("layer")] | |||||
public ILayer Layer { get; set; } | |||||
public WrapperArgs(ILayer layer) | |||||
{ | |||||
Layer = layer; | |||||
} | |||||
public static implicit operator WrapperArgs(BidirectionalArgs args) | |||||
=> new WrapperArgs(args.Layer); | |||||
} | |||||
} |
@@ -14,6 +14,9 @@ public interface ICallback | |||||
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | ||||
void on_predict_end(); | void on_predict_end(); | ||||
void on_test_begin(); | void on_test_begin(); | ||||
void on_test_end(Dictionary<string, float> logs); | |||||
void on_test_batch_begin(long step); | void on_test_batch_begin(long step); | ||||
void on_test_batch_end(long end_step, Dictionary<string, float> logs); | void on_test_batch_end(long end_step, Dictionary<string, float> logs); | ||||
} | } |
@@ -60,7 +60,7 @@ public interface IModel : ILayer | |||||
bool skip_mismatch = false, | bool skip_mismatch = false, | ||||
object options = null); | object options = null); | ||||
Dictionary<string, float> evaluate(Tensor x, Tensor y, | |||||
Dictionary<string, float> evaluate(NDArray x, NDArray y, | |||||
int batch_size = -1, | int batch_size = -1, | ||||
int verbose = 1, | int verbose = 1, | ||||
int steps = -1, | int steps = -1, | ||||
@@ -0,0 +1,75 @@ | |||||
namespace Tensorflow.Keras.Engine; | |||||
/// <summary> | |||||
/// A representation of a Keras in/output during Functional API construction. | |||||
/// </summary> | |||||
public class KerasTensor | |||||
{ | |||||
private Tensors _original_tensors; | |||||
public Tensors original_tensors | |||||
{ | |||||
get => _original_tensors; | |||||
set => _original_tensors = value; | |||||
} | |||||
private Shape _inferred_value; | |||||
public Shape inferred_value => _inferred_value; | |||||
private string _name; | |||||
private TensorSpec _type_spec; | |||||
public Shape shape => _type_spec.shape; | |||||
public TF_DataType dtype => _type_spec.dtype; | |||||
public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null) | |||||
{ | |||||
_type_spec = type_spec; | |||||
_inferred_value = inferred_value; | |||||
_name = name; | |||||
} | |||||
public static KerasTensor from_tensor(Tensor tensor) | |||||
{ | |||||
var type_spec = tensor.ToTensorSpec(); | |||||
Shape? inferred_value = default; | |||||
if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2) | |||||
{ | |||||
inferred_value = tf.ones(tensor).shape; | |||||
} | |||||
var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name); | |||||
kt.original_tensors = tensor; | |||||
return kt; | |||||
} | |||||
public KerasTensor this[int idx] | |||||
=> _original_tensors.First()[idx]; | |||||
public KerasTensor this[params Slice[] slices] | |||||
=> _original_tensors.First()[slices]; | |||||
public override string ToString() | |||||
=> _original_tensors.Length switch | |||||
{ | |||||
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]", | |||||
1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}", | |||||
_ => _original_tensors.ToString(), | |||||
}; | |||||
private string GetInferredValueString() | |||||
=> _inferred_value == null ? "" : $" inferred_value={_inferred_value}"; | |||||
public static implicit operator Tensors(KerasTensor kt) | |||||
=> kt._original_tensors; | |||||
public static implicit operator Tensor(KerasTensor kt) | |||||
{ | |||||
Tensor tensor = kt._original_tensors; | |||||
tensor.IsFromKerasTensor = true; | |||||
return tensor; | |||||
} | |||||
public static implicit operator KerasTensor(Tensor tensor) | |||||
=> from_tensor(tensor); | |||||
public static implicit operator KerasTensor(Tensors tensors) | |||||
=> from_tensor(tensors.First()); | |||||
} |
@@ -25,6 +25,27 @@ namespace Tensorflow.Keras | |||||
bool amsgrad = false, | bool amsgrad = false, | ||||
string name = "Adam"); | string name = "Adam"); | ||||
/// <summary> | |||||
/// Adam enables L2 weight decay on gradients. | |||||
/// </summary> | |||||
/// <param name="learning_rate"></param> | |||||
/// <param name="weight_decay"></param> | |||||
/// <param name="beta_1"></param> | |||||
/// <param name="beta_2"></param> | |||||
/// <param name="epsilon"></param> | |||||
/// <param name="amsgrad"></param> | |||||
/// <param name="decay_params"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
IOptimizer AdamW(float learning_rate = 0.001f, | |||||
float weight_decay = 0.004f, | |||||
float beta_1 = 0.9f, | |||||
float beta_2 = 0.999f, | |||||
float epsilon = 1e-7f, | |||||
bool amsgrad = false, | |||||
List<string> no_decay_params = null, | |||||
string name = "AdamW"); | |||||
/// <summary> | /// <summary> | ||||
/// Construct a new RMSprop optimizer. | /// Construct a new RMSprop optimizer. | ||||
/// </summary> | /// </summary> | ||||
@@ -42,6 +63,6 @@ namespace Tensorflow.Keras | |||||
bool centered = false, | bool centered = false, | ||||
string name = "RMSprop"); | string name = "RMSprop"); | ||||
IOptimizer SGD(float learning_rate); | |||||
IOptimizer SGD(float learning_rate = 0.01f, float momentum = 0f); | |||||
} | } | ||||
} | } |
@@ -1,4 +1,5 @@ | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow.Training; | using Tensorflow.Training; | ||||
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras | |||||
List<ILayer> Layers { get; } | List<ILayer> Layers { get; } | ||||
List<INode> InboundNodes { get; } | List<INode> InboundNodes { get; } | ||||
List<INode> OutboundNodes { get; } | List<INode> OutboundNodes { get; } | ||||
Tensors Apply(Tensors inputs, Tensor state = null, bool training = false); | |||||
Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null); | |||||
List<IVariableV1> TrainableVariables { get; } | List<IVariableV1> TrainableVariables { get; } | ||||
List<IVariableV1> TrainableWeights { get; } | List<IVariableV1> TrainableWeights { get; } | ||||
List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
@@ -9,6 +9,10 @@ namespace Tensorflow.Keras.Layers | |||||
public ILayer Reshape(Shape target_shape); | public ILayer Reshape(Shape target_shape); | ||||
public ILayer Reshape(object[] target_shape); | public ILayer Reshape(object[] target_shape); | ||||
public ILayer UpSampling1D( | |||||
int size | |||||
); | |||||
public ILayer UpSampling2D(Shape size = null, | public ILayer UpSampling2D(Shape size = null, | ||||
string data_format = null, | string data_format = null, | ||||
string interpolation = "nearest"); | string interpolation = "nearest"); | ||||
@@ -1,5 +1,7 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | ||||
@@ -134,7 +136,7 @@ namespace Tensorflow.Keras.Layers | |||||
public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | ||||
public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | ||||
public Tensors Input(Shape shape = null, | |||||
public KerasTensor Input(Shape shape = null, | |||||
int batch_size = -1, | int batch_size = -1, | ||||
string name = null, | string name = null, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
@@ -159,6 +161,18 @@ namespace Tensorflow.Keras.Layers | |||||
public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | ||||
public ILayer LeakyReLU(float alpha = 0.3f); | public ILayer LeakyReLU(float alpha = 0.3f); | ||||
public IRnnCell LSTMCell(int uints, | |||||
string activation = "tanh", | |||||
string recurrent_activation = "sigmoid", | |||||
bool use_bias = true, | |||||
string kernel_initializer = "glorot_uniform", | |||||
string recurrent_initializer = "orthogonal", | |||||
string bias_initializer = "zeros", | |||||
bool unit_forget_bias = true, | |||||
float dropout = 0f, | |||||
float recurrent_dropout = 0f, | |||||
int implementation = 2); | |||||
public ILayer LSTM(int units, | public ILayer LSTM(int units, | ||||
Activation activation = null, | Activation activation = null, | ||||
Activation recurrent_activation = null, | Activation recurrent_activation = null, | ||||
@@ -192,6 +206,19 @@ namespace Tensorflow.Keras.Layers | |||||
float offset = 0, | float offset = 0, | ||||
Shape input_shape = null); | Shape input_shape = null); | ||||
public IRnnCell SimpleRNNCell( | |||||
int units, | |||||
string activation = "tanh", | |||||
bool use_bias = true, | |||||
string kernel_initializer = "glorot_uniform", | |||||
string recurrent_initializer = "orthogonal", | |||||
string bias_initializer = "zeros", | |||||
float dropout = 0f, | |||||
float recurrent_dropout = 0f); | |||||
public IRnnCell StackedRNNCells( | |||||
IEnumerable<IRnnCell> cells); | |||||
public ILayer SimpleRNN(int units, | public ILayer SimpleRNN(int units, | ||||
string activation = "tanh", | string activation = "tanh", | ||||
string kernel_initializer = "glorot_uniform", | string kernel_initializer = "glorot_uniform", | ||||
@@ -200,6 +227,69 @@ namespace Tensorflow.Keras.Layers | |||||
bool return_sequences = false, | bool return_sequences = false, | ||||
bool return_state = false); | bool return_state = false); | ||||
public ILayer RNN( | |||||
IRnnCell cell, | |||||
bool return_sequences = false, | |||||
bool return_state = false, | |||||
bool go_backwards = false, | |||||
bool stateful = false, | |||||
bool unroll = false, | |||||
bool time_major = false | |||||
); | |||||
public ILayer RNN( | |||||
IEnumerable<IRnnCell> cell, | |||||
bool return_sequences = false, | |||||
bool return_state = false, | |||||
bool go_backwards = false, | |||||
bool stateful = false, | |||||
bool unroll = false, | |||||
bool time_major = false | |||||
); | |||||
public IRnnCell GRUCell( | |||||
int units, | |||||
string activation = "tanh", | |||||
string recurrent_activation = "sigmoid", | |||||
bool use_bias = true, | |||||
string kernel_initializer = "glorot_uniform", | |||||
string recurrent_initializer = "orthogonal", | |||||
string bias_initializer = "zeros", | |||||
float dropout = 0f, | |||||
float recurrent_dropout = 0f, | |||||
bool reset_after = true); | |||||
public ILayer GRU( | |||||
int units, | |||||
string activation = "tanh", | |||||
string recurrent_activation = "sigmoid", | |||||
bool use_bias = true, | |||||
string kernel_initializer = "glorot_uniform", | |||||
string recurrent_initializer = "orthogonal", | |||||
string bias_initializer = "zeros", | |||||
float dropout = 0f, | |||||
float recurrent_dropout = 0f, | |||||
bool return_sequences = false, | |||||
bool return_state = false, | |||||
bool go_backwards = false, | |||||
bool stateful = false, | |||||
bool unroll = false, | |||||
bool time_major = false, | |||||
bool reset_after = true | |||||
); | |||||
/// <summary> | |||||
/// Bidirectional wrapper for RNNs. | |||||
/// </summary> | |||||
/// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param> | |||||
/// automatically.</param> | |||||
/// <returns></returns> | |||||
public ILayer Bidirectional( | |||||
ILayer layer, | |||||
string merge_mode = "concat", | |||||
NDArray weights = null, | |||||
ILayer backward_layer = null); | |||||
public ILayer Subtract(); | public ILayer Subtract(); | ||||
} | } | ||||
} | } |
@@ -0,0 +1,25 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Common.Types; | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | |||||
public interface IRnnCell: ILayer | |||||
{ | |||||
/// <summary> | |||||
/// If the derived class tends to not implement it, please return null. | |||||
/// </summary> | |||||
INestStructure<long>? StateSize { get; } | |||||
/// <summary> | |||||
/// If the derived class tends to not implement it, please return null. | |||||
/// </summary> | |||||
INestStructure<long>? OutputSize { get; } | |||||
/// <summary> | |||||
/// Whether the optional RNN args are supported when appying the layer. | |||||
/// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. | |||||
/// </summary> | |||||
bool SupportOptionalArgs { get; } | |||||
Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype); | |||||
} | |||||
} |
@@ -0,0 +1,12 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | |||||
public interface IStackedRnnCells : IRnnCell | |||||
{ | |||||
int Count { get; } | |||||
IRnnCell this[int idx] { get; } | |||||
} | |||||
} |
@@ -3,6 +3,7 @@ using Newtonsoft.Json; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Common.Types; | |||||
namespace Tensorflow.Keras.Saving.Json | namespace Tensorflow.Keras.Saving.Json | ||||
{ | { | ||||
@@ -6,6 +6,7 @@ using System.Text; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using OneOf.Types; | using OneOf.Types; | ||||
using Tensorflow.Keras.Saving.Json; | using Tensorflow.Keras.Saving.Json; | ||||
using Tensorflow.Common.Types; | |||||
namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
{ | { | ||||
@@ -74,8 +74,3 @@ namespace Tensorflow | |||||
=> IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; | => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; | ||||
} | } | ||||
} | } | ||||
namespace System.Runtime.CompilerServices | |||||
{ | |||||
internal static class IsExternalInit { } | |||||
} |
@@ -107,9 +107,15 @@ namespace Tensorflow.NumPy | |||||
public static implicit operator NDArray(bool value) | public static implicit operator NDArray(bool value) | ||||
=> new NDArray(value); | => new NDArray(value); | ||||
public static implicit operator NDArray(byte value) | |||||
=> new NDArray(value); | |||||
public static implicit operator NDArray(int value) | public static implicit operator NDArray(int value) | ||||
=> new NDArray(value); | => new NDArray(value); | ||||
public static implicit operator NDArray(long value) | |||||
=> new NDArray(value); | |||||
public static implicit operator NDArray(float value) | public static implicit operator NDArray(float value) | ||||
=> new NDArray(value); | => new NDArray(value); | ||||
@@ -7,7 +7,7 @@ namespace Tensorflow.NumPy | |||||
{ | { | ||||
public class NDArrayRender | public class NDArrayRender | ||||
{ | { | ||||
public static string ToString(NDArray array) | |||||
public static string ToString(NDArray array, int maxLength = 10) | |||||
{ | { | ||||
Shape shape = array.shape; | Shape shape = array.shape; | ||||
if (shape.IsScalar) | if (shape.IsScalar) | ||||
@@ -15,12 +15,12 @@ namespace Tensorflow.NumPy | |||||
var s = new StringBuilder(); | var s = new StringBuilder(); | ||||
s.Append("array("); | s.Append("array("); | ||||
Build(s, array); | |||||
Build(s, array, maxLength); | |||||
s.Append(")"); | s.Append(")"); | ||||
return s.ToString(); | return s.ToString(); | ||||
} | } | ||||
static void Build(StringBuilder s, NDArray array) | |||||
static void Build(StringBuilder s, NDArray array, int maxLength) | |||||
{ | { | ||||
var shape = array.shape; | var shape = array.shape; | ||||
@@ -35,11 +35,11 @@ namespace Tensorflow.NumPy | |||||
var len = shape[0]; | var len = shape[0]; | ||||
s.Append("["); | s.Append("["); | ||||
if (len <= 10) | |||||
if (len <= maxLength) | |||||
{ | { | ||||
for (int i = 0; i < len; i++) | for (int i = 0; i < len; i++) | ||||
{ | { | ||||
Build(s, array[i]); | |||||
Build(s, array[i], maxLength); | |||||
if (i < len - 1) | if (i < len - 1) | ||||
{ | { | ||||
s.Append(", "); | s.Append(", "); | ||||
@@ -49,9 +49,9 @@ namespace Tensorflow.NumPy | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
for (int i = 0; i < 5; i++) | |||||
for (int i = 0; i < maxLength / 2; i++) | |||||
{ | { | ||||
Build(s, array[i]); | |||||
Build(s, array[i], maxLength); | |||||
if (i < len - 1) | if (i < len - 1) | ||||
{ | { | ||||
s.Append(", "); | s.Append(", "); | ||||
@@ -62,9 +62,9 @@ namespace Tensorflow.NumPy | |||||
s.Append(" ... "); | s.Append(" ... "); | ||||
s.AppendLine(); | s.AppendLine(); | ||||
for (int i = (int)len - 5; i < len; i++) | |||||
for (int i = (int)len - maxLength / 2; i < len; i++) | |||||
{ | { | ||||
Build(s, array[i]); | |||||
Build(s, array[i], maxLength); | |||||
if (i < len - 1) | if (i < len - 1) | ||||
{ | { | ||||
s.Append(", "); | s.Append(", "); | ||||
@@ -13,6 +13,10 @@ namespace Tensorflow.NumPy | |||||
public static NDArray argmax(NDArray a, Axis? axis = null) | public static NDArray argmax(NDArray a, Axis? axis = null) | ||||
=> new NDArray(math_ops.argmax(a, axis ?? 0)); | => new NDArray(math_ops.argmax(a, axis ?? 0)); | ||||
[AutoNumPy] | |||||
public static NDArray argmin(NDArray a, Axis? axis = null) | |||||
=> new NDArray(math_ops.argmin(a, axis ?? 0)); | |||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray argsort(NDArray a, Axis? axis = null) | public static NDArray argsort(NDArray a, Axis? axis = null) | ||||
=> new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); | => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); | ||||
@@ -10,10 +10,10 @@ namespace Tensorflow.NumPy | |||||
public partial class np | public partial class np | ||||
{ | { | ||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis)); | |||||
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.min(x, axis)); | |||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); | |||||
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.max(x, axis)); | |||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | ||||
@@ -49,9 +49,30 @@ namespace Tensorflow.NumPy | |||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray prod<T>(params T[] array) where T : unmanaged | public static NDArray prod<T>(params T[] array) where T : unmanaged | ||||
=> new NDArray(tf.reduce_prod(new NDArray(array))); | => new NDArray(tf.reduce_prod(new NDArray(array))); | ||||
[AutoNumPy] | |||||
public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null) | |||||
{ | |||||
//if axes mentioned | |||||
if (axes != null) | |||||
{ | |||||
return new NDArray(tf.dot_prod(x1, x2, axes, name)); | |||||
} | |||||
if (x1.shape.ndim > 1) | |||||
{ | |||||
x1 = GetFlattenArray(x1); | |||||
} | |||||
if (x2.shape.ndim > 1) | |||||
{ | |||||
x2 = GetFlattenArray(x2); | |||||
} | |||||
//if axes not mentioned, default 0,0 | |||||
return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name)); | |||||
} | |||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y)); | public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y)); | ||||
[AutoNumPy] | |||||
public static NDArray square(NDArray x) => new NDArray(tf.square(x)); | |||||
[AutoNumPy] | [AutoNumPy] | ||||
public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x)); | public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x)); | ||||
@@ -19,13 +19,14 @@ using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Keras.Saving.Common; | using Tensorflow.Keras.Saving.Common; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
[JsonConverter(typeof(CustomizedShapeJsonConverter))] | [JsonConverter(typeof(CustomizedShapeJsonConverter))] | ||||
public class Shape | |||||
public class Shape : INestStructure<long> | |||||
{ | { | ||||
public int ndim => _dims == null ? -1 : _dims.Length; | public int ndim => _dims == null ? -1 : _dims.Length; | ||||
long[] _dims; | long[] _dims; | ||||
@@ -41,6 +42,27 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public NestType NestType => NestType.List; | |||||
public int ShallowNestedCount => ndim; | |||||
/// <summary> | |||||
/// The total item count of depth 1 of the nested structure. | |||||
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||||
/// </summary> | |||||
public int TotalNestedCount => ndim; | |||||
public IEnumerable<long> Flatten() => dims.Select(x => x); | |||||
public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func) | |||||
{ | |||||
return new NestList<TOut>(dims.Select(x => func(x))); | |||||
} | |||||
public Nest<long> AsNest() | |||||
{ | |||||
return new NestList<long>(Flatten()).AsNest(); | |||||
} | |||||
#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges | #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges | ||||
public int Length => ndim; | public int Length => ndim; | ||||
public long[] Slice(int start, int length) | public long[] Slice(int start, int length) | ||||
@@ -0,0 +1,22 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.NumPy; | |||||
namespace Tensorflow.Operations.Initializers | |||||
{ | |||||
/// <summary> | |||||
/// An initializer specially used for debugging (to load weights from disk). | |||||
/// </summary> | |||||
class NpyLoadInitializer : IInitializer | |||||
{ | |||||
string _path; | |||||
public NpyLoadInitializer(string path) { _path = path; } | |||||
public string ClassName => ""; | |||||
public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||||
public Tensor Apply(InitializerArgs args) | |||||
{ | |||||
return np.load(_path); | |||||
} | |||||
} | |||||
} |
@@ -53,13 +53,12 @@ public class Orthogonal : IInitializer | |||||
// Compute the qr factorization | // Compute the qr factorization | ||||
var (q, r) = tf.linalg.qr(a, full_matrices: false); | var (q, r) = tf.linalg.qr(a, full_matrices: false); | ||||
// Make Q uniform | // Make Q uniform | ||||
var d = tf.linalg.tensor_diag_part(r); | |||||
var d = tf.linalg.tensor_diag_part(r.Single); | |||||
q *= tf.sign(d); | q *= tf.sign(d); | ||||
if (num_rows < num_cols) | if (num_rows < num_cols) | ||||
{ | { | ||||
// q = tf.linalg.matrix_transpose(q); | |||||
throw new NotImplementedException(""); | |||||
q = array_ops.matrix_transpose(q); | |||||
} | } | ||||
return _gain * tf.reshape(q, shape); | return _gain * tf.reshape(q, shape); | ||||
@@ -11,6 +11,7 @@ namespace Tensorflow | |||||
/// Basic LSTM recurrent network cell. | /// Basic LSTM recurrent network cell. | ||||
/// The implementation is based on: http://arxiv.org/abs/1409.2329. | /// The implementation is based on: http://arxiv.org/abs/1409.2329. | ||||
/// </summary> | /// </summary> | ||||
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||||
public class BasicLstmCell : LayerRnnCell | public class BasicLstmCell : LayerRnnCell | ||||
{ | { | ||||
int _num_units; | int _num_units; | ||||
@@ -88,7 +89,7 @@ namespace Tensorflow | |||||
gate_inputs = nn_ops.bias_add(gate_inputs, _bias); | gate_inputs = nn_ops.bias_add(gate_inputs, _bias); | ||||
// i = input_gate, j = new_input, f = forget_gate, o = output_gate | // i = input_gate, j = new_input, f = forget_gate, o = output_gate | ||||
var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); | |||||
var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); | |||||
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | ||||
var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | ||||
@@ -20,6 +20,7 @@ using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||||
public class BasicRnnCell : LayerRnnCell | public class BasicRnnCell : LayerRnnCell | ||||
{ | { | ||||
int _num_units; | int _num_units; | ||||
@@ -19,6 +19,7 @@ using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||||
public class LayerRnnCell : RnnCell | public class LayerRnnCell : RnnCell | ||||
{ | { | ||||
protected InputSpec inputSpec; | protected InputSpec inputSpec; | ||||
@@ -16,10 +16,11 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
@@ -50,7 +51,8 @@ namespace Tensorflow | |||||
/// matching structure of Tensors having shape `[batch_size].concatenate(s)` | /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | ||||
/// for each `s` in `self.batch_size`. | /// for each `s` in `self.batch_size`. | ||||
/// </summary> | /// </summary> | ||||
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell | |||||
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||||
public abstract class RnnCell : ILayer, IRnnCell | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight | /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | ||||
@@ -142,7 +144,7 @@ namespace Tensorflow | |||||
throw new NotImplementedException("_zero_state_tensors"); | throw new NotImplementedException("_zero_state_tensors"); | ||||
} | } | ||||
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null) | |||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
@@ -173,5 +175,18 @@ namespace Tensorflow | |||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public INestStructure<long> StateSize => throw new NotImplementedException(); | |||||
public INestStructure<long> OutputSize => throw new NotImplementedException(); | |||||
public bool IsTFRnnCell => throw new NotImplementedException(); | |||||
public bool SupportOptionalArgs => throw new NotImplementedException(); | |||||
} | } | ||||
} | } |
@@ -15,9 +15,11 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using Google.Protobuf.Collections; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Functions; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.OpDef.Types; | using static Tensorflow.OpDef.Types; | ||||
@@ -387,9 +389,13 @@ namespace Tensorflow | |||||
case "list(type)": | case "list(type)": | ||||
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); | attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); | ||||
break; | break; | ||||
case "list(float)": | |||||
if (value != null) | |||||
attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray()); | |||||
break; | |||||
case "list(int)": | case "list(int)": | ||||
if (value != null) | if (value != null) | ||||
attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x))); | |||||
attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x))); | |||||
break; | break; | ||||
case "bool": | case "bool": | ||||
attr_value.B = (bool)value; | attr_value.B = (bool)value; | ||||
@@ -420,6 +426,15 @@ namespace Tensorflow | |||||
case "list(shape)": | case "list(shape)": | ||||
attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); | attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); | ||||
break; | break; | ||||
case "func": | |||||
attr_value.Func = _MakeFunc(value, attr_def.Name); | |||||
break; | |||||
case "list(func)": | |||||
attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); | |||||
break; | |||||
case "list(string)": | |||||
attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x))); | |||||
break; | |||||
default: | default: | ||||
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | ||||
} | } | ||||
@@ -427,6 +442,47 @@ namespace Tensorflow | |||||
return attr_value; | return attr_value; | ||||
} | } | ||||
private NameAttrList _MakeFunc(object func, string arg_name) | |||||
{ | |||||
if(func is NameAttrList attrList) | |||||
{ | |||||
return attrList; | |||||
} | |||||
NameAttrList fn_attr; | |||||
if(func is string funcStr) | |||||
{ | |||||
fn_attr = new NameAttrList() { Name = funcStr }; | |||||
} | |||||
else if(func is ConcreteFunction concrete) | |||||
{ | |||||
concrete.AddTograph(ops.get_default_graph()); | |||||
fn_attr = concrete.AsNameAttrList; | |||||
} | |||||
else if(func is EagerDefinedFunction eager) | |||||
{ | |||||
eager.AddToGraph(ops.get_default_graph()); | |||||
fn_attr = new NameAttrList() { Name = eager.Name }; | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}"); | |||||
} | |||||
return fn_attr; | |||||
} | |||||
private List<NameAttrList> _MakeFuncList(object funcList, string arg_name) | |||||
{ | |||||
List<NameAttrList> res = new List<NameAttrList>(); | |||||
if(funcList is IEnumerable enumerable) | |||||
{ | |||||
foreach(var func in enumerable) | |||||
{ | |||||
res.Add(_MakeFunc(func, arg_name)); | |||||
} | |||||
} | |||||
return res; | |||||
} | |||||
private bool _IsListParameter(ArgDef arg) | private bool _IsListParameter(ArgDef arg) | ||||
{ | { | ||||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | if (!String.IsNullOrEmpty(arg.NumberAttr)) | ||||