test_on_all_languages.py

import os
import sys
import json
import random
from collections import Counter
from itertools import product
from math import log

sys.path.append("./")
sys.path.append("../")

import kenlm
import numpy as np
import pandas as pd
from tqdm import tqdm

from get_most_probable_phonemes import get_most_probable_phonemes
from make_noiser import Noise

random.seed(1023)

LANGUAGES_TYPOLOGIES = {
    'da': ("Danish", "fusional"),
    'de': ("German", "fusional"),
    'en': ("English", "fusional"),
    'es': ("Spanish", "fusional"),
    'et': ("Estonian", "agglutinative"),
    'eu': ("Basque", "agglutinative"),
    'fr': ("French", "fusional"),
    'ja': ("Japanese", "agglutinative"),
    'pl': ("Polish", "fusional"),
    'pt': ("Portuguese", "fusional"),
    'sr': ("Serbian", "fusional"),
    'tr': ("Turkish", "agglutinative"),
}

def compute_word_frequencies(word_train_corpus):
    """Count word occurrences in a word-level training file (one utterance per line)."""
    frequencies = Counter()
    with open(word_train_corpus) as corpus:
        for line in corpus:
            line = line.strip()
            if not line:
                continue
            frequencies.update(line.split(" "))
    return dict(frequencies)
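
# Example (sketch; the path below is hypothetical). Given a file whose lines
# look like "the cat sat on the mat", the function returns a plain dict of
# token counts:
#
#     freqs = compute_word_frequencies("data/word_train/en.one_sentence_per_line")
#     freqs.get("the", 0)  # number of occurrences of "the" in the corpus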

def statistics_word(utterances, word_frequencies, model):
    """Compute per-speaker statistics over a list of utterances: perplexity,
    entropy, mean length of utterance in words (mlu_w) and phonemes (mlu_p),
    word type-token ratio (ttr_w), summed training-corpus word frequencies,
    and the number of out-of-vocabulary words (nb_unk).

    In each utterance string, "@" marks word boundaries and "$" marks
    phoneme boundaries within a word.
    """
    phoneme_utterances = []
    unique_words = set()
    nb_unk = 0
    mlu_w = 0.0
    mlu_p = 0.0
    mean_word_frequencies = 0
    nb_utterances = 0
    nb_words = 0
    statistics = {}
    for utterance in utterances:
        utterance = utterance.strip()
        if not utterance:
            continue
        nb_utterances += 1
        # Word view: drop phoneme separators; phoneme view: split on both separators.
        utterance_w = utterance.replace("@", " ").replace("$", "")
        utterance_p = utterance.replace("@", " ").replace("$", " ")
        phoneme_utterances.append(utterance_p)
        utterance_words = utterance_w.split(" ")
        mlu_w += len(utterance_words)
        mlu_p += len(utterance_p.split(" "))
        nb_words += len(utterance_words)
        unique_words |= set(utterance_words)
        for word in utterance_words:
            word = word.strip()
            if word in word_frequencies:
                # Accumulates the training-corpus frequencies of in-vocabulary
                # words (a running sum, despite the variable's name).
                mean_word_frequencies += word_frequencies[word]
            else:
                nb_unk += 1
    mlu_w /= nb_utterances
    mlu_p /= nb_utterances
    ttr_w = len(unique_words) / nb_words
    ppl = model.perplexity("\n".join(phoneme_utterances))
    entropy = log(ppl)
    statistics["ppl"] = ppl
    statistics["entropy"] = entropy
    statistics["mlu_w"] = mlu_w
    statistics["mlu_p"] = mlu_p
    statistics["ttr_w"] = ttr_w
    statistics["mean_word_frequencies"] = mean_word_frequencies
    statistics["nb_unk"] = nb_unk
    return statistics
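
# Example (sketch; "model" would be a loaded kenlm.Model, and the utterance
# below is hypothetical, with "$" between phonemes and "@" between words):
#
#     stats = statistics_word(["h$e$l$l$o@w$o$r$l$d"], {"hello": 3}, model)
#     stats["mlu_w"]    # -> 2.0 (two words in one utterance)
#     stats["entropy"]  # -> log(stats["ppl"]), the natural log of perplexity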

def create_sparse_combinations(values):
    """Return the set of noise tuples in which at most one dimension is non-zero."""
    sparse_combinations = []
    for value in values:
        for idx in range(len(values)):
            sparse_values = [0.0] * len(values)
            sparse_values[idx] = value
            sparse_combinations.append(tuple(sparse_values))
    return set(sparse_combinations)
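
# Example (sketch): create_sparse_combinations((0.0, 0.25)) returns
# {(0.0, 0.0), (0.25, 0.0), (0.0, 0.25)} -- each tuple varies a single
# noise dimension while the others stay at 0.0.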

def test(json_files_directory, models_directory, phoneme_train_files, word_train_files, add_noise=False):
    """Evaluate each language's model on its matching test JSON, per family,
    age, and speaker, optionally after injecting noise, and return one row of
    statistics per (language, family, age, speaker) combination."""
    columns = ["language", "typology", "family", "speaker",
               "age", "perplexity", "entropy", "mlu_w", "mlu_p",
               "ttr_w", "mean_word_frequencies", "nb_unk",
               "phonemes_order_noise", "speakers_noise_adult",
               "speakers_noise_child", "phonemes_noise"]
    rows = []
    all_combinations = (list(product((0.0, 0.25, 0.5, 0.75), repeat=4))
                        if add_noise else [(0.0, 0.0, 0.0, 0.0)])
    # sparse_combinations = create_sparse_combinations((0.0, 0.25, 0.5, 0.75))
    # noise_values = np.linspace(0.0, 1.0, num=6)
    for phonemes_noise, speakers_noise_child, speakers_noise_adult, phonemes_order_noise in tqdm(all_combinations, total=len(all_combinations)):
        for test_filename, model_filename in product(os.listdir(json_files_directory), os.listdir(models_directory)):
            # Pair each test file with the model trained on the same language.
            lg_iso = test_filename.split(".")[0]
            model_lg = model_filename.split(".")[0]
            if lg_iso != model_lg:
                continue
            print(lg_iso, model_lg)
            most_probable_phonemes = get_most_probable_phonemes(f"{phoneme_train_files}/{lg_iso}.one_sentence_per_line")
            word_frequencies = compute_word_frequencies(f"{word_train_files}/{lg_iso}.one_sentence_per_line")
            with open(f"{json_files_directory}/{test_filename}") as test_file:
                loaded_json = json.load(test_file)
            if add_noise:
                noise = Noise(most_probable_phonemes,
                              phonemes_order_noise=phonemes_order_noise,
                              speakers_noise=(speakers_noise_child, speakers_noise_adult),
                              phonemes_noise=phonemes_noise)
                loaded_json = noise(loaded_json)
            model = kenlm.Model(f"{models_directory}/{model_filename}")
            for family in loaded_json:
                for age in loaded_json[family]:
                    if age == "None":
                        print(family, lg_iso, age)
                        continue
                    for speaker in loaded_json[family][age]:
                        if speaker not in ["Adult", "Target_Child"]:
                            continue
                        results_statistics = statistics_word(loaded_json[family][age][speaker], word_frequencies, model)
                        language, typology = LANGUAGES_TYPOLOGIES[lg_iso]
                        rows.append({"language": language,
                                     "typology": typology,
                                     "family": family,
                                     "speaker": speaker,
                                     "age": float(age),
                                     "perplexity": results_statistics["ppl"],
                                     "entropy": results_statistics["entropy"],
                                     "mlu_w": results_statistics["mlu_w"],
                                     "mlu_p": results_statistics["mlu_p"],
                                     "ttr_w": results_statistics["ttr_w"],
                                     "mean_word_frequencies": results_statistics["mean_word_frequencies"],
                                     "nb_unk": results_statistics["nb_unk"],
                                     "phonemes_order_noise": phonemes_order_noise,
                                     "speakers_noise_adult": speakers_noise_adult,
                                     "speakers_noise_child": speakers_noise_child,
                                     "phonemes_noise": phonemes_noise})
    return pd.DataFrame(rows, columns=columns)
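
# Example (sketch; the directory names are hypothetical):
#
#     df = test("data/json_test", "data/models",
#               "data/phoneme_train", "data/word_train", add_noise=False)
#     df.head()  # one row per (language, family, age, speaker) combination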

if __name__ == "__main__":
    from argparse import ArgumentParser, BooleanOptionalAction

    parser = ArgumentParser()
    parser.add_argument('--phoneme_train_directory',
                        required=True,
                        help="Directory containing the phoneme-level training files (*.one_sentence_per_line)")
    parser.add_argument('--word_train_directory',
                        required=True,
                        help="Directory containing the word-level training files (*.one_sentence_per_line)")
    parser.add_argument('--models_directory',
                        required=True,
                        help="Directory containing the trained language models")
    parser.add_argument('--json_files_directory',
                        required=True,
                        help="Directory containing the JSON test files")
    parser.add_argument('--out_dirname',
                        required=True,
                        help="Output directory")
    parser.add_argument('--out_filename',
                        required=True,
                        help="Output filename (without extension)")
    parser.add_argument("--add_noise", action=BooleanOptionalAction)
    args = parser.parse_args()

    add_noise = args.add_noise
    json_files_directory = args.json_files_directory
    phoneme_train_files, word_train_files = args.phoneme_train_directory, args.word_train_directory
    models_directory = args.models_directory
    out_dirname = args.out_dirname
    out_filename = args.out_filename

    os.makedirs(out_dirname, exist_ok=True)
    test(json_files_directory,
         models_directory,
         phoneme_train_files,
         word_train_files,
         add_noise).to_csv(f"{out_dirname}/{out_filename}.csv")
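
# Example invocation (sketch; all paths are hypothetical):
#
#     python test_on_all_languages.py \
#         --phoneme_train_directory data/phoneme_train \
#         --word_train_directory data/word_train \
#         --models_directory data/models \
#         --json_files_directory data/json_test \
#         --out_dirname results \
#         --out_filename scores \
#         --no-add_noise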