Browse Source

GeTensor aligned addr & zero copy support

tags/v1.2.0
chenyemeng 4 years ago
parent
commit
1537a0d187
6 changed files with 27 additions and 36 deletions
  1. +0
    -2
      ge/ge_local_engine/engine/host_cpu_engine.cc
  2. +2
    -3
      ge/graph/manager/host_mem_allocator.cc
  3. +7
    -9
      ge/graph/manager/host_mem_manager.cc
  4. +9
    -8
      ge/graph/passes/assign_pass.cc
  5. +8
    -12
      ge/graph/passes/inplace_support_check_pass.cc
  6. +1
    -2
      ge/hybrid/model/hybrid_model_builder.cc

+ 0
- 2
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -47,8 +47,6 @@ namespace {
auto tensor_name = op_desc->GetOutputNameByIndex(i); \ auto tensor_name = op_desc->GetOutputNameByIndex(i); \
GE_RETURN_WITH_LOG_IF_TRUE(tensor_name.empty(), "Failed to get output name. node = %s, index = %zu", \ GE_RETURN_WITH_LOG_IF_TRUE(tensor_name.empty(), "Failed to get output name. node = %s, index = %zu", \
op_desc->GetName().c_str(), i); \ op_desc->GetName().c_str(), i); \
GELOGD("Successfully inserted output tensor. node = %s, index = %zu, output name = %s, addr = %p, size = %zu", \
op_desc->GetName().c_str(), i, tensor_name.c_str(), tensor.GetData(), tensor.GetSize()); \
named_outputs.emplace(tensor_name, tensor); \ named_outputs.emplace(tensor_name, tensor); \
break; \ break; \
} }


+ 2
- 3
ge/graph/manager/host_mem_allocator.cc View File

@@ -24,7 +24,7 @@ const void *HostMemAllocator::Malloc(const std::shared_ptr<AlignedPtr> &aligned_
GELOGW("Insert a null aligned_ptr"); GELOGW("Insert a null aligned_ptr");
return nullptr; return nullptr;
} }
GELOGD("allocate existed host memory succ, addr=%p, size=%zu", aligned_ptr->Get(), size);
GELOGD("allocate existed host memory succ, size=%zu", size);
allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr }; allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr };
return aligned_ptr->Get(); return aligned_ptr->Get();
} }
@@ -38,12 +38,11 @@ uint8_t *HostMemAllocator::Malloc(size_t size) {
return nullptr; return nullptr;
} }
allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr }; allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr };
GELOGD("allocate host memory succ, addr=%p, size=%zu", aligned_ptr->Get(), size);
GELOGD("allocate host memory succ, size=%zu", size);
return aligned_ptr->MutableGet(); return aligned_ptr->MutableGet();
} }


Status HostMemAllocator::Free(const void *memory_addr) { Status HostMemAllocator::Free(const void *memory_addr) {
GELOGD("Free host memory, addr=%p", memory_addr);
if (memory_addr == nullptr) { if (memory_addr == nullptr) {
GELOGE(GE_GRAPH_FREE_FAILED, "Invalid memory pointer"); GELOGE(GE_GRAPH_FREE_FAILED, "Invalid memory pointer");
return GE_GRAPH_FREE_FAILED; return GE_GRAPH_FREE_FAILED;


+ 7
- 9
ge/graph/manager/host_mem_manager.cc View File

@@ -44,15 +44,13 @@ Status SharedMemAllocator::Allocate(SharedMemInfo &mem_info) {
} }
mem_info.fd = output_para.fd; mem_info.fd = output_para.fd;
#ifndef ONLY_COMPILE_OPEN_SRC #ifndef ONLY_COMPILE_OPEN_SRC
mem_info.host_aligned_ptr = AlignedPtr::BuildAlignedPtr(mem_info.mem_size,
[&output_para](std::unique_ptr<uint8_t[], deleter> &ptr) {
GELOGD("set aligned_ptr, addr=%p", output_para.ptr);
ptr.reset(reinterpret_cast<uint8_t *>(output_para.ptr));
},
[](uint8_t *ptr) {
GELOGD("reset aligned_ptr in SharedMemAllocator, addr=%p", ptr);
ptr = nullptr;
}, 0);
mem_info.host_aligned_ptr = AlignedPtr::BuildFromAllocFunc(mem_info.mem_size,
[&output_para](std::unique_ptr<uint8_t[], deleter> &ptr) {
ptr.reset(reinterpret_cast<uint8_t *>(output_para.ptr));
},
[](uint8_t *ptr) {
ptr = nullptr;
}, 0);
#else #else
mem_info.host_address = reinterpret_cast<uint8_t *>(output_para.ptr); mem_info.host_address = reinterpret_cast<uint8_t *>(output_para.ptr);
#endif #endif


+ 9
- 8
ge/graph/passes/assign_pass.cc View File

@@ -9,7 +9,7 @@
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.l
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
@@ -20,9 +20,12 @@
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"


namespace { namespace {
const uint32_t kValidInputNodeOutputNum = 1;
const int32_t kAssignRefInputIndex = 0;
const int32_t kAssignValueInputIndex = 1;
constexpr uint32_t kValidInputNodeOutputNum = 1;
constexpr int32_t kAssignRefInputIndex = 0;
constexpr int32_t kAssignValueInputIndex = 1;
static const std::set<std::string> kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
} }


namespace ge { namespace ge {
@@ -80,7 +83,6 @@ Status AssignPass::OptimizedAssignNode(NodePtr &assign_node) {
GELOGE(FAILED, "Isolate and delete assign_node %s failed.", assign_node->GetName().c_str()); GELOGE(FAILED, "Isolate and delete assign_node %s failed.", assign_node->GetName().c_str());
return FAILED; return FAILED;
} }
AddNodeDeleted(assign_node);


const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc(); const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc();
const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc(); const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc();
@@ -221,9 +223,8 @@ bool AssignPass::IsCondMatch(const NodePtr &node, const OutDataAnchorPtr &ref_pe
node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(), node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(),
value_peer_anchor->GetOwnerNode()->GetName().c_str()); value_peer_anchor->GetOwnerNode()->GetName().c_str());


const std::string &value_type = value_peer_anchor->GetOwnerNode()->GetType();
if ((value_type == CONSTANTOP) || (value_type == CONSTANT)) {
GELOGD("value input is const");
if (kNoTaskNodeTypes.count(value_peer_anchor->GetOwnerNode()->GetType()) > 0) {
GELOGD("value input is not calculate node");
return false; return false;
} }




+ 8
- 12
ge/graph/passes/inplace_support_check_pass.cc View File

@@ -20,20 +20,16 @@
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"


namespace { namespace {
const uint32_t kInplaceSupportOutputIndex = 0;
const uint32_t kInplaceSupportOutputNum = 1;
static const std::set<std::string> src_node_types = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
constexpr uint32_t kInplaceSupportOutputIndex = 0;
constexpr uint32_t kInplaceSupportOutputNum = 1;
static const std::set<std::string> kSrcNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
} }


namespace ge { namespace ge {
Status InplaceSupportCheckPass::Run(NodePtr &node) { Status InplaceSupportCheckPass::Run(NodePtr &node) {
GELOGD("InplaceSupportCheckPass running"); GELOGD("InplaceSupportCheckPass running");
if (src_node_types.count(node->GetType()) > 0) {
GELOGD("meet src_node %s, skip InplaceSupportCheckPass", node->GetName().c_str());
return SUCCESS;
}
if (node->GetAllOutDataAnchorsSize() != kInplaceSupportOutputNum) { if (node->GetAllOutDataAnchorsSize() != kInplaceSupportOutputNum) {
GELOGD("output num of node %s is not %u, skip InplaceSupportCheckPass", GELOGD("output num of node %s is not %u, skip InplaceSupportCheckPass",
node->GetName().c_str(), kInplaceSupportOutputNum); node->GetName().c_str(), kInplaceSupportOutputNum);
@@ -49,7 +45,7 @@ Status InplaceSupportCheckPass::Run(NodePtr &node) {
continue; continue;
} }
auto in_node = peer_data_anchor->GetOwnerNode(); auto in_node = peer_data_anchor->GetOwnerNode();
if (src_node_types.count(in_node->GetType()) > 0) {
if (kSrcNodeTypes.count(in_node->GetType()) > 0) {
GELOGD("meet src_node %s", in_node->GetName().c_str()); GELOGD("meet src_node %s", in_node->GetName().c_str());
continue; continue;
} }
@@ -62,11 +58,11 @@ Status InplaceSupportCheckPass::Run(NodePtr &node) {
const DataType &input_type = node->GetOpDesc()->GetInputDesc(inplace_input_idx).GetDataType(); const DataType &input_type = node->GetOpDesc()->GetInputDesc(inplace_input_idx).GetDataType();
const GeShape &input_shape = node->GetOpDesc()->GetInputDesc(inplace_input_idx).GetShape(); const GeShape &input_shape = node->GetOpDesc()->GetInputDesc(inplace_input_idx).GetShape();
if (input_type != output_type) { if (input_type != output_type) {
GELOGD("DataType mismatch, in_idx=%d, input_type=%u, output_type=%u", inplace_input_idx, input_type, output_type);
GELOGW("DataType mismatch, in_idx=%d, input_type=%u, output_type=%u", inplace_input_idx, input_type, output_type);
continue; continue;
} }
if (input_shape.GetDims() != output_shape.GetDims()) { if (input_shape.GetDims() != output_shape.GetDims()) {
GELOGD("Shape mismatch, in_idx=%d, input_shape=[%s], output_shape=[%s]",
GELOGW("Shape mismatch, in_idx=%d, input_shape=[%s], output_shape=[%s]",
inplace_input_idx, input_shape.ToString().c_str(), output_shape.ToString().c_str()); inplace_input_idx, input_shape.ToString().c_str(), output_shape.ToString().c_str());
continue; continue;
} }


+ 1
- 2
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -932,8 +932,7 @@ Status HybridModelBuilder::InitVariableTensors() {
GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed."); GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed.");
return MEMALLOC_FAILED; return MEMALLOC_FAILED;
} }
GELOGD("Host variable [%s] malloc success, host_addr=%p, dev_addr=%p, size=%lld.",
it.first.c_str(), mem_info.host_aligned_ptr->Get(), mem_info.device_address, tensor_size);
GELOGD("Host variable [%s] malloc success, size=%lld.", it.first.c_str(), tensor_size);


std::unique_ptr<TensorValue> tensor(new (std::nothrow) TensorValue(mem_info.host_aligned_ptr->MutableGet(), std::unique_ptr<TensorValue> tensor(new (std::nothrow) TensorValue(mem_info.host_aligned_ptr->MutableGet(),
tensor_size)); tensor_size));


Loading…
Cancel
Save