@@ -15,6 +15,7 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | using System; | ||||
using static Tensorflow.Python; | |||||
namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
{ | { | ||||
@@ -43,9 +44,9 @@ namespace Tensorflow.Keras | |||||
{ | { | ||||
if (value == null) | if (value == null) | ||||
value = _IMAGE_DATA_FORMAT; | value = _IMAGE_DATA_FORMAT; | ||||
if (value.GetType() == typeof(ImageDataFormat)) | |||||
if (isinstance(value, typeof(ImageDataFormat))) | |||||
return (ImageDataFormat)value; | return (ImageDataFormat)value; | ||||
else if (value.GetType() == typeof(string)) | |||||
else if (isinstance(value, typeof(string))) | |||||
{ | { | ||||
ImageDataFormat dataFormat; | ImageDataFormat dataFormat; | ||||
if(Enum.TryParse((string)value, true, out dataFormat)) | if(Enum.TryParse((string)value, true, out dataFormat)) | ||||
@@ -141,7 +141,7 @@ namespace Tensorflow | |||||
dtype = input_arg.Type; | dtype = input_arg.Type; | ||||
else if (attrs.ContainsKey(input_arg.TypeAttr)) | else if (attrs.ContainsKey(input_arg.TypeAttr)) | ||||
dtype = (DataType)attrs[input_arg.TypeAttr]; | dtype = (DataType)attrs[input_arg.TypeAttr]; | ||||
else if (values.GetType() == typeof(string) && dtype == DataType.DtInvalid) | |||||
else if (isinstance(values, typeof(string)) && dtype == DataType.DtInvalid) | |||||
dtype = DataType.DtString; | dtype = DataType.DtString; | ||||
else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) | else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) | ||||
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; | default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; | ||||
@@ -309,6 +309,36 @@ namespace Tensorflow | |||||
} | } | ||||
return (__object__)((object[] args) => { return "NaN"; }); | return (__object__)((object[] args) => { return "NaN"; }); | ||||
} | } | ||||
public static IEnumerable TupleToEnumerable(object tuple) | |||||
{ | |||||
Type t = tuple.GetType(); | |||||
if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||||
{ | |||||
var flds = t.GetFields(); | |||||
for(int i = 0; i < flds.Length;i++) | |||||
{ | |||||
yield return ((object)flds[i].GetValue(tuple)); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
throw new System.Exception("Expected Tuple."); | |||||
} | |||||
} | |||||
public static bool isinstance(object Item1, Type Item2) | |||||
{ | |||||
return (Item1.GetType() == Item2); | |||||
} | |||||
public static bool isinstance(object Item1, object tuple) | |||||
{ | |||||
var tup = TupleToEnumerable(tuple); | |||||
foreach(var t in tup) | |||||
{ | |||||
if(isinstance(Item1, (Type)t)) | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
} | } | ||||
public interface IPython : IDisposable | public interface IPython : IDisposable | ||||
@@ -25,6 +25,31 @@ namespace TensorFlowNET.UnitTest | |||||
Assert.IsTrue(b); | Assert.IsTrue(b); | ||||
} | } | ||||
[TestMethod] | |||||
public void isinstance_test() | |||||
{ | |||||
var s1 = "hi"; | |||||
var s2 = "hello"; | |||||
var t1 = (s1, s2); | |||||
var t2 = (s1, s2, s1); | |||||
var t3 = (s2, s1); | |||||
var true1 = isinstance(s1, typeof(string)); | |||||
var false1 = isinstance(t1, typeof(string)); | |||||
var true2 = isinstance(t1, t3.GetType()); | |||||
var false2 = isinstance(t1, t2.GetType()); | |||||
var true3 = isinstance(t1, (t2.GetType(), t1.GetType(), typeof(string))); | |||||
var false3 = isinstance(t3, (t2.GetType(), typeof(string))); | |||||
Assert.IsTrue(true1); | |||||
Assert.IsTrue(true2); | |||||
Assert.IsTrue(true3); | |||||
Assert.IsFalse(false1); | |||||
Assert.IsFalse(false2); | |||||
Assert.IsFalse(false3); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void hasattr_getattr() | public void hasattr_getattr() | ||||
{ | { | ||||