@@ -24,7 +24,7 @@ namespace Tensorflow | |||||
public partial class Operation | public partial class Operation | ||||
{ | { | ||||
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | ||||
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); | |||||
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index)); | |||||
public int OutputListLength(string name) | public int OutputListLength(string name) | ||||
{ | { | ||||
@@ -44,7 +44,6 @@ namespace Tensorflow | |||||
public partial class Operation : ITensorOrOperation | public partial class Operation : ITensorOrOperation | ||||
{ | { | ||||
private readonly IntPtr _handle; // _c_op in python | private readonly IntPtr _handle; // _c_op in python | ||||
private readonly IntPtr _operDesc; | |||||
private readonly Graph _graph; | private readonly Graph _graph; | ||||
private NodeDef _node_def; | private NodeDef _node_def; | ||||
@@ -91,7 +90,7 @@ namespace Tensorflow | |||||
{ | { | ||||
_graph = g; | _graph = g; | ||||
_operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
var _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | ||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
using (var status = new Status()) | using (var status = new Status()) | ||||
@@ -161,7 +160,7 @@ namespace Tensorflow | |||||
op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
(_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
// Initialize self._outputs. | // Initialize self._outputs. | ||||
output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
@@ -170,7 +169,7 @@ namespace Tensorflow | |||||
_outputs = new Tensor[NumOutputs]; | _outputs = new Tensor[NumOutputs]; | ||||
for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
_outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
_outputs[i] = new Tensor(this, i, output_types[i]); | |||||
graph._add_op(this); | graph._add_op(this); | ||||
@@ -275,7 +275,7 @@ namespace Tensorflow | |||||
/// </returns> | /// </returns> | ||||
public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | ||||
{ | { | ||||
data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); | |||||
data = ops.convert_to_tensor_or_composite(data, name: "data"); | |||||
// NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | ||||
// addresses the following scenario. | // addresses the following scenario. | ||||
// | // | ||||
@@ -296,9 +296,8 @@ namespace Tensorflow | |||||
{ | { | ||||
if (data is Tensor) | if (data is Tensor) | ||||
{ | { | ||||
// TODO: ref_switch | |||||
//if (data.dtype._is_ref_dtype) | |||||
// return control_flow_ops.ref_switch(data, pred, name = name); | |||||
if (data.dtype.is_ref_dtype()) | |||||
return gen_control_flow_ops.ref_switch(data, pred, name: name); | |||||
} | } | ||||
return @switch(data, pred, name: name); | return @switch(data, pred, name: name); | ||||
} | } | ||||
@@ -114,6 +114,12 @@ namespace Tensorflow | |||||
return _op; | return _op; | ||||
} | } | ||||
public static Tensor[] ref_switch(Tensor data, Tensor pred, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("RefSwitch", name, new { data, pred }); | |||||
return _op.outputs; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Forwards `data` to the output port determined by `pred`. | /// Forwards `data` to the output port determined by `pred`. | ||||
/// | /// | ||||
@@ -5,7 +5,7 @@ | |||||
<AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<TargetTensorFlow>1.14.0</TargetTensorFlow> | <TargetTensorFlow>1.14.0</TargetTensorFlow> | ||||
<Version>0.11.2</Version> | |||||
<Version>0.11.3</Version> | |||||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
<Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
@@ -17,7 +17,7 @@ | |||||
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | ||||
<Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
<AssemblyVersion>0.11.2.0</AssemblyVersion> | |||||
<AssemblyVersion>0.11.3.0</AssemblyVersion> | |||||
<PackageReleaseNotes>Changes since v0.10.0: | <PackageReleaseNotes>Changes since v0.10.0: | ||||
1. Upgrade NumSharp to v0.20. | 1. Upgrade NumSharp to v0.20. | ||||
2. Add DisposableObject class to manage object lifetime. | 2. Add DisposableObject class to manage object lifetime. | ||||
@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
9. MultiThread is safe. | 9. MultiThread is safe. | ||||
10. Support n-dim indexing for tensor.</PackageReleaseNotes> | 10. Support n-dim indexing for tensor.</PackageReleaseNotes> | ||||
<LangVersion>7.3</LangVersion> | <LangVersion>7.3</LangVersion> | ||||
<FileVersion>0.11.2.0</FileVersion> | |||||
<FileVersion>0.11.3.0</FileVersion> | |||||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
<SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
@@ -117,7 +117,7 @@ namespace Tensorflow | |||||
/// Key to collect Variable objects that are global (shared across machines). | /// Key to collect Variable objects that are global (shared across machines). | ||||
/// Default collection for all variables, except local ones. | /// Default collection for all variables, except local ones. | ||||
/// </summary> | /// </summary> | ||||
public string GLOBAL_VARIABLES = GLOBAL_VARIABLES_; | |||||
public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; | |||||
public string TRAIN_OP => TRAIN_OP_; | public string TRAIN_OP => TRAIN_OP_; | ||||
@@ -206,7 +206,7 @@ namespace Tensorflow | |||||
/// </param> | /// </param> | ||||
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | ||||
/// <returns>A wrapped TF_Operation*.</returns> | /// <returns>A wrapped TF_Operation*.</returns> | ||||
public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
{ | { | ||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
{ | { | ||||
@@ -249,7 +249,7 @@ namespace Tensorflow | |||||
status.Check(true); | status.Check(true); | ||||
return (c_op, op_desc); | |||||
return c_op; | |||||
} | } | ||||
} | } | ||||
@@ -27,7 +27,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
using (var g = tf.Graph().as_default()) | using (var g = tf.Graph().as_default()) | ||||
{ | { | ||||
var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | ||||
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
var op = g._create_op_from_tf_operation(c_op); | var op = g._create_op_from_tf_operation(c_op); | ||||
Assert.AreEqual("myop", op.name); | Assert.AreEqual("myop", op.name); | ||||
@@ -68,7 +68,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
var true_fn = new Func<Tensor>(() => | var true_fn = new Func<Tensor>(() => | ||||
{ | { | ||||
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); | |||||
var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); | |||||
var new_ops = g._add_new_tf_operations(); | var new_ops = g._add_new_tf_operations(); | ||||
self.assertEqual(len(new_ops), 1); | self.assertEqual(len(new_ops), 1); | ||||
return x; | return x; | ||||