__init__.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. import pandas as pd
  2. import numpy as np
  3. from .gdm_data_classes import GDMFile, GDMRow
  4. from view.python_core.overviews.ctv_handlers import PixelWiseCTVHandler
  5. from view.python_core.p1_class import MetadataDefinition
  6. from view.python_core.rois.roi_io import get_roi_io_class
  7. def apply_mask_to_data(roi_data_dict, data_xyt, area=None):
  8. data_txy = np.moveaxis(data_xyt, [0, 1, 2], [1, 2, 0])
  9. roi_label_traces_dict = {}
  10. for label, roi_data in roi_data_dict.items():
  11. if area is not None:
  12. weighted_mask_xy = roi_data.get_weighted_mask_considering_area(area)
  13. else:
  14. weighted_mask_xy = roi_data.get_weighted_mask(data_xyt.shape[:2])
  15. # weighted_mask_xy will add up to 1
  16. # multiply each frame with the mask and sum over each frame
  17. roi_label_traces_dict[label] = (data_txy * weighted_mask_xy).sum(axis=(1, 2))
  18. return roi_label_traces_dict
  19. def get_roi_gdm_traces_dict(p1, flags, roi_data_dict):
  20. if flags["GDM_withinArea"]:
  21. return apply_mask_to_data(data_xyt=p1.sig1, roi_data_dict=roi_data_dict, area=p1.area_mask)
  22. else:
  23. return apply_mask_to_data(data_xyt=p1.sig1, roi_data_dict=roi_data_dict)
  24. def get_glodatamix_row_boiler_plate(p1, animal_name=None):
  25. meta_def = MetadataDefinition()
  26. list_column_p1_metadata_mapping = meta_def.get_list_column_p1_metadata_mapping()
  27. default_p1_metadata = meta_def.get_default_row()
  28. # add all non-default metadata from p1.metadata to GDM metadata
  29. temp = {
  30. k: p1.metadata[v] for k, v in list_column_p1_metadata_mapping.items()
  31. if pd.notnull(v) and (p1.metadata[v] != default_p1_metadata[k])
  32. }
  33. # add extra metadata in p1 to GDM metadata
  34. temp.update(p1.extra_metadata)
  35. metadata_boiler_plate = pd.Series(data=temp)
  36. # add animal name
  37. if animal_name is not None:
  38. metadata_boiler_plate["Animal"] = animal_name
  39. else:
  40. metadata_boiler_plate["Animal"] = "not set"
  41. # add stimulus timing information
  42. stim_starts_ms = [
  43. x / np.timedelta64(1, 'ms')
  44. for x in p1.pulsed_stimuli_handler.get_pulse_start_times()]
  45. metadata_boiler_plate["StimONms"] = str(stim_starts_ms)[1:-1]
  46. stim_durations_ms = [
  47. x / np.timedelta64(1, 'ms')
  48. for x in p1.pulsed_stimuli_handler.get_pulse_durations()]
  49. metadata_boiler_plate["StimLen"] = str(stim_durations_ms)[1:-1]
  50. return metadata_boiler_plate
  51. def create_gdm_file_basic(
  52. common_metadata: pd.Series, roi_label_gdm_traces_dict: dict, sampling_period_ms: int, trace_onset: float=0,
  53. roi_descriptions: dict = None, roi_label_additional_metadata: dict = None
  54. ):
  55. gdm_file = GDMFile()
  56. for roi_label, gdm_trace in roi_label_gdm_traces_dict.items():
  57. metadata_boiler_plate = common_metadata.copy()
  58. metadata_boiler_plate['GloTag'] = roi_label
  59. if roi_descriptions is not None:
  60. metadata_boiler_plate['GloInfo'] = roi_descriptions[roi_label]
  61. if roi_label_additional_metadata is not None:
  62. # the alternative implementation using series.update not working for some reason
  63. for k, v in roi_label_additional_metadata.get(roi_label, {}).items():
  64. metadata_boiler_plate[k] = v
  65. gdm_row = \
  66. GDMRow.from_data_and_metadata(
  67. metadata_dict=metadata_boiler_plate, trace=gdm_trace,
  68. sampling_period_ms=sampling_period_ms, starting_time_s=trace_onset)
  69. gdm_file.append_gdm_row(gdm_row)
  70. return gdm_file
  71. class FullTraceGDMGenerator(object):
  72. def __init__(self, p1, flags, additional_metadata=None):
  73. """
  74. :param p1:
  75. :param flags:
  76. :param int trace_onset: offset of the measurement in seconds
  77. :param dict additional_metadata: any additional metadata about this measurement to be added to GDMs
  78. """
  79. self.roi_data_dict, self.roi_file = get_roi_io_class(flags["RM_ROITrace"]).read(
  80. flags=flags, measurement_label=p1.metadata.ex_name)
  81. self.roi_descriptions = \
  82. {k: v.get_text_description(frame_size=p1.get_frame_size())
  83. for k, v in self.roi_data_dict.items()}
  84. self.roi_label_gdm_traces_dict = get_roi_gdm_traces_dict(p1=p1, flags=flags, roi_data_dict=self.roi_data_dict)
  85. self.metadata_boiler_plate = get_glodatamix_row_boiler_plate(
  86. p1=p1, animal_name=flags["STG_ReportTag"])
  87. if additional_metadata is not None:
  88. for k, v in additional_metadata.items():
  89. self.metadata_boiler_plate[k] = v
  90. self.sampling_period_ms = p1.metadata["trial_ticks"]
  91. self.ctv_handler_obj = PixelWiseCTVHandler(flags=flags, p1=p1)
  92. self.ctv_name = f"CTV_{flags['CTV_Method']}"
  93. self.pulsed_stimuli_handler = p1.pulsed_stimuli_handler
  94. def calc_ctv(self, trace):
  95. try:
  96. ctv_value = self.ctv_handler_obj.apply_pixel(trace)
  97. except (IndexError, AssertionError) as err:
  98. ctv_value = np.nan
  99. return ctv_value
  100. def get_gdm_file(self):
  101. roi_label_additional_metadata = {}
  102. for roi_label, trace in self.roi_label_gdm_traces_dict.items():
  103. roi_label_additional_metadata[roi_label] = {self.ctv_name: self.calc_ctv(trace)}
  104. return create_gdm_file_basic(
  105. common_metadata=self.metadata_boiler_plate,
  106. sampling_period_ms=self.sampling_period_ms,
  107. roi_label_gdm_traces_dict=self.roi_label_gdm_traces_dict,
  108. roi_descriptions=self.roi_descriptions,
  109. roi_label_additional_metadata=roi_label_additional_metadata
  110. )
  111. class ChunksOnlyGDMGenerator(FullTraceGDMGenerator):
  112. def __init__(self, p1, flags, additional_metadata=None):
  113. super().__init__(p1, flags, additional_metadata)
  114. self.gdm_chunkPostStim = flags['GDM_chunkPostStim']
  115. self.gdm_chunkPreStim = flags["GDM_chunkPreStim"]
  116. def get_gdm_file(self, write_ctv=True):
  117. gdm_file_all = GDMFile()
  118. # for every stimulus
  119. for ind, (odor, conc, start_td, end_td, sampling_period_td) in \
  120. self.pulsed_stimuli_handler.stimulus_frame.iterrows():
  121. start_sec = start_td.total_seconds()
  122. end_sec = end_td.total_seconds()
  123. chunk_start_td = start_td - pd.to_timedelta(self.gdm_chunkPreStim, "s")
  124. chunk_end_td = end_td + pd.to_timedelta(self.gdm_chunkPostStim, "s")
  125. chunk_slice_start_ind = np.round(chunk_start_td / sampling_period_td).astype(int)
  126. chunk_slice_end_ind = np.round(chunk_end_td / sampling_period_td).astype(int)
  127. chunk_slice_start_ind = max(0, chunk_slice_start_ind)
  128. chunk_start_quantized_ms = chunk_slice_start_ind * self.sampling_period_ms
  129. common_metadata = self.metadata_boiler_plate.copy()
  130. common_metadata["StimLen"] = (end_sec - start_sec) * 1000 # in ms
  131. # this is relative to chunk start
  132. common_metadata["StimONms"] = start_sec * 1000 - chunk_start_quantized_ms # in ms
  133. common_metadata["Odour"] = odor
  134. common_metadata["OConc"] = conc
  135. roi_label_chunk_dict = {}
  136. roi_label_additional_metadata = {}
  137. # for every ROI
  138. for roi_label, gdm_trace in self.roi_label_gdm_traces_dict.items():
  139. chunk_slice_end_ind = min(gdm_trace.shape[0], chunk_slice_end_ind)
  140. chunk = gdm_trace[chunk_slice_start_ind: chunk_slice_end_ind + 1]
  141. roi_label_chunk_dict[roi_label] = chunk
  142. roi_label_additional_metadata[roi_label] = {self.ctv_name: self.calc_ctv(chunk)}
  143. gdm_file_this_stim = create_gdm_file_basic(
  144. common_metadata=common_metadata,
  145. sampling_period_ms=self.sampling_period_ms,
  146. roi_label_gdm_traces_dict=roi_label_chunk_dict,
  147. roi_descriptions=self.roi_descriptions,
  148. roi_label_additional_metadata=roi_label_additional_metadata,
  149. # metadata field for arbitrary delay in measurement
  150. # here used to indicate chunk start time relative to trace start
  151. trace_onset=chunk_start_quantized_ms / 1000 # in seconds
  152. )
  153. gdm_file_all.append_from_a_gdm_file(gdm_file_this_stim)
  154. return gdm_file_all
  155. def get_gdm_file(p1, flags):
  156. if flags["GDM_outputType"] == "full_traces":
  157. return FullTraceGDMGenerator(p1=p1, flags=flags).get_gdm_file()
  158. elif flags["GDM_outputType"] == "chunks_only":
  159. return ChunksOnlyGDMGenerator(p1=p1, flags=flags).get_gdm_file()
  160. else:
  161. raise NotImplementedError