You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Inputs.cs 2.1 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow.Data;
  5. using Tensorflow.Models.ObjectDetection.Protos;
  6. namespace Tensorflow.Models.ObjectDetection
  7. {
  8. public class Inputs
  9. {
  10. ModelBuilder modelBuilder;
  11. DatasetBuilder datasetBuilder;
  12. public Inputs()
  13. {
  14. modelBuilder = new ModelBuilder();
  15. datasetBuilder = new DatasetBuilder();
  16. }
  17. public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config)
  18. {
  19. Func<DatasetV1Adapter> _train_input_fn = () =>
  20. train_input(train_config, train_input_config, model_config);
  21. return _train_input_fn;
  22. }
  23. /// <summary>
  24. /// Returns `features` and `labels` tensor dictionaries for training.
  25. /// </summary>
  26. /// <param name="train_config"></param>
  27. /// <param name="train_input_config"></param>
  28. /// <param name="model_config"></param>
  29. /// <returns></returns>
  30. public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config)
  31. {
  32. var arch = modelBuilder.build(model_config, true, true);
  33. Func<Tensor, (Tensor, Tensor)> model_preprocess_fn = arch.preprocess;
  34. Func<Dictionary<string, Tensor>, (Dictionary<string, Tensor>, Dictionary<string, Tensor>) > transform_and_pad_input_data_fn = (tensor_dict) =>
  35. {
  36. return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict));
  37. };
  38. var dataset = datasetBuilder.build(train_input_config);
  39. return dataset;
  40. }
  41. private Dictionary<string, Tensor> _get_features_dict(Dictionary<string, Tensor> input_dict)
  42. {
  43. throw new NotImplementedException("_get_features_dict");
  44. }
  45. private Dictionary<string, Tensor> _get_labels_dict(Dictionary<string, Tensor> input_dict)
  46. {
  47. throw new NotImplementedException("_get_labels_dict");
  48. }
  49. }
  50. }