dataladcmd_exec.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from concurrent.futures import ThreadPoolExecutor
  2. import threading
  3. from time import time
  4. from types import MappingProxyType
  5. from typing import (
  6. Dict,
  7. )
  8. from PySide6.QtCore import (
  9. QObject,
  10. Signal,
  11. Slot,
  12. )
  13. from PySide6.QtWidgets import (
  14. QToolButton,
  15. )
  16. from datalad.interface.base import Interface
  17. from datalad.support.exceptions import CapturedException
  18. from datalad.utils import get_wrapped_class
  # lazy import: handle on the ``datalad.api`` module. It stays None until the
  # first GooeyDataladCmdExec.execute() call imports and assigns it, instead
  # of importing the full DataLad API when this module is imported.
  dlapi = None
  class GooeyDataladCmdExec(QObject):
      """Non-blocking execution of DataLad API commands
      and Qt-signal result reporting

      Commands are submitted via ``execute()`` and run one after another in a
      single background worker thread. Start, results, completion, and
      failure are reported exclusively through the Qt signals declared below,
      which are emitted from within the worker thread.

      A ``QToolButton`` (``activity_widget``) is shown while a command runs.
      Clicking it requests that the running command be interrupted; the
      interruption takes effect when the command yields its next result.
      """
      # Emitted right before a command is executed.
      # thread_id, cmdname, cmdargs/kwargs, exec_params
      execution_started = Signal(str, str, MappingProxyType, MappingProxyType)
      # Emitted after a command's results were consumed without an exception.
      # thread_id, cmdname, cmdargs/kwargs, exec_params
      execution_finished = Signal(str, str, MappingProxyType, MappingProxyType)
      # Emitted when resolving, preparing, or running a command raised.
      # thread_id, cmdname, cmdargs/kwargs, exec_params, CapturedException
      execution_failed = Signal(
          str, str, MappingProxyType, MappingProxyType, CapturedException)
      # Batches of result records: (command's wrapped class, list of results)
      results_received = Signal(Interface, list)

      def __init__(self):
          super().__init__()
          aw = QToolButton()
          aw.setAutoRaise(True)
          # clicking the (only temporarily visible) button requests a stop
          aw.clicked.connect(self._stop_thread)
          aw.hide()
          self._activity_widget = aw
          self.execution_started.connect(self._enable_activity_widget)
          self.execution_finished.connect(self._disable_activity_widget)
          self.execution_failed.connect(self._disable_activity_widget)

          # flag whether a running thread should stop ASAP
          self._kaboom = False
          # a single worker: submitted commands execute one at a time
          self._threadpool = ThreadPoolExecutor(
              max_workers=1,
              thread_name_prefix='gooey_datalad_cmdexec',
          )
          # futures of submitted commands that have not completed yet
          self._futures = set()
          # connect maintenance slot to give us an accurate
          # assessment of ongoing commands
          self.execution_finished.connect(self._update_futures)
          self.execution_failed.connect(self._update_futures)

      def _update_futures(self):
          """Forget the futures of commands that have completed.

          Pending futures (queued behind the running command in the
          single-worker pool) must be kept. Filtering on ``running()`` would
          discard them, and ``n_running`` would then never count those
          commands once they start executing.
          """
          # The future whose signal triggered this call may not be done yet
          # (the signal is emitted before the worker function returns); it is
          # then simply dropped at the next update.
          self._futures = {f for f in self._futures if not f.done()}

      # NOTE(review): the declared slot signature (str, dict) does not list
      # the third parameter (exec_params) -- confirm how this slot is wired.
      @Slot(str, dict)
      def execute(self, cmd: str,
                  kwargs: MappingProxyType or None = None,
                  exec_params: MappingProxyType or None = None):
          """Schedule a DataLad API command for execution in the worker thread.

          Returns immediately; the outcome is reported via the Qt signals.

          Parameters
          ----------
          cmd: str
            Name of the command, resolved against ``datalad.api`` in the
            worker thread.
          kwargs: mapping or None
            Keyword arguments for the command. None means no arguments.
          exec_params: mapping or None
            Execution parameters. Recognized keys are
            'preferred_result_interval' (approximate number of seconds between
            ``results_received`` batches, default 1.0) and 'result_override'
            (dict merged into every result record, default none).
          """
          if kwargs is None:
              kwargs = dict()
          if exec_params is None:
              exec_params = dict()

          global dlapi
          if dlapi is None:
              # deferred import of the full DataLad API
              from datalad import api as dl
              dlapi = dl

          # Results and completion are communicated via Qt signals; the future
          # is only retained to track ongoing commands (see n_running).
          self._futures.add(self._threadpool.submit(
              self._cmdexec_thread,
              cmd,
              kwargs,
              exec_params,
          ))

      def _cmdexec_thread(
              self, cmdname: str,
              cmdkwargs: MappingProxyType,
              exec_params: MappingProxyType):
          """The code is executed in a worker thread

          Resolves ``cmdname`` against ``datalad.api``, runs it with
          ``cmdkwargs``, and reports everything through Qt signals. Exceptions
          are never propagated; they are emitted via ``execution_failed``.
          """
          # we need to amend the record below, make a mutable version
          cmdkwargs = cmdkwargs.copy()
          print('EXECINTHREAD', cmdname, cmdkwargs, exec_params)
          preferred_result_interval = exec_params.get(
              'preferred_result_interval', 1.0)
          res_override = exec_params.get(
              'result_override', {})
          # get_ident() is an int, but in the future we might want to move
          # to PY3.8+ native thread IDs, so let's go with a string identifier
          # right away
          thread_id = str(threading.get_ident())

          # get functor to execute, resolve name against full API
          try:
              cmd = getattr(dlapi, cmdname)
              cls = get_wrapped_class(cmd)
          except Exception as e:
              self.execution_failed.emit(
                  thread_id,
                  cmdname,
                  cmdkwargs,
                  exec_params,
                  CapturedException(e),
              )
              return

          try:
              # the following is trivial, but we wrap it nevertheless to prevent
              # a silent crash of the worker thread
              self.execution_started.emit(
                  thread_id,
                  cmdname,
                  cmdkwargs,
                  exec_params,
              )
              # enforce return_type='generator' so that results become
              # available as soon as the command yields them
              cmdkwargs['return_type'] = 'generator'
              # Unless explicitly specified, force result records instead of the
              # command's default transformation which might give Dataset
              # instances for example.
              if 'result_xfm' not in cmdkwargs:
                  cmdkwargs['result_xfm'] = None
              if 'dataset' in cmdkwargs:
                  # Pass actual instance, to have path arguments resolved
                  # against it instead of Gooey's CWD.
                  cmdkwargs['dataset'] = dlapi.Dataset(cmdkwargs['dataset'])
          except Exception as e:
              ce = CapturedException(e)
              self.execution_failed.emit(
                  thread_id,
                  cmdname,
                  cmdkwargs,
                  exec_params,
                  ce
              )
              return

          gathered_results = []
          last_report_ts = time()
          try:
              for res in cmd(**cmdkwargs):
                  t = time()
                  res.update(res_override)
                  gathered_results.append(res)
                  # a stop request is only noticed when a new result arrives
                  if self._kaboom:
                      raise InterruptedError()
                  if (t - last_report_ts) > preferred_result_interval:
                      self.results_received.emit(cls, gathered_results)
                      gathered_results = []
                      last_report_ts = t
          except Exception as e:
              # still deliver whatever was gathered before the failure
              if gathered_results:
                  self.results_received.emit(cls, gathered_results)
              ce = CapturedException(e)
              self.execution_failed.emit(
                  thread_id,
                  cmdname,
                  cmdkwargs,
                  exec_params,
                  ce
              )
          else:
              if gathered_results:
                  self.results_received.emit(cls, gathered_results)
              self.execution_finished.emit(
                  thread_id,
                  cmdname,
                  cmdkwargs,
                  exec_params,
              )

      def _enable_activity_widget(
              self, thread_id: str, cmdname: str, cmdkwargs: dict,
              exec_params: dict):
          """Show the activity button labeled with the running command."""
          # thread_id, cmdname, cmdargs/kwargs, exec_params
          aw = self._activity_widget
          aw.setText(f"KABOOM {cmdname}")
          aw.show()

      def _disable_activity_widget(
              self, thread_id: str, cmdname: str, cmdkwargs: dict,
              exec_params: dict, exc: CapturedException = None):
          """Clear any stop request and hide the activity button.

          Connected to both ``execution_finished`` and ``execution_failed``;
          ``exc`` is only passed by the latter.
          """
          # NOTE(review): the signals are emitted from the worker thread, so
          # this slot presumably runs via a queued connection; the reset may
          # then happen after the next queued command has already started --
          # confirm a stale stop request cannot interrupt that command.
          self._kaboom = False
          # thread_id, cmdname, cmdargs/kwargs, exec_params
          aw = self._activity_widget
          aw.hide()

      def _stop_thread(self):
          """Request that the running command stops at its next result."""
          self._kaboom = True

      @property
      def activity_widget(self):
          """Button shown while a command runs; clicking it requests a stop."""
          return self._activity_widget

      @property
      def n_running(self):
          """Number of tracked commands that are currently executing."""
          return len([f for f in self._futures if f.running()])