Browse Source

Multithreading with Keras models. #890

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
74fda601fb
2 changed files with 4 additions and 3 deletions
  1. +2
    -1
      src/TensorFlowNET.Keras/KerasInterface.cs
  2. +2
    -2
      test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs

+ 2
- 1
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -25,7 +25,8 @@ namespace Tensorflow.Keras
public LossesApi losses { get; } = new LossesApi();
public Activations activations { get; } = new Activations();
public Preprocessing preprocessing { get; } = new Preprocessing();
public BackendImpl backend { get; } = new BackendImpl();
ThreadLocal<BackendImpl> _backend = new ThreadLocal<BackendImpl>(() => new BackendImpl());
public BackendImpl backend => _backend.Value;
public OptimizerApi optimizers { get; } = new OptimizerApi();
public MetricsApi metrics { get; } = new MetricsApi();
public ModelsApi models { get; } = new ModelsApi();


+ 2
- 2
test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs View File

@@ -60,9 +60,9 @@ namespace TensorFlowNET.Keras.UnitTest

//act
ParallelOptions parallelOptions = new ParallelOptions();
parallelOptions.MaxDegreeOfParallelism = 1;
parallelOptions.MaxDegreeOfParallelism = 8;
var input = np.array(new float[,] { { 0, 0 } });
Parallel.For(0, 1, parallelOptions, i =>
Parallel.For(0, 8, parallelOptions, i =>
{
var clone = BuildModel();
clone.load_weights(savefile);


Loading…
Cancel
Save