Browse Source

Tensor indexing. #368

tags/v0.12
Oceania2018 6 years ago
parent
commit
ca3a775e7a
2 changed files with 140 additions and 105 deletions
  1. +140
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
  2. +0
    -105
      src/TensorFlowNET.Core/Tensors/Tensor.cs

+ 140
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Index.cs View File

@@ -0,0 +1,140 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
public partial class Tensor
{
public Tensor this[int idx]
{
get
{
return slice(idx);
}
}

public Tensor slice(Slice slice)
{
var slice_spec = new int[] { slice.Start.Value };
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();

var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slice_spec)
{
begin.Add(s);
if (slice.Stop.HasValue)
{
end.Add(slice.Stop.Value);
}
else
{
end.Add(0);
end_mask |= (1 << index);
}

strides.Add(slice.Step);

index += 1;
}

return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));

return gen_array_ops.strided_slice(
this,
packed_begin,
packed_end,
packed_strides,
begin_mask: begin_mask,
end_mask: end_mask,
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
name: name);
}

throw new NotImplementedException("");
});
}

public Tensor slice(int start)
{
var slice_spec = new int[] { start };
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();

var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slice_spec)
{
begin.Add(s);
end.Add(s + 1);
strides.Add(1);
shrink_axis_mask |= (1 << index);
index += 1;
}

return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));

return gen_array_ops.strided_slice(
this,
packed_begin,
packed_end,
packed_strides,
begin_mask: begin_mask,
end_mask: end_mask,
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
name: name);
}

throw new NotImplementedException("");
});
}
}
}

+ 0
- 105
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -28,7 +28,6 @@ using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using NumSharp.Utilities;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -436,110 +435,6 @@ namespace Tensorflow
return ops._eval_using_default_session(this, feed_dict, graph, session);
}

public Tensor slice(Slice slice)
{
var slice_spec = new int[] {slice.Start.Value};
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();

var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slice_spec)
{
begin.Add(s);
if (slice.Stop.HasValue)
{
end.Add(slice.Stop.Value);
} else
{
end.Add(0);
end_mask |= (1 << index);
}

strides.Add(slice.Step);

index += 1;
}

return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));

return gen_array_ops.strided_slice(
this,
packed_begin,
packed_end,
packed_strides,
begin_mask: begin_mask,
end_mask: end_mask,
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
name: name);
}

throw new NotImplementedException("");
});
}

public Tensor slice(int start)
{
var slice_spec = new int[] {start};
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();

var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slice_spec)
{
begin.Add(s);
end.Add(s + 1);
strides.Add(1);
shrink_axis_mask |= (1 << index);
index += 1;
}

return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));

return gen_array_ops.strided_slice(
this,
packed_begin,
packed_end,
packed_strides,
begin_mask: begin_mask,
end_mask: end_mask,
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
name: name);
}

throw new NotImplementedException("");
});
}

public override string ToString()
{
// this can throw IndexOutOfRangeException


Loading…
Cancel
Save