Browse Source

ConcatV2 should be recorded #701

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
7293c328f7
2 changed files with 5 additions and 3 deletions
  1. +4
    -2
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  2. +1
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs

+ 4
- 2
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -380,8 +380,10 @@ namespace Tensorflow.Eager
c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value)); c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value));
break; break;
case TF_AttrType.TF_ATTR_INT: case TF_AttrType.TF_ATTR_INT:
attr_list_sizes[key] = Convert.ToInt64(value);
c_api.TFE_OpSetAttrInt(op, key, attr_list_sizes[key]);
var size = Convert.ToInt64(value);
c_api.TFE_OpSetAttrInt(op, key, size);
if (attr_list_sizes != null)
attr_list_sizes[key] = size;
break; break;
case TF_AttrType.TF_ATTR_FLOAT: case TF_AttrType.TF_ATTR_FLOAT:
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs View File

@@ -22,7 +22,7 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
var shape_tensor = array_ops.shape(inputs); var shape_tensor = array_ops.shape(inputs);
var shape = new List<int> { shape_tensor.shape[0] };
var shape = new List<int> { inputs.shape[0] };
shape.AddRange(args.TargetShape.dims); shape.AddRange(args.TargetShape.dims);


var result = array_ops.reshape(inputs, shape.ToArray()); var result = array_ops.reshape(inputs, shape.ToArray());


Loading…
Cancel
Save