oe_data.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
import json
import logging
import sys
import time
import uuid
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import serial
import zmq

sys.path.append('../')
from openephys import OpenEphysEvent, OpenEphysSpikeEvent
  12. context = zmq.Context()
  13. TTL_THRESHOLD = 2.0
  14. def send_heartbeat():
  15. global data_socket, event_socket, poller
  16. global socket_waits_reply
  17. global last_heartbeat_time
  18. d = {'application': 'latency_test', 'uuid': uuid, 'type': 'heartbeat'}
  19. j_msg = json.dumps(d)
  20. event_socket.send(j_msg.encode('utf-8'))
  21. last_heartbeat_time = time.time()
  22. socket_waits_reply = True
  23. def fdump(fhnd, ts, data=None):
  24. if data is not None:
  25. fhnd.write("%d, %.5f" % (ts, data))
  26. else:
  27. fhnd.write("%d" % ts)
  28. fhnd.flush()
  29. def connect():
  30. global data_socket, event_socket, poller
  31. global ser
  32. print("init socket")
  33. data_socket = context.socket(zmq.SUB)
  34. data_socket.connect("tcp://localhost:%d" % dataport)
  35. event_socket = context.socket(zmq.REQ)
  36. event_socket.connect("tcp://localhost:%d" % eventport)
  37. data_socket.setsockopt(zmq.SUBSCRIBE, b'')
  38. poller.register(data_socket, zmq.POLLIN)
  39. poller.register(event_socket, zmq.POLLIN)
  40. ser = serial.Serial('COM6', 2000000, timeout=.1)
  41. def dump_event(header, event):
  42. global timestamp
  43. if event.type == 'TIMESTAMP':
  44. timestamp = event.timestamp
  45. elif event.type == 'TTL' and event.event_id == 1:
  46. fdump(fttl, event.sample_num + timestamp)
  47. print("Event:", header)
  48. print(event)
  49. def dump_data(header, content, data):
  50. if timestamp == -1:
  51. print("Dropping data - arrived before timestamp")
  52. print("Data:", content)
  53. print(header)
  54. def send_pps():
  55. ser.write(b'x')
  56. print("Pulse triggered")
  57. def find_ttl(data):
  58. if data.max() > TTL_THRESHOLD:
  59. send_pps()
  60. #print(data)
  61. #print(data.shape)
  62. #plt.plot(data.transpose())
  63. #plt.show()
  64. if __name__ == "__main__":
  65. dataport=5556
  66. eventport=5557
  67. data_socket = None
  68. event_socket = None
  69. poller = zmq.Poller()
  70. message_no = -1
  71. socket_waits_reply = False
  72. app_name = 'Dumper Process'
  73. uuid = str(uuid.uuid4())
  74. last_heartbeat_time = 0
  75. last_reply_time = time.time()
  76. timestamp = -1
  77. connect()
  78. next = time.perf_counter() + 1
  79. last_sent = 0
  80. measurements = []
  81. print("Waiting 20 seconds")
  82. #time.sleep(20)
  83. limit = 500
  84. evtcount = 0
  85. while evtcount < limit:
  86. socks = dict(poller.poll(1))
  87. if not socks:
  88. # print("poll exits")
  89. continue
  90. if data_socket in socks:
  91. print(".", end='', flush=True)
  92. try:
  93. # drop data immediately
  94. message = data_socket.recv_multipart(zmq.NOBLOCK)
  95. except zmq.ZMQError as err:
  96. logger.error("Got error: {0}".format(err))
  97. break
  98. if message:
  99. if len(message) < 2:
  100. logger.info("No frames for message: ", message[0])
  101. else:
  102. try:
  103. header = json.loads(message[1].decode('utf-8'))
  104. if header['type'] == 'event':
  105. print("event detected")
  106. elif header['type'] == 'data':
  107. c = header['content']
  108. n_samples = c['n_samples']
  109. n_channels = c['n_channels']
  110. n_real_samples = c['n_real_samples']
  111. try:
  112. n_arr = np.frombuffer(message[2], dtype=np.float32)
  113. n_arr = np.reshape(n_arr, (n_channels, n_samples))
  114. #print (n_channels, n_samples)
  115. if n_real_samples > 0:
  116. n_arr = n_arr[:, 0:n_real_samples]
  117. find_ttl(n_arr)
  118. except IndexError as e:
  119. logger.error(e)
  120. logger.error(header)
  121. logger.error(message[1])
  122. if len(message) > 2:
  123. logger.error(len(message[2]))
  124. else:
  125. logger.error("Only one frame???")
  126. except ValueError as e:
  127. logger.error("ValueError: ", e)
  128. logger.info(message[1])
  129. elif event_socket in socks and socket_waits_reply:
  130. evtcount += 1
  131. message = event_socket.recv()
  132. print("+")