Browse Source

Base Layer methods in progress

tags/v0.20
Deepak Kumar 5 years ago
parent
commit
22c05f0155
3 changed files with 194 additions and 8 deletions
  1. +0
    -6
      src/TensorFlowNET.Keras/IInitializer.cs
  2. +193
    -1
      src/TensorFlowNET.Keras/Layers/Layer.cs
  3. +1
    -1
      src/TensorFlowNET.Keras/Losses/Loss.cs

+ 0
- 6
src/TensorFlowNET.Keras/IInitializer.cs View File

@@ -1,6 +0,0 @@
namespace Keras
{
interface IInitializer
{
}
}

+ 193
- 1
src/TensorFlowNET.Keras/Layers/Layer.cs View File

@@ -1,14 +1,172 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic;
using Tensorflow;
using Tensorflow.Keras.Constraints;
using Tensorflow.Keras.Initializers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Regularizers;

namespace Keras.Layers
{
public abstract class Layer
{
public TF_DataType dtype
{
get
{
throw new NotImplementedException();
}
}

public string name
{
get
{
throw new NotImplementedException();
}
}

public bool stateful
{
get
{
throw new NotImplementedException();
}
set
{
throw new NotImplementedException();
}
}

public bool trainable
{
get
{
throw new NotImplementedException();
}
}

public Regularizer activity_regularizer
{
get
{
throw new NotImplementedException();
}
set
{
throw new NotImplementedException();
}
}

public dynamic input_spec
{
get
{
throw new NotImplementedException();
}
set
{
throw new NotImplementedException();
}
}

public Tensor[] trainable_weights
{
get
{
throw new NotImplementedException();
}
}

public Tensor[] non_trainable_weights
{
get
{
throw new NotImplementedException();
}
}

public Tensor[] weights
{
get
{
throw new NotImplementedException();
}
}

public Func<bool>[] updates
{
get
{
throw new NotImplementedException();
}
}

public Tensor[] losses
{
get
{
throw new NotImplementedException();
}
}

public Tensor[] metrics
{
get
{
throw new NotImplementedException();
}
}

public Tensor[] input_mask
{
get
{
throw new NotImplementedException();
}
}

public Tensor[] output_mask
{
get
{
throw new NotImplementedException();
}
}

public Tensor[] input
{
get
{
throw new NotImplementedException();
}
}

public Tensor[] output
{
get
{
throw new NotImplementedException();
}
}

public TensorShape[] input_shape
{
get
{
throw new NotImplementedException();
}
}

public TensorShape[] output_shape
{
get
{
throw new NotImplementedException();
}
}

public Layer(bool trainable = true, string name = null, string dtype = null, bool @dynamic = false, Dictionary<string, object> kwargs = null)
{

@@ -32,5 +190,39 @@ namespace Keras.Layers
public TensorShape compute_output_shape(TensorShape input_shape) => throw new NotImplementedException();

public dynamic compute_output_signature(dynamic input_signature) => throw new NotImplementedException();

public Tensor[] compute_mask(Tensor[] inputs, Tensor[] mask = null) => throw new NotImplementedException();

public void __call__(Tensor[] inputs) => throw new NotImplementedException();

public void add_loss(Loss[] losses, Tensor[] inputs = null) => throw new NotImplementedException();

public void _clear_losses() => throw new NotImplementedException();

public void add_metric(Tensor value, string aggregation= null, string name= null) => throw new NotImplementedException();

public void add_update(Func<bool>[] updates) => throw new NotImplementedException();

public void set_weights(NDArray[] weights) => throw new NotImplementedException();

public NDArray[] get_weights() => throw new NotImplementedException();

public Func<bool>[] get_updates_for(Tensor[] inputs) => throw new NotImplementedException();

public Tensor[] get_losses_for(Tensor[] inputs) => throw new NotImplementedException();

public Tensor[] get_input_mask_at(int node_index) => throw new NotImplementedException();

public Tensor[] get_output_mask_at(int node_index) => throw new NotImplementedException();

public TensorShape[] get_input_shape_at(int node_index) => throw new NotImplementedException();

public TensorShape[] get_output_shape_at(int node_index) => throw new NotImplementedException();

public Tensor[] get_input_at(int node_index) => throw new NotImplementedException();

public Tensor[] get_output_at(int node_index) => throw new NotImplementedException();

public int count_params() => throw new NotImplementedException();
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Losses/Loss.cs View File

@@ -4,7 +4,7 @@ using System.Text;

namespace Tensorflow.Keras.Losses
{
class Loss
public abstract class Loss
{
}
}

Loading…
Cancel
Save