You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

interface.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright 2023 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import cv2
  16. import numpy as np
  17. import torch
  18. from PIL import Image
  19. from torchvision import transforms
  20. from torch.utils.data import DataLoader
  21. from torchvision import transforms
  22. from sedna.common.config import Context
  23. from sedna.common.file_ops import FileOps
  24. from sedna.common.log import LOGGER
  25. from sedna.common.config import BaseConfig
  26. from dataloaders import custom_transforms as tr
  27. from utils.args import TrainingArguments, EvaluationArguments
  28. from estimators.train import Trainer
  29. from estimators.eval import Validator, load_my_state_dict
  30. from accuracy import accuracy
  31. def preprocess_url(image_urls):
  32. transformed_images = []
  33. for paths in image_urls:
  34. if len(paths) == 2:
  35. img_path, depth_path = paths
  36. _img = Image.open(img_path).convert(
  37. 'RGB').resize((2048, 1024), Image.BILINEAR)
  38. _depth = Image.open(depth_path).resize(
  39. (2048, 1024), Image.BILINEAR)
  40. else:
  41. img_path = paths[0]
  42. _img = Image.open(img_path).convert(
  43. 'RGB').resize((2048, 1024), Image.BILINEAR)
  44. _depth = _img
  45. sample = {'image': _img, 'depth': _depth, 'label': _img}
  46. composed_transforms = transforms.Compose([
  47. tr.Normalize(
  48. mean=(
  49. 0.485, 0.456, 0.406), std=(
  50. 0.229, 0.224, 0.225)),
  51. tr.ToTensor()])
  52. transformed_images.append((composed_transforms(sample), img_path))
  53. return transformed_images
  54. def preprocess_frames(frames):
  55. composed_transforms = transforms.Compose([
  56. tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
  57. tr.ToTensor()])
  58. trainsformed_frames = []
  59. for frame in frames:
  60. img = frame.get('image')
  61. img = cv2.resize(np.array(img), (2048, 1024),
  62. interpolation=cv2.INTER_CUBIC)
  63. img = Image.fromarray(np.array(img))
  64. sample = {'image': img, "depth": img, "label": img}
  65. trainsformed_frames.append((composed_transforms(sample), ""))
  66. return trainsformed_frames
  67. class Estimator:
  68. def __init__(self, **kwargs):
  69. self.train_args = TrainingArguments(**kwargs)
  70. self.val_args = EvaluationArguments(**kwargs)
  71. self.train_args.resume = Context.get_parameters(
  72. "PRETRAINED_MODEL_URL", None)
  73. self.trainer = None
  74. self.train_model_url = None
  75. label_save_dir = Context.get_parameters(
  76. "INFERENCE_RESULT_DIR",
  77. os.path.join(BaseConfig.data_path_prefix,
  78. "inference_results"))
  79. self.val_args.color_label_save_path = os.path.join(
  80. label_save_dir, "color")
  81. self.val_args.merge_label_save_path = os.path.join(
  82. label_save_dir, "merge")
  83. self.val_args.label_save_path = os.path.join(label_save_dir, "label")
  84. self.val_args.weight_path = kwargs.get("weight_path")
  85. self.validator = Validator(self.val_args)
  86. def train(self, train_data, valid_data=None, **kwargs):
  87. self.trainer = Trainer(
  88. self.train_args, train_data=train_data, valid_data=valid_data)
  89. LOGGER.info("Total epoches: {}".format(self.trainer.args.epochs))
  90. for epoch in range(
  91. self.trainer.args.start_epoch,
  92. self.trainer.args.epochs):
  93. if epoch == 0 and self.trainer.val_loader:
  94. self.trainer.validation(epoch)
  95. self.trainer.training(epoch)
  96. if self.trainer.args.no_val and \
  97. (epoch % self.trainer.args.eval_interval ==
  98. (self.trainer.args.eval_interval - 1) or
  99. epoch == self.trainer.args.epochs - 1):
  100. # save checkpoint when it meets eval_interval
  101. # or the training finishes
  102. is_best = False
  103. train_model_url = self.trainer.saver.save_checkpoint({
  104. 'epoch': epoch + 1,
  105. 'state_dict': self.trainer.model.state_dict(),
  106. 'optimizer': self.trainer.optimizer.state_dict(),
  107. 'best_pred': self.trainer.best_pred,
  108. }, is_best)
  109. self.trainer.writer.close()
  110. self.train_model_url = train_model_url
  111. return {"mIoU": 0 if not valid_data
  112. else self.trainer.validation(epoch)}
  113. def predict(self, data, **kwargs):
  114. if isinstance(data[0], dict):
  115. data = preprocess_frames(data)
  116. if isinstance(data[0], np.ndarray):
  117. data = preprocess_url(data)
  118. self.validator.test_loader = DataLoader(
  119. data,
  120. batch_size=self.val_args.test_batch_size,
  121. shuffle=False,
  122. pin_memory=False)
  123. return self.validator.validate()
  124. def evaluate(self, data, **kwargs):
  125. predictions = self.predict(data.x)
  126. return accuracy(data.y, predictions)
  127. def load(self, model_url, **kwargs):
  128. if model_url:
  129. self.validator.new_state_dict = torch.load(model_url)
  130. self.validator.model = load_my_state_dict(
  131. self.validator.model,
  132. self.validator.new_state_dict['state_dict'])
  133. self.train_args.resume = model_url
  134. else:
  135. raise Exception("model url does not exist.")
  136. def save(self, model_path=None):
  137. if not model_path:
  138. LOGGER.warning(f"Not specify model path.")
  139. return self.train_model_url
  140. return FileOps.upload(self.train_model_url, model_path)