You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they may include dashes ('-') and can be up to 35 characters long.

server.cc 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "core/server.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <atomic>
#include <csignal>
#include <future>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "mindspore/ccsrc/utils/log_adapter.h"
#include "serving/ms_service.grpc.pb.h"
#include "core/util/option_parser.h"
#include "core/version_control/version_controller.h"
#include "mindspore/ccsrc/utils/context/ms_context.h"
#include "core/util/file_system_operation.h"
#include "graphengine/third_party/fwkacllib/inc/runtime/context.h"
  33. using ms_serving::MSService;
  34. using ms_serving::PredictReply;
  35. using ms_serving::PredictRequest;
  36. namespace mindspore {
  37. namespace serving {
  38. using MSTensorPtr = std::shared_ptr<inference::MSTensor>;
  39. Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
  40. session_ = inference::MSSession::CreateSession(device, device_id);
  41. if (session_ == nullptr) {
  42. MS_LOG(ERROR) << "Creat Session Failed";
  43. return FAILED;
  44. }
  45. device_type_ = device;
  46. return SUCCESS;
  47. }
// Returns the process-wide singleton Session (Meyers singleton; its
// construction is thread-safe since C++11).
Session &Session::Instance() {
  static Session instance;
  return instance;
}
  52. Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::MultiTensor *outputs) {
  53. if (last_graph_ == nullptr) {
  54. MS_LOG(ERROR) << "the model has not loaded";
  55. return FAILED;
  56. }
  57. if (session_ == nullptr) {
  58. MS_LOG(ERROR) << "the inference session has not be initialized";
  59. return FAILED;
  60. }
  61. std::lock_guard<std::mutex> lock(mutex_);
  62. MS_LOG(INFO) << "run Predict";
  63. if (!session_->CheckModelInputs(graph_id_, inputs)) {
  64. MS_LOG(ERROR) << "Input error.";
  65. return FAILED;
  66. }
  67. *outputs = session_->RunGraph(graph_id_, inputs);
  68. MS_LOG(INFO) << "run Predict finished";
  69. return SUCCESS;
  70. }
// Loads the model file from disk and compiles it on the backend session so
// the first real request does not pay graph-compilation latency.
// @param model descriptor giving the model's directory and file name
// @return SUCCESS on load+compile, FAILED if no session exists yet, the file
//         cannot be read, or the graph cannot be loaded.
// Requires CreatDeviceSession() to have succeeded first.
Status Session::Warmup(const MindSporeModelPtr model) {
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup";
    return FAILED;
  }
  // Serialize against Predict(), which runs the graph under the same mutex.
  std::lock_guard<std::mutex> lock(mutex_);
  size_t size = 0;
  std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
  // NOTE(review): graphBuf is never released on any path below (success or
  // failure) — this looks like a leak on every model (re)load. The
  // deallocator matching ReadFile's allocation is not visible in this file;
  // confirm how ReadFile allocates and free accordingly.
  char *graphBuf = ReadFile(file_name.c_str(), &size);
  if (graphBuf == nullptr) {
    MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
    return FAILED;
  }
  last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
  if (last_graph_ == nullptr) {
    MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
    return FAILED;
  }
  graph_id_ = session_->CompileGraph(last_graph_);
  MS_LOG(INFO) << "Session Warmup finished";
  return SUCCESS;
}
// Releases the backend inference session (called during process shutdown via
// ClearEnv()). NOTE(review): last_graph_ is left non-null, so a Predict()
// after Clear() fails on the session_ null check rather than the "model not
// loaded" check — confirm this asymmetry is intended.
Status Session::Clear() {
  session_ = nullptr;
  return SUCCESS;
}
  97. namespace {
  98. static const uint32_t uint32max = 0x7FFFFFFF;
  99. std::promise<void> exit_requested;
  100. const std::map<ms_serving::DataType, TypeId> type2id_map{
  101. {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
  102. {ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
  103. {ms_serving::MS_INT16, TypeId::kNumberTypeInt16}, {ms_serving::MS_UINT16, TypeId::kNumberTypeUInt16},
  104. {ms_serving::MS_INT32, TypeId::kNumberTypeInt32}, {ms_serving::MS_UINT32, TypeId::kNumberTypeUInt32},
  105. {ms_serving::MS_INT64, TypeId::kNumberTypeInt64}, {ms_serving::MS_UINT64, TypeId::kNumberTypeUInt64},
  106. {ms_serving::MS_FLOAT16, TypeId::kNumberTypeFloat16}, {ms_serving::MS_FLOAT32, TypeId::kNumberTypeFloat32},
  107. {ms_serving::MS_FLOAT64, TypeId::kNumberTypeFloat64},
  108. };
  109. const std::map<TypeId, ms_serving::DataType> id2type_map{
  110. {TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL},
  111. {TypeId::kNumberTypeInt8, ms_serving::MS_INT8}, {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8},
  112. {TypeId::kNumberTypeInt16, ms_serving::MS_INT16}, {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16},
  113. {TypeId::kNumberTypeInt32, ms_serving::MS_INT32}, {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32},
  114. {TypeId::kNumberTypeInt64, ms_serving::MS_INT64}, {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64},
  115. {TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32},
  116. {TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64},
  117. };
  118. const std::map<ms_serving::DataType, size_t> length_map{
  119. {ms_serving::MS_UNKNOWN, 0},
  120. {ms_serving::MS_BOOL, sizeof(bool)},
  121. {ms_serving::MS_INT8, sizeof(int8_t)},
  122. {ms_serving::MS_UINT8, sizeof(uint8_t)},
  123. {ms_serving::MS_INT16, sizeof(int16_t)},
  124. {ms_serving::MS_UINT16, sizeof(uint16_t)},
  125. {ms_serving::MS_INT32, sizeof(int32_t)},
  126. {ms_serving::MS_UINT32, sizeof(uint32_t)},
  127. {ms_serving::MS_INT64, sizeof(int64_t)},
  128. {ms_serving::MS_UINT64, sizeof(uint64_t)},
  129. {ms_serving::MS_FLOAT16, 2},
  130. {ms_serving::MS_FLOAT32, 4},
  131. {ms_serving::MS_FLOAT64, 8},
  132. };
  133. MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
  134. std::vector<int> shape;
  135. for (auto dim : tensor.tensor_shape().dims()) {
  136. shape.push_back(static_cast<int>(dim));
  137. }
  138. auto iter = type2id_map.find(tensor.tensor_type());
  139. if (iter == type2id_map.end()) {
  140. MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type();
  141. return nullptr;
  142. }
  143. TypeId type = iter->second;
  144. auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
  145. memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size());
  146. return ms_tensor;
  147. }
  148. ms_serving::Tensor MSTensor2ServingTensor(MSTensorPtr ms_tensor) {
  149. ms_serving::Tensor tensor;
  150. ms_serving::TensorShape shape;
  151. for (auto dim : ms_tensor->shape()) {
  152. shape.add_dims(dim);
  153. }
  154. *tensor.mutable_tensor_shape() = shape;
  155. auto iter = id2type_map.find(ms_tensor->data_type());
  156. if (iter == id2type_map.end()) {
  157. MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type();
  158. return tensor;
  159. }
  160. tensor.set_tensor_type(iter->second);
  161. tensor.set_data(ms_tensor->MutableData(), ms_tensor->Size());
  162. return tensor;
  163. }
  164. void ClearEnv() {
  165. Session::Instance().Clear();
  166. inference::ExitInference();
  167. }
  168. void HandleSignal(int sig) { exit_requested.set_value(); }
  169. #ifdef ENABLE_D
  170. static rtContext_t g_ctx = nullptr;
  171. #endif
  172. } // namespace
  173. // Service Implement
  174. class MSServiceImpl final : public MSService::Service {
  175. grpc::Status Predict(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
  176. std::lock_guard<std::mutex> lock(mutex_);
  177. #ifdef ENABLE_D
  178. if (g_ctx == nullptr) {
  179. MS_LOG(ERROR) << "rtCtx is nullptr";
  180. return grpc::Status::CANCELLED;
  181. }
  182. rtError_t rt_ret = rtCtxSetCurrent(g_ctx);
  183. if (rt_ret != RT_ERROR_NONE) {
  184. MS_LOG(ERROR) << "set Ascend rtCtx failed";
  185. }
  186. #endif
  187. std::vector<MSTensorPtr> inputs;
  188. inference::MultiTensor outputs;
  189. for (int i = 0; i < request->data_size(); i++) {
  190. auto input = ServingTensor2MSTensor(request->data(i));
  191. if (input == nullptr) {
  192. MS_LOG(ERROR) << "Tensor convert failed";
  193. return grpc::Status::CANCELLED;
  194. }
  195. inputs.push_back(input);
  196. }
  197. auto res = Session::Instance().Predict(inputs, &outputs);
  198. if (res != SUCCESS) {
  199. return grpc::Status::CANCELLED;
  200. }
  201. for (const auto &tensor : outputs) {
  202. *reply->add_result() = MSTensor2ServingTensor(tensor);
  203. }
  204. MS_LOG(INFO) << "Finish call service Eval";
  205. return grpc::Status::OK;
  206. }
  207. grpc::Status Test(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
  208. MS_LOG(INFO) << "TestService call";
  209. return grpc::Status::OK;
  210. }
  211. std::mutex mutex_;
  212. };
  213. Status Server::BuildAndStart() {
  214. // handle exit signal
  215. signal(SIGINT, HandleSignal);
  216. signal(SIGTERM, HandleSignal);
  217. Status res;
  218. auto option_args = Options::Instance().GetArgs();
  219. std::string server_address = "0.0.0.0:" + std::to_string(option_args->grpc_port);
  220. std::string model_path = option_args->model_path;
  221. std::string model_name = option_args->model_name;
  222. std::string device_type = option_args->device_type;
  223. auto device_id = option_args->device_id;
  224. res = Session::Instance().CreatDeviceSession(device_type, device_id);
  225. if (res != SUCCESS) {
  226. MS_LOG(ERROR) << "creat session failed";
  227. ClearEnv();
  228. return res;
  229. }
  230. VersionController version_controller(option_args->poll_model_wait_seconds, model_path, model_name);
  231. res = version_controller.Run();
  232. if (res != SUCCESS) {
  233. MS_LOG(ERROR) << "load model failed";
  234. ClearEnv();
  235. return res;
  236. }
  237. #ifdef ENABLE_D
  238. // set d context
  239. rtContext_t ctx = nullptr;
  240. rtError_t rt_ret = rtCtxGetCurrent(&ctx);
  241. if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
  242. MS_LOG(ERROR) << "the ascend device context is null";
  243. ClearEnv();
  244. return FAILED;
  245. }
  246. g_ctx = ctx;
  247. #endif
  248. MSServiceImpl ms_service;
  249. grpc::EnableDefaultHealthCheckService(true);
  250. grpc::reflection::InitProtoReflectionServerBuilderPlugin();
  251. // Set the port is not reuseable
  252. auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
  253. grpc::ServerBuilder serverBuilder;
  254. serverBuilder.SetOption(std::move(option));
  255. serverBuilder.SetMaxMessageSize(uint32max);
  256. serverBuilder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
  257. serverBuilder.RegisterService(&ms_service);
  258. std::unique_ptr<grpc::Server> server(serverBuilder.BuildAndStart());
  259. if (server == nullptr) {
  260. MS_LOG(ERROR) << "The serving server create failed";
  261. ClearEnv();
  262. return FAILED;
  263. }
  264. auto grpc_server_run = [&server]() { server->Wait(); };
  265. std::thread serving_thread(grpc_server_run);
  266. MS_LOG(INFO) << "MS Serving listening on " << server_address;
  267. auto exit_future = exit_requested.get_future();
  268. exit_future.wait();
  269. ClearEnv();
  270. server->Shutdown();
  271. serving_thread.join();
  272. return SUCCESS;
  273. }
  274. } // namespace serving
  275. } // namespace mindspore