dataladcmd_exec.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
from concurrent.futures import ThreadPoolExecutor
import threading
from time import monotonic
from time import time
from typing import (
    Dict,
    Optional,
)

from PySide6.QtCore import (
    QObject,
    Signal,
    Slot,
)

from datalad.interface.base import Interface
from datalad.support.exceptions import CapturedException
from datalad.utils import get_wrapped_class
# Placeholder for the lazily imported `datalad.api` module. It stays None
# until GooeyDataladCmdExec.execute() performs the import on first use and
# rebinds this module-level name.
dlapi = None
  17. class GooeyDataladCmdExec(QObject):
  18. """Non-blocking execution of DataLad API commands
  19. and Qt-signal result reporting
  20. """
  21. # thread_id, cmdname, cmdargs/kwargs
  22. execution_started = Signal(str, str, dict, dict)
  23. execution_finished = Signal(str, str, dict, dict)
  24. # thread_id, cmdname, cmdargs/kwargs, CapturedException
  25. execution_failed = Signal(str, str, dict, CapturedException)
  26. results_received = Signal(Interface, list)
  27. def __init__(self):
  28. super().__init__()
  29. self._threadpool = ThreadPoolExecutor(
  30. max_workers=1,
  31. thread_name_prefix='gooey_datalad_cmdexec',
  32. # some callable to start at each thread execution
  33. #initializer=self.
  34. #initargs=
  35. )
  36. self._futures = set()
  37. # connect maintenance slot to give us an accurate
  38. # assessment of ongoing commands
  39. self.execution_finished.connect(self._update_futures)
  40. self.execution_failed.connect(self._update_futures)
  41. def _update_futures(self):
  42. self._futures = set(f for f in self._futures if f.running())
  43. @Slot(str, dict)
  44. def execute(self, cmd: str,
  45. kwargs: Dict or None = None,
  46. exec_params: Dict or None = None):
  47. if kwargs is None:
  48. kwargs = dict()
  49. if exec_params is None:
  50. exec_params = dict()
  51. global dlapi
  52. if dlapi is None:
  53. from datalad import api as dl
  54. dlapi = dl
  55. # right now, we have no use for the returned future, because result
  56. # communication and thread finishing are handled by emitting Qt signals
  57. self._futures.add(self._threadpool.submit(
  58. self._cmdexec_thread,
  59. cmd,
  60. kwargs,
  61. exec_params,
  62. ))
  63. def _cmdexec_thread(self, cmdname: str, cmdkwargs: Dict, exec_params: Dict):
  64. """The code is executed in a worker thread"""
  65. print('EXECINTHREAD', cmdname, cmdkwargs, exec_params)
  66. preferred_result_interval = exec_params.get(
  67. 'preferred_result_interval', 1.0)
  68. res_override = exec_params.get(
  69. 'result_override', {})
  70. # get_ident() is an int, but in the future we might want to move
  71. # to PY3.8+ native thread IDs, so let's go with a string identifier
  72. # right away
  73. thread_id = str(threading.get_ident())
  74. # get functor to execute, resolve name against full API
  75. try:
  76. cmd = getattr(dlapi, cmdname)
  77. cls = get_wrapped_class(cmd)
  78. except Exception as e:
  79. self.execution_failed.emit(
  80. thread_id,
  81. cmdname,
  82. cmdkwargs,
  83. exec_params,
  84. CapturedException(e),
  85. )
  86. return
  87. self.execution_started.emit(
  88. thread_id,
  89. cmdname,
  90. cmdkwargs,
  91. exec_params,
  92. )
  93. # enforce return_type='generator' to get the most responsive
  94. # any command could be
  95. cmdkwargs['return_type'] = 'generator'
  96. # Unless explicitly specified, force result records instead of the
  97. # command's default transformation which might give Dataset instances
  98. # for example.
  99. if 'result_xfm' not in cmdkwargs:
  100. cmdkwargs['result_xfm'] = None
  101. if 'dataset' in cmdkwargs:
  102. # Pass actual instance, to have path arguments resolved against it
  103. # instead of Gooey's CWD.
  104. cmdkwargs['dataset'] = dlapi.Dataset(cmdkwargs['dataset'])
  105. gathered_results = []
  106. last_report_ts = time()
  107. try:
  108. for res in cmd(**cmdkwargs):
  109. t = time()
  110. res.update(res_override)
  111. gathered_results.append(res)
  112. if (t - last_report_ts) > preferred_result_interval:
  113. self.results_received.emit(cls, gathered_results)
  114. gathered_results = []
  115. last_report_ts = t
  116. except Exception as e:
  117. if gathered_results:
  118. self.results_received.emit(cls, gathered_results)
  119. ce = CapturedException(e)
  120. self.execution_failed.emit(
  121. thread_id,
  122. cmdname,
  123. cmdkwargs,
  124. exec_params,
  125. ce
  126. )
  127. else:
  128. if gathered_results:
  129. self.results_received.emit(cls, gathered_results)
  130. self.execution_finished.emit(
  131. thread_id,
  132. cmdname,
  133. cmdkwargs,
  134. exec_params,
  135. )
  136. @property
  137. def n_running(self):
  138. return len([f for f in self._futures if f.running()])