You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Fit.cs 14 kB

2 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. using Tensorflow.NumPy;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine.DataAdapters;
  7. using System.Diagnostics;
  8. using Tensorflow.Keras.Callbacks;
  9. using System.Data;
  10. namespace Tensorflow.Keras.Engine
  11. {
  12. public partial class Model
  13. {
  /// <summary>
  /// Trains the model on in-memory arrays for a fixed number of epochs (iterations on a dataset).
  /// </summary>
  /// <param name="x">Input samples. Its size along dimension 0 must match <paramref name="y"/>.</param>
  /// <param name="y">Target values. Its size along dimension 0 must match <paramref name="x"/>.</param>
  /// <param name="batch_size">Batch size forwarded to the data handler.</param>
  /// <param name="epochs">Epoch count forwarded to the data handler and the training loop.</param>
  /// <param name="verbose">Verbosity level forwarded to the training loop.</param>
  /// <param name="callbacks">Optional additional callbacks forwarded to the training loop; may be null.</param>
  /// <param name="validation_split">
  /// Fraction used to split off the trailing samples as validation data. Only applied when it is
  /// non-zero and <paramref name="validation_data"/> is null.
  /// </param>
  /// <param name="validation_data">Explicit validation arrays; when supplied, no split is performed.</param>
  /// <param name="shuffle">Shuffle flag forwarded to the data handler.</param>
  /// <param name="initial_epoch">Initial epoch forwarded to the data handler.</param>
  /// <param name="max_queue_size">Queue size forwarded to the data handler.</param>
  /// <param name="workers">Worker count forwarded to the data handler.</param>
  /// <param name="use_multiprocessing">Multiprocessing flag forwarded to the data handler.</param>
  /// <returns>The history object returned by the internal training loop.</returns>
  /// <exception cref="InvalidArgumentError">
  /// Thrown when <paramref name="x"/> and <paramref name="y"/> differ in size along dimension 0.
  /// </exception>
  public ICallback fit(NDArray x, NDArray y,
      int batch_size = -1,
      int epochs = 1,
      int verbose = 1,
      List<ICallback> callbacks = null,
      float validation_split = 0f,
      (NDArray val_x, NDArray val_y)? validation_data = null,
      bool shuffle = true,
      int initial_epoch = 0,
      int max_queue_size = 10,
      int workers = 1,
      bool use_multiprocessing = false)
  {
      // Inputs and targets must pair up one-to-one along the first dimension.
      if (x.dims[0] != y.dims[0])
      {
          throw new InvalidArgumentError(
              $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
      }

      NDArray fit_x = x;
      NDArray fit_y = y;

      // Split off a validation tail only when the caller asked for a split
      // and did not provide explicit validation arrays.
      bool split_requested = validation_split != 0f;
      if (split_requested && !validation_data.HasValue)
      {
          int split_at = Convert.ToInt32(x.dims[0] * (1 - validation_split));

          // The leading part is used for training ...
          fit_x = x[new Slice(0, split_at)];
          fit_y = y[new Slice(0, split_at)];

          // ... and everything from split_at onward becomes the validation set.
          validation_data = (val_x: x[new Slice(split_at)], val_y: y[new Slice(split_at)]);
      }

      var handler_args = new DataHandlerArgs
      {
          X = fit_x,
          Y = fit_y,
          BatchSize = batch_size,
          InitialEpoch = initial_epoch,
          Epochs = epochs,
          Shuffle = shuffle,
          MaxQueueSize = max_queue_size,
          Workers = workers,
          UseMultiprocessing = use_multiprocessing,
          Model = this,
          StepsPerExecution = _steps_per_execution
      };
      var data_handler = new DataHandler(handler_args);

      return FitInternal(
          data_handler,
          epochs,
          verbose,
          callbackList: callbacks,
          validation_data: validation_data,
          train_step_func: train_step_function);
  }
  /// <summary>
  /// Trains the model on several in-memory input arrays that share one target array.
  /// </summary>
  /// <param name="x">Input arrays. Each one must match <paramref name="y"/> in size along dimension 0.</param>
  /// <param name="y">Target values.</param>
  /// <param name="batch_size">Batch size forwarded to the data handler.</param>
  /// <param name="epochs">Epoch count forwarded to the data handler and the training loop.</param>
  /// <param name="verbose">Verbosity level forwarded to the training loop.</param>
  /// <param name="callbacks">Optional additional callbacks forwarded to the training loop; may be null.</param>
  /// <param name="validation_split">
  /// Fraction used to split off the trailing samples of every input (and of the targets) as
  /// validation data. Only applied when it is non-zero and <paramref name="validation_data"/> is null.
  /// </param>
  /// <param name="validation_data">Explicit validation inputs and targets; when supplied, no split is performed.</param>
  /// <param name="shuffle">Shuffle flag forwarded to the data handler.</param>
  /// <param name="initial_epoch">Initial epoch forwarded to the data handler.</param>
  /// <param name="max_queue_size">Queue size forwarded to the data handler.</param>
  /// <param name="workers">Worker count forwarded to the data handler.</param>
  /// <param name="use_multiprocessing">Multiprocessing flag forwarded to the data handler.</param>
  /// <returns>The history object returned by the internal training loop.</returns>
  /// <exception cref="InvalidArgumentError">
  /// Thrown when any array in <paramref name="x"/> differs from <paramref name="y"/> in size along dimension 0.
  /// </exception>
  public ICallback fit(IEnumerable<NDArray> x, NDArray y,
      int batch_size = -1,
      int epochs = 1,
      int verbose = 1,
      List<ICallback> callbacks = null,
      float validation_split = 0f,
      (IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
      bool shuffle = true,
      int initial_epoch = 0,
      int max_queue_size = 10,
      int workers = 1,
      bool use_multiprocessing = false)
  {
      // Materialize the inputs once. The sequence is walked several times below
      // (shape check, training slice, validation slice), and a lazy or one-shot
      // enumerable would otherwise be re-evaluated on every pass.
      NDArray[] inputs = x.ToArray();

      foreach (var input in inputs)
      {
          if (input.dims[0] != y.dims[0])
          {
              throw new InvalidArgumentError(
                  $"The array x and y should have same value at dim 0, but got {input.dims[0]} and {y.dims[0]}");
          }
      }

      NDArray[] train_x = inputs;
      var train_y = y;

      if (validation_split != 0f && validation_data == null)
      {
          int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));

          // ToArray() forces the slices to be computed exactly once. A deferred
          // Select would re-slice every input each time the validation data is
          // enumerated.
          train_x = inputs.Select(input => input[new Slice(0, train_count)] as NDArray).ToArray();
          train_y = y[new Slice(0, train_count)];

          IEnumerable<NDArray> val_x = inputs.Select(input => input[new Slice(train_count)] as NDArray).ToArray();
          var val_y = y[new Slice(train_count)];
          validation_data = (val_x, val_y);
      }

      var data_handler = new DataHandler(new DataHandlerArgs
      {
          X = new Tensors(train_x),
          Y = train_y,
          BatchSize = batch_size,
          InitialEpoch = initial_epoch,
          Epochs = epochs,
          Shuffle = shuffle,
          MaxQueueSize = max_queue_size,
          Workers = workers,
          UseMultiprocessing = use_multiprocessing,
          Model = this,
          StepsPerExecution = _steps_per_execution
      });

      // Look the dataset up once and reuse it for both structural checks.
      var dataset = data_handler.DataAdapter.GetDataset();
      bool multi_input = dataset.structure.Length > 2 || dataset.FirstInputTensorCount > 1;

      if (multi_input)
      {
          return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
              train_step_func: train_step_multi_inputs_function);
      }
      else
      {
          return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
              train_step_func: train_step_function);
      }
  }
  /// <summary>
  /// Trains the model on a dataset for a fixed number of epochs.
  /// </summary>
  /// <param name="dataset">Dataset handed to the data handler as the training source.</param>
  /// <param name="batch_size">Batch size forwarded to the data handler.</param>
  /// <param name="epochs">Epoch count forwarded to the data handler and the training loop.</param>
  /// <param name="verbose">Verbosity level forwarded to the training loop.</param>
  /// <param name="callbacks">Optional additional callbacks forwarded to the training loop; may be null.</param>
  /// <param name="validation_data">Optional validation dataset forwarded to the training loop.</param>
  /// <param name="validation_step">Interval, in epochs, at which validation is run.</param>
  /// <param name="shuffle">Shuffle flag forwarded to the data handler.</param>
  /// <param name="initial_epoch">Initial epoch forwarded to the data handler.</param>
  /// <param name="max_queue_size">Queue size forwarded to the data handler.</param>
  /// <param name="workers">Worker count forwarded to the data handler.</param>
  /// <param name="use_multiprocessing">Multiprocessing flag forwarded to the data handler.</param>
  /// <returns>The history object returned by the internal training loop.</returns>
  public History fit(IDatasetV2 dataset,
      int batch_size = -1,
      int epochs = 1,
      int verbose = 1,
      List<ICallback> callbacks = null,
      IDatasetV2 validation_data = null,
      int validation_step = 10, // how many epochs pass between validation runs
      bool shuffle = true,
      int initial_epoch = 0,
      int max_queue_size = 10,
      int workers = 1,
      bool use_multiprocessing = false)
  {
      // Collect the data-handler configuration first, then build the handler from it.
      var handler_args = new DataHandlerArgs
      {
          Dataset = dataset,
          BatchSize = batch_size,
          InitialEpoch = initial_epoch,
          Epochs = epochs,
          Shuffle = shuffle,
          MaxQueueSize = max_queue_size,
          Workers = workers,
          UseMultiprocessing = use_multiprocessing,
          Model = this,
          StepsPerExecution = _steps_per_execution
      };
      var data_handler = new DataHandler(handler_args);

      History history = FitInternal(
          data_handler,
          epochs,
          validation_step,
          verbose,
          callbacks,
          validation_data: validation_data,
          train_step_func: train_step_function);
      return history;
  }
  /// <summary>
  /// Runs the training loop for a dataset-backed data handler, evaluating on an optional
  /// validation dataset at a fixed epoch interval.
  /// </summary>
  /// <param name="data_handler">Supplies the epochs, the per-epoch iterator and the steps.</param>
  /// <param name="epochs">Epoch count reported to the callbacks.</param>
  /// <param name="validation_step">
  /// Validation interval in epochs. For positive values validation runs on every epoch that is a
  /// non-zero multiple of the interval. A value of 0 disables validation instead of dividing by zero.
  /// </param>
  /// <param name="verbose">Verbosity level reported to the callbacks.</param>
  /// <param name="callbackList">Optional additional callbacks; may be null.</param>
  /// <param name="validation_data">Optional validation dataset; null skips validation entirely.</param>
  /// <param name="train_step_func">Function executed for each training step; returns that step's logs.</param>
  /// <returns>The history held by the callback list.</returns>
  History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
      Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
  {
      stop_training = false;
      _train_counter.assign(0);
      var callbacks = new CallbackList(new CallbackParams
      {
          Model = this,
          Verbose = verbose,
          Epochs = epochs,
          Steps = data_handler.Inferredsteps
      });

      if (callbackList != null)
      {
          foreach (var callback in callbackList)
          {
              callbacks.callbacks.add(callback);
          }
      }

      callbacks.on_train_begin();

      foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
      {
          reset_metrics();
          callbacks.on_epoch_begin(epoch);

          var logs = new Dictionary<string, float>();
          long End_step = 0;
          foreach (var step in data_handler.steps())
          {
              callbacks.on_train_batch_begin(step);
              logs = train_step_func(data_handler, iterator);
              var end_step = step + data_handler.StepIncrement;
              End_step = end_step;
              callbacks.on_train_batch_end(end_step, logs);
          }

          // Decide whether this epoch validates. The schedule for non-zero intervals is
          // the same as before; a zero interval is rejected up front so the modulo below
          // cannot throw DivideByZeroException.
          bool run_validation = validation_data != null
              && validation_step != 0
              && epoch % validation_step == 0
              && !(validation_step > 0 && epoch == 0);

          // Only the validation work is conditional. The previous `continue` also skipped
          // the epoch-end callback, the GC pass and the stop_training check on every
          // epoch that did not validate.
          if (run_validation)
          {
              var val_logs = evaluate(validation_data);
              foreach (var log in val_logs)
              {
                  logs["val_" + log.Key] = log.Value;
              }
              // Report again so the freshly added val_* entries reach the callbacks.
              callbacks.on_train_batch_end(End_step, logs);
          }

          callbacks.on_epoch_end(epoch, logs);

          GC.Collect();
          GC.WaitForPendingFinalizers();

          if (stop_training)
          {
              break;
          }
      }

      return callbacks.History;
  }
  /// <summary>
  /// Runs the training loop for array-backed training data, evaluating on an optional pair of
  /// validation arrays after every epoch.
  /// </summary>
  /// <param name="data_handler">Supplies the epochs, the per-epoch iterator and the steps.</param>
  /// <param name="epochs">Epoch count reported to the callbacks.</param>
  /// <param name="verbose">Verbosity level reported to the callbacks.</param>
  /// <param name="callbackList">Optional additional callbacks; may be null.</param>
  /// <param name="validation_data">Optional (inputs, targets) validation pair; null skips validation.</param>
  /// <param name="train_step_func">Function executed for each training step; returns that step's logs.</param>
  /// <returns>The history held by the callback list.</returns>
  History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data,
      Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
  {
      stop_training = false;
      _train_counter.assign(0);

      var callback_params = new CallbackParams
      {
          Model = this,
          Verbose = verbose,
          Epochs = epochs,
          Steps = data_handler.Inferredsteps
      };
      var callbacks = new CallbackList(callback_params);

      // Register any caller-supplied callbacks; a null list contributes nothing.
      foreach (var extra in callbackList ?? new List<ICallback>())
      {
          callbacks.callbacks.add(extra);
      }

      callbacks.on_train_begin();

      foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
      {
          reset_metrics();
          callbacks.on_epoch_begin(epoch);

          var logs = new Dictionary<string, float>();
          long last_step = 0;
          foreach (var step in data_handler.steps())
          {
              callbacks.on_train_batch_begin(step);
              logs = train_step_func(data_handler, iterator);
              last_step = step + data_handler.StepIncrement;
              callbacks.on_train_batch_end(last_step, logs);
          }

          if (validation_data.HasValue)
          {
              var (val_x, val_y) = validation_data.Value;

              // is_val: true is passed so that evaluate does not emit its own
              // test-batch-end output, which would interfere with the training output.
              var val_logs = evaluate(val_x, val_y, is_val: true);
              foreach (var entry in val_logs)
              {
                  logs["val_" + entry.Key] = entry.Value;
              }

              // Report once more so the newly merged val_* entries get printed.
              callbacks.on_train_batch_end(last_step, logs);
          }

          callbacks.on_epoch_end(epoch, logs);

          GC.Collect();
          GC.WaitForPendingFinalizers();

          if (stop_training)
          {
              break;
          }
      }

      return callbacks.History;
  }
  /// <summary>
  /// Runs the training loop for multi-input training data, evaluating on an optional
  /// (inputs, targets) validation pair after every epoch.
  /// </summary>
  /// <param name="data_handler">Supplies the epochs, the per-epoch iterator and the steps.</param>
  /// <param name="epochs">Epoch count reported to the callbacks.</param>
  /// <param name="verbose">Verbosity level reported to the callbacks.</param>
  /// <param name="callbackList">Optional additional callbacks; may be null.</param>
  /// <param name="validation_data">Optional validation inputs and targets; null skips validation.</param>
  /// <param name="train_step_func">Function executed for each training step; returns that step's logs.</param>
  /// <returns>The history held by the callback list.</returns>
  History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data,
      Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
  {
      stop_training = false;
      _train_counter.assign(0);
      var callbacks = new CallbackList(new CallbackParams
      {
          Model = this,
          Verbose = verbose,
          Epochs = epochs,
          Steps = data_handler.Inferredsteps
      });

      if (callbackList != null)
      {
          foreach (var callback in callbackList)
          {
              callbacks.callbacks.add(callback);
          }
      }

      callbacks.on_train_begin();

      foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
      {
          reset_metrics();
          callbacks.on_epoch_begin(epoch);

          var logs = new Dictionary<string, float>();
          long End_step = 0;
          foreach (var step in data_handler.steps())
          {
              callbacks.on_train_batch_begin(step);
              logs = train_step_func(data_handler, iterator);
              var end_step = step + data_handler.StepIncrement;
              End_step = end_step;
              callbacks.on_train_batch_end(end_step, logs);
          }

          if (validation_data != null)
          {
              var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2);
              foreach (var log in val_logs)
              {
                  logs["val_" + log.Key] = log.Value;
              }

              // Notify the callbacks once, after all val_* entries are merged. Doing this
              // inside the loop fired the batch-end callback once per validation metric,
              // each time with only partially merged logs.
              callbacks.on_train_batch_end(End_step, logs);
          }

          callbacks.on_epoch_end(epoch, logs);

          GC.Collect();
          GC.WaitForPendingFinalizers();

          if (stop_training)
          {
              break;
          }
      }

      return callbacks.History;
  }
  318. }
  319. }