123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- import pandas as pd
- import numpy as np
- from .gdm_data_classes import GDMFile, GDMRow
- from view.python_core.overviews.ctv_handlers import PixelWiseCTVHandler
- from view.python_core.p1_class import MetadataDefinition
- from view.python_core.rois.roi_io import get_roi_io_class
def apply_mask_to_data(roi_data_dict, data_xyt, area=None):
    """
    Reduce a 3D data array to one trace per ROI using each ROI's weighted mask.

    :param dict roi_data_dict: ROI labels as keys, ROI objects as values. Each ROI object must provide
        `get_weighted_mask(frame_shape)` and `get_weighted_mask_considering_area(area)`
    :param data_xyt: 3D array; its first two axes form one frame and its last axis indexes the frames
    :param area: optional; when not None, masks are obtained with `get_weighted_mask_considering_area(area)`,
        otherwise with `get_weighted_mask` called with the frame shape
    :return: dict with ROI labels as keys and, as values, 1D arrays with one entry per frame
    """
    frame_shape = data_xyt.shape[:2]
    # bring the frame axis to the front so that a 2D mask broadcasts against every frame
    frames_txy = np.transpose(data_xyt, (2, 0, 1))

    def weighted_mask_for(roi):
        # masks are expected to add up to 1 (note carried over from the original implementation)
        if area is None:
            return roi.get_weighted_mask(frame_shape)
        return roi.get_weighted_mask_considering_area(area)

    # weight every frame with the mask, then sum within each frame
    return {
        roi_label: (frames_txy * weighted_mask_for(roi)).sum(axis=(1, 2))
        for roi_label, roi in roi_data_dict.items()
    }
def get_roi_gdm_traces_dict(p1, flags, roi_data_dict):
    """
    Compute one trace per ROI from `p1.sig1`.

    :param p1: measurement object; `sig1` is read always, `area_mask` only when the flag below is set
    :param flags: mapping of settings; the entry "GDM_withinArea" decides whether `p1.area_mask` is
        passed on as area
    :param dict roi_data_dict: ROI labels as keys, ROI objects as values
    :return: the dict returned by `apply_mask_to_data`
    """
    restrict_to_area = flags["GDM_withinArea"]
    data_xyt = p1.sig1
    # without the flag, no area is passed on
    area = p1.area_mask if restrict_to_area else None
    return apply_mask_to_data(roi_data_dict=roi_data_dict, data_xyt=data_xyt, area=area)
def get_glodatamix_row_boiler_plate(p1, animal_name=None):
    """
    Assemble the metadata shared by all GDM rows of one measurement.

    The result contains every entry of `p1.metadata` that differs from the default row of
    `MetadataDefinition`, all entries of `p1.extra_metadata`, the animal name and the stimulus timing.

    :param p1: measurement object; `metadata`, `extra_metadata` and `pulsed_stimuli_handler` are read
    :param animal_name: value for the entry "Animal"; "not set" is used when None
    :return: pandas.Series with the metadata; "StimONms" and "StimLen" are comma separated strings
        of pulse start times and pulse durations in milliseconds
    """
    meta_def = MetadataDefinition()
    column_to_p1_key = meta_def.get_list_column_p1_metadata_mapping()
    default_row = meta_def.get_default_row()

    # add all non-default metadata from p1.metadata to GDM metadata
    row_values = {
        column: p1.metadata[p1_key]
        for column, p1_key in column_to_p1_key.items()
        if pd.notnull(p1_key) and (p1.metadata[p1_key] != default_row[column])
    }
    # add extra metadata in p1 to GDM metadata
    row_values.update(p1.extra_metadata)
    metadata_boiler_plate = pd.Series(data=row_values)

    # add animal name
    metadata_boiler_plate["Animal"] = animal_name if animal_name is not None else "not set"

    def as_ms_string(timedeltas):
        # float() makes sure plain Python floats are stringified: the string of a list uses the repr of
        # its elements, and NumPy scalars are rendered like "np.float64(1000.0)" from NumPy 2 onwards
        values_ms = [float(td / np.timedelta64(1, 'ms')) for td in timedeltas]
        # drop the enclosing square brackets of the list representation
        return str(values_ms)[1:-1]

    # add stimulus timing information
    stimuli_handler = p1.pulsed_stimuli_handler
    metadata_boiler_plate["StimONms"] = as_ms_string(stimuli_handler.get_pulse_start_times())
    metadata_boiler_plate["StimLen"] = as_ms_string(stimuli_handler.get_pulse_durations())

    return metadata_boiler_plate
def create_gdm_file_basic(
        common_metadata: pd.Series, roi_label_gdm_traces_dict: dict, sampling_period_ms: int, trace_onset: float=0,
        roi_descriptions: dict = None, roi_label_additional_metadata: dict = None
):
    """
    Create a GDM file with one row per entry of `roi_label_gdm_traces_dict`.

    :param pandas.Series common_metadata: metadata copied into every row
    :param dict roi_label_gdm_traces_dict: ROI labels as keys, traces as values
    :param sampling_period_ms: passed to `GDMRow.from_data_and_metadata` as `sampling_period_ms`
    :param trace_onset: passed to `GDMRow.from_data_and_metadata` as `starting_time_s`
    :param dict roi_descriptions: optional, ROI labels as keys; when given, it must have an entry for
        every ROI label, which is stored in the row metadata as "GloInfo"
    :param dict roi_label_additional_metadata: optional, ROI labels as keys and dicts of extra row
        metadata as values; ROI labels without an entry get no extra metadata
    :return: GDMFile to which all rows were appended
    """
    if roi_label_additional_metadata is None:
        extras_by_roi = {}
    else:
        extras_by_roi = roi_label_additional_metadata

    gdm_file = GDMFile()
    for roi_label, gdm_trace in roi_label_gdm_traces_dict.items():

        # every row starts from its own copy, so that <common_metadata> is not modified
        row_metadata = common_metadata.copy()
        row_metadata['GloTag'] = roi_label

        if roi_descriptions is not None:
            row_metadata['GloInfo'] = roi_descriptions[roi_label]

        # entries are assigned one by one; according to the original author, series.update did not work here
        for key, value in extras_by_roi.get(roi_label, {}).items():
            row_metadata[key] = value

        gdm_file.append_gdm_row(
            GDMRow.from_data_and_metadata(
                metadata_dict=row_metadata, trace=gdm_trace,
                sampling_period_ms=sampling_period_ms, starting_time_s=trace_onset)
        )

    return gdm_file
class FullTraceGDMGenerator(object):
    """
    Generates a GDM file with one row per ROI, each row holding the trace of the ROI over the whole measurement.
    """

    def __init__(self, p1, flags, additional_metadata=None):
        """
        Read the ROIs of the measurement, compute their traces and prepare the metadata common to all rows.

        :param p1: measurement object; `metadata`, `get_frame_size()` and `pulsed_stimuli_handler` are
            used here, and the object is passed on to helper functions
        :param flags: mapping of settings; the entries "RM_ROITrace", "STG_ReportTag" and "CTV_Method"
            are read here, and the object is passed on to helper functions
        :param dict additional_metadata: any additional metadata about this measurement to be added to GDMs
        """
        roi_io_class = get_roi_io_class(flags["RM_ROITrace"])
        self.roi_data_dict, self.roi_file = roi_io_class.read(
            flags=flags, measurement_label=p1.metadata.ex_name)

        self.roi_descriptions = {}
        for roi_label, roi_data in self.roi_data_dict.items():
            self.roi_descriptions[roi_label] = roi_data.get_text_description(frame_size=p1.get_frame_size())

        self.roi_label_gdm_traces_dict = get_roi_gdm_traces_dict(
            p1=p1, flags=flags, roi_data_dict=self.roi_data_dict)

        boiler_plate = get_glodatamix_row_boiler_plate(p1=p1, animal_name=flags["STG_ReportTag"])
        if additional_metadata is not None:
            for key, value in additional_metadata.items():
                boiler_plate[key] = value
        self.metadata_boiler_plate = boiler_plate

        self.sampling_period_ms = p1.metadata["trial_ticks"]

        self.ctv_handler_obj = PixelWiseCTVHandler(flags=flags, p1=p1)
        # name of the metadata entry under which CTV values are stored
        self.ctv_name = f"CTV_{flags['CTV_Method']}"

        self.pulsed_stimuli_handler = p1.pulsed_stimuli_handler

    def calc_ctv(self, trace):
        """
        Apply the CTV handler to one trace.

        :param trace: trace passed to `apply_pixel` of the CTV handler
        :return: the value returned by the handler, or NaN if it raised an IndexError or an AssertionError
        """
        try:
            return self.ctv_handler_obj.apply_pixel(trace)
        except (IndexError, AssertionError):
            return np.nan

    def get_gdm_file(self):
        """
        Create the GDM file: one row per ROI, with the CTV of the whole trace as additional metadata.

        :return: the GDM file returned by `create_gdm_file_basic`
        """
        ctv_metadata_by_roi = {
            roi_label: {self.ctv_name: self.calc_ctv(trace)}
            for roi_label, trace in self.roi_label_gdm_traces_dict.items()
        }

        return create_gdm_file_basic(
            common_metadata=self.metadata_boiler_plate,
            roi_label_gdm_traces_dict=self.roi_label_gdm_traces_dict,
            sampling_period_ms=self.sampling_period_ms,
            roi_descriptions=self.roi_descriptions,
            roi_label_additional_metadata=ctv_metadata_by_roi
        )
class ChunksOnlyGDMGenerator(FullTraceGDMGenerator):
    """
    Generates a GDM file with one row per stimulus and ROI, each row holding only the part of the
    trace around that stimulus.
    """

    def __init__(self, p1, flags, additional_metadata=None):
        """
        :param p1: measurement object, passed on to `FullTraceGDMGenerator`
        :param flags: mapping of settings; in addition to what `FullTraceGDMGenerator` reads, the entries
            "GDM_chunkPostStim" and "GDM_chunkPreStim" are read; they are interpreted as seconds
        :param dict additional_metadata: passed on to `FullTraceGDMGenerator`
        """
        super().__init__(p1, flags, additional_metadata)
        self.gdm_chunkPostStim = flags['GDM_chunkPostStim']
        self.gdm_chunkPreStim = flags["GDM_chunkPreStim"]

    def get_gdm_file(self, write_ctv=True):
        """
        Create the GDM file: for every row of the stimulus frame, one GDM row per ROI with the chunk of
        the trace from "GDM_chunkPreStim" seconds before stimulus start to "GDM_chunkPostStim" seconds
        after stimulus end.

        :param write_ctv: currently not used by this method
        :return: GDMFile with the rows of all stimuli
        """
        combined_gdm_file = GDMFile()
        stimulus_frame = self.pulsed_stimuli_handler.stimulus_frame

        # for every stimulus
        for _, stimulus_row in stimulus_frame.iterrows():
            # NOTE: relies on the stimulus frame having exactly these five columns in this order
            odor, conc, start_td, end_td, sampling_period_td = stimulus_row

            start_sec = start_td.total_seconds()
            end_sec = end_td.total_seconds()

            window_start_td = start_td - pd.to_timedelta(self.gdm_chunkPreStim, "s")
            window_end_td = end_td + pd.to_timedelta(self.gdm_chunkPostStim, "s")

            # window limits expressed as sample indices
            first_ind = np.round(window_start_td / sampling_period_td).astype(int)
            last_ind = np.round(window_end_td / sampling_period_td).astype(int)

            # a window cannot start before the first sample
            first_ind = max(0, first_ind)
            window_onset_ms = first_ind * self.sampling_period_ms

            common_metadata = self.metadata_boiler_plate.copy()
            common_metadata["StimLen"] = (end_sec - start_sec) * 1000  # in ms
            # this is relative to chunk start
            common_metadata["StimONms"] = start_sec * 1000 - window_onset_ms  # in ms
            common_metadata["Odour"] = odor
            common_metadata["OConc"] = conc

            chunks_by_roi = {}
            ctv_metadata_by_roi = {}

            # for every ROI
            for roi_label, full_trace in self.roi_label_gdm_traces_dict.items():
                # limit the end index to the length of the trace; the limited value is kept for the following ROIs
                last_ind = min(full_trace.shape[0], last_ind)
                # the sample at <last_ind> is included
                chunk = full_trace[first_ind: last_ind + 1]

                chunks_by_roi[roi_label] = chunk
                ctv_metadata_by_roi[roi_label] = {self.ctv_name: self.calc_ctv(chunk)}

            gdm_file_of_stimulus = create_gdm_file_basic(
                common_metadata=common_metadata,
                roi_label_gdm_traces_dict=chunks_by_roi,
                sampling_period_ms=self.sampling_period_ms,
                roi_descriptions=self.roi_descriptions,
                roi_label_additional_metadata=ctv_metadata_by_roi,
                # metadata field for arbitrary delay in measurement
                # here used to indicate chunk start time relative to trace start
                trace_onset=window_onset_ms / 1000  # in seconds
            )

            combined_gdm_file.append_from_a_gdm_file(gdm_file_of_stimulus)

        return combined_gdm_file
def get_gdm_file(p1, flags):
    """
    Create the GDM file of a measurement with the generator selected by the flag "GDM_outputType".

    :param p1: measurement object, passed on to the generator
    :param flags: mapping of settings; "GDM_outputType" must be "full_traces" or "chunks_only"
    :return: the GDM file returned by `get_gdm_file` of the selected generator
    :raises NotImplementedError: if "GDM_outputType" has any other value
    """
    output_type = flags["GDM_outputType"]

    if output_type == "full_traces":
        generator_class = FullTraceGDMGenerator
    elif output_type == "chunks_only":
        generator_class = ChunksOnlyGDMGenerator
    else:
        # name the offending value, so that a wrong setting can be spotted from the error alone
        raise NotImplementedError(
            f"Unknown value for the flag 'GDM_outputType': {output_type!r}. "
            f"Supported values are 'full_traces' and 'chunks_only'."
        )

    return generator_class(p1=p1, flags=flags).get_gdm_file()
|