@@ -30,6 +30,10 @@ namespace Tensorflow | |||
if (src_graph == null) | |||
src_graph = ops.get_default_graph(); | |||
// If src_graph is a _FuncGraph (i.e. a function body), gather it and all | |||
// ancestor graphs. This is necessary for correctly handling captured values. | |||
var curr_graph = src_graph; | |||
var ys1 = _AsList(ys); | |||
var xs1 = _AsList(xs); | |||
List<Tensor> grad_ys1 = null; | |||
@@ -47,7 +51,10 @@ namespace Tensorflow | |||
string grad_scope = ""; | |||
using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all)) | |||
{ | |||
grad_scope = namescope; | |||
} | |||
} | |||
private static List<Tensor> _AsList(object ys) | |||
@@ -173,7 +173,6 @@ namespace Tensorflow | |||
string new_stack = ""; | |||
if (name.EndsWith("/")) | |||
new_stack = ops._name_from_scope_name(name); | |||
else | |||
@@ -15,14 +15,15 @@ namespace Tensorflow | |||
var g = ops.get_default_graph(); | |||
var op_def = g.GetOpDef(op_type_name); | |||
// Default name if not specified. | |||
if (String.IsNullOrEmpty(name)) | |||
{ | |||
name = op_type_name; | |||
} | |||
string scope = ""; | |||
using (var namescope = new ops.name_scope<object>(name)) | |||
scope = namescope; | |||
// Check for deprecation | |||
if(op_def.Deprecation != null && op_def.Deprecation.Version > 0) | |||
{ | |||
} | |||
var default_type_attr_map = new Dictionary<string, object>(); | |||
foreach (var attr_def in op_def.Attr) | |||
@@ -39,101 +40,107 @@ namespace Tensorflow | |||
var inputs = new List<Tensor>(); | |||
var input_types = new List<TF_DataType>(); | |||
// Perform input type inference | |||
foreach (var input_arg in op_def.InputArg) | |||
string scope = ""; | |||
using (var namescope = new ops.name_scope<object>(name)) | |||
{ | |||
var input_name = input_arg.Name; | |||
if (keywords[input_name] is double int_value) | |||
{ | |||
keywords[input_name] = constant_op.Constant(int_value, input_name); | |||
} | |||
scope = namescope; | |||
if (keywords[input_name] is Tensor value) | |||
// Perform input type inference | |||
foreach (var input_arg in op_def.InputArg) | |||
{ | |||
if (keywords.ContainsKey(input_name)) | |||
var input_name = input_arg.Name; | |||
if (keywords[input_name] is double int_value) | |||
{ | |||
inputs.Add(value); | |||
keywords[input_name] = constant_op.Constant(int_value, input_name); | |||
} | |||
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||
if (keywords[input_name] is Tensor value) | |||
{ | |||
attrs[input_arg.TypeAttr] = value.dtype; | |||
if (keywords.ContainsKey(input_name)) | |||
{ | |||
inputs.Add(value); | |||
} | |||
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||
{ | |||
attrs[input_arg.TypeAttr] = value.dtype; | |||
} | |||
if (input_arg.IsRef) | |||
{ | |||
} | |||
else | |||
{ | |||
input_types.Add(value.dtype); | |||
} | |||
} | |||
} | |||
if (input_arg.IsRef) | |||
{ | |||
} | |||
else | |||
// Process remaining attrs | |||
foreach (var attr in op_def.Attr) | |||
{ | |||
if (keywords.ContainsKey(attr.Name)) | |||
{ | |||
input_types.Add(value.dtype); | |||
attrs[attr.Name] = keywords[attr.Name]; | |||
} | |||
} | |||
} | |||
// Process remaining attrs | |||
foreach (var attr in op_def.Attr) | |||
{ | |||
if (keywords.ContainsKey(attr.Name)) | |||
// Convert attr values to AttrValue protos. | |||
var attr_protos = new Dictionary<string, AttrValue>(); | |||
foreach (var attr_def in op_def.Attr) | |||
{ | |||
attrs[attr.Name] = keywords[attr.Name]; | |||
} | |||
} | |||
var key = attr_def.Name; | |||
var value = attrs[key]; | |||
var attr_value = new AttrValue(); | |||
// Convert attr values to AttrValue protos. | |||
var attr_protos = new Dictionary<string, AttrValue>(); | |||
foreach (var attr_def in op_def.Attr) | |||
{ | |||
var key = attr_def.Name; | |||
var value = attrs[key]; | |||
var attr_value = new AttrValue(); | |||
switch (attr_def.Type) | |||
{ | |||
case "string": | |||
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||
break; | |||
case "type": | |||
attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||
break; | |||
case "bool": | |||
attr_value.B = (bool)value; | |||
break; | |||
case "shape": | |||
attr_value.Shape = value == null ? | |||
attr_def.DefaultValue.Shape : | |||
tensor_util.as_shape((long[])value); | |||
break; | |||
default: | |||
throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); | |||
} | |||
switch (attr_def.Type) | |||
{ | |||
case "string": | |||
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||
break; | |||
case "type": | |||
attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||
break; | |||
case "bool": | |||
attr_value.B = (bool)value; | |||
break; | |||
case "shape": | |||
attr_value.Shape = value == null ? | |||
attr_def.DefaultValue.Shape : | |||
tensor_util.as_shape((long[])value); | |||
break; | |||
default: | |||
throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); | |||
} | |||
attr_protos[key] = attr_value; | |||
} | |||
attr_protos[key] = attr_value; | |||
} | |||
// Determine output types (possibly using attrs) | |||
var output_types = new List<TF_DataType>(); | |||
// Determine output types (possibly using attrs) | |||
var output_types = new List<TF_DataType>(); | |||
foreach (var arg in op_def.OutputArg) | |||
{ | |||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||
foreach (var arg in op_def.OutputArg) | |||
{ | |||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||
{ | |||
} | |||
else if (!String.IsNullOrEmpty(arg.TypeAttr)) | |||
{ | |||
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||
} | |||
} | |||
else if (!String.IsNullOrEmpty(arg.TypeAttr)) | |||
{ | |||
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||
} | |||
} | |||
// Add Op to graph | |||
var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||
name: scope, | |||
input_types: input_types.ToArray(), | |||
attrs: attr_protos, | |||
op_def: op_def); | |||
// Add Op to graph | |||
var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||
name: scope, | |||
input_types: input_types.ToArray(), | |||
attrs: attr_protos, | |||
op_def: op_def); | |||
return op; | |||
return op; | |||
} | |||
} | |||
public DataType _MakeType(TF_DataType v, AttrDef attr_def) | |||
@@ -4,9 +4,9 @@ | |||
<TargetFramework>netstandard2.0</TargetFramework> | |||
<AssemblyName>TensorFlow.NET</AssemblyName> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<Version>0.0.2</Version> | |||
<Version>0.0.3</Version> | |||
<Authors>Haiping Chen</Authors> | |||
<Company>SciSharp.org</Company> | |||
<Company>SciSharp STACK</Company> | |||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
<Copyright>Apache 2.0</Copyright> | |||
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | |||
@@ -16,7 +16,7 @@ | |||
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | |||
<Description>Google's TensorFlow binding in .NET Standard. | |||
Docs: https://tensorflownet.readthedocs.io</Description> | |||
<AssemblyVersion>0.0.2.0</AssemblyVersion> | |||
<AssemblyVersion>0.0.3.0</AssemblyVersion> | |||
<PackageReleaseNotes>API updated</PackageReleaseNotes> | |||
<LangVersion>7.2</LangVersion> | |||
</PropertyGroup> | |||
@@ -4,6 +4,10 @@ using System.Text; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | |||
/// The enum values here are identical to corresponding values in types.proto. | |||
/// </summary> | |||
public enum TF_DataType | |||
{ | |||
DtInvalid = 0, | |||
@@ -30,6 +34,8 @@ namespace Tensorflow | |||
TF_RESOURCE = 20, | |||
TF_VARIANT = 21, | |||
TF_UINT32 = 22, | |||
TF_UINT64 = 23 | |||
TF_UINT64 = 23, | |||
DtDoubleRef = 102, // DT_DOUBLE_REF | |||
} | |||
} |
@@ -19,7 +19,10 @@ namespace Tensorflow | |||
public Graph Graph => op.Graph; | |||
public Operation op { get; } | |||
public string name; | |||
/// <summary> | |||
/// The string name of this tensor. | |||
/// </summary> | |||
public string name => $"{(op == null ? "Operation was not named" : $"{op.Name}:{value_index}")}"; | |||
public int value_index { get; } | |||
@@ -222,7 +225,7 @@ namespace Tensorflow | |||
} | |||
} | |||
return $"{name} {dtype} {rank} {string.Join(",", shape)}"; | |||
return $"{name} {dtype.ToString()} {rank} {string.Join(",", shape)}"; | |||
} | |||
public void Dispose() | |||
@@ -17,6 +17,10 @@ namespace Tensorflow | |||
public string Name { get; set; } | |||
public double LearningRate { get; set; } | |||
public Tensor LearningRateTensor { get; set; } | |||
public bool _use_locking; | |||
public Dictionary<string, object> _slots; | |||
public Dictionary<string, object> _non_slot_dict; | |||
public Dictionary<string, object> _deferred_slot_restorations; | |||
public Optimizer(double learning_rate, bool use_locking, string name = "") | |||
{ | |||
@@ -24,6 +28,11 @@ namespace Tensorflow | |||
throw new NotImplementedException("Must specify the optimizer name"); | |||
Name = name; | |||
_use_locking = use_locking; | |||
// Dictionary of slots. | |||
_slots = new Dictionary<string, object>(); | |||
_non_slot_dict = new Dictionary<string, object>(); | |||
_deferred_slot_restorations = new Dictionary<string, object>(); | |||
} | |||
/// <summary> | |||
@@ -68,7 +77,7 @@ namespace Tensorflow | |||
break; | |||
} | |||
var processors = var_list.Select(v => optimizer._get_processor(v)); | |||
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | |||
var var_refs = processors.Select(x => x.target()).ToList(); | |||
gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss, | |||
@@ -79,6 +79,17 @@ namespace Tensorflow | |||
// have an issue if these other variables aren't initialized first by | |||
// using their initialized_value() method. | |||
var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; | |||
if (!String.IsNullOrEmpty(caching_device)) | |||
{ | |||
} | |||
else | |||
{ | |||
} | |||
ops.add_to_collections(collections, this); | |||
} | |||
} | |||
@@ -1,4 +1,5 @@ | |||
using System; | |||
using NumSharp.Core; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
@@ -33,5 +34,31 @@ namespace Tensorflow | |||
return new Tensor(_op, 0, dtype); | |||
} | |||
/// <summary> | |||
/// Update 'ref' by assigning 'value' to it | |||
/// </summary> | |||
/// <param name="REF"></param> | |||
/// <param name="value"></param> | |||
/// <param name="validate_shape"></param> | |||
/// <param name="use_locking"></param> | |||
/// <param name="name"></param> | |||
public static Tensor assign(Tensor tensor, Tensor value, | |||
bool validate_shape = true, | |||
bool use_locking = true, | |||
string name = "") | |||
{ | |||
var keywords = new Dictionary<string, object>(); | |||
keywords.Add("ref", tensor); | |||
keywords.Add("value", value); | |||
keywords.Add("validate_shape", validate_shape); | |||
keywords.Add("use_locking", use_locking); | |||
var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); | |||
var _result = _op.outputs[0]; | |||
var _inputs_flat = _op.inputs; | |||
return _result; | |||
} | |||
} | |||
} |
@@ -21,8 +21,6 @@ namespace Tensorflow | |||
_default_name = default_name; | |||
_values = values; | |||
_ctx = new Context(); | |||
_name_scope = __enter__(); | |||
} | |||
public string __enter__() | |||
@@ -38,8 +36,10 @@ namespace Tensorflow | |||
if (g == null) | |||
g = get_default_graph(); | |||
return g.name_scope(_name); ; | |||
_name_scope = g.name_scope(_name); | |||
return _name_scope; | |||
} | |||
public void Dispose() | |||
@@ -48,9 +48,13 @@ namespace Tensorflow | |||
g._name_stack = g.old_stack; | |||
} | |||
/// <summary> | |||
/// __enter__() | |||
/// </summary> | |||
/// <param name="ns"></param> | |||
public static implicit operator string(name_scope<T> ns) | |||
{ | |||
return ns._name_scope; | |||
return ns.__enter__(); | |||
} | |||
} | |||
} | |||
@@ -7,7 +7,7 @@ | |||
<ItemGroup> | |||
<PackageReference Include="NumSharp" Version="0.6.5" /> | |||
<PackageReference Include="TensorFlow.NET" Version="0.0.2" /> | |||
<PackageReference Include="TensorFlow.NET" Version="0.0.3" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
@@ -20,7 +20,7 @@ | |||
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | |||
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | |||
<PackageReference Include="NumSharp" Version="0.6.5" /> | |||
<PackageReference Include="TensorFlow.NET" Version="0.0.2" /> | |||
<PackageReference Include="TensorFlow.NET" Version="0.0.3" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||