|
|
@@ -3,6 +3,7 @@ using System.Collections; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using Tensorflow.Framework.Models; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
|
{ |
|
|
@@ -98,6 +99,20 @@ namespace Tensorflow |
|
|
|
return dataset; |
|
|
|
} |
|
|
|
|
|
|
|
public Tensor dataset_cardinality(string name = null) |
|
|
|
{ |
|
|
|
if (tf.Context.executing_eagerly()) |
|
|
|
{ |
|
|
|
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, |
|
|
|
"DatasetCardinality", name, |
|
|
|
null, |
|
|
|
variant_tensor); |
|
|
|
return results[0]; |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
public override string ToString() |
|
|
|
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}"; |
|
|
|
|
|
|
@@ -117,7 +132,9 @@ namespace Tensorflow |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
yield return (results[0], results.Length == 1 ? null : results[1]); |
|
|
|
yield return results.Length == 2 |
|
|
|
? (results[0], results[1]) |
|
|
|
: (null, results[0]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|