From 3c4e064201457c149521f0338c40a86441e25e54 Mon Sep 17 00:00:00 2001 From: xs5871 Date: Wed, 29 Mar 2023 20:02:29 +0000 Subject: [PATCH] Implement a heap based task scheduler --- kmk/kmk_keyboard.py | 61 ++----------- kmk/scheduler.py | 67 ++++++++++++++ tests/keyboard_test.py | 3 +- tests/mocks.py | 10 ++- tests/task.py | 196 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 282 insertions(+), 55 deletions(-) create mode 100644 kmk/scheduler.py create mode 100644 tests/task.py diff --git a/kmk/kmk_keyboard.py b/kmk/kmk_keyboard.py index d4fdecc..63cb672 100644 --- a/kmk/kmk_keyboard.py +++ b/kmk/kmk_keyboard.py @@ -1,19 +1,17 @@ try: - from typing import Callable, Optional, Tuple + from typing import Callable, Optional except ImportError: pass -from supervisor import ticks_ms - from collections import namedtuple from keypad import Event as KeyEvent from kmk.consts import UnicodeMode from kmk.hid import BLEHID, USBHID, AbstractHID, HIDModes from kmk.keys import KC, Key -from kmk.kmktime import ticks_add, ticks_diff from kmk.modules import Module from kmk.scanners.keypad import MatrixScanner +from kmk.scheduler import Task, cancel_task, create_task, get_due_task from kmk.utils import Debug debug = Debug('kmk.keyboard') @@ -266,60 +264,17 @@ class KMKKeyboard: def tap_key(self, keycode: Key) -> None: self.add_key(keycode) # On the next cycle, we'll remove the key. - self.set_timeout(False, lambda: self.remove_key(keycode)) + self.set_timeout(0, lambda: self.remove_key(keycode)) - def set_timeout( - self, after_ticks: int, callback: Callable[[None], None] - ) -> Tuple[int, int]: - # We allow passing False as an implicit "run this on the next process timeouts cycle" - if after_ticks is False: - after_ticks = 0 - - if after_ticks == 0 and self._processing_timeouts: - after_ticks += 1 - - timeout_key = ticks_add(ticks_ms(), after_ticks) - - if timeout_key not in self._timeouts: - self._timeouts[timeout_key] = [] - - idx = len(self._timeouts[timeout_key]) - self._timeouts[timeout_key].append(callback) - - return (timeout_key, idx) + def set_timeout(self, after_ticks: int, callback: Callable[[None], None]) -> [Task]: + return create_task(callback, after_ms=after_ticks) def cancel_timeout(self, timeout_key: int) -> None: - try: - self._timeouts[timeout_key[0]][timeout_key[1]] = None - except (KeyError, IndexError): - if debug.enabled: - debug(f'no such timeout: {timeout_key}') + cancel_task(timeout_key) def _process_timeouts(self) -> None: - if not self._timeouts: - return - - # Copy timeout keys to a temporary list to allow sorting. - # Prevent net timeouts set during handling from running on the current - # cycle by setting a flag `_processing_timeouts`. - current_time = ticks_ms() - timeout_keys = [] - self._processing_timeouts = True - - for k in self._timeouts.keys(): - if ticks_diff(k, current_time) <= 0: - timeout_keys.append(k) - - if timeout_keys and debug.enabled: - debug('processing timeouts') - - for k in sorted(timeout_keys): - for callback in self._timeouts[k]: - if callback: - callback() - del self._timeouts[k] - - self._processing_timeouts = False + for task in get_due_task(): + task() def _init_sanity_check(self) -> None: ''' diff --git a/kmk/scheduler.py b/kmk/scheduler.py new file mode 100644 index 0000000..1e14ebf --- /dev/null +++ b/kmk/scheduler.py @@ -0,0 +1,67 @@ +''' +Here we're abusing _asyncios TaskQueue to implement a very simple priority +queue task scheduler. +Despite documentation, Circuitpython doesn't usually ship with a min-heap +module; it does however implement a pairing-heap for `TaskQueue` in native code. +''' + +try: + from typing import Callable +except ImportError: + pass + +from supervisor import ticks_ms + +from _asyncio import Task, TaskQueue + +from kmk.kmktime import ticks_add, ticks_diff + +_task_queue = TaskQueue() + + +class PeriodicTaskMeta: + def __init__(self, func: Callable[[None], None], period: int) -> None: + self._task = Task(self.call) + self._coro = func + self.period = period + + def call(self) -> None: + self._coro() + after_ms = ticks_add(self._task.ph_key, self.period) + _task_queue.push_sorted(self._task, after_ms) + + +def create_task( + func: Callable[[None], None], + *, + after_ms: int = 0, + period_ms: int = 0, +) -> [Task, PeriodicTaskMeta]: + if period_ms: + r = PeriodicTaskMeta(func, period_ms) + t = r._task + else: + t = r = Task(func) + + if after_ms: + after_ms = ticks_add(ticks_ms(), after_ms) + _task_queue.push_sorted(t, after_ms) + else: + _task_queue.push_head(t) + + return r + + +def get_due_task() -> [Callable, None]: + while True: + t = _task_queue.peek() + if not t or ticks_diff(t.ph_key, ticks_ms()) > 0: + break + _task_queue.pop_head() + yield t.coro + + +def cancel_task(t: [Task, PeriodicTaskMeta]) -> None: + if isinstance(t, PeriodicTaskMeta): + t = t._task + _task_queue.remove(t) diff --git a/tests/keyboard_test.py b/tests/keyboard_test.py index 84a78e3..7bb47d0 100644 --- a/tests/keyboard_test.py +++ b/tests/keyboard_test.py @@ -6,6 +6,7 @@ from kmk.keys import KC, ModifierKey from kmk.kmk_keyboard import KMKKeyboard from kmk.scanners import DiodeOrientation from kmk.scanners.digitalio import MatrixScanner +from kmk.scheduler import _task_queue class DigitalInOut(Mock): @@ -81,7 +82,7 @@ class KeyboardTest: timeout = time.time_ns() + 10 * 1_000_000_000 while timeout > time.time_ns(): self.do_main_loop() - if not self.keyboard._timeouts and not self.keyboard._resume_buffer: + if not _task_queue.peek() and not self.keyboard._resume_buffer: break assert timeout > time.time_ns(), 'infinite loop detected' diff --git a/tests/mocks.py b/tests/mocks.py index 669977f..991ca83 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -9,6 +9,10 @@ class KeyEvent: self.pressed = pressed +def ticks_ms(): + return (time.time_ns() // 1_000_000) % (1 << 29) + + def init_circuit_python_modules_mocks(): sys.modules['usb_hid'] = Mock() sys.modules['digitalio'] = Mock() @@ -26,4 +30,8 @@ def init_circuit_python_modules_mocks(): sys.modules['micropython'].const = lambda x: x sys.modules['supervisor'] = Mock() - sys.modules['supervisor'].ticks_ms = lambda: time.time_ns() // 1_000_000 + sys.modules['supervisor'].ticks_ms = ticks_ms + + from . import task + + sys.modules['_asyncio'] = task diff --git a/tests/task.py b/tests/task.py new file mode 100644 index 0000000..832776b --- /dev/null +++ b/tests/task.py @@ -0,0 +1,196 @@ +# MicroPython uasyncio module +# MIT license; Copyright (c) 2019-2020 Damien P. George + +# This file contains the core TaskQueue based on a pairing heap, and the core Task class. +# They can optionally be replaced by C implementations. + +# This file is a modified version, based on the extmod in Circuitpython, for +# unit testing in KMK only. + +from supervisor import ticks_ms + +from kmk.kmktime import ticks_diff + +cur_task = None +__task_queue = None + + +class CancelledError(BaseException): + pass + + +# pairing-heap meld of 2 heaps; O(1) +def ph_meld(h1, h2): + if h1 is None: + return h2 + if h2 is None: + return h1 + lt = ticks_diff(h1.ph_key, h2.ph_key) < 0 + if lt: + if h1.ph_child is None: + h1.ph_child = h2 + else: + h1.ph_child_last.ph_next = h2 + h1.ph_child_last = h2 + h2.ph_next = None + h2.ph_rightmost_parent = h1 + return h1 + else: + h1.ph_next = h2.ph_child + h2.ph_child = h1 + if h1.ph_next is None: + h2.ph_child_last = h1 + h1.ph_rightmost_parent = h2 + return h2 + + +# pairing-heap pairing operation; amortised O(log N) +def ph_pairing(child): + heap = None + while child is not None: + n1 = child + child = child.ph_next + n1.ph_next = None + if child is not None: + n2 = child + child = child.ph_next + n2.ph_next = None + n1 = ph_meld(n1, n2) + heap = ph_meld(heap, n1) + return heap + + +# pairing-heap delete of a node; stable, amortised O(log N) +def ph_delete(heap, node): + if node is heap: + child = heap.ph_child + node.ph_child = None + return ph_pairing(child) + # Find parent of node + parent = node + while parent.ph_next is not None: + parent = parent.ph_next + parent = parent.ph_rightmost_parent + if parent is None or parent.ph_child is None: + return heap + # Replace node with pairing of its children + if node is parent.ph_child and node.ph_child is None: + parent.ph_child = node.ph_next + node.ph_next = None + return heap + elif node is parent.ph_child: + child = node.ph_child + next = node.ph_next + node.ph_child = None + node.ph_next = None + node = ph_pairing(child) + parent.ph_child = node + else: + n = parent.ph_child + while node is not n.ph_next: + n = n.ph_next + if not n: + return heap + child = node.ph_child + next = node.ph_next + node.ph_child = None + node.ph_next = None + node = ph_pairing(child) + if node is None: + node = n + else: + n.ph_next = node + node.ph_next = next + if next is None: + node.ph_rightmost_parent = parent + parent.ph_child_last = node + return heap + + +# TaskQueue class based on the above pairing-heap functions. +class TaskQueue: + def __init__(self): + self.heap = None + + def peek(self): + return self.heap + + def push_sorted(self, v, key): + v.data = None + v.ph_key = key + v.ph_child = None + v.ph_next = None + self.heap = ph_meld(v, self.heap) + + def push_head(self, v): + self.push_sorted(v, ticks_ms()) + + def pop_head(self): + v = self.heap + self.heap = ph_pairing(v.ph_child) + # v.ph_child = None + return v + + def remove(self, v): + self.heap = ph_delete(self.heap, v) + + +# Task class representing a coroutine, can be waited on and cancelled. +class Task: + def __init__(self, coro, globals=None): + self.coro = coro # Coroutine of this Task + self.data = None # General data for queue it is waiting on + self.state = True # None, False, True or a TaskQueue instance + self.ph_key = 0 # Pairing heap + self.ph_child = None # Paring heap + self.ph_child_last = None # Paring heap + self.ph_next = None # Paring heap + self.ph_rightmost_parent = None # Paring heap + + def __await__(self): + if not self.state: + # Task finished, signal that is has been await'ed on. + self.state = False + elif self.state is True: + # Allocated head of linked list of Tasks waiting on completion of this task. + self.state = TaskQueue() + return self + + def __next__(self): + if not self.state: + if self.data is None: + # Task finished but has already been sent to the loop's exception handler. + raise StopIteration + else: + # Task finished, raise return value to caller so it can continue. + raise self.data + else: + # Put calling task on waiting queue. + self.state.push_head(cur_task) + # Set calling task's data to this task that it waits on, to double-link it. + cur_task.data = self + + def done(self): + return not self.state + + def cancel(self): + # Check if task is already finished. + if not self.state: + return False + # Can't cancel self (not supported yet). + if self is cur_task: + raise RuntimeError("can't cancel self") + # If Task waits on another task then forward the cancel to the one it's waiting on. + while isinstance(self.data, Task): + self = self.data + # Reschedule Task as a cancelled task. + if hasattr(self.data, 'remove'): + # Not on the main running queue, remove the task from the queue it's on. + self.data.remove(self) + __task_queue.push_head(self) + elif ticks_diff(self.ph_key, ticks_ms()) > 0: + # On the main running queue but scheduled in the future, so bring it forward to now. + __task_queue.remove(self) + __task_queue.push_head(self) + self.data = CancelledError + return True