Browse Source

Allow list parameter in OpDefLibrary

tags/v0.8.0
haiping008 6 years ago
parent
commit
5ccef1bd65
8 changed files with 149 additions and 18 deletions
  1. +19
    -0
      src/TensorFlowNET.Core/Exceptions/TypeError.cs
  2. +71
    -15
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  4. +19
    -0
      src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs
  5. +3
    -1
      src/TensorFlowNET.Core/Operations/math_ops.py.cs
  6. +6
    -0
      src/TensorFlowNET.Core/Python.cs
  7. +1
    -0
      src/TensorFlowNET.Core/ops.name_scope.cs
  8. +28
    -0
      src/TensorFlowNET.Core/ops.py.cs

+ 19
- 0
src/TensorFlowNET.Core/Exceptions/TypeError.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class TypeError : Exception
{
public TypeError() : base()
{

}

public TypeError(string message) : base(message)
{

}
}
}

+ 71
- 15
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.ComponentModel;
using System.Dynamic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using static Tensorflow.OpDef.Types;
@@ -41,6 +42,7 @@ namespace Tensorflow
var attrs = new Dictionary<string, object>();
var inputs = new List<Tensor>();
var input_types = new List<TF_DataType>();
var base_types = new List<TF_DataType>();

Operation op = null;
Python.with<ops.name_scope>(new ops.name_scope(name), scope =>
@@ -49,34 +51,67 @@ namespace Tensorflow
foreach (var input_arg in op_def.InputArg)
{
var input_name = input_arg.Name;
if (keywords[input_name] is double int_value)
var values = keywords[input_name];
// Goals:
// * Convert values to Tensors if it contains constants.
// * Verify that values is a list if that matches the input_arg's
// type.
// * If the input_arg's type is determined by attrs, either set
// those attrs and validate those attr values are legal (if
// they have not yet been set) or validate the input matches
// the type indicated by the attrs (if they have already been
// inferred via an earlier input).
// * If the input_arg has an explicit type, make sure the input
// conforms.

if (_IsListParameter(input_arg))
{
keywords[input_name] = constant_op.constant(int_value, input_name);
}
DataType dtype = DataType.DtInvalid;
DataType default_dtype = DataType.DtInvalid;

if (keywords[input_name] is Tensor value)
{
if (keywords.ContainsKey(input_name))
if (!_IsListValue(values))
throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}.");
if(input_arg.Type != DataType.DtInvalid)
{
inputs.Add(value);
dtype = input_arg.Type;
}

if (!String.IsNullOrEmpty(input_arg.TypeAttr))
else if (!String.IsNullOrEmpty(input_arg.NumberAttr))
{
attrs[input_arg.TypeAttr] = value.dtype;
}

if (input_arg.IsRef)
{
if(input_arg.IsRef && dtype != DataType.DtInvalid)
dtype = dtype.as_base_dtype();

values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, dtype: dtype, preferred_dtype: default_dtype, as_ref: input_arg.IsRef);
inputs.AddRange(values as Tensor[]);
}
else
{
if (!(values is Tensor))
{
keywords[input_name] = constant_op.constant(values, input_name);
}
else

if (keywords[input_name] is Tensor value)
{
var base_type = value.dtype.as_base_dtype();
if (keywords.ContainsKey(input_name))
{
inputs.Add(value);
}

if (!String.IsNullOrEmpty(input_arg.TypeAttr))
{
attrs[input_arg.TypeAttr] = value.dtype;
}

input_types.Add(base_type);
values = new Tensor[] { value };
}
}

base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
input_types.AddRange(base_types);
}

// Process remaining attrs
@@ -152,6 +187,27 @@ namespace Tensorflow
return v.as_base_dtype().as_datatype_enum();
}

private bool _IsListParameter(ArgDef arg)
{
if (!String.IsNullOrEmpty(arg.NumberAttr))
return true;
else if (!String.IsNullOrEmpty(arg.TypeListAttr))
return true;
else
return false;
}

private bool _IsListValue(object v)
{
switch (v)
{
case Tensor[] val:
return true;
default:
return false;
}
}

private Dictionary<string, object> ConvertToDict(dynamic dyn)
{
var dictionary = new Dictionary<string, object>();


+ 2
- 2
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -50,10 +50,10 @@ namespace Tensorflow
/// Creates a tensor filled with a scalar value.
/// </summary>
/// <param name="dims">A `Tensor`.</param>
/// <param name="value">A `Tensor`.</param>
/// <param name="value">A `Tensor`. 0-D (scalar). Value to fill the returned tensor.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `value`.</returns>
public static Tensor fill(Tensor dims, Tensor value, string name = "")
public static Tensor fill<T>(Tensor dims, T value, string name = "")
{
var _op = _op_def_lib._apply_op_helper("Fill", name, new { dims, value });



+ 19
- 0
src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class gen_data_flow_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = "")
{
var _attr_N = indices.Length;
var _op = _op_def_lib._apply_op_helper("DynamicStitch", name, new { indices, data });

return _op.outputs[0];
}
}
}

+ 3
- 1
src/TensorFlowNET.Core/Operations/math_ops.py.cs View File

@@ -20,8 +20,10 @@ namespace Tensorflow
var input_rank = array_ops.size(input_shape);
axes = (axes + input_rank) % input_rank;
var axes_shape = array_ops.shape(axes);
var a1 = new Tensor[] { input_rank, axes };
var a2 = new Tensor[] { input_shape, gen_array_ops.fill(axes_shape, 1) };

return null;
return gen_data_flow_ops.dynamic_stitch(a1, a2);
}

/// <summary>


+ 6
- 0
src/TensorFlowNET.Core/Python.cs View File

@@ -83,6 +83,12 @@ namespace Tensorflow
for (int i = 0; i < t1.Count; i++)
yield return (t1[i], t2[i]);
}

public static IEnumerable<(int, T)> enumerate<T>(IList<T> values)
{
for (int i = 0; i < values.Count; i++)
yield return (i, values[i]);
}
}

public interface IPython : IDisposable


+ 1
- 0
src/TensorFlowNET.Core/ops.name_scope.cs View File

@@ -46,6 +46,7 @@ namespace Tensorflow
public void Dispose()
{
var g = get_default_graph();
Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}");
g._name_stack = old_stack;
}



+ 28
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -309,5 +309,33 @@ namespace Tensorflow
return (p1.GetValue(result, null) as Tensor, p2.GetValue(result, null) as Tensor);*/
};
}

public static T[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
string name = "", DataType preferred_dtype = DataType.DtInvalid,
bool as_ref = false)
{
var ret = new List<T>();

foreach((int i, T value) in Python.enumerate(values))
{
string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
ret.Add(internal_convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype));
}

return ret.ToArray();
}

public static T internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
string name = "", DataType preferred_dtype = DataType.DtInvalid,
bool as_ref = false)
{
switch (typeof(T).Name)
{
case "Tensor":
return value;
default:
throw new NotImplementedException("internal_convert_to_tensor");
}
}
}
}

Loading…
Cancel
Save