Browse Source

gpu inference mixed precision

pull/15612/head
wilfChen 4 years ago
parent
commit
ba9bbfadf8
6 changed files with 36 additions and 3 deletions
  1. +12
    -0
      include/api/context.h
  2. +9
    -0
      mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc
  3. +11
    -0
      mindspore/ccsrc/cxx_api/context.cc
  4. +2
    -2
      mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc
  5. +1
    -1
      mindspore/ccsrc/runtime/device/gpu/trt_loader.h
  6. +1
    -0
      mindspore/core/utils/ms_context.h

+ 12
- 0
include/api/context.h View File

@@ -116,8 +116,20 @@ class MS_API NvidiaGPUDeviceInfo : public DeviceInfoContext {

void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
bool GetGpuTrtInferMode() const;

inline void SetPrecisionMode(const std::string &precison_mode);
inline std::string GetPrecisionMode() const;

private:
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
};

// Public inline setter: converts the std::string precision mode into the
// char-vector form used at the MS_API boundary and delegates to the
// private overload that stores it.
void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) { SetPrecisionMode(StringToChar(precision_mode)); }
std::string NvidiaGPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
public:
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };


+ 9
- 0
mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc View File

@@ -198,6 +198,15 @@ bool TrtConverterContext::Serialize(std::string *model) {
MS_EXCEPTION_IF_NULL(model);
builder_->setMaxBatchSize(batch_size_);
config_->setMaxWorkspaceSize(workspace_size_);

// Set precision mode
const auto &context = MsContext::GetInstance();
const auto &precision_mode = context->get_param<std::string>(MS_CTX_INFER_PRECISION_MODE);
if (precision_mode == "fp16") {
MS_LOG(WARNING) << "Inference with mixed precision mode. It will take few minutes for operators selection.";
config_->setFlag(nvinfer1::BuilderFlag::kFP16);
}

engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
MS_EXCEPTION_IF_NULL(engine_);



+ 11
- 0
mindspore/ccsrc/cxx_api/context.cc View File

@@ -27,6 +27,7 @@ constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequ
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
constexpr auto kModelOptionNvidiaGpuDeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionNvidiaGpuTrtInferMode = "mindspore.option.nvidia_gpu.trt_infer_mode";
constexpr auto kModelOptionNvidiaGpuPrecisionMode = "mindspore.option.nvidia_gpu.precision_mode";
constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
@@ -153,6 +154,16 @@ bool NvidiaGPUDeviceInfo::GetGpuTrtInferMode() const {
return GetValue<bool>(data_, kModelOptionNvidiaGpuTrtInferMode);
}

// Store the requested GPU inference precision mode into the shared option
// map under kModelOptionNvidiaGpuPrecisionMode. The value arrives as a
// char vector (presumably to keep std::string off the MS_API ABI boundary
// — confirm against include/api/context.h) and is stored as a std::string.
void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(data_);  // option storage must exist before any set
data_->params[kModelOptionNvidiaGpuPrecisionMode] = CharToString(precision_mode);
}
// Read the stored GPU inference precision mode from the shared option map
// and return it in the char-vector form used at the MS_API boundary.
//
// Note: the original named a `const std::string &` bound to the result of
// GetValue<std::string>(...). If GetValue returns by value that binding
// only works via const-ref lifetime extension of a temporary — legal, but
// misleading to readers; returning the converted expression directly is
// equivalent and unambiguous either way.
std::vector<char> NvidiaGPUDeviceInfo::GetPrecisionModeChar() const {
MS_EXCEPTION_IF_NULL(data_);  // option storage must exist before any get
return StringToChar(GetValue<std::string>(data_, kModelOptionNvidiaGpuPrecisionMode));
}

void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend910DeviceID] = device_id;


+ 2
- 2
mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc View File

@@ -63,8 +63,8 @@ Status GPUGraphImpl::InitEnv() {
if (gpu_info == nullptr) {
return kMCDeviceError;
}
auto enable_trt = gpu_info->GetGpuTrtInferMode();
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, enable_trt);
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, gpu_info->GetGpuTrtInferMode());
ms_context->set_param<std::string>(MS_CTX_INFER_PRECISION_MODE, gpu_info->GetPrecisionMode());

session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
if (session_impl_ == nullptr) {


+ 1
- 1
mindspore/ccsrc/runtime/device/gpu/trt_loader.h View File

@@ -31,7 +31,7 @@ class TrtLoader {

std::shared_ptr<nvinfer1::IBuilder> CreateInferBuilder(nvinfer1::ILogger *logger);
std::shared_ptr<nvinfer1::IRuntime> CreateInferRuntime(nvinfer1::ILogger *logger);
bool nvinfer_loaded() { return nvinfer_loaded_; }
bool nvinfer_loaded() const { return nvinfer_loaded_; }

private:
bool nvinfer_loaded_;


+ 1
- 0
mindspore/core/utils/ms_context.h View File

@@ -118,6 +118,7 @@ enum MsCtxParam : unsigned {
MS_CTX_ENV_CONFIG_PATH,
MS_CTX_TUNE_MODE,
MS_CTX_GRAPH_KERNEL_FLAGS,
MS_CTX_INFER_PRECISION_MODE, // GPU inference precision mode configured by Serving or Unify API.
MS_CTX_TYPE_STRING_END,

// parameter numbers of each type


Loading…
Cancel
Save