Implement a heap based task scheduler

This commit is contained in:
xs5871 2023-03-29 20:02:29 +00:00 committed by xs5871
parent bc5fb9dc9e
commit 3c4e064201
5 changed files with 282 additions and 55 deletions

View File

@ -1,19 +1,17 @@
try: try:
from typing import Callable, Optional, Tuple from typing import Callable, Optional
except ImportError: except ImportError:
pass pass
from supervisor import ticks_ms
from collections import namedtuple from collections import namedtuple
from keypad import Event as KeyEvent from keypad import Event as KeyEvent
from kmk.consts import UnicodeMode from kmk.consts import UnicodeMode
from kmk.hid import BLEHID, USBHID, AbstractHID, HIDModes from kmk.hid import BLEHID, USBHID, AbstractHID, HIDModes
from kmk.keys import KC, Key from kmk.keys import KC, Key
from kmk.kmktime import ticks_add, ticks_diff
from kmk.modules import Module from kmk.modules import Module
from kmk.scanners.keypad import MatrixScanner from kmk.scanners.keypad import MatrixScanner
from kmk.scheduler import Task, cancel_task, create_task, get_due_task
from kmk.utils import Debug from kmk.utils import Debug
debug = Debug('kmk.keyboard') debug = Debug('kmk.keyboard')
@ -266,60 +264,17 @@ class KMKKeyboard:
def tap_key(self, keycode: Key) -> None: def tap_key(self, keycode: Key) -> None:
self.add_key(keycode) self.add_key(keycode)
# On the next cycle, we'll remove the key. # On the next cycle, we'll remove the key.
self.set_timeout(False, lambda: self.remove_key(keycode)) self.set_timeout(0, lambda: self.remove_key(keycode))
def set_timeout( def set_timeout(self, after_ticks: int, callback: Callable[[None], None]) -> [Task]:
self, after_ticks: int, callback: Callable[[None], None] return create_task(callback, after_ms=after_ticks)
) -> Tuple[int, int]:
# We allow passing False as an implicit "run this on the next process timeouts cycle"
if after_ticks is False:
after_ticks = 0
if after_ticks == 0 and self._processing_timeouts:
after_ticks += 1
timeout_key = ticks_add(ticks_ms(), after_ticks)
if timeout_key not in self._timeouts:
self._timeouts[timeout_key] = []
idx = len(self._timeouts[timeout_key])
self._timeouts[timeout_key].append(callback)
return (timeout_key, idx)
def cancel_timeout(self, timeout_key: int) -> None: def cancel_timeout(self, timeout_key: int) -> None:
try: cancel_task(timeout_key)
self._timeouts[timeout_key[0]][timeout_key[1]] = None
except (KeyError, IndexError):
if debug.enabled:
debug(f'no such timeout: {timeout_key}')
def _process_timeouts(self) -> None: def _process_timeouts(self) -> None:
if not self._timeouts: for task in get_due_task():
return task()
# Copy timeout keys to a temporary list to allow sorting.
# Prevent net timeouts set during handling from running on the current
# cycle by setting a flag `_processing_timeouts`.
current_time = ticks_ms()
timeout_keys = []
self._processing_timeouts = True
for k in self._timeouts.keys():
if ticks_diff(k, current_time) <= 0:
timeout_keys.append(k)
if timeout_keys and debug.enabled:
debug('processing timeouts')
for k in sorted(timeout_keys):
for callback in self._timeouts[k]:
if callback:
callback()
del self._timeouts[k]
self._processing_timeouts = False
def _init_sanity_check(self) -> None: def _init_sanity_check(self) -> None:
''' '''

67
kmk/scheduler.py Normal file
View File

@ -0,0 +1,67 @@
'''
Here we're abusing _asyncios TaskQueue to implement a very simple priority
queue task scheduler.
Despite documentation, Circuitpython doesn't usually ship with a min-heap
module; it does however implement a pairing-heap for `TaskQueue` in native code.
'''
try:
from typing import Callable
except ImportError:
pass
from supervisor import ticks_ms
from _asyncio import Task, TaskQueue
from kmk.kmktime import ticks_add, ticks_diff
_task_queue = TaskQueue()
class PeriodicTaskMeta:
def __init__(self, func: Callable[[None], None], period: int) -> None:
self._task = Task(self.call)
self._coro = func
self.period = period
def call(self) -> None:
self._coro()
after_ms = ticks_add(self._task.ph_key, self.period)
_task_queue.push_sorted(self._task, after_ms)
def create_task(
func: Callable[[None], None],
*,
after_ms: int = 0,
period_ms: int = 0,
) -> [Task, PeriodicTaskMeta]:
if period_ms:
r = PeriodicTaskMeta(func, period_ms)
t = r._task
else:
t = r = Task(func)
if after_ms:
after_ms = ticks_add(ticks_ms(), after_ms)
_task_queue.push_sorted(t, after_ms)
else:
_task_queue.push_head(t)
return r
def get_due_task() -> [Callable, None]:
while True:
t = _task_queue.peek()
if not t or ticks_diff(t.ph_key, ticks_ms()) > 0:
break
_task_queue.pop_head()
yield t.coro
def cancel_task(t: [Task, PeriodicTaskMeta]) -> None:
if isinstance(t, PeriodicTaskMeta):
t = t._task
_task_queue.remove(t)

View File

@ -6,6 +6,7 @@ from kmk.keys import KC, ModifierKey
from kmk.kmk_keyboard import KMKKeyboard from kmk.kmk_keyboard import KMKKeyboard
from kmk.scanners import DiodeOrientation from kmk.scanners import DiodeOrientation
from kmk.scanners.digitalio import MatrixScanner from kmk.scanners.digitalio import MatrixScanner
from kmk.scheduler import _task_queue
class DigitalInOut(Mock): class DigitalInOut(Mock):
@ -81,7 +82,7 @@ class KeyboardTest:
timeout = time.time_ns() + 10 * 1_000_000_000 timeout = time.time_ns() + 10 * 1_000_000_000
while timeout > time.time_ns(): while timeout > time.time_ns():
self.do_main_loop() self.do_main_loop()
if not self.keyboard._timeouts and not self.keyboard._resume_buffer: if not _task_queue.peek() and not self.keyboard._resume_buffer:
break break
assert timeout > time.time_ns(), 'infinite loop detected' assert timeout > time.time_ns(), 'infinite loop detected'

View File

@ -9,6 +9,10 @@ class KeyEvent:
self.pressed = pressed self.pressed = pressed
def ticks_ms():
return (time.time_ns() // 1_000_000) % (1 << 29)
def init_circuit_python_modules_mocks(): def init_circuit_python_modules_mocks():
sys.modules['usb_hid'] = Mock() sys.modules['usb_hid'] = Mock()
sys.modules['digitalio'] = Mock() sys.modules['digitalio'] = Mock()
@ -26,4 +30,8 @@ def init_circuit_python_modules_mocks():
sys.modules['micropython'].const = lambda x: x sys.modules['micropython'].const = lambda x: x
sys.modules['supervisor'] = Mock() sys.modules['supervisor'] = Mock()
sys.modules['supervisor'].ticks_ms = lambda: time.time_ns() // 1_000_000 sys.modules['supervisor'].ticks_ms = ticks_ms
from . import task
sys.modules['_asyncio'] = task

196
tests/task.py Normal file
View File

@ -0,0 +1,196 @@
# MicroPython uasyncio module
# MIT license; Copyright (c) 2019-2020 Damien P. George
# This file contains the core TaskQueue based on a pairing heap, and the core Task class.
# They can optionally be replaced by C implementations.
# This file is a modified version, based on the extmod in Circuitpython, for
# unit testing in KMK only.
from supervisor import ticks_ms
from kmk.kmktime import ticks_diff
cur_task = None
__task_queue = None
class CancelledError(BaseException):
pass
# pairing-heap meld of 2 heaps; O(1)
def ph_meld(h1, h2):
if h1 is None:
return h2
if h2 is None:
return h1
lt = ticks_diff(h1.ph_key, h2.ph_key) < 0
if lt:
if h1.ph_child is None:
h1.ph_child = h2
else:
h1.ph_child_last.ph_next = h2
h1.ph_child_last = h2
h2.ph_next = None
h2.ph_rightmost_parent = h1
return h1
else:
h1.ph_next = h2.ph_child
h2.ph_child = h1
if h1.ph_next is None:
h2.ph_child_last = h1
h1.ph_rightmost_parent = h2
return h2
# pairing-heap pairing operation; amortised O(log N)
def ph_pairing(child):
heap = None
while child is not None:
n1 = child
child = child.ph_next
n1.ph_next = None
if child is not None:
n2 = child
child = child.ph_next
n2.ph_next = None
n1 = ph_meld(n1, n2)
heap = ph_meld(heap, n1)
return heap
# pairing-heap delete of a node; stable, amortised O(log N)
def ph_delete(heap, node):
if node is heap:
child = heap.ph_child
node.ph_child = None
return ph_pairing(child)
# Find parent of node
parent = node
while parent.ph_next is not None:
parent = parent.ph_next
parent = parent.ph_rightmost_parent
if parent is None or parent.ph_child is None:
return heap
# Replace node with pairing of its children
if node is parent.ph_child and node.ph_child is None:
parent.ph_child = node.ph_next
node.ph_next = None
return heap
elif node is parent.ph_child:
child = node.ph_child
next = node.ph_next
node.ph_child = None
node.ph_next = None
node = ph_pairing(child)
parent.ph_child = node
else:
n = parent.ph_child
while node is not n.ph_next:
n = n.ph_next
if not n:
return heap
child = node.ph_child
next = node.ph_next
node.ph_child = None
node.ph_next = None
node = ph_pairing(child)
if node is None:
node = n
else:
n.ph_next = node
node.ph_next = next
if next is None:
node.ph_rightmost_parent = parent
parent.ph_child_last = node
return heap
# TaskQueue class based on the above pairing-heap functions.
class TaskQueue:
def __init__(self):
self.heap = None
def peek(self):
return self.heap
def push_sorted(self, v, key):
v.data = None
v.ph_key = key
v.ph_child = None
v.ph_next = None
self.heap = ph_meld(v, self.heap)
def push_head(self, v):
self.push_sorted(v, ticks_ms())
def pop_head(self):
v = self.heap
self.heap = ph_pairing(v.ph_child)
# v.ph_child = None
return v
def remove(self, v):
self.heap = ph_delete(self.heap, v)
# Task class representing a coroutine, can be waited on and cancelled.
class Task:
def __init__(self, coro, globals=None):
self.coro = coro # Coroutine of this Task
self.data = None # General data for queue it is waiting on
self.state = True # None, False, True or a TaskQueue instance
self.ph_key = 0 # Pairing heap
self.ph_child = None # Paring heap
self.ph_child_last = None # Paring heap
self.ph_next = None # Paring heap
self.ph_rightmost_parent = None # Paring heap
def __await__(self):
if not self.state:
# Task finished, signal that is has been await'ed on.
self.state = False
elif self.state is True:
# Allocated head of linked list of Tasks waiting on completion of this task.
self.state = TaskQueue()
return self
def __next__(self):
if not self.state:
if self.data is None:
# Task finished but has already been sent to the loop's exception handler.
raise StopIteration
else:
# Task finished, raise return value to caller so it can continue.
raise self.data
else:
# Put calling task on waiting queue.
self.state.push_head(cur_task)
# Set calling task's data to this task that it waits on, to double-link it.
cur_task.data = self
def done(self):
return not self.state
def cancel(self):
# Check if task is already finished.
if not self.state:
return False
# Can't cancel self (not supported yet).
if self is cur_task:
raise RuntimeError("can't cancel self")
# If Task waits on another task then forward the cancel to the one it's waiting on.
while isinstance(self.data, Task):
self = self.data
# Reschedule Task as a cancelled task.
if hasattr(self.data, 'remove'):
# Not on the main running queue, remove the task from the queue it's on.
self.data.remove(self)
__task_queue.push_head(self)
elif ticks_diff(self.ph_key, ticks_ms()) > 0:
# On the main running queue but scheduled in the future, so bring it forward to now.
__task_queue.remove(self)
__task_queue.push_head(self)
self.data = CancelledError
return True