@@ -30,6 +30,10 @@ namespace Tensorflow | |||||
if (src_graph == null) | if (src_graph == null) | ||||
src_graph = ops.get_default_graph(); | src_graph = ops.get_default_graph(); | ||||
// If src_graph is a _FuncGraph (i.e. a function body), gather it and all | |||||
// ancestor graphs. This is necessary for correctly handling captured values. | |||||
var curr_graph = src_graph; | |||||
var ys1 = _AsList(ys); | var ys1 = _AsList(ys); | ||||
var xs1 = _AsList(xs); | var xs1 = _AsList(xs); | ||||
List<Tensor> grad_ys1 = null; | List<Tensor> grad_ys1 = null; | ||||
@@ -47,7 +51,10 @@ namespace Tensorflow | |||||
string grad_scope = ""; | string grad_scope = ""; | ||||
using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all)) | using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all)) | ||||
{ | |||||
grad_scope = namescope; | grad_scope = namescope; | ||||
} | |||||
} | } | ||||
private static List<Tensor> _AsList(object ys) | private static List<Tensor> _AsList(object ys) | ||||
@@ -173,7 +173,6 @@ namespace Tensorflow | |||||
string new_stack = ""; | string new_stack = ""; | ||||
if (name.EndsWith("/")) | if (name.EndsWith("/")) | ||||
new_stack = ops._name_from_scope_name(name); | new_stack = ops._name_from_scope_name(name); | ||||
else | else | ||||
@@ -15,14 +15,15 @@ namespace Tensorflow | |||||
var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
var op_def = g.GetOpDef(op_type_name); | var op_def = g.GetOpDef(op_type_name); | ||||
// Default name if not specified. | |||||
if (String.IsNullOrEmpty(name)) | if (String.IsNullOrEmpty(name)) | ||||
{ | |||||
name = op_type_name; | name = op_type_name; | ||||
} | |||||
string scope = ""; | |||||
using (var namescope = new ops.name_scope<object>(name)) | |||||
scope = namescope; | |||||
// Check for deprecation | |||||
if(op_def.Deprecation != null && op_def.Deprecation.Version > 0) | |||||
{ | |||||
} | |||||
var default_type_attr_map = new Dictionary<string, object>(); | var default_type_attr_map = new Dictionary<string, object>(); | ||||
foreach (var attr_def in op_def.Attr) | foreach (var attr_def in op_def.Attr) | ||||
@@ -39,101 +40,107 @@ namespace Tensorflow | |||||
var inputs = new List<Tensor>(); | var inputs = new List<Tensor>(); | ||||
var input_types = new List<TF_DataType>(); | var input_types = new List<TF_DataType>(); | ||||
// Perform input type inference | |||||
foreach (var input_arg in op_def.InputArg) | |||||
string scope = ""; | |||||
using (var namescope = new ops.name_scope<object>(name)) | |||||
{ | { | ||||
var input_name = input_arg.Name; | |||||
if (keywords[input_name] is double int_value) | |||||
{ | |||||
keywords[input_name] = constant_op.Constant(int_value, input_name); | |||||
} | |||||
scope = namescope; | |||||
if (keywords[input_name] is Tensor value) | |||||
// Perform input type inference | |||||
foreach (var input_arg in op_def.InputArg) | |||||
{ | { | ||||
if (keywords.ContainsKey(input_name)) | |||||
var input_name = input_arg.Name; | |||||
if (keywords[input_name] is double int_value) | |||||
{ | { | ||||
inputs.Add(value); | |||||
keywords[input_name] = constant_op.Constant(int_value, input_name); | |||||
} | } | ||||
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||||
if (keywords[input_name] is Tensor value) | |||||
{ | { | ||||
attrs[input_arg.TypeAttr] = value.dtype; | |||||
if (keywords.ContainsKey(input_name)) | |||||
{ | |||||
inputs.Add(value); | |||||
} | |||||
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||||
{ | |||||
attrs[input_arg.TypeAttr] = value.dtype; | |||||
} | |||||
if (input_arg.IsRef) | |||||
{ | |||||
} | |||||
else | |||||
{ | |||||
input_types.Add(value.dtype); | |||||
} | |||||
} | } | ||||
} | |||||
if (input_arg.IsRef) | |||||
{ | |||||
} | |||||
else | |||||
// Process remaining attrs | |||||
foreach (var attr in op_def.Attr) | |||||
{ | |||||
if (keywords.ContainsKey(attr.Name)) | |||||
{ | { | ||||
input_types.Add(value.dtype); | |||||
attrs[attr.Name] = keywords[attr.Name]; | |||||
} | } | ||||
} | } | ||||
} | |||||
// Process remaining attrs | |||||
foreach (var attr in op_def.Attr) | |||||
{ | |||||
if (keywords.ContainsKey(attr.Name)) | |||||
// Convert attr values to AttrValue protos. | |||||
var attr_protos = new Dictionary<string, AttrValue>(); | |||||
foreach (var attr_def in op_def.Attr) | |||||
{ | { | ||||
attrs[attr.Name] = keywords[attr.Name]; | |||||
} | |||||
} | |||||
var key = attr_def.Name; | |||||
var value = attrs[key]; | |||||
var attr_value = new AttrValue(); | |||||
// Convert attr values to AttrValue protos. | |||||
var attr_protos = new Dictionary<string, AttrValue>(); | |||||
foreach (var attr_def in op_def.Attr) | |||||
{ | |||||
var key = attr_def.Name; | |||||
var value = attrs[key]; | |||||
var attr_value = new AttrValue(); | |||||
switch (attr_def.Type) | |||||
{ | |||||
case "string": | |||||
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||||
break; | |||||
case "type": | |||||
attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||||
break; | |||||
case "bool": | |||||
attr_value.B = (bool)value; | |||||
break; | |||||
case "shape": | |||||
attr_value.Shape = value == null ? | |||||
attr_def.DefaultValue.Shape : | |||||
tensor_util.as_shape((long[])value); | |||||
break; | |||||
default: | |||||
throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); | |||||
} | |||||
switch (attr_def.Type) | |||||
{ | |||||
case "string": | |||||
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||||
break; | |||||
case "type": | |||||
attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||||
break; | |||||
case "bool": | |||||
attr_value.B = (bool)value; | |||||
break; | |||||
case "shape": | |||||
attr_value.Shape = value == null ? | |||||
attr_def.DefaultValue.Shape : | |||||
tensor_util.as_shape((long[])value); | |||||
break; | |||||
default: | |||||
throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); | |||||
} | |||||
attr_protos[key] = attr_value; | |||||
} | |||||
attr_protos[key] = attr_value; | |||||
} | |||||
// Determine output types (possibly using attrs) | |||||
var output_types = new List<TF_DataType>(); | |||||
// Determine output types (possibly using attrs) | |||||
var output_types = new List<TF_DataType>(); | |||||
foreach (var arg in op_def.OutputArg) | |||||
{ | |||||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||||
foreach (var arg in op_def.OutputArg) | |||||
{ | { | ||||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||||
{ | |||||
} | |||||
else if (!String.IsNullOrEmpty(arg.TypeAttr)) | |||||
{ | |||||
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||||
} | |||||
} | } | ||||
else if (!String.IsNullOrEmpty(arg.TypeAttr)) | |||||
{ | |||||
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||||
} | |||||
} | |||||
// Add Op to graph | |||||
var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||||
name: scope, | |||||
input_types: input_types.ToArray(), | |||||
attrs: attr_protos, | |||||
op_def: op_def); | |||||
// Add Op to graph | |||||
var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||||
name: scope, | |||||
input_types: input_types.ToArray(), | |||||
attrs: attr_protos, | |||||
op_def: op_def); | |||||
return op; | |||||
return op; | |||||
} | |||||
} | } | ||||
public DataType _MakeType(TF_DataType v, AttrDef attr_def) | public DataType _MakeType(TF_DataType v, AttrDef attr_def) | ||||
@@ -4,9 +4,9 @@ | |||||
<TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
<AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<Version>0.0.2</Version> | |||||
<Version>0.0.3</Version> | |||||
<Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
<Company>SciSharp.org</Company> | |||||
<Company>SciSharp STACK</Company> | |||||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
<Copyright>Apache 2.0</Copyright> | <Copyright>Apache 2.0</Copyright> | ||||
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | ||||
@@ -16,7 +16,7 @@ | |||||
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | ||||
<Description>Google's TensorFlow binding in .NET Standard. | <Description>Google's TensorFlow binding in .NET Standard. | ||||
Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
<AssemblyVersion>0.0.2.0</AssemblyVersion> | |||||
<AssemblyVersion>0.0.3.0</AssemblyVersion> | |||||
<PackageReleaseNotes>API updated</PackageReleaseNotes> | <PackageReleaseNotes>API updated</PackageReleaseNotes> | ||||
<LangVersion>7.2</LangVersion> | <LangVersion>7.2</LangVersion> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -4,6 +4,10 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
/// <summary> | |||||
/// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | |||||
/// The enum values here are identical to corresponding values in types.proto. | |||||
/// </summary> | |||||
public enum TF_DataType | public enum TF_DataType | ||||
{ | { | ||||
DtInvalid = 0, | DtInvalid = 0, | ||||
@@ -30,6 +34,8 @@ namespace Tensorflow | |||||
TF_RESOURCE = 20, | TF_RESOURCE = 20, | ||||
TF_VARIANT = 21, | TF_VARIANT = 21, | ||||
TF_UINT32 = 22, | TF_UINT32 = 22, | ||||
TF_UINT64 = 23 | |||||
TF_UINT64 = 23, | |||||
DtDoubleRef = 102, // DT_DOUBLE_REF | |||||
} | } | ||||
} | } |
@@ -19,7 +19,10 @@ namespace Tensorflow | |||||
public Graph Graph => op.Graph; | public Graph Graph => op.Graph; | ||||
public Operation op { get; } | public Operation op { get; } | ||||
public string name; | |||||
/// <summary> | |||||
/// The string name of this tensor. | |||||
/// </summary> | |||||
public string name => $"{(op == null ? "Operation was not named" : $"{op.Name}:{value_index}")}"; | |||||
public int value_index { get; } | public int value_index { get; } | ||||
@@ -222,7 +225,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
return $"{name} {dtype} {rank} {string.Join(",", shape)}"; | |||||
return $"{name} {dtype.ToString()} {rank} {string.Join(",", shape)}"; | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -17,6 +17,10 @@ namespace Tensorflow | |||||
public string Name { get; set; } | public string Name { get; set; } | ||||
public double LearningRate { get; set; } | public double LearningRate { get; set; } | ||||
public Tensor LearningRateTensor { get; set; } | public Tensor LearningRateTensor { get; set; } | ||||
public bool _use_locking; | |||||
public Dictionary<string, object> _slots; | |||||
public Dictionary<string, object> _non_slot_dict; | |||||
public Dictionary<string, object> _deferred_slot_restorations; | |||||
public Optimizer(double learning_rate, bool use_locking, string name = "") | public Optimizer(double learning_rate, bool use_locking, string name = "") | ||||
{ | { | ||||
@@ -24,6 +28,11 @@ namespace Tensorflow | |||||
throw new NotImplementedException("Must specify the optimizer name"); | throw new NotImplementedException("Must specify the optimizer name"); | ||||
Name = name; | Name = name; | ||||
_use_locking = use_locking; | |||||
// Dictionary of slots. | |||||
_slots = new Dictionary<string, object>(); | |||||
_non_slot_dict = new Dictionary<string, object>(); | |||||
_deferred_slot_restorations = new Dictionary<string, object>(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -68,7 +77,7 @@ namespace Tensorflow | |||||
break; | break; | ||||
} | } | ||||
var processors = var_list.Select(v => optimizer._get_processor(v)); | |||||
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | |||||
var var_refs = processors.Select(x => x.target()).ToList(); | var var_refs = processors.Select(x => x.target()).ToList(); | ||||
gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss, | gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss, | ||||
@@ -79,6 +79,17 @@ namespace Tensorflow | |||||
// have an issue if these other variables aren't initialized first by | // have an issue if these other variables aren't initialized first by | ||||
// using their initialized_value() method. | // using their initialized_value() method. | ||||
var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; | |||||
if (!String.IsNullOrEmpty(caching_device)) | |||||
{ | |||||
} | |||||
else | |||||
{ | |||||
} | |||||
ops.add_to_collections(collections, this); | ops.add_to_collections(collections, this); | ||||
} | } | ||||
} | } | ||||
@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using NumSharp.Core; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
@@ -33,5 +34,31 @@ namespace Tensorflow | |||||
return new Tensor(_op, 0, dtype); | return new Tensor(_op, 0, dtype); | ||||
} | } | ||||
/// <summary> | |||||
/// Update 'ref' by assigning 'value' to it | |||||
/// </summary> | |||||
/// <param name="REF"></param> | |||||
/// <param name="value"></param> | |||||
/// <param name="validate_shape"></param> | |||||
/// <param name="use_locking"></param> | |||||
/// <param name="name"></param> | |||||
public static Tensor assign(Tensor tensor, Tensor value, | |||||
bool validate_shape = true, | |||||
bool use_locking = true, | |||||
string name = "") | |||||
{ | |||||
var keywords = new Dictionary<string, object>(); | |||||
keywords.Add("ref", tensor); | |||||
keywords.Add("value", value); | |||||
keywords.Add("validate_shape", validate_shape); | |||||
keywords.Add("use_locking", use_locking); | |||||
var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); | |||||
var _result = _op.outputs[0]; | |||||
var _inputs_flat = _op.inputs; | |||||
return _result; | |||||
} | |||||
} | } | ||||
} | } |
@@ -21,8 +21,6 @@ namespace Tensorflow | |||||
_default_name = default_name; | _default_name = default_name; | ||||
_values = values; | _values = values; | ||||
_ctx = new Context(); | _ctx = new Context(); | ||||
_name_scope = __enter__(); | |||||
} | } | ||||
public string __enter__() | public string __enter__() | ||||
@@ -38,8 +36,10 @@ namespace Tensorflow | |||||
if (g == null) | if (g == null) | ||||
g = get_default_graph(); | g = get_default_graph(); | ||||
return g.name_scope(_name); ; | |||||
_name_scope = g.name_scope(_name); | |||||
return _name_scope; | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -48,9 +48,13 @@ namespace Tensorflow | |||||
g._name_stack = g.old_stack; | g._name_stack = g.old_stack; | ||||
} | } | ||||
/// <summary> | |||||
/// __enter__() | |||||
/// </summary> | |||||
/// <param name="ns"></param> | |||||
public static implicit operator string(name_scope<T> ns) | public static implicit operator string(name_scope<T> ns) | ||||
{ | { | ||||
return ns._name_scope; | |||||
return ns.__enter__(); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -7,7 +7,7 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="NumSharp" Version="0.6.5" /> | <PackageReference Include="NumSharp" Version="0.6.5" /> | ||||
<PackageReference Include="TensorFlow.NET" Version="0.0.2" /> | |||||
<PackageReference Include="TensorFlow.NET" Version="0.0.3" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -20,7 +20,7 @@ | |||||
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
<PackageReference Include="NumSharp" Version="0.6.5" /> | <PackageReference Include="NumSharp" Version="0.6.5" /> | ||||
<PackageReference Include="TensorFlow.NET" Version="0.0.2" /> | |||||
<PackageReference Include="TensorFlow.NET" Version="0.0.3" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||