Browse Source

Base layer skeleton added

tags/v0.20
Deepak Battini 5 years ago
parent
commit
2dd7ac9195
2 changed files with 263 additions and 6 deletions
  1. +65
    -2
      src/TensorFlowNET.Keras/Engine/BaseLayer.cs
  2. +198
    -4
      src/TensorFlowNET.Keras/Layers/Layer.cs

+ 65
- 2
src/TensorFlowNET.Keras/Engine/BaseLayer.cs View File

@@ -1,10 +1,73 @@
using System;
using Keras.Layers;
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
public class Layer
public class TensorFlowOpLayer : Layer
{
public TensorFlowOpLayer(string node_def, string name, NDArray[] constants = null, bool trainable = true, string dtype = null)
{

}

public override void call(Tensor[] inputs)
{
throw new NotImplementedException();
}

public override Dictionary<string, object> get_config()
{
throw new NotImplementedException();
}

private NodeDef _make_node_def(Graph graph) => throw new NotImplementedException();

private Tensor[] _make_op(Tensor[] inputs) => throw new NotImplementedException();

private Tensor[] _defun_call(Tensor[] inputs) => throw new NotImplementedException();
}

public class AddLoss : Layer
{
public AddLoss(bool unconditional)
{
throw new NotImplementedException();
}

public override void call(Tensor[] inputs)
{
throw new NotImplementedException();
}

public override Dictionary<string, object> get_config()
{
throw new NotImplementedException();
}
}

public class AddMetric : Layer
{
public AddMetric(string aggregation = null, string metric_name = null)
{
throw new NotImplementedException();
}

public override void call(Tensor[] inputs)
{
throw new NotImplementedException();
}

public override Dictionary<string, object> get_config()
{
throw new NotImplementedException();
}
}

public class KerasHistory
{

}
}

+ 198
- 4
src/TensorFlowNET.Keras/Layers/Layer.cs View File

@@ -5,6 +5,7 @@ using Tensorflow;
using Tensorflow.Keras.Constraints;
using Tensorflow.Keras.Initializers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Regularizers;

namespace Keras.Layers
@@ -87,7 +88,7 @@ namespace Keras.Layers
}
}

public Tensor[] weights
private Tensor[] _weights
{
get
{
@@ -167,6 +168,38 @@ namespace Keras.Layers
}
}

public Tensor[] variables
{
get
{
return _weights;
}
}

public Tensor[] trainable_variables
{
get
{
return trainable_weights;
}
}

public Tensor[] non_trainable_variables
{
get
{
return non_trainable_weights;
}
}

private string _compute_dtype
{
get
{
throw new NotImplementedException();
}
}

public Layer(bool trainable = true, string name = null, string dtype = null, bool @dynamic = false, Dictionary<string, object> kwargs = null)
{

@@ -174,7 +207,7 @@ namespace Keras.Layers

public void build(TensorShape shape) => throw new NotImplementedException();

public void call(Tensor[] inputs) => throw new NotImplementedException();
public virtual void call(Tensor[] inputs) => throw new NotImplementedException();

public void _add_trackable(dynamic trackable_object, bool trainable) => throw new NotImplementedException();

@@ -183,7 +216,7 @@ namespace Keras.Layers
dynamic partitioner= null, bool? use_resource= null, VariableSynchronization synchronization= VariableSynchronization.Auto,
VariableAggregation aggregation= VariableAggregation.None, Dictionary<string, object> kwargs = null) => throw new NotImplementedException();

public Dictionary<string, object> get_config() => throw new NotImplementedException();
public virtual Dictionary<string, object> get_config() => throw new NotImplementedException();

public Layer from_config(Dictionary<string, object> config) => throw new NotImplementedException();

@@ -224,5 +257,166 @@ namespace Keras.Layers
public Tensor[] get_output_at(int node_index) => throw new NotImplementedException();

public int count_params() => throw new NotImplementedException();
}

private void _set_dtype_policy(string dtype) => throw new NotImplementedException();

private Tensor _maybe_cast_inputs(Tensor inputs) => throw new NotImplementedException();

private void _warn_about_input_casting(string input_dtype) => throw new NotImplementedException();

private string _name_scope()
{
return name;
}

private string _obj_reference_counts
{
get
{
throw new NotImplementedException();
}
}

private dynamic _attribute_sentinel
{
get
{
throw new NotImplementedException();
}
}

private dynamic _call_full_argspec
{
get
{
throw new NotImplementedException();
}
}

private string[] _call_fn_args
{
get
{
throw new NotImplementedException();
}
}

private string[] _call_accepts_kwargs
{
get
{
throw new NotImplementedException();
}
}

private bool _should_compute_mask
{
get
{
throw new NotImplementedException();
}
}

private Tensor[] _eager_losses
{
get
{
throw new NotImplementedException();
}
set
{
throw new NotImplementedException();
}
}

private dynamic _trackable_saved_model_saver
{
get
{
throw new NotImplementedException();
}
}

private string _object_identifier
{
get
{
throw new NotImplementedException();
}
}

private string _tracking_metadata
{
get
{
throw new NotImplementedException();
}
}

public Dictionary<string, bool> state
{
get
{
throw new NotImplementedException();
}
set
{
throw new NotImplementedException();
}
}

private void _init_set_name(string name, bool zero_based= true) => throw new NotImplementedException();

private Metric _get_existing_metric(string name = null) => throw new NotImplementedException();

private void _eager_add_metric(Metric value, string aggregation= null, string name= null) => throw new NotImplementedException();

private void _symbolic_add_metric(Metric value, string aggregation = null, string name = null) => throw new NotImplementedException();

private void _handle_weight_regularization(string name, VariableV1 variable, Regularizer regularizer) => throw new NotImplementedException();

private void _handle_activity_regularization(Tensor[] inputs, Tensor[] outputs) => throw new NotImplementedException();

private void _set_mask_metadata(Tensor[] inputs, Tensor[] outputs, Tensor previous_mask) => throw new NotImplementedException();

private Tensor[] _collect_input_masks(Tensor[] inputs, Dictionary<string, object> args, Dictionary<string, object> kwargs) => throw new NotImplementedException();

private bool _call_arg_was_passed(string arg_name, Dictionary<string, object> args, Dictionary<string, object> kwargs, bool inputs_in_args= false) => throw new NotImplementedException();

private T _get_call_arg_value<T>(string arg_name, Dictionary<string, object> args, Dictionary<string, object> kwargs, bool inputs_in_args = false) => throw new NotImplementedException();

private (Tensor[], Tensor[]) _set_connectivity_metadata_(Tensor[] inputs, Tensor[] outputs, Dictionary<string, object> args, Dictionary<string, object> kwargs) => throw new NotImplementedException();

private void _add_inbound_node(Tensor[] input_tensors, Tensor[] output_tensors, Dictionary<string, object> args = null) => throw new NotImplementedException();

private AttrValue _get_node_attribute_at_index(int node_index, string attr, string attr_name) => throw new NotImplementedException();

private void _maybe_build(Tensor[] inputs) => throw new NotImplementedException();

private void _symbolic_call(Tensor[] inputs) => throw new NotImplementedException();

private Dictionary<Layer, bool> _get_trainable_state() => throw new NotImplementedException();

private void _set_trainable_state(bool trainable_state) => throw new NotImplementedException();

private void _maybe_create_attribute(string name, object default_value) => throw new NotImplementedException();

private void __delattr__(string name) => throw new NotImplementedException();

private void __setattr__(string name, object value) => throw new NotImplementedException();

private List<AttrValue> _gather_children_attribute(string attribute) => throw new NotImplementedException();

private List<Layer> _gather_unique_layers() => throw new NotImplementedException();

private List<Layer> _gather_layers() => throw new NotImplementedException();

private bool _is_layer() => throw new NotImplementedException();

private void _init_call_fn_args() => throw new NotImplementedException();

public dynamic _list_extra_dependencies_for_serialization(dynamic serialization_cache) => throw new NotImplementedException();

public dynamic _list_functions_for_serialization(dynamic serialization_cache) => throw new NotImplementedException();
}
}

Loading…
Cancel
Save