123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
import json
import logging
import sys
import time
import uuid
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import serial
import zmq

sys.path.append('../')
from openephys import OpenEphysEvent, OpenEphysSpikeEvent
# Single ZMQ context for the process; connect() creates both sockets from it.
context = zmq.Context()

# A data block whose largest sample exceeds this value makes find_ttl()
# fire a serial pulse. Expressed in the same units as the incoming samples.
TTL_THRESHOLD = 2.0
def send_heartbeat():
    """Send a heartbeat request on the event socket and mark a reply as pending.

    The message is a UTF-8 encoded JSON object carrying the application name,
    the module-level ``uuid`` value and the message type ``'heartbeat'``.
    After sending, ``socket_waits_reply`` is set and the send time is stored
    in ``last_heartbeat_time``.
    """
    global socket_waits_reply, last_heartbeat_time

    # NOTE: `uuid` is whatever the module-level name currently holds; the
    # script entry point rebinds it from the module to a string id.
    payload = json.dumps({
        'application': 'latency_test',
        'uuid': uuid,
        'type': 'heartbeat',
    })
    event_socket.send(payload.encode('utf-8'))

    last_heartbeat_time = time.time()
    socket_waits_reply = True
def fdump(fhnd, ts, data=None):
    """Append one record to *fhnd* and flush it immediately.

    Args:
        fhnd: Writable text file-like object (needs ``write`` and ``flush``).
        ts: Timestamp, formatted as an integer (``%d``).
        data: Optional value; when given it is written after the timestamp,
            separated by ``", "`` and formatted with five decimals.

    Every record is terminated with a newline so that consecutive records
    remain separable when the file is read back.
    """
    if data is None:
        record = "%d" % ts
    else:
        record = "%d, %.5f" % (ts, data)

    fhnd.write(record + "\n")
    # Flush per record so nothing is lost if the process is killed.
    fhnd.flush()
def connect(serial_port='COM6', baudrate=2000000):
    """Open the ZMQ data/event sockets and the serial port.

    Creates a SUB socket on ``dataport`` and a REQ socket on ``eventport``
    (both on localhost), registers both with the module-level ``poller`` for
    incoming messages, and opens the serial port. The results are stored in
    the module-level names ``data_socket``, ``event_socket`` and ``ser``.

    Args:
        serial_port: Name of the serial device to open. Defaults to
            ``'COM6'``, the value that used to be hard-coded.
        baudrate: Serial baud rate. Defaults to 2000000.
    """
    global data_socket, event_socket, ser

    print("init socket")

    data_socket = context.socket(zmq.SUB)
    data_socket.connect("tcp://localhost:%d" % dataport)

    event_socket = context.socket(zmq.REQ)
    event_socket.connect("tcp://localhost:%d" % eventport)

    # An empty prefix subscribes to every message published on the socket.
    data_socket.setsockopt(zmq.SUBSCRIBE, b'')

    for sock in (data_socket, event_socket):
        poller.register(sock, zmq.POLLIN)

    ser = serial.Serial(serial_port, baudrate, timeout=.1)
def dump_event(header, event):
    """Update the stream timestamp or record a TTL event, then echo the event.

    A ``'TIMESTAMP'`` event stores ``event.timestamp`` in the module-level
    ``timestamp``. A ``'TTL'`` event with ``event_id == 1`` is written through
    ``fdump`` as ``event.sample_num + timestamp``. The header and the event
    are printed in every case.
    """
    global timestamp

    kind = event.type
    if kind == 'TIMESTAMP':
        timestamp = event.timestamp
    elif kind == 'TTL' and event.event_id == 1:
        # NOTE(review): `fttl` is not defined in the visible part of this
        # file; confirm it is opened before TTL events can arrive.
        fdump(fttl, event.sample_num + timestamp)

    # NOTE(review): printing is assumed to be unconditional; confirm.
    print("Event:", header)
    print(event)
def dump_data(header, content, data):
    """Print the content and header of a data message.

    If no stream timestamp has been seen yet (module-level ``timestamp`` is
    still -1) a warning is printed first. *data* is accepted but not used.
    """
    no_timestamp_yet = timestamp == -1
    if no_timestamp_yet:
        print("Dropping data - arrived before timestamp")

    # NOTE(review): the message is still printed after the warning above;
    # the function has no early return. Confirm this is intended.
    print("Data:", content)
    print(header)
def send_pps():
    """Write a single ``b'x'`` byte to the serial port and report it on stdout."""
    trigger = b'x'
    ser.write(trigger)
    print("Pulse triggered")
def find_ttl(data):
    """Fire a serial pulse if any sample in *data* exceeds ``TTL_THRESHOLD``.

    Args:
        data: NumPy array of samples. The caller in this file passes a
            ``(n_channels, n_samples)`` float32 array.

    An empty array is ignored: it cannot contain a sample above the
    threshold, and ``ndarray.max()`` raises ``ValueError`` on zero-size input.
    """
    if data.size == 0:
        return
    if data.max() > TTL_THRESHOLD:
        send_pps()
logger = logging.getLogger(__name__)


def _handle_data_message(message):
    """Decode one multipart message from the data socket and act on it.

    Frame 1 is parsed as a UTF-8 JSON header. For a header of type ``'data'``
    frame 2 is read as float32 samples, reshaped to
    ``(n_channels, n_samples)``, trimmed to ``n_real_samples`` columns and
    passed to ``find_ttl``. Headers of type ``'event'`` are only reported.
    Malformed messages are logged and skipped.
    """
    if not message:
        return
    if len(message) < 2:
        logger.info("No frames for message: %s", message[0])
        return

    try:
        header = json.loads(message[1].decode('utf-8'))
        if header['type'] == 'event':
            print("event detected")
        elif header['type'] == 'data':
            content = header['content']
            n_samples = content['n_samples']
            n_channels = content['n_channels']
            n_real_samples = content['n_real_samples']
            try:
                samples = np.frombuffer(message[2], dtype=np.float32)
                samples = np.reshape(samples, (n_channels, n_samples))
                # NOTE(review): the TTL scan is assumed to run only when the
                # header reports real samples; confirm.
                if n_real_samples > 0:
                    find_ttl(samples[:, 0:n_real_samples])
            except IndexError as e:
                # Raised e.g. when the sample frame (frame 2) is missing.
                logger.error(e)
                logger.error(header)
                logger.error(message[1])
                if len(message) > 2:
                    logger.error(len(message[2]))
                else:
                    logger.error("Only one frame???")
    except ValueError as e:
        # Covers invalid JSON / UTF-8 and a reshape size mismatch.
        logger.error("ValueError: %s", e)
        logger.info(message[1])


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Module-level state shared with connect(), send_heartbeat() and the
    # dump_* helpers through their `global` declarations.
    dataport = 5556
    eventport = 5557
    data_socket = None
    event_socket = None
    ser = None
    poller = zmq.Poller()
    message_no = -1
    socket_waits_reply = False
    app_name = 'Dumper Process'
    # send_heartbeat() reads the module-level name `uuid`, so the name is
    # deliberately rebound from the module to this process's id string.
    uuid = str(uuid.uuid4())
    last_heartbeat_time = 0
    last_reply_time = time.time()
    timestamp = -1

    limit = 500
    evtcount = 0

    try:
        connect()

        print("Waiting 20 seconds")
        # NOTE: the 20 s start-up delay (time.sleep(20)) is currently disabled.

        while evtcount < limit:
            socks = dict(poller.poll(1))
            if not socks:
                continue

            if data_socket in socks:
                print(".", end='', flush=True)
                try:
                    message = data_socket.recv_multipart(zmq.NOBLOCK)
                except zmq.ZMQError as err:
                    logger.error("Got error: %s", err)
                    break
                _handle_data_message(message)
            elif event_socket in socks and socket_waits_reply:
                evtcount += 1
                message = event_socket.recv()
                # The reply completes the request started by send_heartbeat().
                socket_waits_reply = False
                print("+")
    finally:
        # Release the serial port and sockets on normal exit, error or Ctrl-C.
        if ser is not None:
            ser.close()
        for sock in (data_socket, event_socket):
            if sock is not None:
                # linger=0 discards unsent messages so term() cannot block.
                sock.close(linger=0)
        context.term()
|