test_on_all_languages.py

import os
import random
import json
from math import log
from typing import Iterable
from itertools import product

from tqdm import tqdm
import kenlm
import pandas as pd

from make_noiser import Noise
from get_most_probable_phonemes import get_most_probable_phonemes

random.seed(1023)

LANGUAGES_TYPOLOGIES = {
    'da': ("Danish", "fusional"),
    'de': ("German", "fusional"),
    'en': ("English", "fusional"),
    'es': ("Spanish", "fusional"),
    'et': ("Estonian", "agglutinative"),
    'eu': ("Basque", "agglutinative"),
    'fr': ("French", "fusional"),
    'ja': ("Japanese", "agglutinative"),
    'pl': ("Polish", "fusional"),
    'pt': ("Portuguese", "fusional"),
    'sr': ("Serbian", "fusional"),
    'tr': ("Turkish", "agglutinative"),
}
def statistics_word(utterances: list, model: kenlm.Model) -> dict:
    """
    Test a given language model on a given list of utterances and
    compute some statistics: perplexity, entropy, MLU, TTR, etc.

    Parameters
    ----------
    - utterances: list
        The utterances to test.
    - model
        The estimated language model.
    """
    phoneme_utterances = []
    unique_words = set()
    mlu_w = 0.0
    mlu_p = 0.0
    nb_utterances = 0
    nb_words = 0
    statistics = {}
    for utterance in utterances:
        utterance = utterance.strip()
        if not utterance:
            continue
        nb_utterances += 1
        # "@" separates words and "$" separates phonemes within an utterance.
        utterance_w = utterance.replace("@", " ").replace("$", "")
        utterance_p = utterance.replace("@", " ").replace("$", " ")
        phoneme_utterances.append(utterance_p)
        utterance_words = utterance_w.split(" ")
        mlu_w += len(utterance_words)
        mlu_p += len(utterance_p.split(" "))
        nb_words += len(utterance_words)
        unique_words |= set(utterance_words)
    mlu_w /= nb_utterances
    mlu_p /= nb_utterances
    ttr_w = len(unique_words) / nb_words

    ppl = model.perplexity("\n".join(phoneme_utterances))
    entropy = log(ppl)

    statistics["ppl"] = ppl
    statistics["entropy"] = entropy
    statistics["mlu_w"] = mlu_w
    statistics["mlu_p"] = mlu_p
    statistics["ttr_w"] = ttr_w
    return statistics
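
# Illustrative usage of statistics_word. The model path and the utterance are
# assumptions, not from this repository; "$" separates phonemes and "@"
# separates words, matching the replace() calls above:
#   model = kenlm.Model("models/en.arpa")
#   stats = statistics_word(["h$e$l$l$o@w$o$r$l$d"], model)
#   print(stats["ppl"], stats["mlu_w"], stats["ttr_w"])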
def create_sparse_combinations(values: Iterable, variables=3) -> set:
    """
    Create the combinations used for noising.

    Each item in the returned set contains three values corresponding to
    (1) phoneme noise, (2) speaker noise (adult utterances attributed to the
    child and vice versa) and (3) phoneme order noise.
    These combinations are sparse because we only noise one value at a time.
    For example, an item can be (0.25, 0.0, 0.0), which means that we only
    noise 25 percent of the phonemes, and nothing else is affected.
    See the file make_noiser.py for more information.
    """
    sparse_combinations = []
    for value in values:
        for idx in range(variables):
            sparse_values = [0.0] * variables
            sparse_values[idx] = value
            sparse_combinations.append(tuple(sparse_values))
    return set(sparse_combinations)
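
# For example, with the default variables=3:
#   create_sparse_combinations((0.0, 0.5))
#   -> {(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (0.0, 0.5, 0.0), (0.0, 0.0, 0.5)}
# i.e. at most one noise dimension is non-zero in any given tuple.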
def test(json_files_directory, models_directory, train_files, add_noise=True):
    """
    Test the language models on the CHILDES corpora.
    """
    columns = ["language", "typology", "family", "speaker",
               "age", "perplexity", "entropy", "phonemes_order_noise",
               "speakers_noise", "phonemes_noise"]
    rows = []
    # all_combinations = (list(product((0.0, 0.25, 0.5, 0.75), repeat=4))
    #                     if add_noise else [(0.0, 0.0, 0.0, 0.0)])
    sparse_combinations = create_sparse_combinations((0.0, 0.25, 0.5, 0.75, 1.0))
    # noise_values = np.linspace(0.0, 1.0, num=6)
    for phonemes_noise, speakers_noise, phonemes_order_noise in tqdm(
            sparse_combinations, total=len(sparse_combinations)):
        for test_filename, model_filename in product(os.listdir(json_files_directory),
                                                     os.listdir(models_directory)):
            lg_iso = test_filename.split(".")[0]
            model_lg = model_filename.split(".")[0]
            if lg_iso != model_lg:
                continue
            most_probable_phonemes = get_most_probable_phonemes(
                f"{train_files}/{lg_iso}.one_sentence_per_line")
            with open(f"{json_files_directory}/{test_filename}") as json_file:
                loaded_json = json.load(json_file)
            if add_noise:
                noise = Noise(most_probable_phonemes,
                              phonemes_order_noise_value=phonemes_order_noise,
                              speakers_noise_values=(speakers_noise, speakers_noise),
                              phonemes_noise_value=phonemes_noise)
                loaded_json = noise(loaded_json)
            model = kenlm.Model(f"{models_directory}/{model_filename}")
            for family in loaded_json:
                for age in loaded_json[family]:
                    if age == "None":
                        print(family, lg_iso, age)
                        continue
                    for speaker in loaded_json[family][age]:
                        if speaker not in ["Adult", "Target_Child"]:
                            continue
                        # results_statistics = statistics_word(loaded_json[family][age][speaker], model)
                        language, typology = LANGUAGES_TYPOLOGIES[lg_iso]
                        ppl = model.perplexity("\n".join(loaded_json[family][age][speaker]))
                        entropy = log(ppl)
                        rows.append({"language": language,
                                     "typology": typology,
                                     "family": family,
                                     "speaker": speaker,
                                     "age": float(age),
                                     "perplexity": ppl,
                                     "entropy": entropy,
                                     "phonemes_order_noise": phonemes_order_noise,
                                     "speakers_noise": speakers_noise,
                                     "phonemes_noise": phonemes_noise})
    # DataFrame.append was removed in pandas 2.0; build the frame from a list
    # of row dicts instead.
    return pd.DataFrame(rows, columns=columns)
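
# Directory layout assumed by test(), inferred from the path handling above.
# The ".json" and ".arpa" extensions are illustrative; only the part before
# the first "." has to be the language ISO code:
#   json_files_directory/en.json            CHILDES utterances, one file per language
#   models_directory/en.arpa                trained KenLM model for that language
#   train_files/en.one_sentence_per_line    phonemized training text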
if __name__ == "__main__":
    from argparse import ArgumentParser, BooleanOptionalAction

    parser = ArgumentParser()
    parser.add_argument('--train_files_directory',
                        required=True,
                        help="The directory containing the training files tokenized in phonemes.")
    parser.add_argument('--model_files_directory',
                        required=True,
                        help="The directory containing the trained language models.")
    parser.add_argument('--json_files_directory',
                        required=True,
                        help="The directory containing the CHILDES utterances in JSON format for each language.")
    parser.add_argument("--add_noise",
                        help="Whether or not to noise the CHILDES utterances.",
                        action=BooleanOptionalAction)
    args = parser.parse_args()

    add_noise = args.add_noise
    json_files_directory = args.json_files_directory
    phoneme_train_files = args.train_files_directory
    models_directory = args.model_files_directory

    if not os.path.exists("results"):
        os.makedirs("results")
    test(json_files_directory,
         models_directory,
         phoneme_train_files,
         add_noise=add_noise).to_csv("results/results.csv")
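
# Example invocation (the flags come from the parser above; the directory
# names are illustrative, not from this repository):
#   python test_on_all_languages.py \
#       --train_files_directory data/train_phonemized \
#       --model_files_directory models \
#       --json_files_directory data/childes_json \
#       --add_noise
# Results are written to results/results.csv.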