@@ -35,7 +35,7 @@ function(ms_build_flatbuffers source_schema_files
         set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs})
     endforeach()
-    foreach(schema ${source_schema_files})
+    foreach(schema IN LISTS ${source_schema_files})
         get_filename_component(filename ${schema} NAME_WE)
         if(NOT ${generated_output_dir} STREQUAL "")
             set(generated_file ${generated_output_dir}/${filename}_generated.h)
@@ -212,7 +212,7 @@ if(ENABLE_GPU)
     )
 endif()
-if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
+if(ENABLE_CPU AND NOT WIN32)
     install(
         TARGETS ps_cache
         DESTINATION ${INSTALL_LIB_DIR}
@@ -373,7 +373,7 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
     target_link_libraries(mindspore mindspore_gvar)
     target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load)
 else()
-    if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
+    if(ENABLE_CPU AND NOT WIN32)
         target_link_libraries(mindspore proto_input mindspore::protobuf
                               mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json)
         target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
@@ -75,7 +75,7 @@ if(ENABLE_CPU)
     endif()
 endif()
-if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
+if(NOT ENABLE_CPU OR WIN32)
     list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc")
     list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc")
     list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc")
@@ -421,7 +421,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
       size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
     }
     if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
       const std::string &param_name = input_node->fullname_with_scope();
       if (ps::ps_cache_instance.IsHashTable(param_name)) {
         continue;
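Review note: the rewritten guard leans on the preprocessor rule that identifiers not defined as macros evaluate to 0 inside `#if` expressions, so `!_WIN32` is 1 everywhere except Windows, where compilers predefine `_WIN32`. The net effect is that the parameter-server/ps-cache code is now compiled into every non-Windows CPU build instead of only Ascend/GPU builds. A minimal sketch of the evaluation, assuming `ENABLE_CPU` is injected as a compile definition by the build system:

```cpp
// Minimal illustration of the new guard. ENABLE_CPU is assumed to come from
// the build system, e.g. add_compile_definitions(ENABLE_CPU=1) in CMake.
#include <cstdio>

#if (ENABLE_CPU && !_WIN32)
// Non-Windows CPU build: _WIN32 is undefined, the preprocessor treats it as 0,
// so !_WIN32 evaluates to 1 and this branch is compiled in.
static const char *kPsCacheState = "ps cache compiled in";
#else
// Windows (_WIN32 predefined as 1) or a build without ENABLE_CPU.
static const char *kPsCacheState = "ps cache compiled out";
#endif

int main() {
  std::printf("%s\n", kPsCacheState);
  return 0;
}
```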
@@ -33,7 +33,7 @@
 #include "debug/anf_ir_dump.h"
 #include "debug/dump_proto.h"
 #include "debug/data_dump/dump_json_parser.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/util.h"
 #include "ps/ps_context.h"
 #endif
@@ -74,7 +74,7 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos
 void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) {
@@ -174,7 +174,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap
   MS_LOG(INFO) << "Bind input output address";
   runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   InitPSParamAndOptim(kernel_graph, inputs);
 #endif
 }
@@ -21,7 +21,7 @@
 #include "utils/comm_manager.h"
 #include "utils/scoped_long_running.h"
 #include "pybind_api/ir/tensor_py.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/ps_cache/ps_cache_manager.h"
 #endif
@@ -43,7 +43,7 @@
 #include "debug/common.h"
 #include "utils/trace_base.h"
 #include "frontend/parallel/context.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/ps_cache/ps_cache_manager.h"
 #include "ps/constants.h"
 #include "ps/util.h"
@@ -2357,7 +2357,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
 #endif
 }
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
   if (!ps::PSContext::instance()->is_worker()) {
     return;
@@ -244,7 +244,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   std::vector<uint32_t> GetAllReduceSplitIndex();
   virtual std::string GetCommWorldGroup() { return std::string(); }
   void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph);
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
   void GetBatchElements(const AnfNodePtr &kernel_node) const;
   void InitPsWorker(const KernelGraphPtr &kernel_graph);
@@ -263,7 +263,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
 #if !defined(_WIN32) && !defined(_WIN64)
   std::shared_ptr<Debugger> debugger_;
 #endif
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   bool initialized_ps_cache_{false};
 #endif
 };
@@ -25,7 +25,7 @@
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/context.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/ps_cache/ps_cache_manager.h"
 #include "utils/ms_context.h"
 #endif
@@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() {
   if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
     dynamic_shape_indices_ = true;
   }
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
@@ -617,7 +617,7 @@ Status GatherPInfo::InferBias() {
       rank = rank % (params_strategy[0] * params_strategy[1]);
     }
   }
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
     bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
     return SUCCESS;
@@ -28,7 +28,7 @@
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/context.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/ps_cache/ps_cache_manager.h"
 #endif
@@ -192,7 +192,7 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) {
   return SUCCESS;
 }
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   GenerateGraph gen_g = GenerateGraph();
   if (gen_g.Init(cnode) != SUCCESS) {
@@ -230,7 +230,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 #endif
 ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
     auto inputs = cnode->inputs();
     if (inputs.empty()) {
@@ -51,7 +51,7 @@ class UniqueInfo : public OperatorInfo {
   Status InferMirrorOps() override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferAsLossDivisor() override { return SUCCESS; }
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   Status ComputeReplaceGraph(const CNodePtr &cnode);
 #endif
@@ -47,14 +47,14 @@
 #include "ir/anf.h"
 #include "ir/param_info.h"
 #include "ir/tensor.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/util.h"
 #endif
 namespace mindspore {
 namespace parallel {
 bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
     return false;
   }
@@ -46,7 +46,7 @@
 #include "utils/ms_context.h"
 #include "utils/symbolic.h"
 #include "mindspore/core/utils/parallel_node_check.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/util.h"
 #include "ps/ps_context.h"
 #endif
@@ -3553,7 +3553,7 @@ static void HandleFullySplitParameters(const FuncGraphPtr &root) {
 }
 bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
     return false;
   }
@@ -295,7 +295,7 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
     target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite)
 else()
     target_link_libraries(_c_dataengine PRIVATE _c_mindrecord)
-    if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
+    if(ENABLE_CPU AND NOT WIN32)
         if(${ENABLE_IBVERBS} STREQUAL "ON")
             target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
         endif()
@@ -1,7 +1,8 @@
 add_subdirectory(perf EXCLUDE_FROM_ALL)
 include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
 set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
-ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
+set(FBS_FILES de_tensor.fbs)
+ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
@@ -43,7 +43,7 @@
 #include "vm/transform.h"
 #include "parse/python_adapter.h"
 #include "frontend/optimizer/py_pass_manager.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/parameter_server.h"
 #include "ps/scheduler.h"
 #include "ps/worker.h"
@@ -606,7 +606,7 @@ bool ExecuteAction(const ResourcePtr &res) {
   return true;
 }
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 bool StartPSWorkerAction(const ResourcePtr &res) {
   ps::Worker::GetInstance().Run();
   return true;
@@ -782,7 +782,7 @@ std::vector<ActionItem> VmPipeline() {
   actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
   actions.emplace_back(std::make_pair("validate", ValidateAction));
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::PSContext::instance()->is_worker()) {
     actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
   }
@@ -796,7 +796,7 @@ std::vector<ActionItem> VmPipeline() {
   return actions;
 }
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 std::vector<ActionItem> PServerPipeline() {
   auto actions = CommonPipeline();
   actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
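Review note: the pipeline entries above pair a stage name with a callable over the compile resource. A self-contained sketch of that pattern; the `ActionItem` alias and the `Resource` stand-in are inferred from the `std::make_pair("name", Action)` call sites in this patch, not taken from the real headers:

```cpp
// Sketch of the action-pipeline pattern used above. ActionItem and Resource
// are local stand-ins inferred from the call sites, not the real headers.
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Resource {};  // stand-in for the pipeline's Resource type
using ResourcePtr = std::shared_ptr<Resource>;
using ActionItem = std::pair<std::string, std::function<bool(const ResourcePtr &)>>;

// Stand-in action; in the patch this calls ps::Worker::GetInstance().Run().
bool StartPSWorkerAction(const ResourcePtr &) { return true; }

std::vector<ActionItem> WorkerTail() {
  std::vector<ActionItem> actions;
  // Mirrors actions.emplace_back(std::make_pair("worker", StartPSWorkerAction)) above.
  actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
  return actions;
}
```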
@@ -34,7 +34,7 @@
 #else
 #include "runtime/device/gpu/distribution/collective_fake_init.h"
 #endif
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/util.h"
 #endif
 #include "ps/ps_context.h"
@@ -42,7 +42,7 @@
 #include "pipeline/jit/pipeline_split.h"
 #include "pipeline/jit/static_analysis/auto_monad.h"
 #include "frontend/optimizer/irpass/gradient_eliminate.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/util.h"
 #include "ps/ps_context.h"
 #endif
@@ -407,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
 }
 bool AddCacheEmbeddingPass(const ResourcePtr &res) {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::PSContext::instance()->is_ps_mode()) {
     return true;
   }
@@ -49,7 +49,7 @@
 #include "utils/shape_utils.h"
 #include "utils/info.h"
 #include "load_mindir/load_model.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/constants.h"
 #include "ps/util.h"
 #include "ps/worker.h"
@@ -528,7 +528,7 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
   std::string backend = MsContext::GetInstance()->backend_policy();
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::PSContext::instance()->is_server()) {
     resource->results()[kBackend] = compile::CreateBackend();
     return PServerPipeline();
@@ -961,7 +961,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
 bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
                        const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                        const std::vector<int64_t> &input_indexes, bool need_run) {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) {
     return true;
   }
@@ -1027,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
   }
   ConfigManager::GetInstance().set_iter_num(size);
   // PS cache does not support loop sink.
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
     ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
     ConfigManager::GetInstance().set_iter_num(1);
@@ -1150,7 +1150,7 @@ void FinalizeBackend() {
 void ClearResAtexit() {
   MS_LOG(DEBUG) << "Pipeline clear all resource";
   pynative::ClearPyNativeSession();
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
     if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
       ps::ps_cache_instance.Finalize();
@@ -1,6 +1,13 @@
 file(GLOB_RECURSE _PS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
-if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
+set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema")
+set(FBS_FILES
+    ${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs
+    ${CMAKE_CURRENT_SOURCE_DIR}/../../schema/fl_job.fbs
+)
+ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/../../schema generated_fbs_files ${SERVER_FLATBUFFER_OUTPUT})
+if(NOT ENABLE_CPU OR WIN32)
     list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info_builder.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc")
@@ -12,11 +19,6 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
     list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_client.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_message_handler.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc")
-    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc")
-    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc")
-    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc")
-    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc")
-    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
@@ -39,18 +41,32 @@ if(NOT ENABLE_GPU)
     list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
 endif()
-if(WIN32 OR NOT ENABLE_CPU)
+if(NOT ENABLE_CPU OR WIN32)
+    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_storage.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_store.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/collective_ops_impl.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_count_service.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_metadata_store.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc")
 endif()
 list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
@@ -59,3 +75,5 @@ add_subdirectory(ps_cache)
 set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
 add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
+add_dependencies(_mindspore_ps_obj generated_fbs_files)
+target_link_libraries(_mindspore_ps_obj mindspore::flatbuffers)
@@ -34,13 +34,26 @@
 namespace mindspore {
 namespace ps {
 namespace core {
-enum class TcpUserCommand { kPush, kPull, kCount, kReachThreshold, kResetCount, kGetValue, kPutValue, kCounterEvent };
+enum class TcpUserCommand {
+  kPush,
+  kPull,
+  kCount,
+  kReachThreshold,
+  kResetCount,
+  kGetMetadata,
+  kUpdateMetadata,
+  kCounterEvent
+};
 const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
-  {TcpUserCommand::kPush, "push"},           {TcpUserCommand::kPull, "pull"},
-  {TcpUserCommand::kCount, "count"},         {TcpUserCommand::kReachThreshold, "reachThreshold"},
-  {TcpUserCommand::kResetCount, "resetCnt"}, {TcpUserCommand::kGetValue, "getValue"},
-  {TcpUserCommand::kPutValue, "putValue"},   {TcpUserCommand::kCounterEvent, "counterEvent"},
+  {TcpUserCommand::kPush, "push"},
+  {TcpUserCommand::kPull, "pull"},
+  {TcpUserCommand::kCount, "count"},
+  {TcpUserCommand::kReachThreshold, "countReachThreshold"},
+  {TcpUserCommand::kResetCount, "resetCnt"},
+  {TcpUserCommand::kGetMetadata, "getMetadata"},
+  {TcpUserCommand::kUpdateMetadata, "updateMetadata"},
+  {TcpUserCommand::kCounterEvent, "counterEvent"},
 };
 class TcpCommunicator : public CommunicatorBase {
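Review note: `kGetValue`/`kPutValue` become `kGetMetadata`/`kUpdateMetadata`, and the wire string for `kReachThreshold` changes to `"countReachThreshold"`, which matches the callback name registered in `DistributedCountService::RegisterCallback` later in this patch. A self-contained sketch of resolving a command to its message-type string; the `MsgTypeOf` helper is illustrative, not part of the patch:

```cpp
// Illustrative lookup against the command table above; the MsgTypeOf helper
// and its empty-string fallback are assumptions, not part of the patch.
#include <string>
#include <unordered_map>

enum class TcpUserCommand { kPush, kPull, kCount, kReachThreshold, kResetCount,
                            kGetMetadata, kUpdateMetadata, kCounterEvent };

const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
  {TcpUserCommand::kPush, "push"},
  {TcpUserCommand::kPull, "pull"},
  {TcpUserCommand::kCount, "count"},
  {TcpUserCommand::kReachThreshold, "countReachThreshold"},
  {TcpUserCommand::kResetCount, "resetCnt"},
  {TcpUserCommand::kGetMetadata, "getMetadata"},
  {TcpUserCommand::kUpdateMetadata, "updateMetadata"},
  {TcpUserCommand::kCounterEvent, "counterEvent"},
};

// Resolve a command to the message type string sent on the wire, or an empty
// string if the command is not registered in the table.
std::string MsgTypeOf(TcpUserCommand cmd) {
  auto it = kUserCommandToMsgType.find(cmd);
  return it == kUserCommandToMsgType.end() ? std::string() : it->second;
}
```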
@@ -0,0 +1,155 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+syntax = "proto3";
+package mindspore.ps;
+message CollectiveData {
+  bytes data = 1;
+}
+message CountRequest {
+  string name = 1;
+  string id = 2;
+}
+message CountResponse {
+  bool result = 1;
+  string reason = 2;
+}
+message CountReachThresholdRequest {
+  string name = 1;
+}
+message CountReachThresholdResponse {
+  bool is_enough = 1;
+}
+message ResetCounterRequest {
+  string name = 1;
+}
+message UpdateMetadataRequest {
+  string name = 1;
+  bytes value = 2;
+}
+message GetMetadataRequest {
+  string name = 1;
+}
+message GetMetadataResponse {
+  bytes value = 1;
+}
+enum CounterEventType {
+  FIRST_CNT = 0;
+  LAST_CNT = 1;
+}
+message CounterEvent {
+  CounterEventType type = 1;
+  string name = 2;
+  bytes data = 3;
+}
+message FLId {
+  string fl_id = 1;
+}
+message UpdateModelClientList {
+  repeated string fl_id = 1;
+}
+message DeviceMeta {
+  string fl_name = 1;
+  string fl_id = 2;
+  uint64 data_size = 3;
+}
+message FLIdToDeviceMeta {
+  map<string, DeviceMeta> fl_id_to_meta = 1;
+}
+message UpdateModelThreshold {
+  uint64 threshold = 1;
+}
+message ClientShares {
+  map<string, SharesPb> client_secret_shares = 1;
+}
+message PairClientShares {
+  string fl_id = 1;
+  SharesPb client_shares = 2;
+}
+message ClientKeys {
+  map<string, KeysPb> client_keys = 1;
+}
+message ClientNoises {
+  OneClientNoises one_client_noises = 1;
+}
+message PairClientKeys {
+  string fl_id = 1;
+  KeysPb client_keys = 2;
+}
+message OneClientNoises {
+  repeated float noise = 1;
+}
+message ClientShareStr {
+  string fl_id = 1;
+  bytes share = 2;  // todo: verify the correctness
+  int32 index = 3;
+}
+message SharesPb {
+  repeated ClientShareStr clientsharestrs = 1;
+}
+message KeysPb {
+  repeated bytes key = 1;
+}
+message PBMetadata {
+  oneof value {
+    DeviceMeta device_meta = 1;
+    FLIdToDeviceMeta device_metas = 2;
+    FLId fl_id = 3;
+    UpdateModelClientList client_list = 4;
+    UpdateModelThreshold update_model_threshold = 5;
+    PairClientShares pair_client_shares = 6;
+    ClientShares client_shares = 7;
+    PairClientKeys pair_client_keys = 8;
+    ClientKeys client_keys = 9;
+    OneClientNoises one_client_noises = 10;
+    ClientNoises client_noises = 11;
+  }
+}
+message PBMetadataWithName {
+  string name = 1;
+  PBMetadata metadata = 2;
+}
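Review note: these messages carry the distributed counting and federated-learning metadata traffic between servers. A round-trip sketch of the generated protobuf API, assuming the definitions above compile into `proto/ps.pb.h` as the includes elsewhere in this patch indicate; no network calls are made and the field values are placeholders:

```cpp
// Round-trip sketch for the new counting messages; exercises only the
// generated protobuf API. Counter name and id values are placeholders.
#include <cassert>
#include <string>
#include "proto/ps.pb.h"  // generated from the definitions above

int main() {
  mindspore::ps::CountRequest req;
  req.set_name("startFLJob");  // counter name, placeholder
  req.set_id("worker-0");      // reporter id, placeholder
  const std::string wire = req.SerializeAsString();  // bytes sent to the leader

  mindspore::ps::CountRequest parsed;
  assert(parsed.ParseFromString(wire));
  assert(parsed.name() == "startFLJob");

  mindspore::ps::CountResponse rsp;  // what the leader would answer with
  rsp.set_result(true);
  rsp.set_reason("success");
  return rsp.result() ? 0 : 1;
}
```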
@@ -60,4 +60,4 @@ message EmbeddingTableLookup {
   uint64 key = 2;
   repeated int32 keys = 3;
   repeated float values = 4;
-}
+}
@@ -1,4 +1,4 @@
-if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
+if(ENABLE_CPU AND NOT WIN32)
     file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc")
     set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
    add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES})
@@ -18,7 +18,7 @@
 #include "utils/log_adapter.h"
 #include "utils/ms_utils.h"
 #include "backend/kernel_compiler/kernel.h"
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
 #include "ps/ps_cache/ps_cache_manager.h"
 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
 #endif
@@ -68,7 +68,7 @@ void PSContext::Reset() {
   is_worker_ = false;
   is_pserver_ = false;
   is_sched_ = false;
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
     ps_cache_instance.Finalize();
     set_cache_enable(false);
@@ -108,46 +108,62 @@ int PSContext::ps_rank_id() const { return rank_id_; }
 void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
                                     size_t vocab_size) const {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
 #endif
 }
 void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
                                       size_t cache_vocab_size, size_t embedding_size) const {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
 #endif
 }
 void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
 #endif
 }
 void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
 #endif
 }
 void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
 #endif
 }
 void PSContext::set_cache_enable(bool cache_enable) const {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
 #endif
 }
 void PSContext::set_rank_id(int rank_id) const {
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#if (ENABLE_CPU && !_WIN32)
   ps_cache_instance.set_rank_id(rank_id);
 #endif
 }
+void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
+const std::string &PSContext::fl_name() const { return fl_name_; }
+void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
+uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; }
+void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; }
+uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
+void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
+uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
 }  // namespace ps
 }  // namespace mindspore
@@ -60,6 +60,19 @@ class PSContext {
   void set_cache_enable(bool cache_enable) const;
   void set_rank_id(int rank_id) const;
+  // Setters and getters for federated learning.
+  void set_fl_name(const std::string &fl_name);
+  const std::string &fl_name() const;
+  void set_fl_iteration_num(uint64_t fl_iteration_num);
+  uint64_t fl_iteration_num() const;
+  void set_client_epoch_num(uint64_t client_epoch_num);
+  uint64_t client_epoch_num() const;
+  void set_client_batch_size(uint64_t client_batch_size);
+  uint64_t client_batch_size() const;
  private:
   PSContext()
       : ps_enabled_(false),
@@ -80,6 +93,12 @@ class PSContext {
   uint32_t server_num_;
   std::string scheduler_host_;
   uint16_t scheduler_port_;
+  // Members for federated learning.
+  std::string fl_name_;
+  uint64_t fl_iteration_num_;
+  uint64_t client_epoch_num_;
+  uint64_t client_batch_size_;
 };
 }  // namespace ps
 }  // namespace mindspore
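Review note: the new accessors only store plain values on the context; nothing in this patch reads them yet. A hypothetical configuration sketch using the accessors as declared above (all values are placeholders):

```cpp
// Hypothetical FL-server configuration via the new PSContext accessors.
// The concrete values are placeholders, not defaults from the patch.
#include "ps/ps_context.h"

void ConfigureFlServer() {
  auto ctx = mindspore::ps::PSContext::instance();
  ctx->set_fl_name("lenet_fl");   // federated job name, placeholder
  ctx->set_fl_iteration_num(25);  // global FL iterations, placeholder
  ctx->set_client_epoch_num(1);   // local epochs per client, placeholder
  ctx->set_client_batch_size(32); // client-side batch size, placeholder
}
```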
@@ -0,0 +1,223 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ps/server/collective_ops_impl.h"
+namespace mindspore {
+namespace ps {
+namespace server {
+void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
+  MS_EXCEPTION_IF_NULL(server_node);
+  server_node_ = server_node;
+  local_rank_ = server_node_->rank_id();
+  server_num_ = PSContext::instance()->initial_server_num();
+  return;
+}
+template <typename T>
+bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
+  int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
+  if (ret != 0) {
+    MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
+    return false;
+  }
+  uint32_t rank_size = server_num_;
+  uint32_t local_rank_ = server_node_->rank_id();
+  size_t chunk_size = count / rank_size;
+  size_t remainder_size = count % rank_size;
+  std::vector<size_t> chunk_sizes(rank_size, chunk_size);
+  // The rest of the data should be assigned to each chunk.
+  for (size_t i = 0; i < remainder_size; i++) {
+    chunk_sizes[i]++;
+  }
+  // Store offsets to get every data chunk's address.
+  std::vector<size_t> chunk_offset;
+  for (size_t i = 0; i < rank_size; i++) {
+    size_t ofs =
+      std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast<size_t>(0), std::plus<size_t>());
+    chunk_offset.push_back(ofs);
+  }
+  T *output_buff = reinterpret_cast<T *>(recvbuff);
+  uint32_t send_to_rank = (local_rank_ + 1) % rank_size;
+  uint32_t recv_from_rank = (local_rank_ - 1 + rank_size) % rank_size;
+  MS_LOG(DEBUG) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", local_rank_:" << local_rank_
+                << ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size
+                << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
+                << ", recv_from_rank:" << recv_from_rank;
+  // Ring ReduceScatter.
+  MS_LOG(DEBUG) << "Start Ring ReduceScatter.";
+  std::unique_ptr<T[]> tmp_recv_chunk = std::make_unique<T[]>(chunk_sizes[0]);
+  for (size_t i = 0; i < rank_size - 1; i++) {
+    // Step 1: Async send data to next rank.
+    size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
+    T *send_chunk = output_buff + chunk_offset[send_chunk_index];
+    auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
+                                                         chunk_sizes[send_chunk_index] * sizeof(T));
+    // Step 2: Async receive data to next rank and wait until it's done.
+    size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;
+    T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
+    MS_LOG(DEBUG) << "Ring ReduceScatter send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
+                  << ", send count:" << chunk_sizes[send_chunk_index]
+                  << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
+    std::shared_ptr<std::vector<unsigned char>> recv_str;
+    auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
+    if (!server_node_->CollectiveWait(recv_req_id, 1)) {
+      MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
+      return false;
+    }
+    memcpy_s(tmp_recv_chunk.get(), chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
+    // Step 3: Reduce the data so we can overlap the time cost of send.
+    for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) {
+      recv_chunk[j] += tmp_recv_chunk[j];
+    }
+    // Step 4: Wait until send is done.
+    if (!server_node_->Wait(send_req_id, 1)) {
+      MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
+      return false;
+    }
+  }
+  MS_LOG(DEBUG) << "End Ring ReduceScatter.";
+  // Ring AllGather.
+  MS_LOG(DEBUG) << "Start Ring AllGather.";
+  for (size_t i = 0; i < rank_size - 1; i++) {
+    size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
+    T *send_chunk = output_buff + chunk_offset[send_chunk_index];
+    auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
+                                                         chunk_sizes[send_chunk_index] * sizeof(T));
+    size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
+    T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
+    MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
+                  << ", send count:" << chunk_sizes[send_chunk_index]
+                  << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
+    std::shared_ptr<std::vector<unsigned char>> recv_str;
+    auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
+    if (!server_node_->CollectiveWait(recv_req_id, 1)) {
+      MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
+      return false;
+    }
+    memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
+    if (!server_node_->Wait(send_req_id, 1)) {
+      MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
+      return false;
+    }
+  }
+  MS_LOG(DEBUG) << "End Ring AllGather.";
+  return true;
+}
+template <typename T>
+bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
+  uint32_t rank_size = server_num_;
+  uint32_t local_rank_ = server_node_->rank_id();
+  MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_
+                << ", count:" << count;
+  int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
+  if (ret != 0) {
+    MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
+    return false;
+  }
+  T *output_buff = reinterpret_cast<T *>(recvbuff);
+  // Reduce data to rank 0 process.
+  MS_LOG(DEBUG) << "Start Reduce to rank 0 process.";
+  if (local_rank_ == 0) {
+    std::unique_ptr<T[]> tmp_recv_buff = std::make_unique<T[]>(count);
+    for (uint32_t i = 1; i < rank_size; i++) {
+      std::shared_ptr<std::vector<unsigned char>> recv_str;
+      MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
+      auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str);
+      if (!server_node_->CollectiveWait(recv_req_id, 1)) {
+        MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
+        return false;
+      }
+      memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
+      for (size_t j = 0; j < count; j++) {
+        output_buff[j] += tmp_recv_buff[j];
+      }
+    }
+  } else {
+    MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
+    auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
+    if (!server_node_->Wait(send_req_id, 1)) {
+      MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
+      return false;
+    }
+  }
+  MS_LOG(DEBUG) << "End Reduce.";
+  // Broadcast data to not 0 rank process.
+  MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes.";
+  if (local_rank_ == 0) {
+    for (uint32_t i = 1; i < rank_size; i++) {
+      MS_LOG(DEBUG) << "Broadcast data to process " << i;
+      auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
+      if (!server_node_->Wait(send_req_id, 1)) {
+        MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
+        return false;
+      }
+    }
+  } else {
+    MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
+    std::shared_ptr<std::vector<unsigned char>> recv_str;
+    auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str);
+    if (!server_node_->CollectiveWait(recv_req_id, 1)) {
+      MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
+      return false;
+    }
+    memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
+  }
+  MS_LOG(DEBUG) << "End broadcast.";
+  return true;
+}
+template <typename T>
+bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count) {
+  // The collective communication API does not support calling Send and Recv concurrently with multiple threads;
+  std::unique_lock<std::mutex> lock(mtx_);
+  if (sendbuff == nullptr || recvbuff == nullptr) {
+    MS_LOG(ERROR) << "AllReduce sendbuff or recvbuff is nullptr.";
+    return false;
+  }
+  uint32_t rank_size = server_num_;
+  if (count >= rank_size) {
+    return RingAllReduce<T>(sendbuff, recvbuff, count);
+  } else {
+    return ReduceBroadcastAllReduce<T>(sendbuff, recvbuff, count);
+  }
+}
+template bool CollectiveOpsImpl::RingAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
+template bool CollectiveOpsImpl::RingAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
+template bool CollectiveOpsImpl::RingAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
+template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
+template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
+template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
+template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
+template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
+template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
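Review note on the dispatch in `AllReduce`: ring all-reduce splits the buffer into `rank_size` chunks, so it needs `count >= rank_size` to keep every chunk non-empty; smaller tensors fall back to reduce-then-broadcast through rank 0. A hypothetical call site, assuming the singleton has already been initialized with a server node elsewhere:

```cpp
// Hypothetical caller of the new server-side AllReduce; the gradient buffer
// contents are placeholders and Initialize() is assumed to have run already.
#include <vector>
#include "ps/server/collective_ops_impl.h"

bool AverageGradients(std::vector<float> *grads) {
  std::vector<float> reduced(grads->size());
  // Internally dispatches to RingAllReduce when count >= server number,
  // otherwise to ReduceBroadcastAllReduce.
  if (!mindspore::ps::server::CollectiveOpsImpl::GetInstance().AllReduce<float>(
        grads->data(), reduced.data(), grads->size())) {
    return false;
  }
  grads->swap(reduced);  // keep the element-wise sum across all servers
  return true;
}
```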
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <functional> | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/ps_context.h" | |||
| #include "ps/core/server_node.h" | |||
| #include "ps/server/common.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // CollectiveOpsImpl is the collective communication API of the server. | |||
| // For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also | |||
| // supported for the elastic scaling feature of the server. | |||
| class CollectiveOpsImpl { | |||
| public: | |||
| static CollectiveOpsImpl &GetInstance() { | |||
| static CollectiveOpsImpl instance; | |||
| return instance; | |||
| } | |||
| void Initialize(const std::shared_ptr<core::ServerNode> &server_node); | |||
| template <typename T> | |||
| bool AllReduce(const void *sendbuff, void *recvbuff, size_t count); | |||
| private: | |||
| CollectiveOpsImpl() = default; | |||
| ~CollectiveOpsImpl() = default; | |||
| CollectiveOpsImpl(const CollectiveOpsImpl &) = delete; | |||
| CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete; | |||
| // Implementation of RingAllReduce. | |||
| template <typename T> | |||
| bool RingAllReduce(const void *sendbuff, void *recvbuff, size_t count); | |||
| // Implementation of BroadcastAllReduce. | |||
| template <typename T> | |||
| bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count); | |||
| std::shared_ptr<core::ServerNode> server_node_; | |||
| uint32_t local_rank_; | |||
| uint32_t server_num_; | |||
| // The mutex to ensure that collective communication is threadsafe. | |||
| std::mutex mtx_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ | |||
| @@ -24,13 +24,17 @@ | |||
| #include <memory> | |||
| #include <functional> | |||
| #include "proto/ps.pb.h" | |||
| #include "proto/fl.pb.h" | |||
| #include "ir/anf.h" | |||
| #include "utils/utils.h" | |||
| #include "ir/dtype/type_id.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "schema/fl_job_generated.h" | |||
| #include "schema/cipher_generated.h" | |||
| #include "ps/ps_context.h" | |||
| #include "ps/core/communicator/http_message_handler.h" | |||
| #include "ps/core/communicator/tcp_server.h" | |||
| #include "ps/core/communicator/message_handler.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -40,13 +44,15 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER }; | |||
| enum CommType { HTTP = 0, TCP }; | |||
| enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum }; | |||
| using kernel::Address; | |||
| using kernel::AddressPtr; | |||
| using kernel::CPUKernel; | |||
| using mindspore::kernel::Address; | |||
| using mindspore::kernel::AddressPtr; | |||
| using mindspore::kernel::CPUKernel; | |||
| using FBBuilder = flatbuffers::FlatBufferBuilder; | |||
| using TimeOutCb = std::function<void(void)>; | |||
| using StopTimerCb = std::function<void(void)>; | |||
| using FinishIterCb = std::function<void(void)>; | |||
| using FinalizeCb = std::function<void(void)>; | |||
| using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>; | |||
| // Information about whether server kernel will reuse kernel node memory from the front end. | |||
| // Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate". | |||
| @@ -0,0 +1,298 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/distributed_count_service.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node, | |||
| uint32_t counting_server_rank) { | |||
| server_node_ = server_node; | |||
| MS_EXCEPTION_IF_NULL(server_node_); | |||
| communicator_ = | |||
| std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr)); | |||
| MS_EXCEPTION_IF_NULL(communicator_); | |||
| local_rank_ = server_node_->rank_id(); | |||
| server_num_ = PSContext::instance()->initial_server_num(); | |||
| counting_server_rank_ = counting_server_rank; | |||
| RegisterCallback(); | |||
| return; | |||
| } | |||
| void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count, | |||
| const CounterHandlers &counter_handlers) { | |||
| if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) { | |||
| MS_LOG(EXCEPTION) << "First count handler or last count handler is not set."; | |||
| return; | |||
| } | |||
| if (global_threshold_count_.count(name) != 0) { | |||
| MS_LOG(ERROR) << "Counter for " << name << " is already set."; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count; | |||
| // If the server is the leader server, it needs to set the counter handlers and do the real counting. | |||
| if (local_rank_ == counting_server_rank_) { | |||
| global_current_count_[name] = {}; | |||
| global_threshold_count_[name] = global_threshold_count; | |||
| mutex_[name]; | |||
| } | |||
| counter_handlers_[name] = counter_handlers; | |||
| return; | |||
| } | |||
| bool DistributedCountService::Count(const std::string &name, const std::string &id) { | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; | |||
| if (local_rank_ == counting_server_rank_) { | |||
| if (global_threshold_count_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "Counter for " << name << " is not registered."; | |||
| return false; | |||
| } | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| if (global_current_count_[name].size() >= global_threshold_count_[name]) { | |||
| MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is " | |||
| << global_threshold_count_[name]; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id; | |||
| global_current_count_[name].insert(id); | |||
| TriggerCounterEvent(name); | |||
| } else { | |||
| // If this server is a follower server, it needs to send CountRequest to the leader server. | |||
| CountRequest report_count_req; | |||
| report_count_req.set_name(name); | |||
| report_count_req.set_id(id); | |||
| std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr; | |||
| if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount, | |||
| &report_cnt_rsp_msg)) { | |||
| MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name; | |||
| return false; | |||
| } | |||
| CountResponse count_rsp; | |||
| count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), report_cnt_rsp_msg->size()); | |||
| if (!count_rsp.result()) { | |||
| MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason(); | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool DistributedCountService::CountReachThreshold(const std::string &name) { | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name; | |||
| if (local_rank_ == counting_server_rank_) { | |||
| if (global_threshold_count_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "Counter for " << name << " is not set."; | |||
| return false; | |||
| } | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| return global_current_count_[name].size() == global_threshold_count_[name]; | |||
| } else { | |||
| CountReachThresholdRequest count_reach_threashold_req; | |||
| count_reach_threashold_req.set_name(name); | |||
| std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr; | |||
| if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_, | |||
| core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) { | |||
| MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name; | |||
| return false; | |||
| } | |||
| CountReachThresholdResponse count_reach_threashold_rsp; | |||
| count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size()); | |||
| return count_reach_threashold_rsp.is_enough(); | |||
| } | |||
| } | |||
| void DistributedCountService::ResetCounter(const std::string &name) { | |||
| if (local_rank_ == counting_server_rank_) { | |||
| MS_LOG(INFO) << "Leader server reset count for " << name; | |||
| global_current_count_[name].clear(); | |||
| } | |||
| return; | |||
| } | |||
| void DistributedCountService::RegisterCallback() { | |||
| if (local_rank_ == counting_server_rank_) { | |||
| communicator_->RegisterMsgCallBack( | |||
| "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1)); | |||
| communicator_->RegisterMsgCallBack( | |||
| "countReachThreshold", | |||
| std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1)); | |||
| } | |||
| // The callback of first/last event must be set in both leader server and follower servers. | |||
| communicator_->RegisterMsgCallBack( | |||
| "counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1)); | |||
| } | |||
| void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| MS_LOG(ERROR) << "Message is nullptr."; | |||
| return; | |||
| } | |||
| CountRequest report_count_req; | |||
| report_count_req.ParseFromArray(message->data(), message->len()); | |||
| const std::string &name = report_count_req.name(); | |||
| const std::string &id = report_count_req.id(); | |||
| CountResponse count_rsp; | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| // If leader server has no counter for the name registered, return an error. | |||
| if (global_threshold_count_.count(name) == 0) { | |||
| std::string reason = "Counter for " + name + " is not registered."; | |||
| count_rsp.set_result(false); | |||
| count_rsp.set_reason(reason); | |||
| MS_LOG(ERROR) << reason; | |||
| communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); | |||
| return; | |||
| } | |||
| // If leader server already has enough count for the name, return an error. | |||
| if (global_current_count_[name].size() >= global_threshold_count_[name]) { | |||
| std::string reason = | |||
| "Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]); | |||
| count_rsp.set_result(false); | |||
| count_rsp.set_reason(reason); | |||
| MS_LOG(ERROR) << reason; | |||
| communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); | |||
| return; | |||
| } | |||
| // Insert the id for the counter, which means the count for the name is increased. | |||
| MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id; | |||
| global_current_count_[name].insert(id); | |||
| TriggerCounterEvent(name); | |||
| count_rsp.set_result(true); | |||
| count_rsp.set_reason("success"); | |||
| communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); | |||
| return; | |||
| } | |||
| void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| MS_LOG(ERROR) << "Message is nullptr."; | |||
| return; | |||
| } | |||
| CountReachThresholdRequest count_reach_threashold_req; | |||
| count_reach_threashold_req.ParseFromArray(message->data(), message->len()); | |||
| const std::string &name = count_reach_threashold_req.name(); | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| if (global_threshold_count_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "Counter for " << name << " is not registered."; | |||
| return; | |||
| } | |||
| CountReachThresholdResponse count_reach_threashold_rsp; | |||
| count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]); | |||
| communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(), | |||
| count_reach_threashold_rsp.SerializeAsString().size(), message); | |||
| return; | |||
| } | |||
| void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| MS_LOG(ERROR) << "Message is nullptr."; | |||
| return; | |||
| } | |||
| // Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the | |||
| // callbacks. | |||
| std::string couter_event_rsp_msg = "success"; | |||
| communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message); | |||
| CounterEvent counter_event; | |||
| counter_event.ParseFromArray(message->data(), message->len()); | |||
| const auto &type = counter_event.type(); | |||
| const auto &name = counter_event.name(); | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name; | |||
| if (counter_handlers_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "The counter handlers for " << name << " are not registered."; | |||
| return; | |||
| } | |||
| if (type == CounterEventType::FIRST_CNT) { | |||
| counter_handlers_[name].first_count_handler(message); | |||
| } else if (type == CounterEventType::LAST_CNT) { | |||
| counter_handlers_[name].last_count_handler(message); | |||
| } else { | |||
| MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid."; | |||
| return; | |||
| } | |||
| return; | |||
| } | |||
| void DistributedCountService::TriggerCounterEvent(const std::string &name) { | |||
| MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size() | |||
| << ", threshold count is " << global_threshold_count_[name]; | |||
| // The threshold count may be 1, in which case both the first and the last count events should be activated. | |||
| if (global_current_count_[name].size() == 1) { | |||
| TriggerFirstCountEvent(name); | |||
| } | |||
| if (global_current_count_[name].size() == global_threshold_count_[name]) { | |||
| TriggerLastCountEvent(name); | |||
| } | |||
| return; | |||
| } | |||
| void DistributedCountService::TriggerFirstCountEvent(const std::string &name) { | |||
| MS_LOG(INFO) << "Activating first count event for " << name; | |||
| CounterEvent first_count_event; | |||
| first_count_event.set_type(CounterEventType::FIRST_CNT); | |||
| first_count_event.set_name(name); | |||
| // Broadcast to all follower servers. | |||
| for (uint32_t i = 1; i < server_num_; i++) { | |||
| if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) { | |||
| MS_LOG(ERROR) << "Activating first count event to server " << i << " failed."; | |||
| return; | |||
| } | |||
| } | |||
| // Leader server directly calls the callback. | |||
| counter_handlers_[name].first_count_handler(nullptr); | |||
| return; | |||
| } | |||
| void DistributedCountService::TriggerLastCountEvent(const std::string &name) { | |||
| MS_LOG(INFO) << "Activating last count event for " << name; | |||
| CounterEvent last_count_event; | |||
| last_count_event.set_type(CounterEventType::LAST_CNT); | |||
| last_count_event.set_name(name); | |||
| // Broadcast to all follower servers. | |||
| for (uint32_t i = 1; i < server_num_; i++) { | |||
| if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) { | |||
| MS_LOG(ERROR) << "Activating last count event to server " << i << " failed."; | |||
| return; | |||
| } | |||
| } | |||
| // Leader server directly calls the callback. | |||
| counter_handlers_[name].last_count_handler(nullptr); | |||
| return; | |||
| } | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,126 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ | |||
| #include <set> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/server/common.h" | |||
| #include "ps/core/server_node.h" | |||
| #include "ps/core/communicator/tcp_communicator.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // The callbacks for the first count and last count event. | |||
| typedef struct { | |||
| MessageCallback first_count_handler; | |||
| MessageCallback last_count_handler; | |||
| } CounterHandlers; | |||
| // DistributedCountService is used for counting in the server cluster dimension. It is used for counting rounds, | |||
| // aggregation counts, etc. | |||
| // A count can be reported by any server, but only one server holds the cluster-wide count information; we mark this | |||
| // server as the counting server. Other servers must communicate with the counting server to increase or query the | |||
| // count. | |||
| // On the first or last count event, DistributedCountService on the counting server triggers the event on the other | |||
| // servers by sending counter event commands, which keeps the server cluster consistent. | |||
| class DistributedCountService { | |||
| public: | |||
| static DistributedCountService &GetInstance() { | |||
| static DistributedCountService instance; | |||
| return instance; | |||
| } | |||
| // Initialize counter service with the server node because communication is needed. | |||
| void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank); | |||
| // Register a counter on the counting server for the given name, with its cluster-wide threshold count and its | |||
| // first/last count event callbacks. | |||
| void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers); | |||
| // Report a count to the counting server. The parameter 'id' guards against repeated counting. | |||
| bool Count(const std::string &name, const std::string &id); | |||
| // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, | |||
| // this method returns true. | |||
| bool CountReachThreshold(const std::string &name); | |||
| // Reset the count of the name to 0. | |||
| void ResetCounter(const std::string &name); | |||
| // Returns this server's rank, which callers can use as the 'id' for the Count method. | |||
| uint32_t local_rank() { return local_rank_; } | |||
| private: | |||
| DistributedCountService() = default; | |||
| ~DistributedCountService() = default; | |||
| DistributedCountService(const DistributedCountService &) = delete; | |||
| DistributedCountService &operator=(const DistributedCountService &) = delete; | |||
| // Register callbacks of the counting server to handle messages sent by the other servers. | |||
| void RegisterCallback(); | |||
| // Callback for count-reporting messages from other servers. Only the counting server calls this method. | |||
| void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message); | |||
| // Callback for messages from other servers querying whether the threshold count is reached. Only the counting | |||
| // server calls this method. | |||
| void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message); | |||
| // Callback for the first/last count event messages from the counting server. Only the other servers call this | |||
| // method. | |||
| void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message); | |||
| // Call the callbacks when the first/last count event is triggered. | |||
| void TriggerCounterEvent(const std::string &name); | |||
| void TriggerFirstCountEvent(const std::string &name); | |||
| void TriggerLastCountEvent(const std::string &name); | |||
| // Members for the communication between counting server and other servers. | |||
| std::shared_ptr<core::ServerNode> server_node_; | |||
| std::shared_ptr<core::TcpCommunicator> communicator_; | |||
| uint32_t local_rank_; | |||
| uint32_t server_num_; | |||
| // Only one server will be set to do the real counting. | |||
| uint32_t counting_server_rank_; | |||
| // Key: name, e.g., startFLJob, updateModel, push. | |||
| // Value: a set of ids without repetition, because each worker may report multiple times. | |||
| std::unordered_map<std::string, std::set<std::string>> global_current_count_; | |||
| // Key: name, e.g., StartFLJobCount. | |||
| // Value: global threshold count in the server cluster dimension for this name. | |||
| std::unordered_map<std::string, size_t> global_threshold_count_; | |||
| // First/last count event callbacks of the name. | |||
| std::unordered_map<std::string, CounterHandlers> counter_handlers_; | |||
| // Because the count is increased/queried concurrently, we must ensure the operations are threadsafe. | |||
| std::unordered_map<std::string, std::mutex> mutex_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ | |||
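| // A minimal usage sketch of DistributedCountService, for illustration only. It assumes a started server_node, rank 0 | |||
| // as the counting server, a hypothetical round name "updateModel" with 10 expected reports, and that MessageCallback | |||
| // is a std::function taking a core::MessageHandler pointer, as the handler calls above suggest. | |||
| void RegisterUpdateModelCounter(const std::shared_ptr<core::ServerNode> &server_node) { | |||
| auto &count_service = DistributedCountService::GetInstance(); | |||
| count_service.Initialize(server_node, 0); | |||
| CounterHandlers handlers; | |||
| // Triggered on every server when the first count for "updateModel" is reported. | |||
| handlers.first_count_handler = [](const std::shared_ptr<core::MessageHandler> &) { | |||
| MS_LOG(INFO) << "First updateModel request of this iteration arrived."; | |||
| }; | |||
| // Triggered on every server when the count reaches the threshold. | |||
| handlers.last_count_handler = [](const std::shared_ptr<core::MessageHandler> &) { | |||
| MS_LOG(INFO) << "Last updateModel request of this iteration arrived."; | |||
| }; | |||
| count_service.RegisterCounter("updateModel", 10, handlers); | |||
| // Each report carries a unique id so that repeated reports are deduplicated. | |||
| (void)count_service.Count("updateModel", "fl_id_0"); | |||
| } | |||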
| @@ -0,0 +1,201 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/distributed_metadata_store.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) { | |||
| MS_EXCEPTION_IF_NULL(server_node); | |||
| server_node_ = server_node; | |||
| communicator_ = | |||
| std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr)); | |||
| MS_EXCEPTION_IF_NULL(communicator_); | |||
| local_rank_ = server_node_->rank_id(); | |||
| server_num_ = PSContext::instance()->initial_server_num(); | |||
| InitHashRing(); | |||
| RegisterCallback(); | |||
| return; | |||
| } | |||
| void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) { | |||
| if (router_ == nullptr) { | |||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||
| return; | |||
| } | |||
| uint32_t stored_rank = router_->Find(name); | |||
| if (local_rank_ == stored_rank) { | |||
| if (metadata_.count(name) != 0) { | |||
| MS_LOG(ERROR) << "The metadata for " << name << " is already registered."; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " register storage for metadata " << name; | |||
| metadata_[name] = meta; | |||
| mutex_[name];  // Default-construct the mutex for this name so later accesses are threadsafe. | |||
| } | |||
| return; | |||
| } | |||
| void DistributedMetadataStore::ResetMetadata(const std::string &name) { | |||
| if (router_ == nullptr) { | |||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||
| return; | |||
| } | |||
| uint32_t stored_rank = router_->Find(name); | |||
| if (local_rank_ == stored_rank) { | |||
| if (metadata_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "The metadata for " << name << " is not registered."; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " reset metadata for " << name; | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| PBMetadata empty_meta; | |||
| metadata_[name] = empty_meta; | |||
| } | |||
| return; | |||
| } | |||
| void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) { | |||
| if (router_ == nullptr) { | |||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||
| return; | |||
| } | |||
| uint32_t stored_rank = router_->Find(name); | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank; | |||
| if (local_rank_ == stored_rank) { | |||
| if (!DoUpdateMetadata(name, meta)) { | |||
| MS_LOG(ERROR) << "Updating meta data failed."; | |||
| return; | |||
| } | |||
| } else { | |||
| PBMetadataWithName metadata_with_name; | |||
| metadata_with_name.set_name(name); | |||
| *metadata_with_name.mutable_metadata() = meta; | |||
| if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) { | |||
| MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed."; | |||
| return; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { | |||
| if (router_ == nullptr) { | |||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||
| return {}; | |||
| } | |||
| uint32_t stored_rank = router_->Find(name); | |||
| MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank; | |||
| if (local_rank_ == stored_rank) { | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| return metadata_[name]; | |||
| } else { | |||
| GetMetadataRequest get_metadata_req; | |||
| get_metadata_req.set_name(name); | |||
| PBMetadata get_metadata_rsp; | |||
| std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr; | |||
| if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata, | |||
| &get_meta_rsp_msg)) { | |||
| MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed."; | |||
| return get_metadata_rsp; | |||
| } | |||
| get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), get_meta_rsp_msg->size()); | |||
| return get_metadata_rsp; | |||
| } | |||
| } | |||
| void DistributedMetadataStore::InitHashRing() { | |||
| router_ = std::make_shared<ConsistentHashRing>(32); | |||
| MS_EXCEPTION_IF_NULL(router_); | |||
| for (uint32_t i = 0; i < server_num_; i++) { | |||
| bool ret = router_->Insert(i); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "Add node " << i << " to router of meta storage failed."; | |||
| return; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
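| // The ConsistentHashRing implementation is not part of this diff. Conceptually, Insert(i) places several virtual | |||
| // nodes for server rank i on a hash ring, and Find(name) hashes the name and walks clockwise to the first virtual | |||
| // node, returning its rank. A minimal, hypothetical sketch of the idea (assuming <map> and <functional>): | |||
| class NaiveHashRing { | |||
| public: | |||
| explicit NaiveHashRing(size_t virtual_node_num) : virtual_node_num_(virtual_node_num) {} | |||
| bool Insert(uint32_t rank) { | |||
| // Each rank owns several virtual nodes so that keys spread evenly across servers. | |||
| for (size_t v = 0; v < virtual_node_num_; ++v) { | |||
| ring_[std::hash<std::string>()(std::to_string(rank) + "#" + std::to_string(v))] = rank; | |||
| } | |||
| return true; | |||
| } | |||
| uint32_t Find(const std::string &name) const { | |||
| // The first virtual node at or after the name's hash owns the key; wrap around at the end of the ring. | |||
| auto iter = ring_.lower_bound(std::hash<std::string>()(name)); | |||
| return iter == ring_.end() ? ring_.begin()->second : iter->second; | |||
| } | |||
| private: | |||
| size_t virtual_node_num_; | |||
| std::map<size_t, uint32_t> ring_;  // Hash value -> server rank. Insert must be called before Find. | |||
| }; | |||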
| void DistributedMetadataStore::RegisterCallback() { | |||
| communicator_->RegisterMsgCallBack( | |||
| "updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1)); | |||
| communicator_->RegisterMsgCallBack( | |||
| "getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1)); | |||
| return; | |||
| } | |||
| void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| MS_LOG(ERROR) << "Message is nullptr."; | |||
| return; | |||
| } | |||
| PBMetadataWithName meta_with_name; | |||
| meta_with_name.ParseFromArray(message->data(), message->len()); | |||
| const std::string &name = meta_with_name.name(); | |||
| MS_LOG(INFO) << "Update metadata for " << name; | |||
| std::string update_meta_rsp_msg; | |||
| if (!DoUpdateMetadata(name, meta_with_name.metadata())) { | |||
| update_meta_rsp_msg = "Updating meta data failed."; | |||
| } else { | |||
| update_meta_rsp_msg = "Success"; | |||
| } | |||
| communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message); | |||
| return; | |||
| } | |||
| void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| MS_LOG(ERROR) << "Message is nullptr."; | |||
| return; | |||
| } | |||
| GetMetadataRequest get_metadata_req; | |||
| get_metadata_req.ParseFromArray(message->data(), message->len()); | |||
| const std::string &name = get_metadata_req.name(); | |||
| MS_LOG(INFO) << "Getting metadata for " << name; | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| PBMetadata stored_meta = metadata_[name]; | |||
| std::string getting_meta_rsp_msg = stored_meta.SerializeAsString(); | |||
| communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message); | |||
| return; | |||
| } | |||
| bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) { | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| metadata_[name] = meta; | |||
| return true; | |||
| } | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,101 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/server/common.h" | |||
| #include "ps/core/server_node.h" | |||
| #include "ps/core/communicator/tcp_communicator.h" | |||
| #include "ps/server/consistent_hash_ring.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // This class is used for distributed metadata storage based on consistent hashing. The metadata is distributed | |||
| // across all servers. Callers don't need to know which server stores the metadata; they only need to know which | |||
| // operations should be done to the metadata. | |||
| // The metadata is stored in Protobuf format because it is easy to serialize and transmit. The concrete Protobuf | |||
| // message type is decided by the caller through Protobuf's API. | |||
| class DistributedMetadataStore { | |||
| public: | |||
| static DistributedMetadataStore &GetInstance() { | |||
| static DistributedMetadataStore instance; | |||
| return instance; | |||
| } | |||
| // Initialize metadata storage with the server node because communication is needed. | |||
| void Initialize(const std::shared_ptr<core::ServerNode> &server_node); | |||
| // Register metadata for the name with the initial value. This method should be only called once for each name. | |||
| void RegisterMetadata(const std::string &name, const PBMetadata &meta); | |||
| // Reset the metadata value for the name. | |||
| void ResetMetadata(const std::string &name); | |||
| // Update the metadata for the name. | |||
| void UpdateMetadata(const std::string &name, const PBMetadata &meta); | |||
| // Get the metadata for the name. | |||
| PBMetadata GetMetadata(const std::string &name); | |||
| private: | |||
| DistributedMetadataStore() = default; | |||
| ~DistributedMetadataStore() = default; | |||
| DistributedMetadataStore(const DistributedMetadataStore &) = delete; | |||
| DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete; | |||
| // Initialize the consistent hash ring for distributed storage. | |||
| void InitHashRing(); | |||
| // Register callbacks for the server to handle update/get metadata messages from other servers. | |||
| void RegisterCallback(); | |||
| // Callback for updating metadata request sent to the server. | |||
| void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message); | |||
| // Callback for getting metadata request sent to the server. | |||
| void HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message); | |||
| // Do updating metadata in the server where the metadata for the name is stored. | |||
| bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta); | |||
| // Members for the communication between servers. | |||
| std::shared_ptr<core::ServerNode> server_node_; | |||
| std::shared_ptr<core::TcpCommunicator> communicator_; | |||
| uint32_t local_rank_; | |||
| uint32_t server_num_; | |||
| // Consistent hash ring, used by DistributedMetadataStore to find which server node stores the metadata for a name. | |||
| std::shared_ptr<ConsistentHashRing> router_; | |||
| // Metadata is stored as serialized Protobuf messages so that the storage and transmission APIs are easy to use. | |||
| // Key: data name. | |||
| // Value: Protobuf message. | |||
| std::unordered_map<std::string, PBMetadata> metadata_; | |||
| // Because the metadata is read/written concurrently, we must ensure the operations are threadsafe. | |||
| std::unordered_map<std::string, std::mutex> mutex_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ | |||
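| // A minimal usage sketch of DistributedMetadataStore, for illustration only. It assumes a started server_node and | |||
| // reuses the kCtxDeviceMetas key and DeviceMeta message that the startFLJob kernel below works with. | |||
| void TrackDeviceMeta(const std::shared_ptr<core::ServerNode> &server_node) { | |||
| auto &store = DistributedMetadataStore::GetInstance(); | |||
| store.Initialize(server_node); | |||
| // Register an empty value first; the consistent hash ring decides which server actually stores it. | |||
| PBMetadata initial; | |||
| store.RegisterMetadata(kCtxDeviceMetas, initial); | |||
| PBMetadata meta; | |||
| meta.mutable_device_meta()->set_fl_id("fl_id_0"); | |||
| // Both calls are transparently forwarded over TCP if another server owns the key. | |||
| store.UpdateMetadata(kCtxDeviceMetas, meta); | |||
| PBMetadata stored = store.GetMetadata(kCtxDeviceMetas); | |||
| MS_LOG(INFO) << "Stored fl_id: " << stored.device_meta().fl_id(); | |||
| } | |||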
| @@ -169,7 +169,7 @@ bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> | |||
| } | |||
| AddressPtr Executor::HandlePull(const std::string ¶m_name) { | |||
| MS_LOG(INFO) << "Handle blocking pull msg for parameter " << param_name; | |||
| MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name; | |||
| if (param_aggrs_.count(param_name) == 0) { | |||
| MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server."; | |||
| return nullptr; | |||
| @@ -193,11 +193,6 @@ AddressPtr Executor::HandlePull(const std::string ¶m_name) { | |||
| return addr; | |||
| } | |||
| std::map<std::string, AddressPtr> Executor::HandleAsyncGetModel() { | |||
| std::unique_lock<std::mutex> lock(model_mutex_); | |||
| return GetModel(); | |||
| } | |||
| std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> ¶m_names) { | |||
| std::map<std::string, AddressPtr> weights; | |||
| for (const auto ¶m_name : param_names) { | |||
| @@ -63,10 +63,6 @@ class Executor { | |||
| // asynchronously. | |||
| bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map); | |||
| // Called in asynchronous federated learning training mode. Returns whole model in key-value where key refers to the | |||
| // parameter name. | |||
| std::map<std::string, AddressPtr> HandleAsyncGetModel(); | |||
| // Forcibly overwrite specific weights in overwriteWeights message. | |||
| bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map); | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/iteration.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <numeric> | |||
| #include "ps/server/model_store.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| Iteration::Iteration() : iteration_num_(1) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); } | |||
| void Iteration::AddRound(const std::shared_ptr<Round> &round) { | |||
| MS_EXCEPTION_IF_NULL(round); | |||
| rounds_.push_back(round); | |||
| } | |||
| void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, | |||
| const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) { | |||
| if (communicators.empty()) { | |||
| MS_LOG(EXCEPTION) << "Communicators for rounds is empty."; | |||
| return; | |||
| } | |||
| std::for_each(communicators.begin(), communicators.end(), | |||
| [&](const std::shared_ptr<core::CommunicatorBase> &communicator) { | |||
| for (auto &round : rounds_) { | |||
| if (round == nullptr) { | |||
| continue; | |||
| } | |||
| round->Initialize(communicator, timeout_cb, finish_iteration_cb); | |||
| } | |||
| }); | |||
| // The time window for one iteration, which will be used in some round kernels. | |||
| size_t iteration_time_window = | |||
| std::accumulate(rounds_.begin(), rounds_.end(), size_t(0), | |||
| [](size_t total, const std::shared_ptr<Round> &round) { return total + round->time_window(); }); | |||
| LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window); | |||
| return; | |||
| } | |||
| void Iteration::ProceedToNextIter() { | |||
| iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| // Store the model for each iteration. | |||
| const auto &model = Executor::GetInstance().GetModel(); | |||
| ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); | |||
| for (auto &round : rounds_) { | |||
| round->Reset(); | |||
| } | |||
| iteration_num_++; | |||
| LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); | |||
| MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n"; | |||
| } | |||
| const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; } | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ps/core/communicator/communicator_base.h" | |||
| #include "ps/server/common.h" | |||
| #include "ps/server/round.h" | |||
| #include "ps/server/local_meta_store.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // In the server's logic, Iteration is the minimum execution unit. Each iteration consists of multiple kinds of | |||
| // Rounds; an iteration is considered complete only after all of its rounds have finished. | |||
| class Iteration { | |||
| public: | |||
| Iteration(); | |||
| ~Iteration() = default; | |||
| // Add a round for the iteration. This method will be called multiple times for each round. | |||
| void AddRound(const std::shared_ptr<Round> &round); | |||
| // Initialize all the rounds in the iteration. | |||
| void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, | |||
| const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb); | |||
| // The server proceeds to the next iteration only after the last iteration finishes. | |||
| void ProceedToNextIter(); | |||
| const std::vector<std::shared_ptr<Round>> &rounds(); | |||
| private: | |||
| std::vector<std::shared_ptr<Round>> rounds_; | |||
| // Server's current iteration number. | |||
| size_t iteration_num_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ | |||
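| // A minimal sketch of wiring an Iteration together, for illustration only. It assumes the rounds and communicators | |||
| // are built elsewhere in the server pipeline (the Round class itself is not part of this diff) and that | |||
| // TimeOutCb/FinishIterCb are std::function<void()> callbacks. | |||
| void SetUpIteration(Iteration *iteration, | |||
| const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, | |||
| const std::vector<std::shared_ptr<Round>> &rounds) { | |||
| for (const auto &round : rounds) { | |||
| iteration->AddRound(round); | |||
| } | |||
| auto timeout_cb = []() { MS_LOG(WARNING) << "Iteration timed out."; }; | |||
| // When the last round reports completion, store the model and move to the next iteration. | |||
| auto finish_cb = [iteration]() { iteration->ProceedToNextIter(); }; | |||
| iteration->InitRounds(communicators, timeout_cb, finish_cb); | |||
| } | |||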
| @@ -0,0 +1,127 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/kernel/round/round_kernel.h" | |||
| #include <mutex> | |||
| #include <queue> | |||
| #include <chrono> | |||
| #include <thread> | |||
| #include <utility> | |||
| #include <string> | |||
| #include <vector> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_("") { | |||
| release_thread_ = std::thread([&]() { | |||
| while (true) { | |||
| std::unique_lock<std::mutex> release_lock(release_mtx_); | |||
| // Every 100 milliseconds, check whether there is any data that needs to be released. | |||
| if (heap_data_to_release_.empty()) { | |||
| release_lock.unlock(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); | |||
| continue; | |||
| } | |||
| AddressPtr addr_ptr = heap_data_to_release_.front(); | |||
| heap_data_to_release_.pop(); | |||
| release_lock.unlock(); | |||
| std::unique_lock<std::mutex> heap_data_lock(heap_data_mtx_); | |||
| if (heap_data_.count(addr_ptr) == 0) { | |||
| MS_LOG(ERROR) << "The data is not stored."; | |||
| continue; | |||
| } | |||
| // Manually release unique_ptr data. | |||
| heap_data_[addr_ptr].reset(nullptr); | |||
| heap_data_.erase(heap_data_.find(addr_ptr)); | |||
| } | |||
| }); | |||
| release_thread_.detach(); | |||
| } | |||
| void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; } | |||
| void RoundKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; } | |||
| void RoundKernel::StopTimer() { | |||
| if (stop_timer_cb_) { | |||
| stop_timer_cb_(); | |||
| } | |||
| return; | |||
| } | |||
| void RoundKernel::FinishIteration() { | |||
| if (finish_iteration_cb_) { | |||
| finish_iteration_cb_(); | |||
| } | |||
| return; | |||
| } | |||
| void RoundKernel::Release(AddressPtr addr_ptr) { | |||
| if (addr_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Data to be released is empty."; | |||
| return; | |||
| } | |||
| std::unique_lock<std::mutex> lock(release_mtx_); | |||
| heap_data_to_release_.push(addr_ptr); | |||
| return; | |||
| } | |||
| void RoundKernel::set_name(const std::string &name) { name_ = name; } | |||
| void RoundKernel::set_stop_timer_cb(StopTimerCb timer_stopper) { stop_timer_cb_ = timer_stopper; } | |||
| void RoundKernel::set_finish_iteration_cb(FinishIterCb finish_iteration_cb) { | |||
| finish_iteration_cb_ = finish_iteration_cb; | |||
| } | |||
| void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len) { | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "The data is nullptr."; | |||
| return; | |||
| } | |||
| if (outputs.empty()) { | |||
| MS_LOG(ERROR) << "Generating output failed. Outputs size is empty."; | |||
| return; | |||
| } | |||
| std::unique_ptr<unsigned char[]> output_data = std::make_unique<unsigned char[]>(len); | |||
| if (output_data == nullptr) { | |||
| MS_LOG(ERROR) << "Output data is nullptr."; | |||
| return; | |||
| } | |||
| size_t dst_size = len; | |||
| int ret = memcpy_s(output_data.get(), dst_size, data, len); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return; | |||
| } | |||
| outputs[0]->addr = output_data.get(); | |||
| outputs[0]->size = len; | |||
| std::unique_lock<std::mutex> lock(heap_data_mtx_); | |||
| heap_data_.insert(std::make_pair(outputs[0], std::move(output_data))); | |||
| return; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
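| // A sketch of how the server framework is expected to pair GenerateOutput with Release, for illustration only: the | |||
| // response buffer stays owned by the kernel (in heap_data_) until the communication layer has sent it, then Release | |||
| // queues it for the detached release thread to free. SendRoundResponse is a hypothetical name, assumed to live in | |||
| // the same mindspore::ps::server namespace. | |||
| void SendRoundResponse(kernel::RoundKernel *round_kernel, const std::vector<AddressPtr> &outputs, | |||
| const std::shared_ptr<core::TcpCommunicator> &communicator, | |||
| const std::shared_ptr<core::MessageHandler> &message) { | |||
| // outputs[0]->addr and outputs[0]->size were filled by GenerateOutput inside the kernel's Launch. | |||
| communicator->SendResponse(outputs[0]->addr, outputs[0]->size, message); | |||
| // Hand the buffer back; the kernel's release thread frees it asynchronously. | |||
| round_kernel->Release(outputs[0]); | |||
| } | |||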
| @@ -0,0 +1,130 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <mutex> | |||
| #include <queue> | |||
| #include <utility> | |||
| #include <chrono> | |||
| #include <thread> | |||
| #include <unordered_map> | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "ps/server/common.h" | |||
| #include "ps/server/local_meta_store.h" | |||
| #include "ps/server/distributed_count_service.h" | |||
| #include "ps/server/distributed_metadata_store.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| // RoundKernel contains the main logic of the server for handling messages from workers. One iteration consists of | |||
| // multiple round kernels that represent the process. They receive and parse messages from the server communication | |||
| // module. After handling these messages, round kernels allocate response data and send it back. | |||
| // For example, the main process of federated learning is: | |||
| // startFLJob round->updateModel round->getModel round. | |||
| class RoundKernel : virtual public CPUKernel { | |||
| public: | |||
| RoundKernel(); | |||
| virtual ~RoundKernel() = default; | |||
| // RoundKernel doesn't use the InitKernel method of the base class CPUKernel to initialize, so the implementation of | |||
| // this inherited method is empty. | |||
| void InitKernel(const CNodePtr &kernel_node) override {} | |||
| // Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count | |||
| // messages. | |||
| virtual void InitKernel(size_t threshold_count) = 0; | |||
| // Launch the round kernel logic to handle the message passed by the communication module. | |||
| virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) = 0; | |||
| // Some rounds could be stateful within an iteration. The Reset method resets the status of this round. | |||
| virtual bool Reset() = 0; | |||
| // The counter event handlers for DistributedCountService. They are called when the first/last message for this round | |||
| // kernel is received, triggered by the leader server (rank 0). | |||
| virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message); | |||
| virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message); | |||
| // Called when this round is finished. This round timer's Stop method will be called. | |||
| void StopTimer(); | |||
| // Called after this iteration (including all rounds) is finished. All rounds' Reset method will | |||
| // be called. | |||
| void FinishIteration(); | |||
| // Release the response data allocated inside the round kernel. | |||
| // Server framework must call this after the response data is sent back. | |||
| void Release(AddressPtr addr_ptr); | |||
| // Set round kernel name, which could be used in round kernel's methods. | |||
| void set_name(const std::string &name); | |||
| // Set callbacks to be called under certain triggered conditions. | |||
| void set_stop_timer_cb(StopTimerCb timer_stopper); | |||
| void set_finish_iteration_cb(FinishIterCb finish_iteration_cb); | |||
| protected: | |||
| // Generate the response data for this round. The data is allocated on the heap to ensure it is not released before | |||
| // being sent back to the worker. | |||
| void GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len); | |||
| // Round kernel's name. | |||
| std::string name_; | |||
| // The current received message count for this round in this iteration. | |||
| size_t current_count_; | |||
| // The required received message count for this round in one iteration. | |||
| size_t required_count_; | |||
| // The reason causes the error in this round kernel. | |||
| std::string error_reason_; | |||
| StopTimerCb stop_timer_cb_; | |||
| FinishIterCb finish_iteration_cb_; | |||
| // Members below are used for allocating and releasing response data on the heap. | |||
| // To ensure performance, another thread releases the data on the heap, so the operations on the data must be | |||
| // threadsafe. | |||
| std::thread release_thread_; | |||
| // The queue of data to be released, and its mutex. | |||
| std::mutex release_mtx_; | |||
| std::queue<AddressPtr> heap_data_to_release_; | |||
| std::mutex heap_data_mtx_; | |||
| std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ | |||
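| // A minimal skeleton of a custom round kernel, for illustration only; "ping" is a hypothetical round. Real round | |||
| // kernels such as StartFLJobKernel below follow the same shape: InitKernel, Launch, and Reset. | |||
| class PingKernel : public RoundKernel { | |||
| public: | |||
| void InitKernel(size_t threshold_count) override { required_count_ = threshold_count; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override { | |||
| std::string rsp = "pong"; | |||
| // Copy the response onto the heap so it outlives Launch; the framework calls Release after sending it. | |||
| GenerateOutput(outputs, rsp.data(), rsp.size()); | |||
| return true; | |||
| } | |||
| bool Reset() override { return true; } | |||
| }; | |||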
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/kernel/round/round_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| RoundKernelFactory &RoundKernelFactory::GetInstance() { | |||
| static RoundKernelFactory instance; | |||
| return instance; | |||
| } | |||
| void RoundKernelFactory::Register(const std::string &name, RoundKernelCreator &&creator) { | |||
| name_to_creator_map_[name] = std::move(creator); | |||
| } | |||
| std::shared_ptr<RoundKernel> RoundKernelFactory::Create(const std::string &name) { | |||
| if (name_to_creator_map_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "Round kernel " << name << " is not registered."; | |||
| return nullptr; | |||
| } | |||
| auto kernel = name_to_creator_map_[name](); | |||
| kernel->set_name(name); | |||
| return kernel; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <unordered_map> | |||
| #include "ps/server/common.h" | |||
| #include "ps/server/kernel/round/round_kernel.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>; | |||
| // Kernel factory of round kernels. | |||
| class RoundKernelFactory { | |||
| public: | |||
| static RoundKernelFactory &GetInstance(); | |||
| void Register(const std::string &name, RoundKernelCreator &&creator); | |||
| std::shared_ptr<RoundKernel> Create(const std::string &name); | |||
| private: | |||
| RoundKernelFactory() = default; | |||
| ~RoundKernelFactory() = default; | |||
| RoundKernelFactory(const RoundKernelFactory &) = delete; | |||
| RoundKernelFactory &operator=(const RoundKernelFactory &) = delete; | |||
| std::unordered_map<std::string, RoundKernelCreator> name_to_creator_map_; | |||
| }; | |||
| class RoundKernelRegister { | |||
| public: | |||
| RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) { | |||
| RoundKernelFactory::GetInstance().Register(name, std::move(creator)); | |||
| } | |||
| }; | |||
| #define REG_ROUND_KERNEL(NAME, CLASS) \ | |||
| static_assert(std::is_base_of<RoundKernel, CLASS>::value, #CLASS " must be derived from RoundKernel"); \ | |||
| static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); }); | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ | |||
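| // Registering and creating a round kernel through the factory, for illustration. REG_ROUND_KERNEL expands to a | |||
| // static RoundKernelRegister, so registration runs at program startup; PingKernel is the hypothetical kernel | |||
| // sketched earlier. | |||
| // REG_ROUND_KERNEL(ping, PingKernel) | |||
| std::shared_ptr<RoundKernel> CreatePingKernel() { | |||
| auto round_kernel = RoundKernelFactory::GetInstance().Create("ping"); | |||
| if (round_kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Round kernel ping is not registered."; | |||
| return nullptr; | |||
| } | |||
| round_kernel->InitKernel(1);  // This round expects one message per iteration. | |||
| return round_kernel; | |||
| } | |||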
| @@ -0,0 +1,192 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/kernel/round/start_fl_job_kernel.h" | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| void StartFLJobKernel::InitKernel(size_t) { | |||
| if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { | |||
| iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); | |||
| } | |||
| executor_ = &Executor::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| if (!executor_->initialized()) { | |||
| MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; | |||
| return; | |||
| } | |||
| PBMetadata devices_metas; | |||
| DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas); | |||
| return; | |||
| } | |||
| bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| MS_LOG(INFO) << "Launching StartFLJobKernel kernel."; | |||
| if (inputs.size() != 1 || outputs.size() != 1) { | |||
| MS_LOG(ERROR) << "inputs or outputs size is invalid."; | |||
| return false; | |||
| } | |||
| void *req_data = inputs[0]->addr; | |||
| const std::shared_ptr<FBBuilder> &fbb = std::make_shared<FBBuilder>(); | |||
| if (fbb == nullptr || req_data == nullptr) { | |||
| MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; | |||
| return false; | |||
| } | |||
| if (ReachThresholdForStartFLJob(fbb)) { | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return false; | |||
| } | |||
| const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data); | |||
| DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req); | |||
| if (!ReadyForStartFLJob(fbb, device_meta)) { | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return false; | |||
| } | |||
| // Count only after ReadyForStartFLJob succeeds; otherwise the count would be inconsistent when a device is not selected. | |||
| if (!CountForStartFLJob(fbb, start_fl_job_req)) { | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return false; | |||
| } | |||
| StartFLJob(fbb, device_meta); | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| bool StartFLJobKernel::Reset() { | |||
| MS_LOG(INFO) << "Starting fl job kernel reset!"; | |||
| StopTimer(); | |||
| DistributedCountService::GetInstance().ResetCounter(name_); | |||
| DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas); | |||
| return true; | |||
| } | |||
| bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) { | |||
| if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { | |||
| std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later."; | |||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||
| MS_LOG(ERROR) << reason; | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) { | |||
| std::string fl_name = start_fl_job_req->fl_name()->str(); | |||
| std::string fl_id = start_fl_job_req->fl_id()->str(); | |||
| int data_size = start_fl_job_req->data_size(); | |||
| MS_LOG(INFO) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size; | |||
| DeviceMeta device_meta; | |||
| device_meta.set_fl_name(fl_name); | |||
| device_meta.set_fl_id(fl_id); | |||
| device_meta.set_data_size(data_size); | |||
| return device_meta; | |||
| } | |||
| bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) { | |||
| bool ret = true; | |||
| std::string reason = ""; | |||
| if (device_meta.data_size() < 1) { | |||
| reason = "FL job data size is not enough."; | |||
| ret = false; | |||
| } | |||
| if (!ret) { | |||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false, | |||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||
| MS_LOG(ERROR) << reason; | |||
| } | |||
| return ret; | |||
| } | |||
| bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestFLJob *start_fl_job_req) { | |||
| if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) { | |||
| std::string reason = "startFLJob counting failed."; | |||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||
| MS_LOG(ERROR) << reason; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) { | |||
| PBMetadata metadata; | |||
| *metadata.mutable_device_meta() = device_meta; | |||
| DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata); | |||
| std::map<std::string, AddressPtr> feature_maps = executor_->GetModel(); | |||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true, | |||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps); | |||
| return; | |||
| } | |||
| void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const bool is_selected, | |||
| const std::string &next_req_time, | |||
| std::map<std::string, AddressPtr> feature_maps) { | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| auto fbs_next_req_time = fbb->CreateString(next_req_time); | |||
| auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name()); | |||
| schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); | |||
| fl_plan_builder.add_fl_name(fbs_fl_name); | |||
| fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num()); | |||
| fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num()); | |||
| fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size()); | |||
| auto fbs_fl_plan = fl_plan_builder.Finish(); | |||
| std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps; | |||
| for (const auto &feature_map : feature_maps) { | |||
| auto fbs_weight_fullname = fbb->CreateString(feature_map.first); | |||
| auto fbs_weight_data = | |||
| fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float)); | |||
| auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data); | |||
| fbs_feature_maps.push_back(fbs_feature_map); | |||
| } | |||
| auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); | |||
| schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get())); | |||
| rsp_fl_job_builder.add_retcode(retcode); | |||
| rsp_fl_job_builder.add_reason(fbs_reason); | |||
| rsp_fl_job_builder.add_iteration(LocalMetaStore::GetInstance().curr_iter_num()); | |||
| rsp_fl_job_builder.add_is_selected(is_selected); | |||
| rsp_fl_job_builder.add_next_req_time(fbs_next_req_time); | |||
| rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan); | |||
| rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector); | |||
| auto rsp_fl_job = rsp_fl_job_builder.Finish(); | |||
| fbb->Finish(rsp_fl_job); | |||
| return; | |||
| } | |||
| REG_ROUND_KERNEL(startFLJob, StartFLJobKernel) | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
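| // A sketch of how a worker might parse the ResponseFLJob buffer built above, for illustration only, assuming the | |||
| // standard FlatBuffers-generated accessors for the fields added by the builder (retcode, reason, is_selected, | |||
| // feature_map) and the field names weight_fullname/data used by CreateFeatureMap. | |||
| void ParseStartFLJobRsp(const uint8_t *rsp_data) { | |||
| const schema::ResponseFLJob *rsp_fl_job = flatbuffers::GetRoot<schema::ResponseFLJob>(rsp_data); | |||
| if (rsp_fl_job->retcode() != schema::ResponseCode_SUCCEED || !rsp_fl_job->is_selected()) { | |||
| MS_LOG(WARNING) << "Device is not selected: " << rsp_fl_job->reason()->str(); | |||
| return; | |||
| } | |||
| // Each feature map entry carries a weight name and its float data. | |||
| for (flatbuffers::uoffset_t i = 0; i < rsp_fl_job->feature_map()->size(); i++) { | |||
| const auto *feature = rsp_fl_job->feature_map()->Get(i); | |||
| MS_LOG(INFO) << "Weight " << feature->weight_fullname()->str() << " has " << feature->data()->size() | |||
| << " float values."; | |||
| } | |||
| } | |||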
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ps/server/common.h" | |||
| #include "ps/server/executor.h" | |||
| #include "ps/server/kernel/round/round_kernel.h" | |||
| #include "ps/server/kernel/round/round_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| class StartFLJobKernel : public RoundKernel { | |||
| public: | |||
| StartFLJobKernel() = default; | |||
| ~StartFLJobKernel() override = default; | |||
| void InitKernel(size_t threshold_count) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| bool Reset() override; | |||
| private: | |||
| // Returns whether the startFLJob count of this iteration has reached the threshold. | |||
| bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb); | |||
| // The metadata of device will be stored and queried in updateModel round. | |||
| DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req); | |||
| // Returns whether the request is valid for startFLJob. For now, the condition is simple; more conditions for | |||
| // devices will be added in later versions. | |||
| bool ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta); | |||
| // Distributed count service counts for startFLJob. | |||
| bool CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req); | |||
| void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta); | |||
| // Build response for startFLJob round no matter success or failure. | |||
| void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const bool is_selected, const std::string &next_req_time, | |||
| std::map<std::string, AddressPtr> feature_maps = {}); | |||
| // The executor is for getting the initial model for startFLJob request. | |||
| Executor *executor_{nullptr}; | |||
| // The time window of one iteration. | |||
| size_t iteration_time_window_{0}; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ | |||
| @@ -14,30 +14,29 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/local_meta_storage.h" | |||
| #include <string> | |||
| #include "ps/server/local_meta_store.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| void LocalMetaStorage::remove_value(const std::string &name) { | |||
| void LocalMetaStore::remove_value(const std::string &name) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| if (key_to_meta_.count(name) != 0) { | |||
| key_to_meta_.erase(key_to_meta_.find(name)); | |||
| } | |||
| } | |||
| bool LocalMetaStorage::has_value(const std::string &name) { | |||
| bool LocalMetaStore::has_value(const std::string &name) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| return key_to_meta_.count(name) != 0; | |||
| } | |||
| void LocalMetaStorage::set_curr_iter_num(size_t num) { | |||
| void LocalMetaStore::set_curr_iter_num(size_t num) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| curr_iter_num_ = num; | |||
| } | |||
| const size_t LocalMetaStorage::curr_iter_num() { | |||
| const size_t LocalMetaStore::curr_iter_num() { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| return curr_iter_num_; | |||
| } | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ | |||
| #include <any> | |||
| #include <mutex> | |||
| @@ -26,13 +26,13 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // LocalMetaStorage class is used for metadata storage of this server process. | |||
| // LocalMetaStore class is used for metadata storage of this server process. | |||
| // For example, the current iteration number, time windows for round kernels, etc. | |||
| // LocalMetaStorage is threadsafe. | |||
| class LocalMetaStorage { | |||
| // LocalMetaStore is thread-safe. | |||
| class LocalMetaStore { | |||
| public: | |||
| static LocalMetaStorage &GetInstance() { | |||
| static LocalMetaStorage instance; | |||
| static LocalMetaStore &GetInstance() { | |||
| static LocalMetaStore instance; | |||
| return instance; | |||
| } | |||
| @@ -43,7 +43,7 @@ class LocalMetaStorage { | |||
| } | |||
| template <typename T> | |||
| const T &value(const std::string &name) { | |||
| T value(const std::string &name) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| try { | |||
| T value = std::any_cast<T>(key_to_meta_[name]); | |||
| @@ -71,10 +71,10 @@ class LocalMetaStorage { | |||
| const size_t curr_iter_num(); | |||
| private: | |||
| LocalMetaStorage() = default; | |||
| ~LocalMetaStorage() = default; | |||
| LocalMetaStorage(const LocalMetaStorage &) = delete; | |||
| LocalMetaStorage &operator=(const LocalMetaStorage &) = delete; | |||
| LocalMetaStore() = default; | |||
| ~LocalMetaStore() = default; | |||
| LocalMetaStore(const LocalMetaStore &) = delete; | |||
| LocalMetaStore &operator=(const LocalMetaStore &) = delete; | |||
| // key_to_meta_ stores metadata with key-value format. | |||
| std::unordered_map<std::string, std::any> key_to_meta_; | |||
| @@ -85,4 +85,4 @@ class LocalMetaStorage { | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ | |||
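For readers new to std::any: the templated value() above works because std::any_cast<T> recovers the stored object when T matches and throws std::bad_any_cast otherwise, which is why the lookup sits in a try block. A self-contained sketch of the same thread-safe pattern; the class and method names here are illustrative, not the PR's exact API:

#include <any>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>

class MetaStore {
 public:
  template <typename T>
  void put_value(const std::string &name, const T &v) {
    std::unique_lock<std::mutex> lock(mtx_);
    key_to_meta_[name] = v;  // std::any stores a copy of v
  }
  template <typename T>
  T value(const std::string &name) {
    std::unique_lock<std::mutex> lock(mtx_);
    return std::any_cast<T>(key_to_meta_[name]);  // throws std::bad_any_cast on type mismatch
  }

 private:
  std::unordered_map<std::string, std::any> key_to_meta_;
  std::mutex mtx_;
};

int main() {
  MetaStore store;
  store.put_value("curr_iter_num", static_cast<std::size_t>(3));
  std::cout << store.value<std::size_t>("curr_iter_num") << std::endl;  // prints 3
  return 0;
}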
| @@ -0,0 +1,144 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/model_store.h" | |||
| #include <map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ps/server/executor.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| void ModelStore::Init(uint32_t max_count) { | |||
| if (!Executor::GetInstance().initialized()) { | |||
| MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage."; | |||
| return; | |||
| } | |||
| max_model_count_ = max_count; | |||
| iteration_to_model_[kInitIterationNum] = AssignNewModelMemory(); | |||
| model_size_ = ComputeModelSize(); | |||
| } | |||
| bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) { | |||
| if (iteration_to_model_.count(iteration) != 0) { | |||
| MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored"; | |||
| return false; | |||
| } | |||
| if (new_model.empty()) { | |||
| MS_LOG(ERROR) << "Model feature map is empty."; | |||
| return false; | |||
| } | |||
| std::shared_ptr<MemoryRegister> memory_register; | |||
| if (iteration_to_model_.size() < max_model_count_) { | |||
| // If iteration_to_model_.size() has not reached max_model_count_, assign new memory for the model. | |||
| memory_register = AssignNewModelMemory(); | |||
| if (memory_register == nullptr) { | |||
| MS_LOG(ERROR) << "Memory for the new model is nullptr."; | |||
| return false; | |||
| } | |||
| iteration_to_model_[iteration] = memory_register; | |||
| } else { | |||
| // If iteration_to_model_ has already reached max_model_count_, reuse the earliest model's memory for the newest model. | |||
| memory_register = iteration_to_model_.begin()->second; | |||
| if (memory_register == nullptr) { | |||
| MS_LOG(ERROR) << "Earliest model is nullptr."; | |||
| return false; | |||
| } | |||
| iteration_to_model_.erase(iteration_to_model_.begin()); | |||
| } | |||
| // Copy the new model data into the stored model. | |||
| auto &stored_model = memory_register->addresses(); | |||
| for (const auto &weight : new_model) { | |||
| const std::string &weight_name = weight.first; | |||
| if (stored_model.count(weight_name) == 0) { | |||
| MS_LOG(ERROR) << "The stored model has no weight " << weight_name; | |||
| continue; | |||
| } | |||
| void *dst_addr = stored_model[weight_name]->addr; | |||
| size_t dst_size = stored_model[weight_name]->size; | |||
| void *src_addr = weight.second->addr; | |||
| size_t src_size = weight.second->size; | |||
| int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return false; | |||
| } | |||
| } | |||
| iteration_to_model_[iteration] = memory_register; | |||
| return true; | |||
| } | |||
| std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) { | |||
| std::map<std::string, AddressPtr> model = {}; | |||
| if (iteration_to_model_.count(iteration) == 0) { | |||
| MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored."; | |||
| return model; | |||
| } | |||
| model = iteration_to_model_[iteration]->addresses(); | |||
| return model; | |||
| } | |||
| const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const { | |||
| return iteration_to_model_; | |||
| } | |||
| size_t ModelStore::model_size() const { return model_size_; } | |||
| std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() { | |||
| std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel(); | |||
| if (model.empty()) { | |||
| MS_LOG(EXCEPTION) << "Model feature map is empty."; | |||
| return nullptr; | |||
| } | |||
| // Assign new memory for the model. | |||
| std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>(); | |||
| for (const auto &weight : model) { | |||
| const std::string weight_name = weight.first; | |||
| size_t weight_size = weight.second->size; | |||
| auto weight_data = std::make_unique<char[]>(weight_size); | |||
| if (weight_data == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Assign memory for weight failed."; | |||
| return nullptr; | |||
| } | |||
| memory_register->RegisterArray(weight_name, &weight_data, weight_size); | |||
| } | |||
| return memory_register; | |||
| } | |||
| size_t ModelStore::ComputeModelSize() { | |||
| if (iteration_to_model_.empty()) { | |||
| MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. "; | |||
| return 0; | |||
| } | |||
| const auto &model = iteration_to_model_[kInitIterationNum]; | |||
| MS_EXCEPTION_IF_NULL(model); | |||
| size_t model_size = std::accumulate(model->addresses().begin(), model->addresses().end(), static_cast<size_t>(0), | |||
| [](size_t s, const auto &weight) { return s + weight.second->size; }); | |||
| MS_LOG(INFO) << "Model size in byte is " << model_size; | |||
| return model_size; | |||
| } | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
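One subtlety in StoreModelByIterNum above: when the store is full, the earliest iteration's MemoryRegister is reused for the incoming model instead of being reallocated, so memory stays bounded by max_model_count_ * model_size_, and std::map's key ordering makes begin() the eviction victim. A stripped-down sketch of that reuse pattern, with Buf standing in for MemoryRegister:

#include <cstddef>
#include <map>
#include <memory>
#include <vector>

struct Buf {
  std::vector<char> data;
};

std::map<std::size_t, std::shared_ptr<Buf>> iter_to_model;
constexpr std::size_t kMaxModels = 3;

void Store(std::size_t iteration, std::size_t model_size) {
  std::shared_ptr<Buf> mem;
  if (iter_to_model.size() < kMaxModels) {
    mem = std::make_shared<Buf>();        // under capacity: allocate fresh memory
    mem->data.resize(model_size);
  } else {
    mem = iter_to_model.begin()->second;  // at capacity: take the earliest buffer
    iter_to_model.erase(iter_to_model.begin());
  }
  // ... copy the new weights into mem->data here ...
  iter_to_model[iteration] = mem;         // re-key the (possibly reused) buffer
}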
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "ps/server/common.h" | |||
| #include "ps/server/memory_register.h" | |||
| #include "ps/server/executor.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // The initial iteration number is 0 in server. | |||
| constexpr size_t kInitIterationNum = 0; | |||
| // The server framework uses ModelStore to store and query models. | |||
| // ModelStore keeps multiple models because workers may fetch models from previous iterations. | |||
| class ModelStore { | |||
| public: | |||
| static ModelStore &GetInstance() { | |||
| static ModelStore instance; | |||
| return instance; | |||
| } | |||
| // Initialize ModelStore with the maximum number of models to be stored. | |||
| void Init(uint32_t max_count = 3); | |||
| // Store the model of the given iteration. The model memory layout is acquired from Executor. If the current model | |||
| // count is already max_model_count_, the earliest model will be replaced. | |||
| bool StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model); | |||
| // Get model of the given iteration. | |||
| std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration); | |||
| // Returns all models stored in ModelStore. | |||
| const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const; | |||
| // Returns the model size, which is calculated during the initialization phase. | |||
| size_t model_size() const; | |||
| private: | |||
| ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {} | |||
| ~ModelStore() = default; | |||
| ModelStore(const ModelStore &) = delete; | |||
| ModelStore &operator=(const ModelStore &) = delete; | |||
| // To store multiple models, new memory must be assigned. The maximum memory size assigned for models is | |||
| // max_model_count_ * model_size_. | |||
| std::shared_ptr<MemoryRegister> AssignNewModelMemory(); | |||
| // Calculate the model size. This method should be called after iteration_to_model_ is initialized. | |||
| size_t ComputeModelSize(); | |||
| size_t max_model_count_; | |||
| size_t model_size_; | |||
| std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ | |||
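A hedged usage sketch of the API above. The call site and iteration bookkeeping are illustrative, and Init must run after the executor is initialized because the initial model memory is laid out from Executor::GetInstance().GetModel():

#include <map>
#include <string>
#include "ps/server/model_store.h"

using namespace mindspore::ps::server;  // for ModelStore and AddressPtr

// Illustrative call site, not code from this PR.
void OnIterationEnd(size_t iteration, const std::map<std::string, AddressPtr> &aggregated) {
  // Persist this iteration's weights; at capacity the earliest model is evicted.
  if (!ModelStore::GetInstance().StoreModelByIterNum(iteration, aggregated)) {
    MS_LOG(WARNING) << "Storing model for iteration " << iteration << " failed.";
  }
  // A device lagging one iteration behind can still fetch the previous model.
  auto prev = ModelStore::GetInstance().GetModelByIterNum(iteration - 1);
  if (prev.empty()) {
    MS_LOG(WARNING) << "Model for iteration " << (iteration - 1) << " is not available.";
  }
}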
| @@ -25,15 +25,15 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| bool ParameterAggregator::Init(const CNodePtr &cnode, size_t required_count) { | |||
| bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| memory_register_ = std::make_shared<MemoryRegister>(); | |||
| MS_EXCEPTION_IF_NULL(memory_register_); | |||
| required_push_count_ = required_count; | |||
| required_push_count_ = threshold_count; | |||
| // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_. | |||
| // required_pull_count_ is normally used in parameter server training mode. | |||
| required_pull_count_ = required_count; | |||
| required_pull_count_ = threshold_count; | |||
| MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode); | |||
| InitAggregationKernels(cnode); | |||
| @@ -61,8 +61,8 @@ class ParameterAggregator { | |||
| ~ParameterAggregator() = default; | |||
| // Initialize ParameterAggregator with a cnode. This cnode is normally an optimizer kernel for now. | |||
| // The parameter required_count helps ParameterAggregator to judge the current status if it's stateful. | |||
| bool Init(const CNodePtr &cnode, size_t required_count = 0); | |||
| // The parameter threshold_count helps ParameterAggregator judge its current status if it is stateful. | |||
| bool Init(const CNodePtr &cnode, size_t threshold_count = 0); | |||
| // Update old data stored in ParameterAggregator with new data. | |||
| // The data could have many meanings: weights, gradients, learning_rate, momentum, etc. | |||
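The rename from required_count to threshold_count lines up with the Round/DistributedCountService terminology: it is the number of Push (and Pull) requests the aggregator waits for before its data is considered complete. A minimal sketch of that stateful gating; the class and member names are illustrative, not ParameterAggregator's real fields:

#include <cstddef>

class PushGate {
 public:
  explicit PushGate(std::size_t threshold_count) : required_push_count_(threshold_count) {}
  // Returns true once enough pushes have arrived to run aggregation.
  bool OnPush() { return ++current_push_count_ >= required_push_count_; }
  // Called when a new iteration starts.
  void Reset() { current_push_count_ = 0; }

 private:
  std::size_t required_push_count_;
  std::size_t current_push_count_ = 0;
};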
| @@ -0,0 +1,139 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/round.h" | |||
| #include <memory> | |||
| #include <string> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count) | |||
| : name_(name), | |||
| check_timeout_(check_timeout), | |||
| time_window_(time_window), | |||
| check_count_(check_count), | |||
| threshold_count_(threshold_count) {} | |||
| void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb, | |||
| FinishIterCb finish_iteration_cb) { | |||
| MS_EXCEPTION_IF_NULL(communicator); | |||
| communicator_ = communicator; | |||
| // Register callback for round kernel. | |||
| communicator_->RegisterMsgCallBack( | |||
| name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); }); | |||
| // Callback when the iteration is finished. | |||
| finish_iteration_cb_ = [this, finish_iteration_cb](void) -> void { | |||
| MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration."; | |||
| finish_iteration_cb(); | |||
| }; | |||
| // Callback for finalizing the server. This can only be called once. | |||
| finalize_cb_ = [&](void) -> void { communicator_->Stop(); }; | |||
| if (check_timeout_) { | |||
| iter_timer_ = std::make_shared<IterationTimer>(); | |||
| // 1. Set the timeout callback for the timer. | |||
| iter_timer_->SetTimeOutCallBack([this, timeout_cb](void) -> void { | |||
| MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration."; | |||
| timeout_cb(); | |||
| }); | |||
| // 2. Set the stop-timer callback, which will be passed to the round kernel. | |||
| stop_timer_cb_ = [&](void) -> void { | |||
| MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer."; | |||
| iter_timer_->Stop(); | |||
| }; | |||
| } | |||
| // Set counter event callbacks for this round if the round kernel is stateful. | |||
| if (check_count_) { | |||
| auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1); | |||
| auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1); | |||
| DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_, | |||
| {first_count_handler, last_count_handler}); | |||
| } | |||
| } | |||
| void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| kernel_ = kernel; | |||
| kernel_->set_stop_timer_cb(stop_timer_cb_); | |||
| kernel_->set_finish_iteration_cb(finish_iteration_cb_); | |||
| return; | |||
| } | |||
| void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| MS_LOG(ERROR) << "Message is nullptr."; | |||
| return; | |||
| } | |||
| AddressPtr input = std::make_shared<Address>(); | |||
| AddressPtr output = std::make_shared<Address>(); | |||
| input->addr = message->data(); | |||
| input->size = message->len(); | |||
| bool ret = kernel_->Launch({input}, {}, {output}); | |||
| if (output->size == 0) { | |||
| std::string reason = "The output of the round " + name_ + " is empty."; | |||
| MS_LOG(WARNING) << reason; | |||
| communicator_->SendResponse(reason.c_str(), reason.size(), message); | |||
| return; | |||
| } | |||
| // The response must be sent back no matter what value the Launch method returns. | |||
| if (!ret) { | |||
| MS_LOG(WARNING) << "Launching round kernel of round " << name_ << " failed."; | |||
| } | |||
| communicator_->SendResponse(output->addr, output->size, message); | |||
| kernel_->Release(output); | |||
| return; | |||
| } | |||
| void Round::Reset() { kernel_->Reset(); } | |||
| const std::string &Round::name() const { return name_; } | |||
| size_t Round::threshold_count() const { return threshold_count_; } | |||
| size_t Round::time_window() const { return time_window_; } | |||
| void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) { | |||
| MS_LOG(INFO) << "Round " << name_ << " first count event is triggered."; | |||
| // The timer starts only after the first count event is triggered by DistributedCountService. | |||
| if (check_timeout_) { | |||
| iter_timer_->Start(std::chrono::milliseconds(time_window_)); | |||
| } | |||
| return; | |||
| } | |||
| void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { | |||
| MS_LOG(INFO) << "Round " << name_ << " last count event is triggered."; | |||
| // As with the first count event, the timer is stopped when DistributedCountService triggers the last count event. | |||
| if (check_timeout_) { | |||
| iter_timer_->Stop(); | |||
| } | |||
| // Some kernels override the OnLastCountEvent method. | |||
| kernel_->OnLastCountEvent(message); | |||
| return; | |||
| } | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
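Putting round.cc together: at startup the server would construct each Round, initialize it with the communicator and iteration callbacks, then bind the matching round kernel. A hedged wiring sketch; communicator and start_fl_job_kernel are placeholders obtained elsewhere, and only Round's own methods come from this diff:

// `communicator` is a std::shared_ptr<core::CommunicatorBase>,
// `start_fl_job_kernel` a std::shared_ptr<kernel::RoundKernel>.
auto round = std::make_shared<Round>("startFLJob", /*check_timeout=*/true,
                                     /*time_window=*/3000, /*check_count=*/true,
                                     /*threshold_count=*/8);
round->Initialize(communicator,
                  /*timeout_cb=*/[]() { /* move to the next iteration */ },
                  /*finish_iteration_cb=*/[]() { /* proceed to the next iteration */ });
round->BindRoundKernel(start_fl_job_kernel);  // must come after Initialize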
| @@ -0,0 +1,95 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "ps/core/communicator/communicator_base.h" | |||
| #include "ps/server/common.h" | |||
| #include "ps/server/iteration_timer.h" | |||
| #include "ps/server/distributed_count_service.h" | |||
| #include "ps/server/kernel/round/round_kernel.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // Round helps the server handle network round messages and launch round kernels. One iteration in the server | |||
| // consists of multiple rounds such as startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful | |||
| // because of counting and timing, so Round registers the counter and timer for them, letting the round kernels | |||
| // focus on their own logic. | |||
| class Round { | |||
| public: | |||
| explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000, | |||
| bool check_count = false, size_t threshold_count = 8); | |||
| ~Round() = default; | |||
| void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb, | |||
| FinishIterCb finish_iteration_cb); | |||
| // Bind a round kernel to this Round. This method should be called after Initialize. | |||
| void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel); | |||
| // This method is the callback that is registered with the communicator and called when the corresponding round | |||
| // message arrives at the server. | |||
| void LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message); | |||
| // Round needs to be reset after each iteration is finished or its timer expires. | |||
| void Reset(); | |||
| const std::string &name() const; | |||
| size_t threshold_count() const; | |||
| size_t time_window() const; | |||
| private: | |||
| // The callbacks which will be set to DistributedCountService. | |||
| void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message); | |||
| void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message); | |||
| std::string name_; | |||
| // Whether this round needs to use a timer. Most rounds in the federated learning with mobile devices scenario | |||
| // need to set check_timeout_ to true. | |||
| bool check_timeout_; | |||
| // The time window duration for this round when check_timeout_ is set to true. | |||
| size_t time_window_; | |||
| // If check_count_ is true, it means the round has to do counting for every round message and the first/last count | |||
| // event will be triggered. | |||
| bool check_count_; | |||
| // The threshold count for this round when check_count_ is set to true. The logic of this round has to check whether | |||
| // the round message count has reached threshold_count_. | |||
| size_t threshold_count_; | |||
| std::shared_ptr<core::CommunicatorBase> communicator_; | |||
| // The round kernel for this Round. | |||
| std::shared_ptr<kernel::RoundKernel> kernel_; | |||
| // Some rounds may need a timer to eliminate the long-tail effect. | |||
| std::shared_ptr<IterationTimer> iter_timer_; | |||
| // The callbacks which will be set to the round kernel. | |||
| StopTimerCb stop_timer_cb_; | |||
| FinishIterCb finish_iteration_cb_; | |||
| FinalizeCb finalize_cb_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ | |||
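round.h depends on IterationTimer, whose interface can be inferred from the call sites above: SetTimeOutCallBack, Start(std::chrono::milliseconds), and Stop. A self-contained sketch with that shape, offered as an illustration rather than the PR's actual implementation:

#include <atomic>
#include <chrono>
#include <functional>
#include <thread>

class SimpleIterationTimer {
 public:
  void SetTimeOutCallBack(std::function<void()> cb) { timeout_cb_ = std::move(cb); }
  // Start must not be called again while the timer is running (kept minimal here).
  void Start(std::chrono::milliseconds window) {
    running_ = true;
    worker_ = std::thread([this, window]() {
      auto deadline = std::chrono::steady_clock::now() + window;
      while (running_ && std::chrono::steady_clock::now() < deadline) {
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
      }
      if (running_ && timeout_cb_) timeout_cb_();  // fire only if not stopped early
    });
  }
  void Stop() {
    running_ = false;
    if (worker_.joinable()) worker_.join();
  }
  ~SimpleIterationTimer() { Stop(); }

 private:
  std::atomic_bool running_{false};
  std::thread worker_;
  std::function<void()> timeout_cb_;
};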
| @@ -31,7 +31,7 @@ | |||
| #include "utils/utils.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "debug/env_config_parser.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| @@ -307,7 +307,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| } | |||
| need_alloc_nodes.push_back(item); | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| bool ps_cache_check = false; | |||
| #endif | |||
| for (auto &item : need_alloc_nodes) { | |||
| @@ -320,7 +320,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| continue; | |||
| } | |||
| DeviceAddressPtr device_address = nullptr; | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| const std::string ¶m_name = item->fullname_with_scope(); | |||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| MS_LOG(INFO) << "Parameter(" << param_name << ")" | |||
| @@ -1038,7 +1038,7 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st | |||
| return device_address; | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, | |||
| AnfNodePtr *const first_cache_input_index, | |||
| size_t *const first_cache_size) { | |||
| @@ -142,7 +142,7 @@ class KernelRuntime { | |||
| void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); | |||
| void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); | |||
| DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index, | |||
| size_t *const first_cache_size); | |||
| void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph); | |||
| @@ -16,14 +16,14 @@ | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "utils/log_adapter.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace device { | |||
| void KernelRuntimeManager::ClearRuntimeResource() { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| ps::ps_cache_instance.SyncEmbeddingTable(); | |||
| } | |||
| @@ -125,7 +125,7 @@ void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, | |||
| if (runtime == nullptr) { | |||
| return; | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| ps::ps_cache_instance.SyncEmbeddingTable(); | |||
| } | |||
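A note on why the new guard works on every platform: inside a #if expression the preprocessor replaces any identifier that is not a defined macro with 0, so !_WIN32 evaluates to 1 on Linux/macOS (where _WIN32 is undefined) and to 0 on Windows (where compilers predefine _WIN32 as 1). ENABLE_CPU is normally injected by the build system (-DENABLE_CPU defines it as 1); it is defined inline below only to keep the snippet self-contained:

#include <iostream>

#define ENABLE_CPU 1  // stands in for the -DENABLE_CPU compile definition

int main() {
#if (ENABLE_CPU && !_WIN32)
  std::cout << "parameter-server cache path compiled in\n";    // _WIN32 undefined -> 0, so !_WIN32 is 1
#else
  std::cout << "path compiled out on Windows or without CPU\n";
#endif
  return 0;
}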
| @@ -0,0 +1,123 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| namespace mindspore.schema; | |||
| table CipherPublicParams { | |||
| t:int; | |||
| p:[ubyte]; | |||
| g:int; | |||
| prime:[ubyte]; | |||
| dp_eps:float; | |||
| dp_delta:float; | |||
| dp_norm_clip:float; | |||
| encrypt_type:int; | |||
| } | |||
| table ClientPublicKeys { | |||
| fl_id:string; | |||
| c_pk:[ubyte]; | |||
| s_pk:[ubyte]; | |||
| } | |||
| table ClientShare { | |||
| fl_id:string; | |||
| share:[ubyte]; | |||
| index:int; | |||
| } | |||
| table RequestExchangeKeys{ | |||
| fl_id:string; | |||
| c_pk:[ubyte]; | |||
| s_pk:[ubyte]; | |||
| iteration:int; | |||
| timestamp:string; | |||
| } | |||
| table ResponseExchangeKeys{ | |||
| retcode:int; | |||
| reason:string; | |||
| next_req_time:string; | |||
| iteration:int; | |||
| } | |||
| table GetExchangeKeys{ | |||
| fl_id:string; | |||
| iteration:int; | |||
| timestamp:string; | |||
| } | |||
| table ReturnExchangeKeys{ | |||
| retcode:int; | |||
| iteration:int; | |||
| remote_publickeys:[ClientPublicKeys]; | |||
| next_req_time:string; | |||
| } | |||
| table RequestShareSecrets{ | |||
| fl_id:string; | |||
| encrypted_shares:[ClientShare]; | |||
| iteration:int; | |||
| timestamp:string; | |||
| } | |||
| table ResponseShareSecrets{ | |||
| retcode:int; | |||
| reason:string; | |||
| next_req_time:string; | |||
| iteration:int; | |||
| } | |||
| table GetShareSecrets{ | |||
| fl_id:string; | |||
| iteration:int; | |||
| timestamp:string; | |||
| } | |||
| table ReturnShareSecrets{ | |||
| retcode:int; | |||
| iteration:int; | |||
| encrypted_shares:[ClientShare]; | |||
| next_req_time:string; | |||
| } | |||
| table GetClientList{ | |||
| fl_id:string; | |||
| iteration:int; | |||
| timestamp:string; | |||
| } | |||
| table ReturnClientList{ | |||
| retcode:int; | |||
| reason:string; | |||
| clients:[string]; | |||
| iteration:int; | |||
| next_req_time:string; | |||
| } | |||
| table SendReconstructSecret{ | |||
| fl_id:string; | |||
| reconstruct_secret_shares:[ClientShare]; | |||
| iteration:int; | |||
| timestamp:string; | |||
| } | |||
| table ReconstructSecret{ | |||
| retcode:int; | |||
| reason:string; | |||
| iteration:int; | |||
| next_req_time:string; | |||
| } | |||
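Once flatc generates C++ bindings for this schema (e.g. flatc --cpp cipher.fbs, emitting cipher_generated.h), tables are built bottom-up: nested strings and vectors first, then the enclosing table. A hedged sketch for RequestExchangeKeys, assuming the conventionally named flatc-generated CreateRequestExchangeKeys helper:

#include <cstdint>
#include <string>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "cipher_generated.h"  // assumed flatc output for this schema

flatbuffers::DetachedBuffer BuildExchangeKeys(const std::string &fl_id,
                                              const std::vector<uint8_t> &c_pk,
                                              const std::vector<uint8_t> &s_pk,
                                              int iteration, const std::string &ts) {
  flatbuffers::FlatBufferBuilder fbb;
  // Child offsets must exist before the table that references them is created.
  auto fl_id_off = fbb.CreateString(fl_id);
  auto c_pk_off = fbb.CreateVector(c_pk);
  auto s_pk_off = fbb.CreateVector(s_pk);
  auto ts_off = fbb.CreateString(ts);
  auto req = mindspore::schema::CreateRequestExchangeKeys(fbb, fl_id_off, c_pk_off,
                                                          s_pk_off, iteration, ts_off);
  fbb.Finish(req);
  return fbb.Release();  // ownership of the finished buffer moves to the caller
}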
| @@ -0,0 +1,159 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| include "cipher.fbs"; | |||
| namespace mindspore.schema; | |||
| file_identifier "FLJ0"; | |||
| file_extension "fl"; | |||
| enum ResponseCode: int { | |||
| SUCCEED=200, | |||
| SucNotReady=201, | |||
| RepeatRequest=202, | |||
| SucNotMatch=204, | |||
| OutOfTime=300, | |||
| NotSelected=301, | |||
| RequestError=400, | |||
| SystemError=500 | |||
| } | |||
| enum AggregationType:byte {FedAvg=0, FedAdam=1, FedAdagrad=2, FedMeta=3, qffl=4} | |||
| enum Metrics:byte {accuracy=0, precision=1, recall=2, AUC=3, f1=4, fbeta=5} | |||
| enum EarlyStopType:byte {loss_diff = 0, loss_abs = 1, weight_diff = 2} | |||
| table Aggregation { | |||
| type:AggregationType; | |||
| weights:[float]; | |||
| } | |||
| table EarlyStop { | |||
| early_stop_type:EarlyStopType; | |||
| weight:float; | |||
| rounds:int; | |||
| } | |||
| table FeatureMap{ | |||
| weight_fullname:string; | |||
| data:[float]; | |||
| } | |||
| table RequestFLJob{ | |||
| fl_name:string; | |||
| fl_id:string; | |||
| iteration:int; | |||
| data_size:int; | |||
| timestamp:string; | |||
| } | |||
| table ResponseFLJob { | |||
| retcode:int; | |||
| reason:string; | |||
| iteration:int; | |||
| is_selected:bool = false; | |||
| next_req_time:string; | |||
| fl_plan_config:FLPlan; | |||
| feature_map:[FeatureMap]; | |||
| timestamp:string; | |||
| } | |||
| table FLPlan { | |||
| fl_name:string; | |||
| iterations:int; | |||
| epochs:int; | |||
| early_stop:EarlyStop; | |||
| mini_batch:int; | |||
| shuffle:bool = false; | |||
| lr:float; | |||
| aggregation:Aggregation; | |||
| metrics:[Metrics]; | |||
| cipher:CipherPublicParams; | |||
| } | |||
| table RequestUpdateModel{ | |||
| fl_name:string; | |||
| fl_id:string; | |||
| iteration:int; | |||
| feature_map:[FeatureMap]; | |||
| timestamp:string; | |||
| } | |||
| table ResponseUpdateModel{ | |||
| retcode:int; | |||
| reason:string; | |||
| feature_map:[FeatureMap]; | |||
| next_req_time:string; | |||
| timestamp:string; | |||
| } | |||
| table RequestAsyncUpdateModel{ | |||
| fl_name:string; | |||
| fl_id:string; | |||
| iteration:int; | |||
| data_size:int; | |||
| feature_map:[FeatureMap]; | |||
| } | |||
| table ResponseAsyncUpdateModel{ | |||
| retcode:int; | |||
| reason:string; | |||
| iteration:int; | |||
| } | |||
| table RequestOverwriteWeightsByKey{ | |||
| iteration:int; | |||
| feature_map:[FeatureMap]; | |||
| } | |||
| table ResponseOverwriteWeightsByKey{ | |||
| retcode:int; | |||
| reason:string; | |||
| } | |||
| table RequestGetModel{ | |||
| fl_name:string; | |||
| iteration:int; | |||
| timestamp:string; | |||
| } | |||
| table ResponseGetModel{ | |||
| retcode:int; | |||
| reason:string; | |||
| iteration:int; | |||
| feature_map:[FeatureMap]; | |||
| timestamp:string; | |||
| } | |||
| table RequestAsyncGetModel{ | |||
| fl_name:string; | |||
| iteration:int; | |||
| } | |||
| table ResponseAsyncGetModel{ | |||
| retcode:int; | |||
| reason:string; | |||
| iteration:int; | |||
| feature_map:[FeatureMap]; | |||
| } | |||
| table RequestGetWeightsByKey{ | |||
| iteration:int; | |||
| weight_names:[string]; | |||
| } | |||
| table ResponseGetWeightsByKey{ | |||
| retcode:int; | |||
| reason:string; | |||
| feature_map:[FeatureMap]; | |||
| } | |||
| // FeatureMapList refers to the whole trained model. | |||
| table FeatureMapList { | |||
| feature_map:[FeatureMap]; | |||
| } | |||
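On the receiving side, the file_identifier "FLJ0" declared above occupies bytes 4-7 of a finished buffer (when the sender passes it to Finish), so the server can cheaply reject foreign payloads before structural verification. A minimal sketch using standard FlatBuffers calls; BufferHasIdentifier, Verifier, and GetRoot are real FlatBuffers APIs, while treating RequestFLJob as the buffer's root type is an assumption of this sketch:

#include <cstddef>
#include <cstdint>
#include "flatbuffers/flatbuffers.h"
#include "fl_generated.h"  // assumed flatc output for this schema

bool CheckStartFLJobRequest(const uint8_t *buf, std::size_t len) {
  // Reject buffers that do not carry the declared identifier "FLJ0".
  if (len < 8 || !flatbuffers::BufferHasIdentifier(buf, "FLJ0")) {
    return false;
  }
  // Structural verification before touching any field.
  flatbuffers::Verifier verifier(buf, len);
  if (!verifier.VerifyBuffer<mindspore::schema::RequestFLJob>("FLJ0")) {
    return false;
  }
  const auto *req = flatbuffers::GetRoot<mindspore::schema::RequestFLJob>(buf);
  // Accessors mirror the table fields: fl_name, fl_id, iteration, data_size, timestamp.
  return req->fl_name() != nullptr && req->data_size() > 0;
}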