@@ -28,6 +28,7 @@
#include <thread>
#include <cmath>
#include <random>
#include <list>
#include "ir/func_graph.h"
#include "backend/session/session_basic.h"
#include "backend/session/anf_runtime_algorithm.h"
@@ -70,6 +71,7 @@ class ParameterServer {
handler_(nullptr),
func_graph_(nullptr),
sess_(nullptr),
running_(true),
thread_(nullptr) {}
~ParameterServer() = default;
ParameterServer(const ParameterServer &) = delete;
@@ -89,6 +91,8 @@ class ParameterServer {
::ps::KVPairs<T> *res);
void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
@@ -96,6 +100,9 @@ class ParameterServer {
typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res);
std::unordered_map<int, RequestHandler> handlers_;
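// One-shot guards: each init command takes effect at most once per key,
// even when every worker sends the same request.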
std::unordered_map<Key, bool> init_weights_;
std::unordered_map<Key, bool> init_weight_to_optim_;
std::unordered_map<Key, bool> init_optim_info_;
};
bool Init(const FuncGraphPtr &func_graph);
@@ -106,14 +113,18 @@ class ParameterServer {
void InitGrad(const Key &key, const GradPtr &grad);
void InitEmbeddingTable(const Key &key,
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
void Finalize();
void UpdateWeights();
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
WeightPtr weight(const Key &key);
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
int SumOfShapes(const std::vector<int> &shapes) const;
bool ReadyForUpdateWeights();
bool ReadyForAccumGrads();
bool ReadyForPush(const Key &key);
bool ReadyForPull(const Key &key);
void ResetGradAccumCount();
std::mutex &mutex();
const CNodePtr GetCNode(const std::string &name) const;
size_t pserver_num_;
size_t worker_num_;
@@ -123,20 +134,23 @@ class ParameterServer {
std::unique_ptr<ServerHandler> handler_;
FuncGraphPtr func_graph_;
std::shared_ptr<session::SessionBasic> sess_;
bool running_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
std::unordered_map<Key, std::string> weight_key_to_optims_;
std::unordered_map<Key, std::string> weight_key_to_optim_op_;
std::unordered_map<Key, WeightPtr> weights_;
std::unordered_map<Key, bool> is_embedding_;
std::unordered_map<Key, WeightPtr> grads_;
std::unordered_map<Key, size_t> grads_accum_counter_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
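// Pull tokens per key: refilled to worker_num_ after each optimizer step on
// non-embedding weights and consumed one per weight pull, so each worker
// pulls the freshly updated weight exactly once per round.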
std::unordered_map<Key, uint64_t> tokens_;
std::mutex mutex_;
std::condition_variable apply_grads_cv_;
std::condition_variable accum_grads_cv_;
std::unique_ptr<std::thread> thread_;
@@ -165,6 +179,8 @@ void ParameterServer<T>::ServerHandler::Init() {
handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
}
@@ -186,6 +202,7 @@ void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_me
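// The init handlers below take ps_->mutex() so that initialization cannot
// interleave with the UpdateWeights() thread, which holds the same mutex.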
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
size_t key_num = req_data.keys.size();
T *data_ptr = req_data.vals.data();
size_t pos = 0;
@@ -207,10 +224,16 @@ template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
size_t key_num = req_data.keys.size();
for (size_t i = 0; i < key_num; i++) {
Key key = req_data.keys[i];
T val = req_data.vals[i];
if (init_weight_to_optim_[key]) {
continue;
}
init_weight_to_optim_[key] = true;
ps_->InitWeightKeyToOptims(key, val);
}
}
@@ -218,12 +241,21 @@ void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KV
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
const Key &key = req_data.keys[0];
if (init_optim_info_[key]) {
return;
}
init_optim_info_[key] = true;
ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
const Key &key = req_data.keys[0];
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
@@ -233,7 +265,6 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
shapes->push_back(indices_shape);
shapes->push_back(output_shape);
const Key &key = req_data.keys[0];
const Lengths &lens = req_data.lens;
size_t index = 0;
for (int i = 0; i < lens[0]; i++) {
@@ -248,6 +279,26 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
ps_->InitEmbeddingTable(key, shapes);
}
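// Workers poll these two handlers before pushing or pulling: a push is
// accepted only after the key's pull tokens are drained and the current
// accumulation round still has room; a pull only while tokens remain
// (see ReadyForPush/ReadyForPull below).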
template <typename T>
void ParameterServer<T>::ServerHandler::HandleCheckReadyForPush(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
const Key &key = req_data.keys[0];
bool ready = ps_->ReadyForPush(key);
res->keys.push_back(key);
res->vals.push_back(ready);
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleCheckReadyForPull(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
const Key &key = req_data.keys[0];
bool ready = ps_->ReadyForPull(key);
res->keys.push_back(key);
res->vals.push_back(ready);
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
@@ -261,7 +312,7 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta
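// A finalize request stops the background update thread via ps_->Finalize();
// the ps-lite shutdown itself happens in Run() once that thread joins.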
template <typename T>
void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
::ps::Finalize(0, false);
ps_->Finalize();
}
template <typename T>
@@ -274,7 +325,6 @@ bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
handler_->Init();
InitOptimInfoBuilders();
ps_->set_request_handle(*handler_);
thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
return true;
@@ -296,6 +346,7 @@ void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int &optim_
return;
}
weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
}
template <typename T>
@@ -318,31 +369,49 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
}
if (weight_key_to_optims_.count(key) > 0) {
const std::string &optim_name = weight_key_to_optims_[key];
const std::string &optim_op_name = weight_key_to_optim_op_[key];
if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
const CNodePtr cnode = GetCNode(optim_op_name);
MS_EXCEPTION_IF_NULL(cnode);
if (optim_name == kSparseAdam) {
std::shared_ptr<PServerKernel> optimizer =
std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_);
optimizer->InitKernel(optim_inputs_shape_[key]);
optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
optimizers_[key] = optimizer;
} else if (optim_name == kApplyMomentum) {
std::shared_ptr<PServerKernel> optimizer =
std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_);
optimizer->InitKernel(optim_inputs_shape_[key]);
optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
optimizers_[key] = optimizer;
} else if (optim_name == kSparseFtrl) {
std::shared_ptr<PServerKernel> optimizer =
std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_);
optimizer->InitKernel(optim_inputs_shape_[key]);
optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
optimizers_[key] = optimizer;
}
}
}
}
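// Returns the first CNode whose full name contains both the optimizer op name
// and "Push"; nullptr when no such node exists in the graph.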
template <typename T>
const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
std::list<CNodePtr> cnodes = func_graph_->GetOrderedCnodes();
for (CNodePtr cnode : cnodes) {
std::string fullname = cnode->fullname_with_scope();
if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) {
return cnode;
}
}
return nullptr;
}
template <typename T>
void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
if (weights_.count(key) == 0) {
MS_LOG(INFO) << "Initializing weight for key " << key;
if (weights_.count(key) == 0 || is_embedding_[key]) {
weights_[key] = weight;
tokens_[key] = 0;
is_embedding_[key] = false;
}
}
@@ -357,7 +426,7 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
template <typename T>
void ParameterServer<T>::InitEmbeddingTable(
const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
// Init embedding lookup kernel
MS_LOG(INFO) << "Initializing embedding table for key " << key;
std::shared_ptr<PServerKernel> lookup = std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_);
lookup->InitKernel(shapes);
embedding_lookup_ops_[key] = lookup;
@@ -377,15 +446,26 @@ void ParameterServer<T>::InitEmbeddingTable(
embedding_data[i] = random(engine);
}
weights_[key] = embedding;
tokens_[key] = 0;
is_embedding_[key] = true;
grads_accum_counter_[key] = 0;
}
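// Signals shutdown: clears running_ and wakes the update thread so its wait
// predicate (which also watches running_) can return.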
template <typename T>
void ParameterServer<T>::Finalize() {
running_ = false;
apply_grads_cv_.notify_one();
}
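// Body of the background thread started in Init(): sleeps until a full round
// of gradients has been accumulated (or shutdown is requested), runs every
// optimizer once, refills pull tokens for non-embedding weights, then
// releases any blocked AccumGrad calls.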
template <typename T>
void ParameterServer<T>::UpdateWeights() {
while (true) {
std::unique_lock<std::mutex> lock(mutex_);
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); });
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
if (!running_) {
break;
}
for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
Key key = iter->first;
@@ -408,17 +488,17 @@ void ParameterServer<T>::UpdateWeights() {
optim_info->ComputeMean(worker_num_);
optimizer->Execute(inputs, workspaces, outputs);
optim_info->Reset();
if (!is_embedding_[key]) {
tokens_[key] = worker_num_;
}
}
ResetGradAccumCount();
accum_grads_cv_.notify_all();
}
}
template <typename T>
void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
std::unique_lock<std::mutex> lock(mutex_);
accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); });
const Key &key = keys[0];
std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
@@ -451,14 +531,13 @@ void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const
template <typename T>
WeightPtr ParameterServer<T>::weight(const Key &key) {
std::unique_lock<std::mutex> lock(mutex_);
if (weights_.count(key) == 0) {
MS_LOG(ERROR) << "Invalid weight key " << key;
return nullptr;
MS_LOG(EXCEPTION) << "Invalid weight key " << key;
}
WeightPtr weight_ptr = weights_[key];
WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0);
copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size());
tokens_[key] -= 1;
return copy_weight_ptr;
}
@@ -529,8 +608,22 @@ inline bool ParameterServer<T>::ReadyForUpdateWeights() {
}
template <typename T>
inline bool ParameterServer<T>::ReadyForAccumGrads() {
return grad_accum_count_ < weights_.size();
inline bool ParameterServer<T>::ReadyForPush(const Key &key) {
std::unique_lock<std::mutex> lock(mutex_);
if (weights_.empty()) {
MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
"kInitWeightsCmd command. 2.The Server failed to initialize weights.";
}
return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
}
template <typename T>
inline bool ParameterServer<T>::ReadyForPull(const Key &key) {
std::unique_lock<std::mutex> lock(mutex_);
if (tokens_.count(key) == 0 || weights_.count(key) == 0) {
MS_LOG(EXCEPTION) << "Invalid weight key " << key;
}
return tokens_[key] > 0;
}
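// Example token flow for one non-embedding key with worker_num_ == 2:
//   UpdateWeights() completes a step -> tokens_[key] = 2 (pulls allowed)
//   two weight() calls follow        -> tokens_[key] = 0 (pulls blocked)
//   ReadyForPush(key) now returns true, so the next round of pushes may start.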
template <typename T>
@@ -541,6 +634,11 @@ inline void ParameterServer<T>::ResetGradAccumCount() {
}
}
template <typename T>
inline std::mutex &ParameterServer<T>::mutex() {
return mutex_;
}
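// Starts ps-lite, builds the server state from the graph, then blocks until
// Finalize() stops the update thread; afterwards ps-lite is finalized and the
// process exits.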
template <typename T>
void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
::ps::Start(0);
@@ -550,6 +648,8 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
}
Init(func_graph);
thread_->join();
::ps::Finalize(0, true);
exit(1);
}
} // namespace ps
} // namespace parallel