You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

257 lines
7.7 KiB

5 years ago
8 years ago
8 years ago
8 years ago
7 years ago
  1. from pims import Frame
  2. from pims.base_frames import FramesSequenceND
  3. from nd2reader.exceptions import EmptyFileError
  4. from nd2reader.parser import Parser
  5. import numpy as np
  6. class ND2Reader(FramesSequenceND):
  7. """PIMS wrapper for the ND2 parser.
  8. This is the main class: use this to process your .nd2 files.
  9. """
  10. class_priority = 12
  11. def __init__(self, filename):
  12. super(self.__class__, self).__init__()
  13. self.filename = filename
  14. # first use the parser to parse the file
  15. self._fh = open(filename, "rb")
  16. self._parser = Parser(self._fh)
  17. # Setup metadata
  18. self.metadata = self._parser.metadata
  19. # Set data type
  20. self._dtype = self._parser.get_dtype_from_metadata()
  21. # Setup the axes
  22. self._setup_axes()
  23. # Other properties
  24. self._timesteps = None
  25. @classmethod
  26. def class_exts(cls):
  27. """Let PIMS open function use this reader for opening .nd2 files
  28. """
  29. return {'nd2'} | super(ND2Reader, cls).class_exts()
  30. def close(self):
  31. """Correctly close the file handle
  32. """
  33. if self._fh is not None:
  34. self._fh.close()
  35. def _get_default(self, coord):
  36. try:
  37. return self.default_coords[coord]
  38. except KeyError:
  39. return 0
  40. def get_frame_2D(self, c=0, t=0, z=0, x=0, y=0, v=0):
  41. """Fallback function for backwards compatibility
  42. """
  43. return self.get_frame_vczyx(v=v, c=c, t=t, z=z, x=x, y=y)
  44. def get_frame_vczyx(self, v=None, c=None, t=None, z=None, x=None, y=None):
  45. """Retrieve a frame based on the specified coordinates
  46. Axes order is set by self.bundle_axes, x and y coordinates are ignored,
  47. because we always return Frame objects.
  48. """
  49. # remove 'x', 'y' from bundle axes and set to width, height
  50. bundle_axes = list(self.bundle_axes)
  51. try:
  52. bundle_axes.remove('x')
  53. except ValueError:
  54. pass
  55. try:
  56. bundle_axes.remove('y')
  57. except ValueError:
  58. pass
  59. x = self.metadata["width"]
  60. y = self.metadata["height"]
  61. # make coords dictionary based on function input
  62. coords = dict(v=v, c=c, t=t, z=z)
  63. # Set appropriate values for None and bundle_axes coords
  64. for dim in coords:
  65. coords[dim] = self._get_possible_coords(dim, coords[dim])
  66. # Initialize empty array of Frames of right shape
  67. if len(bundle_axes) > 0:
  68. shape = tuple((len(coords[dim]) for dim in bundle_axes))
  69. results = np.empty(shape, dtype=Frame)
  70. else:
  71. results = np.empty((1,), dtype=Frame)
  72. # order for the get_image_by_attributes function
  73. argument_order = dict(t=0, v=1, c=2, z=3)
  74. # Now, collect the results in the right order
  75. for index, _ in np.ndenumerate(results):
  76. current_coords = [0, 0, 0, 0, y, x]
  77. for dim in coords:
  78. if dim in bundle_axes:
  79. dim_val = coords[dim][index[bundle_axes.index(dim)]]
  80. else:
  81. dim_val = coords[dim][0]
  82. current_coords[argument_order[dim]] = dim_val
  83. # Actually get the corresponding Frame
  84. results[index] = Frame(self._parser.get_image_by_attributes(*current_coords), metadata=self.metadata)
  85. if len(bundle_axes) == 0:
  86. return results[0]
  87. return results
  88. def _get_possible_coords(self, dim, default):
  89. if dim in self.sizes:
  90. if dim in self.bundle_axes:
  91. return range(self.sizes[dim])
  92. else:
  93. return [default] if default is not None else range(self.sizes[dim])
  94. return [None]
  95. @property
  96. def parser(self):
  97. """
  98. Returns the parser object.
  99. Returns:
  100. Parser: the parser object
  101. """
  102. return self._parser
  103. @property
  104. def pixel_type(self):
  105. """Return the pixel data type
  106. Returns:
  107. dtype: the pixel data type
  108. """
  109. return self._dtype
  110. @property
  111. def timesteps(self):
  112. """Get the timesteps of the experiment
  113. Returns:
  114. np.ndarray: an array of times in milliseconds.
  115. """
  116. if self._timesteps is None:
  117. return self.get_timesteps()
  118. return self._timesteps
  119. @property
  120. def events(self):
  121. """Get the events of the experiment
  122. Returns:
  123. iterator of events as dict
  124. """
  125. return self._get_metadata_property("events")
  126. @property
  127. def frame_rate(self):
  128. """The (average) frame rate
  129. Returns:
  130. float: the (average) frame rate in frames per second
  131. """
  132. total_duration = 0.0
  133. for loop in self.metadata['experiment']['loops']:
  134. total_duration += loop['duration']
  135. if total_duration == 0:
  136. raise ValueError('Total measurement duration could not be determined from loops')
  137. return self.metadata['num_frames'] / (total_duration/1000.0)
  138. def _get_metadata_property(self, key, default=None):
  139. if self.metadata is None:
  140. return default
  141. if key not in self.metadata:
  142. return default
  143. if self.metadata[key] is None:
  144. return default
  145. return self.metadata[key]
  146. def _setup_axes(self):
  147. """Setup the xyctz axes, iterate over t axis by default
  148. """
  149. self._init_axis_if_exists('x', self._get_metadata_property("width", default=0))
  150. self._init_axis_if_exists('y', self._get_metadata_property("height", default=0))
  151. self._init_axis_if_exists('c', len(self._get_metadata_property("channels", default=[])), min_size=2)
  152. self._init_axis_if_exists('t', len(self._get_metadata_property("frames", default=[])))
  153. self._init_axis_if_exists('z', len(self._get_metadata_property("z_levels", default=[])), min_size=2)
  154. self._init_axis_if_exists('v', len(self._get_metadata_property("fields_of_view", default=[])), min_size=2)
  155. if len(self.sizes) == 0:
  156. raise EmptyFileError("No axes were found for this .nd2 file.")
  157. # provide the default
  158. self.iter_axes = self._guess_default_iter_axis()
  159. self._register_get_frame(self.get_frame_vczyx, 'vczyx')
  160. self._register_get_frame(self.get_frame_vczyx, 'vzyx')
  161. self._register_get_frame(self.get_frame_vczyx, 'vcyx')
  162. self._register_get_frame(self.get_frame_vczyx, 'vyx')
  163. self._register_get_frame(self.get_frame_vczyx, 'czyx')
  164. self._register_get_frame(self.get_frame_vczyx, 'cyx')
  165. self._register_get_frame(self.get_frame_vczyx, 'zyx')
  166. self._register_get_frame(self.get_frame_vczyx, 'yx')
  167. def _init_axis_if_exists(self, axis, size, min_size=1):
  168. if size >= min_size:
  169. self._init_axis(axis, size)
  170. def _guess_default_iter_axis(self):
  171. """
  172. Guesses the default axis to iterate over based on axis sizes.
  173. Returns:
  174. the axis to iterate over
  175. """
  176. priority = ['t', 'z', 'c', 'v']
  177. found_axes = []
  178. for axis in priority:
  179. try:
  180. current_size = self.sizes[axis]
  181. except KeyError:
  182. continue
  183. if current_size > 1:
  184. return axis
  185. found_axes.append(axis)
  186. return found_axes[0]
  187. def get_timesteps(self):
  188. """Get the timesteps of the experiment
  189. Returns:
  190. np.ndarray: an array of times in milliseconds.
  191. """
  192. if self._timesteps is not None and len(self._timesteps) > 0:
  193. return self._timesteps
  194. self._timesteps = np.array(list(self._parser._raw_metadata.acquisition_times), dtype=np.float) * 1000.0
  195. return self._timesteps