Browse Source

Sequential, Metric base and Metric Util class skeleton methods added

tags/v0.20
Deepak Kumar 5 years ago
parent
commit
4deb320885
7 changed files with 145 additions and 18 deletions
  1. +29
    -0
      src/TensorFlowNET.Keras/Args.cs
  2. +1
    -1
      src/TensorFlowNET.Keras/Engine/Node.cs
  3. +2
    -1
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  4. +30
    -1
      src/TensorFlowNET.Keras/Metrics/Metric.cs
  5. +32
    -4
      src/TensorFlowNET.Keras/Models.cs
  6. +0
    -10
      src/TensorFlowNET.Keras/Ops.cs
  7. +51
    -1
      src/TensorFlowNET.Keras/Utils/MetricsUtils.cs

+ 29
- 0
src/TensorFlowNET.Keras/Args.cs View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras
{
public class Args
{
private List<object> args = new List<object>();

public object this[int index]
{
get
{
return args.Count < index ? args[index] : null;
}
}

public T Get<T>(int index)
{
return args.Count < index ? (T)args[index] : default(T);
}

public void Add<T>(T arg)
{
args.Add(arg);
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Engine/Node.cs View File

@@ -4,7 +4,7 @@ using System.Text;

namespace Tensorflow.Keras.Engine
{
class Node
public class Node
{
}
}

+ 2
- 1
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -4,7 +4,8 @@ using System.Text;

namespace Tensorflow.Keras.Engine
{
class Sequential
public class Sequential
{
}
}

+ 30
- 1
src/TensorFlowNET.Keras/Metrics/Metric.cs View File

@@ -1,10 +1,39 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Metrics
{
public abstract class Metric
public abstract class Metric : Layers.Layer
{
public string dtype
{
get
{
throw new NotImplementedException();
}
}

public Metric(string name, string dtype)
{
throw new NotImplementedException();
}

public void __new__ (Metric cls, Args args, KwArgs kwargs) => throw new NotImplementedException();

public Tensor __call__(Metric cls, Args args, KwArgs kwargs) => throw new NotImplementedException();

public virtual Hashtable get_config() => throw new NotImplementedException();

public virtual void reset_states() => throw new NotImplementedException();

public abstract void update_state(Args args, KwArgs kwargs);

public abstract Tensor result();

public void add_weight(string name, TensorShape shape= null, VariableAggregation aggregation= VariableAggregation.Sum,
VariableSynchronization synchronization = VariableSynchronization.OnRead, Initializers.Initializer initializer= null,
string dtype= null) => throw new NotImplementedException();
}
}

+ 32
- 4
src/TensorFlowNET.Keras/Models.cs View File

@@ -1,14 +1,42 @@
using System;
using Keras.Layers;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras
{
class Models
{
public class Model : Keras.Engine.Training.Model
{
public class Model : Keras.Engine.Training.Model{}

}
public static Layer share_weights(Layer layer) => throw new NotImplementedException();

private static Layer _clone_layer(Layer layer) => throw new NotImplementedException();

private static Layer _insert_ancillary_layers(Model model, Layer ancillary_layers, string[] metrics_names, Node[] new_nodes) => throw new NotImplementedException();

private static Node[] _make_new_nodes(Node[] nodes_by_depth, Func<Layer, Layer> layer_fn, Hashtable layer_map, Hashtable tensor_map) => throw new NotImplementedException();

private static Model _clone_functional_model(Model model, Tensor[] input_tensors = null, Func<Layer, Layer> layer_fn = null) => throw new NotImplementedException();

private static (Hashtable, Layer[]) _clone_layers_and_model_config(Model model, Layer[] input_layers, Func<Layer, Layer> layer_fn) => throw new NotImplementedException();

private static (Layer[], Layer[]) _remove_ancillary_layers(Model model, Hashtable layer_map, Layer[] layers) => throw new NotImplementedException();

private static Sequential _clone_sequential_model(Model model, Tensor[] input_tensors = null, Func<Layer, Layer> layer_fn = null) => throw new NotImplementedException();

public static Model clone_model(Model model, Tensor[] input_tensors = null, Func<Layer, Layer> layer_fn = null) => throw new NotImplementedException();

private static void _in_place_subclassed_model_reset(Model model) => throw new NotImplementedException();

private static void _reset_build_compile_trackers(Model model) => throw new NotImplementedException();

public static void in_place_subclassed_model_state_restoration(Model model) => throw new NotImplementedException();

public static void clone_and_build_model(Model model, Tensor[] input_tensors= null, Tensor[] target_tensors= null, object custom_objects= null,
bool compile_clone= true, bool in_place_reset= false, VariableV1 optimizer_iterations= null, Hashtable optimizer_config= null)
=> throw new NotImplementedException();
}
}

+ 0
- 10
src/TensorFlowNET.Keras/Ops.cs View File

@@ -1,10 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras
{
class Ops
{
}
}

+ 51
- 1
src/TensorFlowNET.Keras/Utils/MetricsUtils.cs View File

@@ -1,10 +1,60 @@
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Text;

namespace Tensorflow.Keras.Utils
{
class MetricsUtils
public class MetricsUtils
{
public static class Reduction
{
public const string SUM = "sum";
public const string SUM_OVER_BATCH_SIZE = "sum_over_batch_size";
public const string WEIGHTED_MEAN = "weighted_mean";
}

public static class ConfusionMatrix
{
public const string TRUE_POSITIVES = "tp";
public const string FALSE_POSITIVES = "fp";
public const string TRUE_NEGATIVES = "tn";
public const string FALSE_NEGATIVES = "fn";
}

public static class AUCCurve
{
public const string ROC = "ROC";
public const string PR = "PR";

public static string from_str(string key) => throw new NotImplementedException();
}

public static class AUCSummationMethod
{
public const string INTERPOLATION = "interpolation";
public const string MAJORING = "majoring";
public const string MINORING = "minoring";

public static string from_str(string key) => throw new NotImplementedException();
}

public static dynamic update_state_wrapper(Func<Args, KwArgs, Func<bool>> update_state_fn) => throw new NotImplementedException();

public static dynamic result_wrapper(Func<Args, Tensor> result_fn) => throw new NotImplementedException();

public static WeakReference weakmethod(MethodInfo method) => throw new NotImplementedException();

public static void assert_thresholds_range(float[] thresholds) => throw new NotImplementedException();

public static void parse_init_thresholds(float[] thresholds, float default_threshold = 0.5f) => throw new NotImplementedException();

public static Operation update_confusion_matrix_variables(variables variables_to_update, Tensor y_true, Tensor y_pred, float[] thresholds,
int? top_k= null,int? class_id= null, Tensor sample_weight= null, bool multi_label= false,
Tensor label_weights= null) => throw new NotImplementedException();

private static Tensor _filter_top_k(Tensor x, int k) => throw new NotImplementedException();

private static (Tensor[], Tensor) ragged_assert_compatible_and_get_flat_values(Tensor[] values, Tensor mask = null) => throw new NotImplementedException();
}
}

Loading…
Cancel
Save