annotate_textgrids.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #!/usr/bin/env python
  2. # -*- coding: utf8 -*-
  3. import os
  4. import re
  5. import pandas as pd
  6. from pprint import pprint
  7. from pympi.Praat import TextGrid
  8. from utils import text_read, text_dump, walk_dir, extract_source_sentence_file, SENT_ID_PATTERN
  9. def annotate_textgrids(path):
  10. all_textgrids = walk_dir(path, ext='.TextGrid')
  11. for textgrid_file in all_textgrids:
  12. tg_path, tg_filename = os.path.split(textgrid_file)
  13. dir_name = tg_path.split(os.sep)[-1]
  14. # Get metadata (apply twice if respeak session)
  15. target_sentence_file = extract_source_sentence_file(dir_name)
  16. if 'respeak' in textgrid_file:
  17. target_sentence_file = extract_source_sentence_file(target_sentence_file)
  18. # Get sentid, wordid, senttype
  19. tg_raw_filename = tg_filename.replace('.TextGrid', '')
  20. _, sentid, *_ = tg_raw_filename.split('_')
  21. sentid_matched = re.search(SENT_ID_PATTERN, sentid)
  22. if not sentid_matched: continue
  23. pair_id, word_id, sent_type = sentid_matched.group(1), sentid_matched.group(2), sentid_matched.group(3)
  24. # Read summary sentences
  25. target_summary_all = pd.read_csv('./extra/sentences/summary/{}.csv'.format(target_sentence_file))
  26. row_idx = target_summary_all[target_summary_all['pair_index'] == 'pair{}'.format(pair_id)].index
  27. assert len(row_idx) == 1, ValueError("More than one pair found for that pairID!") # should never happen
  28. row_idx = int(row_idx[0])
  29. # Get sentence, target word, and target syllable
  30. target_summary_sent = target_summary_all.loc[row_idx, 'word{}.{}'.format(word_id, sent_type)]
  31. target_summary_word = target_summary_all.loc[row_idx, 'word{}.{}'.format(word_id, 'raw')]
  32. target_summary_syl = target_summary_all.loc[row_idx, "key_syllables"]
  33. # Read TextGrid files
  34. textgrid = TextGrid(textgrid_file)
  35. tier_list = [n for _, n in list(textgrid.get_tier_name_num())]
  36. # Remove extra tiers
  37. if len(tier_list) > 1:
  38. for tier in tier_list:
  39. if tier == 'transcription': continue
  40. textgrid.remove_tier(tier)
  41. # Read transcription interval
  42. trans_tier = textgrid.get_tier('transcription')
  43. tg_intervals = trans_tier.get_all_intervals()
  44. assert len(tg_intervals) == 1
  45. tg_start, tg_end, tg_trans = tg_intervals[0]
  46. # Get key syllables to annotate for that word
  47. words_key_syl = []
  48. for item in target_summary_syl.split(','):
  49. item = item.strip()
  50. for syl in [item.split('/')]:
  51. words_key_syl.append(syl[int(word_id)-1])
  52. # Add this information to the target word to annotate
  53. word_annotated = target_summary_word[:]
  54. for syl in words_key_syl:
  55. word_annotated=word_annotated.replace(syl, '[{}]'.format(syl))
  56. # Add information to TextGrid
  57. if tg_trans != target_summary_sent:
  58. print(tg_trans, target_summary_sent)
  59. trans_tier.clear_intervals()
  60. print(trans_tier.tier_type)
  61. trans_tier.add_interval(tg_start, tg_end, target_summary_sent)
  62. kw_tier = textgrid.add_tier('key-word')
  63. kw_tier.add_interval(textgrid.xmin, textgrid.xmax, word_annotated)
  64. textgrid.add_tier('key-syll-segment')
  65. textgrid.add_tier('key-syll-segment-sound')
  66. ordered_tiers = [textgrid.get_tier('transcription'),
  67. textgrid.get_tier('key-word'),
  68. textgrid.get_tier('key-syll-segment'),
  69. textgrid.get_tier('key-syll-segment-sound')]
  70. textgrid.tiers = list(reversed(ordered_tiers))
  71. textgrid.to_file(textgrid_file)
  72. def main(**kwargs):
  73. annotate_textgrids(**kwargs)
  74. def _parse_args(argv):
  75. import argparse
  76. parser = argparse.ArgumentParser(description='Pre-annotate blank generated TextGrid by adding key-word, key-syll, '
  77. 'and key-syll-segment tiers and enclose target word in square brackets.')
  78. parser.add_argument('--path', required=True,
  79. help='Path to the directory whose file will be annotated.')
  80. args = parser.parse_args(argv)
  81. return vars(args)
  82. if __name__ == '__main__':
  83. import sys
  84. import logging
  85. pgrm_name, argv = sys.argv[0], sys.argv[1:]
  86. args = _parse_args(argv)
  87. try:
  88. main(**args)
  89. sys.exit(0)
  90. except Exception as e:
  91. logging.exception(e)
  92. sys.exit(1)