
server.cc
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "core/server.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <csignal>  // signal(), SIGINT
#include <cstdlib>  // exit()
#include <mutex>    // std::mutex, std::lock_guard
#include <string>
#include <map>
#include <vector>
#include <utility>
#include <memory>
#include "securec/include/securec.h"  // assumed location of memcpy_s (Huawei securec library)
#include "mindspore/ccsrc/utils/log_adapter.h"
#include "serving/ms_service.grpc.pb.h"
#include "core/util/option_parser.h"
#include "core/version_control/version_controller.h"
#include "mindspore/ccsrc/utils/context/ms_context.h"
#include "core/util/file_system_operation.h"
#include "graphengine/third_party/fwkacllib/inc/runtime/context.h"
using ms_serving::MSService;
using ms_serving::PredictReply;
using ms_serving::PredictRequest;

namespace mindspore {
namespace serving {
using MSTensorPtr = std::shared_ptr<inference::MSTensor>;
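// Creates the backend inference session for the given device type and device
// id; the device string is suffixed with "Inference" to select the matching
// MSSession implementation.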
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
  session_ = inference::MSSession::CreateSession(device + "Inference", device_id);
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Create session failed, device " << device << ", device id " << device_id;
    return FAILED;
  }
  device_type_ = device;
  return SUCCESS;
}
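// Returns the process-wide Session singleton.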
Session &Session::Instance() {
  static Session instance;
  return instance;
}
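// Runs the compiled graph on the given inputs. Fails if no model has been
// loaded (Warmup not called) or the device session was never created; the
// mutex serializes inference with model (re)loading.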
Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::MultiTensor *outputs) {
  if (last_graph_ == nullptr) {
    MS_LOG(ERROR) << "The model has not been loaded";
    return FAILED;
  }
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "The inference session has not been initialized";
    return FAILED;
  }
  std::lock_guard<std::mutex> lock(mutex_);
  MS_LOG(INFO) << "Run Predict";
  *outputs = session_->RunGraph(graph_id_, inputs);
  return SUCCESS;
}
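// Loads the model file from disk, compiles it into the device session and
// keeps the resulting graph id for later Predict calls.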
Status Session::Warmup(const MindSporeModelPtr model) {
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "CreatDeviceSession must be called before Warmup";
    return FAILED;
  }
  std::lock_guard<std::mutex> lock(mutex_);
  size_t size = 0;
  std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
  char *graphBuf = ReadFile(file_name.c_str(), &size);
  if (graphBuf == nullptr) {
    MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name;
    return FAILED;
  }
  last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
  if (last_graph_ == nullptr) {
    MS_LOG(ERROR) << "Parse graph model failed, file name is " << file_name;
    return FAILED;
  }
  graph_id_ = session_->CompileGraph(last_graph_);
  MS_LOG(INFO) << "Session Warmup finished";
  return SUCCESS;
}
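// Releases the device session; called from ClearEnv on shutdown.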
Status Session::Clear() {
  session_ = nullptr;
  return SUCCESS;
}
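// Lookup tables between the serving protobuf data types and MindSpore
// TypeIds, plus the element size in bytes of each serving type.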
namespace {
const std::map<ms_serving::DataType, TypeId> type2id_map{
  {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin},   {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
  {ms_serving::MS_INT8, TypeId::kNumberTypeInt8},       {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
  {ms_serving::MS_INT16, TypeId::kNumberTypeInt16},     {ms_serving::MS_UINT16, TypeId::kNumberTypeUInt16},
  {ms_serving::MS_INT32, TypeId::kNumberTypeInt32},     {ms_serving::MS_UINT32, TypeId::kNumberTypeUInt32},
  {ms_serving::MS_INT64, TypeId::kNumberTypeInt64},     {ms_serving::MS_UINT64, TypeId::kNumberTypeUInt64},
  {ms_serving::MS_FLOAT16, TypeId::kNumberTypeFloat16}, {ms_serving::MS_FLOAT32, TypeId::kNumberTypeFloat32},
  {ms_serving::MS_FLOAT64, TypeId::kNumberTypeFloat64},
};

const std::map<TypeId, ms_serving::DataType> id2type_map{
  {TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN},   {TypeId::kNumberTypeBool, ms_serving::MS_BOOL},
  {TypeId::kNumberTypeInt8, ms_serving::MS_INT8},       {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8},
  {TypeId::kNumberTypeInt16, ms_serving::MS_INT16},     {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16},
  {TypeId::kNumberTypeInt32, ms_serving::MS_INT32},     {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32},
  {TypeId::kNumberTypeInt64, ms_serving::MS_INT64},     {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64},
  {TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32},
  {TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64},
};

const std::map<ms_serving::DataType, size_t> length_map{
  {ms_serving::MS_UNKNOWN, 0},
  {ms_serving::MS_BOOL, sizeof(bool)},
  {ms_serving::MS_INT8, sizeof(int8_t)},
  {ms_serving::MS_UINT8, sizeof(uint8_t)},
  {ms_serving::MS_INT16, sizeof(int16_t)},
  {ms_serving::MS_UINT16, sizeof(uint16_t)},
  {ms_serving::MS_INT32, sizeof(int32_t)},
  {ms_serving::MS_UINT32, sizeof(uint32_t)},
  {ms_serving::MS_INT64, sizeof(int64_t)},
  {ms_serving::MS_UINT64, sizeof(uint64_t)},
  {ms_serving::MS_FLOAT16, 2},
  {ms_serving::MS_FLOAT32, 4},
  {ms_serving::MS_FLOAT64, 8},
};
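// Converts a serving protobuf tensor into an MSTensor: copies the shape,
// maps the data type and copies the raw payload.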
MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
  std::vector<int> shape;
  for (auto dim : tensor.tensor_shape().dims()) {
    shape.push_back(static_cast<int>(dim));
  }
  auto iter = type2id_map.find(tensor.tensor_type());
  if (iter == type2id_map.end()) {
    MS_LOG(ERROR) << "Input tensor type is wrong, type is " << tensor.tensor_type();
    return nullptr;
  }
  TypeId type = iter->second;
  auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
  if (ms_tensor == nullptr) {
    MS_LOG(ERROR) << "Create tensor failed";
    return nullptr;
  }
  // Use the destination capacity, not the source size, as the bound.
  if (memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size()) != 0) {
    MS_LOG(ERROR) << "Copy tensor data failed";
    return nullptr;
  }
  return ms_tensor;
}
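// Converts an MSTensor back into a serving protobuf tensor for the reply.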
ms_serving::Tensor MSTensor2ServingTensor(MSTensorPtr ms_tensor) {
  ms_serving::Tensor tensor;
  ms_serving::TensorShape shape;
  for (auto dim : ms_tensor->shape()) {
    shape.add_dims(dim);
  }
  *tensor.mutable_tensor_shape() = shape;
  auto iter = id2type_map.find(ms_tensor->data_type());
  if (iter == id2type_map.end()) {
    MS_LOG(ERROR) << "Output tensor type is wrong, type is " << ms_tensor->data_type();
    return tensor;
  }
  tensor.set_tensor_type(iter->second);
  tensor.set_data(ms_tensor->MutableData(), ms_tensor->Size());
  return tensor;
}
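// Releases the inference session and backend resources; installed as the
// SIGINT handler so the server shuts down cleanly on Ctrl-C.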
void ClearEnv() {
  Session::Instance().Clear();
  inference::ExitInference();
}

void HandleSignal(int sig) {
  ClearEnv();
  exit(0);
}

#ifdef ENABLE_D
static rtContext_t g_ctx = nullptr;
#endif
}  // namespace
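// gRPC service implementation: Predict converts the request tensors, runs
// inference through the Session singleton and converts the outputs back
// into the reply. A mutex serializes concurrent RPCs.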
class MSServiceImpl final : public MSService::Service {
  grpc::Status Predict(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
    std::lock_guard<std::mutex> lock(mutex_);
#ifdef ENABLE_D
    if (g_ctx == nullptr) {
      MS_LOG(ERROR) << "rtCtx is nullptr";
      return grpc::Status::CANCELLED;
    }
    rtError_t rt_ret = rtCtxSetCurrent(g_ctx);
    if (rt_ret != RT_ERROR_NONE) {
      MS_LOG(ERROR) << "Set Ascend rtCtx failed";
      return grpc::Status::CANCELLED;
    }
#endif
    std::vector<MSTensorPtr> inputs;
    inference::MultiTensor outputs;
    for (int i = 0; i < request->data_size(); i++) {
      auto input = ServingTensor2MSTensor(request->data(i));
      if (input == nullptr) {
        MS_LOG(ERROR) << "Tensor convert failed";
        return grpc::Status::CANCELLED;
      }
      inputs.push_back(input);
    }
    auto res = Session::Instance().Predict(inputs, &outputs);
    if (res != SUCCESS) {
      return grpc::Status::CANCELLED;
    }
    for (const auto &tensor : outputs) {
      *reply->add_result() = MSTensor2ServingTensor(tensor);
    }
    MS_LOG(INFO) << "Finish call service Predict";
    return grpc::Status::OK;
  }

  grpc::Status Test(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
    MS_LOG(INFO) << "TestService call";
    return grpc::Status::OK;
  }

  std::mutex mutex_;
};
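// Builds and runs the gRPC server: creates the device session, starts the
// model version controller, captures the Ascend runtime context when built
// with ENABLE_D, then listens on the configured port until shutdown.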
Status Server::BuildAndStart() {
  // Handle the exit signal so resources are released on Ctrl-C.
  signal(SIGINT, HandleSignal);
  Status res;
  auto option_args = Options::Instance().GetArgs();
  std::string server_address = "0.0.0.0:" + std::to_string(option_args->grpc_port);
  std::string model_path = option_args->model_path;
  std::string model_name = option_args->model_name;
  std::string device_type = option_args->device_type;
  auto device_id = option_args->device_id;
  res = Session::Instance().CreatDeviceSession(device_type, device_id);
  if (res != SUCCESS) {
    MS_LOG(ERROR) << "Create session failed";
    ClearEnv();
    return res;
  }
  VersionController version_controller(option_args->poll_model_wait_seconds, model_path, model_name);
  res = version_controller.Run();
  if (res != SUCCESS) {
    MS_LOG(ERROR) << "Load model failed";
    ClearEnv();
    return res;
  }
#ifdef ENABLE_D
  // Save the current Ascend runtime context so Predict can restore it.
  rtContext_t ctx = nullptr;
  rtError_t rt_ret = rtCtxGetCurrent(&ctx);
  if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
    MS_LOG(ERROR) << "The Ascend device context is null";
    ClearEnv();
    return FAILED;
  }
  g_ctx = ctx;
#endif
  MSServiceImpl service;
  grpc::EnableDefaultHealthCheckService(true);
  grpc::reflection::InitProtoReflectionServerBuilderPlugin();
  // Disable port reuse so a second server cannot bind the same port.
  auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
  grpc::ServerBuilder builder;
  builder.SetOption(std::move(option));
  // Listen on the given address without any authentication mechanism.
  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
  // Register "service" as the instance through which we'll communicate with
  // clients. In this case it corresponds to a *synchronous* service.
  builder.RegisterService(&service);
  // Finally assemble the server.
  std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
  if (server == nullptr) {
    MS_LOG(ERROR) << "Start server failed, address " << server_address;
    ClearEnv();
    return FAILED;
  }
  MS_LOG(INFO) << "Server listening on " << server_address;
  // Wait for the server to shutdown. Note that some other thread must be
  // responsible for shutting down the server for this call to ever return.
  server->Wait();
  return SUCCESS;
}
}  // namespace serving
}  // namespace mindspore