|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """util file"""
-
- import numpy as np
-
class AverageMeter():
    """Computes and stores the average and current value.

    State kept on the instance:
        val: the value passed to the most recent ``update`` call.
        sum: running total of ``val * n`` since the last ``reset``.
        count: running total of ``n`` since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to 0."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new value and refresh the running average.

        Args:
            val: the new value to record.
            n: weight applied to ``val`` when accumulating; defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A total weight of zero (e.g. the first call made with n=0) would
        # divide by zero; leave the average at its reset value instead.
        self.avg = self.sum / self.count if self.count else 0
-
-
class CTCLabelConverter():
    """Convert between text-label and text-index.

    Every character of the alphabet maps to its position in ``character``.
    One extra ``'[blank]'`` token is appended after the alphabet, so the
    blank index is ``len(character)`` (the last index), not 0.
    """

    def __init__(self, character):
        """Build the character <-> index lookup tables.

        Args:
            character (str): set of the possible characters.
        """
        dict_character = list(character)

        # char -> index. If a character occurs more than once, the later
        # position overwrites the earlier one.
        self.dict = {char: i for i, char in enumerate(dict_character)}

        # index -> char, with the dummy '[blank]' token for CTCLoss appended
        # at the end of the alphabet.
        self.character = dict_character + ['[blank]']
        self.dict['[blank]'] = len(dict_character)

    def encode(self, text):
        """Convert text-label into text-index.

        Args:
            text: text labels of each image. [batch_size]

        Returns:
            A pair of numpy arrays ``(text, length)``:
            text: concatenated text index for CTCLoss.
                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]

        Raises:
            KeyError: if a label contains a character missing from the alphabet.
        """
        length = [len(s) for s in text]
        indices = [self.dict[char] for char in ''.join(text)]

        return np.array(indices), np.array(length)

    def _indices_to_texts(self, text_index, length, collapse_repeats):
        """Split ``text_index`` into segments of ``length`` and map each to a string.

        Blank indices are always dropped. When ``collapse_repeats`` is true, an
        index equal to the one immediately before it in the same segment is
        dropped as well.
        """
        blank = self.dict['[blank]']
        texts = []
        start = 0
        for size in length:
            segment = text_index[start:start + size]

            chars = []
            # Index by position rather than iterating the slice, so that a
            # text_index shorter than sum(length) raises IndexError instead of
            # being silently truncated.
            for pos in range(size):
                current = segment[pos]
                if current == blank:
                    continue
                # The comparison is against the raw previous index, so a blank
                # between two equal indices keeps both characters.
                if collapse_repeats and pos > 0 and segment[pos - 1] == current:
                    continue
                chars.append(self.character[current])

            texts.append(''.join(chars))
            start += size
        return texts

    def decode(self, text_index, length):
        """Convert text-index into text-label, removing repeated characters and blank.

        Args:
            text_index: flat sequence of indices for the whole batch.
            length: number of indices belonging to each item of the batch.

        Returns:
            A list with one decoded string per entry of ``length``.
        """
        return self._indices_to_texts(text_index, length, collapse_repeats=True)

    def reverse_encode(self, text_index, length):
        """Convert text-index into text-label, removing blank only.

        Unlike ``decode``, repeated indices are kept.

        Args:
            text_index: flat sequence of indices for the whole batch.
            length: number of indices belonging to each item of the batch.

        Returns:
            A list with one string per entry of ``length``.
        """
        return self._indices_to_texts(text_index, length, collapse_repeats=False)
|