stim_movie.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. # -*- coding: utf-8 -*-
  2. """Some classes and functions to load and save MultiViewVideo files (*.idlmov)
  3. Client code only needs class StimulusSet.
  4. Save a list of pictures as stimulus set file:
  5. >>> stim_set = StimulusSet(pics=pics)
  6. >>> stim_set.save(file_path)
  7. Load a stimulus file:
  8. >>> stimset = StimulusSet(file_path)
  9. """
  10. import objsimpy.element_tree_object as eto
  11. import struct
  12. import numpy as np
  13. import array
  14. import logging
  15. import copy
  16. from objsimpy.xmltodict import unparse as xml_from_dict
  17. def remove_null_character(s):
  18. end = s.find('\x00')
  19. if end == -1:
  20. return ''
  21. else:
  22. return s[0:end]
  23. class StimMetaData:
  24. def __init__(self, file_path):
  25. self.eto = eto.EtObject(file_path=file_path)
  26. self.movie_file_name = self.eto.StimParas[0].MovieFileName[0].value
  27. self.ny_paras = int(self.eto.StimParas[0].NYParas[0].value)
  28. self.nx_paras = int(self.eto.StimParas[0].NXParas[0].value)
  29. def write_meta_xml(meta_dict, fname_meta):
  30. """Write stimulus meta data xml file.
  31. Args:
  32. meta_dict (dict): stimulus meta data in a dictionary
  33. fname_meta (string): name of xml file
  34. """
  35. meta_copy = copy.deepcopy(meta_dict)
  36. make_content_value_recursive(meta_copy)
  37. with open(fname_meta, 'w') as fobj_meta:
  38. fobj_meta.write(xml_from_dict(meta_copy, pretty=True))
  39. def make_content_value_recursive(meta_dict):
  40. keys = meta_dict.keys()
  41. for key in keys:
  42. if isinstance(meta_dict[key], dict):
  43. make_content_value_recursive(meta_dict[key])
  44. else:
  45. meta_dict[key] = {'@value': str(meta_dict[key])}
  46. class StimFileHeader:
  47. def __init__(self, file_obj=None, info="MindVideo", version="0.05", width=1, height=1, nfilters=1):
  48. if file_obj is not None:
  49. self.from_file(file_obj)
  50. else:
  51. self.info = info
  52. self.version = version
  53. self.width = width
  54. self.height = height
  55. self.nfilters = nfilters
  56. def from_file(self, file_obj):
  57. (info,) = struct.unpack('12s', file_obj.read(12))
  58. self.info = remove_null_character(info)
  59. (version,) = struct.unpack('8s', file_obj.read(8))
  60. self.version = remove_null_character(version)
  61. (self.width, self.height, self.nfilters) = struct.unpack('iii', file_obj.read(12))
  62. def save(self, file_obj):
  63. file_obj.write(struct.pack('12s', self.info))
  64. file_obj.write(struct.pack('8s', self.version))
  65. file_obj.write(struct.pack('iii', self.width, self.height, self.nfilters))
  66. def __unicode__(self):
  67. attrs = ["info", "version", "width", "height", "nfilters"]
  68. ret = u"<StimFileHeader"
  69. for a in attrs:
  70. ret += " " + a + "=" + str(getattr(self, a))
  71. ret += "/>"
  72. return ret
  73. class Filter:
  74. def __init__(self, open_file=None, pic=None, fwidth=2, fheight=2, wshift=0, hshift=0,
  75. output_width=5, output_height=5, output_size=25):
  76. if open_file is not None:
  77. self.from_file(open_file)
  78. else:
  79. if pic is None:
  80. pic = np.zeros((fwidth, fheight))
  81. self.set_filter_dimensions_from_pic(pic)
  82. self.wshift = wshift
  83. self.hshift = hshift
  84. self.output_width = output_width
  85. self.output_height = output_height
  86. def set_filter_dimensions_from_pic(self, pic):
  87. self.pic = pic
  88. self.fheight = pic.shape[0]
  89. self.fwidth = pic.shape[1]
  90. def from_file(self, file_obj):
  91. n_integers = 6
  92. data = struct.unpack('i'*n_integers, file_obj.read(4*n_integers))
  93. self.fwidth, self.fheight = data[:2]
  94. self.fsize = self.fwidth * self.fheight
  95. self.wshift, self.hshift = data[2:4]
  96. self.output_width, self.output_height = data[4:6]
  97. self.output_size = self.output_width * self.output_height
  98. self.pic = struct.unpack('f'*self.fsize, file_obj.read(4*self.fsize))
  99. def save(self, file_obj):
  100. data = [self.fwidth, self.fheight, self.wshift, self.hshift, self.output_width, self.output_height]
  101. int_array = array.array('i', data)
  102. int_array.tofile(file_obj)
  103. pic_float_array = array.array('f', self.pic.flat)
  104. pic_float_array.tofile(file_obj)
  105. def __repr__(self):
  106. attrs = ["fwidth", "fheight", "wshift", "hshift", "output_width", "output_height", "output_size"]
  107. ret = "<Filter"
  108. for a in attrs:
  109. ret += " " + a + "=" + str(getattr(self, a))
  110. ret += "/>"
  111. return ret
  112. class StimulusSet:
  113. """Represents a set of images used as stimuli for neural network.
  114. A single stimulus is a frame, consisting of an image for every filter.
  115. Currently only stimulus sets with single filter are supported.
  116. Args:
  117. file_path (str, optional): if given, the stimulus set will be loaded from there
  118. pics (list of ndarray, optional): List of 2d numpy arrays representing stimulus pictures
  119. Attributes:
  120. pics (list of ndarray): list of images (2d numpy arrays)
  121. """
  122. def __init__(self, file_path=None, pics=None):
  123. """Initializes an instance of StimulusSet.
  124. Args:
  125. file_path (str, optional): if given, the stimulus set will be loaded from there
  126. pics (list of ndarray, optional): List of 2d numpy arrays representing stimulus pictures
  127. """
  128. self.header = None
  129. self.filters = None
  130. if pics is None:
  131. self.pics = []
  132. else:
  133. self.pics = pics
  134. if file_path is not None:
  135. self.pics = self.from_file(file_path)
  136. def add_frame(self, pics=None):
  137. if pics is None:
  138. pics = []
  139. if len(pics) > 1:
  140. raise NotImplementedError
  141. self.pics.append(pics[0])
  142. def get_output_width(self, filter_num=0):
  143. if self.filters is not None:
  144. return self.filters[filter_num].output_width
  145. else:
  146. return 0
  147. def get_output_height(self, filter_num=0):
  148. if self.filters is not None:
  149. return self.filters[filter_num].output_height
  150. else:
  151. return 0
  152. def from_file(self, file_path):
  153. """Loads a stimulus set from file_path"""
  154. file_obj = open(file_path, "rb")
  155. self.header = StimFileHeader(file_obj)
  156. logging.debug(unicode(self.header).encode('utf8'))
  157. self.filters = []
  158. for i in range(self.header.nfilters):
  159. f = Filter(file_obj)
  160. self.filters.append(f)
  161. logging.debug(f)
  162. if self.header.nfilters > 1:
  163. raise NotImplementedError
  164. img_filter = self.filters[0]
  165. n = img_filter.output_size
  166. pics = []
  167. while True:
  168. n_bytes = 4 * n
  169. data = file_obj.read(n_bytes)
  170. if len(data) < n_bytes:
  171. break
  172. pic = struct.unpack('f'*n, data)
  173. pic = np.array(pic)
  174. pic = pic.reshape((img_filter.output_height, img_filter.output_width))
  175. pics.append(pic)
  176. file_obj.close()
  177. return pics
  178. def save(self, file_path):
  179. """Saves the stimulus set to file_path."""
  180. file_obj = open(file_path, "wb")
  181. self.save_header(file_obj)
  182. self.save_filters(file_obj)
  183. self.save_frames(file_obj)
  184. file_obj.close()
  185. # only with 1 filter implemented
  186. def save_header(self, file_obj):
  187. height, width = self.pics[0].shape
  188. if self.filters is None:
  189. self.filters = [Filter(output_width=width, output_height=height),]
  190. nfilters = len(self.filters)
  191. if nfilters > 1:
  192. raise NotImplementedError
  193. header = StimFileHeader(info="MindVideo", version="0.05", width=width, height=height, nfilters=nfilters)
  194. header.save(file_obj)
  195. def save_filters(self, file_obj):
  196. for f in self.filters:
  197. f.save(file_obj)
  198. def save_frames(self, file_obj):
  199. nfilters = len(self.filters)
  200. if nfilters > 1:
  201. raise NotImplementedError
  202. for pic in self.pics:
  203. pic_float_array = array.array('f', pic.flat)
  204. pic_float_array.tofile(file_obj)
  205. def load_stim_movie():
  206. pass