diff --git a/ge/session/inner_session.cc b/ge/session/inner_session.cc index 5a67f7cd..6a56fc05 100755 --- a/ge/session/inner_session.cc +++ b/ge/session/inner_session.cc @@ -77,6 +77,23 @@ Status InnerSession::Initialize() { UpdateThreadContext(std::map{}); + // session device id set here + std::string str_session_device_id; + if (GetContext().GetOption("ge.session_device_id", str_session_device_id) == SUCCESS) { + GELOGI("Option session device id has set, value is %s.", str_session_device_id.c_str()); + + uint32_t session_device_id = 0; + try { + session_device_id = static_cast(std::stoi(str_session_device_id.c_str())); + // session device id has priority + GetContext().SetCtxDeviceId(session_device_id); + } catch (std::invalid_argument &) { + GELOGW("session device id %s transform to int failed.", str_session_device_id.c_str()); + } catch (std::out_of_range &) { + GELOGW("session device id %s transform to int failed.", str_session_device_id.c_str()); + } + } + GE_CHK_RT_RET(rtSetDevice(GetContext().DeviceId())); DumpProperties dump_properties;