|
|
@@ -462,39 +462,6 @@ class multiwoz22Processor(DSTProcessor): |
|
|
|
utt_tok_list.append(self.tokenize( |
|
|
|
utt['text'])) # normalize utterances |
|
|
|
|
|
|
|
# modified_slots = {} |
|
|
|
|
|
|
|
# If sys utt, extract metadata (identify and collect modified slots) |
|
|
|
# if is_sys_utt: |
|
|
|
# for d in utt['metadata']: |
|
|
|
# booked = utt['metadata'][d]['book']['booked'] |
|
|
|
# booked_slots = {} |
|
|
|
# # Check the booked section |
|
|
|
# if booked != []: |
|
|
|
# for s in booked[0]: |
|
|
|
# booked_slots[s] = self.normalize_label( |
|
|
|
# '%s-%s' % (d, s), |
|
|
|
# booked[0][s]) # normalize labels |
|
|
|
# # Check the semi and the inform slots |
|
|
|
# for category in ['book', 'semi']: |
|
|
|
# for s in utt['metadata'][d][category]: |
|
|
|
# cs = '%s-book_%s' % ( |
|
|
|
# d, s) if category == 'book' else '%s-%s' % (d, |
|
|
|
# s) |
|
|
|
# value_label = self.normalize_label( |
|
|
|
# cs, utt['metadata'][d][category] |
|
|
|
# [s]) # normalize labels |
|
|
|
# # Prefer the slot value as stored in the booked section |
|
|
|
# if s in booked_slots: |
|
|
|
# value_label = booked_slots[s] |
|
|
|
# # Remember modified slots and entire dialog state |
|
|
|
# if cs in slot_list and cumulative_labels[ |
|
|
|
# cs] != value_label: |
|
|
|
# modified_slots[cs] = value_label |
|
|
|
# cumulative_labels[cs] = value_label |
|
|
|
# |
|
|
|
# mod_slots_list.append(modified_slots.copy()) |
|
|
|
|
|
|
|
# Form proper (usr, sys) turns |
|
|
|
turn_itr = 0 |
|
|
|
diag_seen_slots_dict = {} |
|
|
@@ -938,8 +905,8 @@ def convert_examples_to_features(examples, |
|
|
|
# Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT) |
|
|
|
if len(tokens_a) + len(tokens_b) + len( |
|
|
|
history) > max_seq_length - model_specs['TOKEN_CORRECTION']: |
|
|
|
logger.info('Truncate Example %s. Total len=%d.' % |
|
|
|
(guid, len(tokens_a) + len(tokens_b) + len(history))) |
|
|
|
# logger.info('Truncate Example %s. Total len=%d.' % |
|
|
|
# (guid, len(tokens_a) + len(tokens_b) + len(history))) |
|
|
|
input_text_too_long = True |
|
|
|
else: |
|
|
|
input_text_too_long = False |
|
|
@@ -968,7 +935,6 @@ def convert_examples_to_features(examples, |
|
|
|
|
|
|
|
def _get_start_end_pos(class_type, token_label_ids, max_seq_length): |
|
|
|
if class_type == 'copy_value' and 1 not in token_label_ids: |
|
|
|
# logger.warn("copy_value label, but token_label not detected. Setting label to 'none'.") |
|
|
|
class_type = 'none' |
|
|
|
start_pos = 0 |
|
|
|
end_pos = 0 |
|
|
@@ -1045,9 +1011,6 @@ def convert_examples_to_features(examples, |
|
|
|
features = [] |
|
|
|
# Convert single example |
|
|
|
for (example_index, example) in enumerate(examples): |
|
|
|
if example_index % 1000 == 0: |
|
|
|
logger.info('Writing example %d of %d' % |
|
|
|
(example_index, len(examples))) |
|
|
|
|
|
|
|
total_cnt += 1 |
|
|
|
|
|
|
@@ -1075,17 +1038,6 @@ def convert_examples_to_features(examples, |
|
|
|
model_specs, example.guid) |
|
|
|
|
|
|
|
if input_text_too_long: |
|
|
|
if example_index < 10: |
|
|
|
if len(token_labels_a) > len(tokens_a): |
|
|
|
logger.info(' tokens_a truncated labels: %s' |
|
|
|
% str(token_labels_a[len(tokens_a):])) |
|
|
|
if len(token_labels_b) > len(tokens_b): |
|
|
|
logger.info(' tokens_b truncated labels: %s' |
|
|
|
% str(token_labels_b[len(tokens_b):])) |
|
|
|
if len(token_labels_history) > len(tokens_history): |
|
|
|
logger.info( |
|
|
|
' tokens_history truncated labels: %s' |
|
|
|
% str(token_labels_history[len(tokens_history):])) |
|
|
|
|
|
|
|
token_labels_a = token_labels_a[:len(tokens_a)] |
|
|
|
token_labels_b = token_labels_b[:len(tokens_b)] |
|
|
@@ -1136,25 +1088,6 @@ def convert_examples_to_features(examples, |
|
|
|
|
|
|
|
assert (len(input_ids) == len(input_ids_unmasked)) |
|
|
|
|
|
|
|
# if example_index < 10: |
|
|
|
# logger.info('*** Example ***') |
|
|
|
# logger.info('guid: %s' % (example.guid)) |
|
|
|
# logger.info('tokens: %s' % ' '.join(tokens)) |
|
|
|
# logger.info('input_ids: %s' % ' '.join([str(x) |
|
|
|
# for x in input_ids])) |
|
|
|
# logger.info('input_mask: %s' |
|
|
|
# % ' '.join([str(x) for x in input_mask])) |
|
|
|
# logger.info('segment_ids: %s' |
|
|
|
# % ' '.join([str(x) for x in segment_ids])) |
|
|
|
# logger.info('start_pos: %s' % str(start_pos_dict)) |
|
|
|
# logger.info('end_pos: %s' % str(end_pos_dict)) |
|
|
|
# logger.info('values: %s' % str(value_dict)) |
|
|
|
# logger.info('inform: %s' % str(inform_dict)) |
|
|
|
# logger.info('inform_slot: %s' % str(inform_slot_dict)) |
|
|
|
# logger.info('refer_id: %s' % str(refer_id_dict)) |
|
|
|
# logger.info('diag_state: %s' % str(diag_state_dict)) |
|
|
|
# logger.info('class_label_id: %s' % str(class_label_id_dict)) |
|
|
|
|
|
|
|
features.append( |
|
|
|
InputFeatures( |
|
|
|
guid=example.guid, |
|
|
@@ -1171,9 +1104,6 @@ def convert_examples_to_features(examples, |
|
|
|
diag_state=diag_state_dict, |
|
|
|
class_label_id=class_label_id_dict)) |
|
|
|
|
|
|
|
logger.info('========== %d out of %d examples have text too long' % |
|
|
|
(too_long_cnt, total_cnt)) |
|
|
|
|
|
|
|
return features |
|
|
|
|
|
|
|
|
|
|
|