test_prediction2.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
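'''
description: test the word prediction of the color speller by replaying the target
sentences in user_test.txt letter by letter and recording, in a CSV file, the predicted
vocabulary and the characters saved by accepting correct word predictions
author: Alessandro Tonin
date: 06.02.2020
Copyright (c) 2020 Alessandro Tonin.
All rights reserved.'''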
  7. import importlib
  8. import logging
  9. import sys
  10. import time
  11. from random import randint
  12. import tkinter as tk
  13. import yaml
  14. # import pyaudio
  15. from colorlog import ColoredFormatter
  16. import aux
  17. from aux import log
  18. from paradigms import colorSpeller
  19. import multiprocessing as mp
  20. from multiprocessing import Array, Pipe
  21. import readchar
  22. import csv
  23. import re
  24. importlib.reload(colorSpeller)
  25. importlib.reload(aux)
  26. file_log = '/kiap/src/kiap_bci/tests/records_for_annotation_yml.yaml'
  27. file_in = '/kiap/src/kiap_bci/tests/user_test.txt'
  28. file_out = '/kiap/src/kiap_bci/tests/user_test_out_update_spcorpus_sp_us_gen.txt'
  29. sp = colorSpeller.ColorSpeller()
  30. update_corpus = True
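# Preprocessing step (disabled): regenerate file_in from the annotation log file_log.
# Each entry's spelled 'phrase' is written on its own line; when the phrase starts with
# its 'phrase_start', that initial part is wrapped in square brackets, e.g. "[Hello] world".
# NOTE(review): yaml.Loader can construct arbitrary Python objects; prefer
# yaml.safe_load if the log only contains plain data.
# with open(file_log, 'r') as flog, open(file_in, 'w') as fin:
#     entries = yaml.load(flog, Loader=yaml.Loader)
#     for k in entries.keys():
#         phrase = entries[k]['phrase']
#         init_phrase = entries[k]['phrase_start']
#         phrase = phrase.strip()  # just in case...
#         if init_phrase and phrase.find(init_phrase) == 0:
#             phrase = '[' + phrase[0:len(init_phrase)] + ']' + phrase[len(init_phrase):]
#         fin.write(phrase + '\n')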
  41. # txt file with one speller selection in each line
  42. # Read txt file line by line
  43. with open(file_in,'r+') as fin, open(file_out, 'w') as fout:
  44. results_writer = csv.writer(fout, delimiter=',')
  45. results_writer.writerow(['num', 'target', 'string', 'vocab1', 'vocab2', 'vocab3', 'vocab4', 'vocab5', 'pred corr','pred wrong','saved char'])
  46. line = fin.readline()
  47. cnt = 1
  48. while line:
  49. # Check initial string
  50. init_str_start = line.find('[')
  51. init_str_end = line.find(']')
  52. if init_str_start != init_str_end:
  53. curr_str = line[init_str_start+1:init_str_end]
  54. else:
  55. curr_str = ''
  56. # Delete initial string identifiers
  57. for r in (('[',''),(']',''),('\n','')):
  58. line = line.replace(*r)
  59. log.warning(f'target: {line}')
  60. sp.set_current_string(curr_str)
  61. sp.external_updates()
  62. vocab = sp.get_vocabulary(5)
  63. # fill empty vocabulary with dash
  64. while len(vocab) < 5:
  65. vocab.append('#')
  66. log.info(f'string: {curr_str}')
  67. log.info(f'vocab: {vocab}')
  68. results_writer.writerow([cnt,line,curr_str]+vocab)
  69. # use while loop instead of for in order to skip iterations when prediction is correct
  70. idx_letter = 0
  71. pred_correct = 0
  72. pred_wrong = 0
  73. saved_char = 0
  74. while idx_letter < len(line):
  75. # find the actual word
  76. word_start = line.rfind(' ',0,idx_letter)
  77. word_end = line.find(' ',idx_letter)
  78. if word_end == -1: word_end = len(line) # because if not find result is -1
  79. curr_word = line[word_start+1:word_end]
  80. # Check if there is a probable word and if it is correct
  81. if len(vocab)>0 and vocab[0].startswith('*') and vocab[0][1:-1] == curr_word:
  82. # in case add the probable word to the string
  83. curr_str = curr_str[:word_start+1] + curr_word
  84. pred_correct += 1
  85. saved_char += word_end-idx_letter
  86. if word_end < len(line):
  87. curr_str += ' '
  88. saved_char += 1
  89. elif len(vocab)>0 and vocab[0].startswith('*'):
  90. curr_str += line[idx_letter]
  91. pred_wrong += 1
  92. else:
  93. # Add to speller the selection letter by letter
  94. curr_str += line[idx_letter]
  95. # Keep trace of predicted words
  96. sp.set_current_string(curr_str)
  97. sp.external_updates()
  98. # Check prediction after every new letter
  99. vocab = sp.get_vocabulary(5)
  100. # fill empty vocabulary with dash
  101. while len(vocab) < 5:
  102. vocab.append('#')
  103. log.info(f'string: {curr_str}')
  104. log.info(f'vocab: {vocab}')
  105. idx_letter = len(curr_str)
  106. results_writer.writerow([cnt,line,curr_str]+vocab+[pred_correct,pred_wrong,saved_char])
  107. if update_corpus:
  108. sp._save_string()
  109. sp.speller_user_cfd = sp._createCFD(root=sp.corpora_path,corpus_name=sp.speller_user_corpus_name)
  110. line = fin.readline()
  111. cnt += 1