|
- /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #ifndef GE_OP_LOOKUP_OPS_H_
- #define GE_OP_LOOKUP_OPS_H_
-
- #include "graph/operator_reg.h"
-
- namespace ge {
-
- REG_OP(LookupTableImport)
- .INPUT(handle, TensorType({DT_RESOURCE}))
- .INPUT(keys, TensorType({DT_BOOL, DT_DOUBLE, \
- DT_FLOAT, DT_INT32, DT_INT64}))
- .INPUT(values, TensorType({DT_BOOL, DT_DOUBLE, \
- DT_FLOAT, DT_INT32, DT_INT64}))
- .OP_END_FACTORY_REG(LookupTableImport)
-
- REG_OP(LookupTableInsert)
- .INPUT(handle, TensorType({DT_RESOURCE}))
- .INPUT(keys, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, \
- DT_INT32, DT_INT64}))
- .INPUT(values, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, \
- DT_INT32, DT_INT64}))
- .OP_END_FACTORY_REG(LookupTableInsert)
-
- REG_OP(LookupTableExport)
- .INPUT(handle, TensorType({DT_RESOURCE}))
- .OUTPUT(keys, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, \
- DT_INT32, DT_INT64}))
- .OUTPUT(values, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, \
- DT_INT32,DT_INT64}))
- .REQUIRED_ATTR(Tkeys, Type)
- .REQUIRED_ATTR(Tvalues, Type)
- .OP_END_FACTORY_REG(LookupTableExport)
- REG_OP(LookupTableSize)
- .INPUT(handle, TensorType({DT_RESOURCE}))
- .OUTPUT(size, TensorType({DT_INT64}))
- .OP_END_FACTORY_REG(LookupTableSize)
-
- REG_OP(LookupTableFind)
- .INPUT(handle, TensorType({DT_RESOURCE}))
- .INPUT(keys, TensorType({DT_DOUBLE, DT_FLOAT, \
- DT_INT32, DT_INT64}))
- .INPUT(default_value, TensorType({DT_DOUBLE, DT_FLOAT, \
- DT_INT32, DT_INT64}))
- .OUTPUT(values, TensorType({DT_DOUBLE, DT_FLOAT, DT_INT32, \
- DT_INT64}))
- .REQUIRED_ATTR(Tout, Type)
- .OP_END_FACTORY_REG(LookupTableFind)
-
- REG_OP(HashTable)
- .OUTPUT(handle, TensorType({DT_RESOURCE}))
- .ATTR(container, String, "")
- .ATTR(shared_name, String, "")
- .ATTR(use_node_name_sharing, Bool, false)
- .REQUIRED_ATTR(key_dtype, Type)
- .REQUIRED_ATTR(value_dtype, Type)
- .OP_END_FACTORY_REG(HashTable)
-
- REG_OP(InitializeTable)
- .INPUT(handle, TensorType({DT_RESOURCE}))
- .INPUT(keys, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \
- DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
- .INPUT(values, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \
- DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
- .OP_END_FACTORY_REG(InitializeTable)
-
- REG_OP(MutableDenseHashTable)
- .INPUT(empty_key, TensorType({DT_INT32, DT_INT64}))
- .INPUT(deleted_key, TensorType({DT_INT32, DT_INT64}))
- .OUTPUT(handle, TensorType({DT_RESOURSE}))
- .ATTR(container, String, "")
- .ATTR(shared_name, String, "")
- .ATTR(use_node_name_sharing, Bool, false)
- .REQUIRED_ATTR(value_dtype, Type)
- .ATTR(value_shape, ListInt, {})
- .ATTR(initial_num_buckets, Int, 131072)
- .ATTR(max_load_factor, Float, 0.8)
- .OP_END_FACTORY_REG(MutableDenseHashTable)
-
- REG_OP(MutableHashTableOfTensors)
- .OUTPUT(handle, TensorType({DT_RESOURSE}))
- .ATTR(container, String, "")
- .ATTR(shared_name, String, "")
- .ATTR(use_node_name_sharing, Bool, false)
- .REQUIRED_ATTR(key_dtype, Type)
- .REQUIRED_ATTR(value_dtype, Type)
- .ATTR(value_shape, ListInt, {})
- .OP_END_FACTORY_REG(MutableHashTableOfTensors)
-
- REG_OP(MutableHashTable)
- .OUTPUT(handle, TensorType({DT_RESOURSE}))
- .ATTR(container, String, "")
- .ATTR(shared_name, String, "")
- .ATTR(use_node_name_sharing, Bool, false)
- .REQUIRED_ATTR(key_dtype, Type)
- .REQUIRED_ATTR(value_dtype, Type)
- .OP_END_FACTORY_REG(MutableHashTable)
- } // namespace ge
-
- #endif // GE_OP_LOOKUP_OPS_H_
|