Implement a heap based task scheduler
This commit is contained in:
parent
bc5fb9dc9e
commit
3c4e064201
@ -1,19 +1,17 @@
|
|||||||
try:
|
try:
|
||||||
from typing import Callable, Optional, Tuple
|
from typing import Callable, Optional
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
from supervisor import ticks_ms
|
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from keypad import Event as KeyEvent
|
from keypad import Event as KeyEvent
|
||||||
|
|
||||||
from kmk.consts import UnicodeMode
|
from kmk.consts import UnicodeMode
|
||||||
from kmk.hid import BLEHID, USBHID, AbstractHID, HIDModes
|
from kmk.hid import BLEHID, USBHID, AbstractHID, HIDModes
|
||||||
from kmk.keys import KC, Key
|
from kmk.keys import KC, Key
|
||||||
from kmk.kmktime import ticks_add, ticks_diff
|
|
||||||
from kmk.modules import Module
|
from kmk.modules import Module
|
||||||
from kmk.scanners.keypad import MatrixScanner
|
from kmk.scanners.keypad import MatrixScanner
|
||||||
|
from kmk.scheduler import Task, cancel_task, create_task, get_due_task
|
||||||
from kmk.utils import Debug
|
from kmk.utils import Debug
|
||||||
|
|
||||||
debug = Debug('kmk.keyboard')
|
debug = Debug('kmk.keyboard')
|
||||||
@ -266,60 +264,17 @@ class KMKKeyboard:
|
|||||||
def tap_key(self, keycode: Key) -> None:
|
def tap_key(self, keycode: Key) -> None:
|
||||||
self.add_key(keycode)
|
self.add_key(keycode)
|
||||||
# On the next cycle, we'll remove the key.
|
# On the next cycle, we'll remove the key.
|
||||||
self.set_timeout(False, lambda: self.remove_key(keycode))
|
self.set_timeout(0, lambda: self.remove_key(keycode))
|
||||||
|
|
||||||
def set_timeout(
|
def set_timeout(self, after_ticks: int, callback: Callable[[None], None]) -> [Task]:
|
||||||
self, after_ticks: int, callback: Callable[[None], None]
|
return create_task(callback, after_ms=after_ticks)
|
||||||
) -> Tuple[int, int]:
|
|
||||||
# We allow passing False as an implicit "run this on the next process timeouts cycle"
|
|
||||||
if after_ticks is False:
|
|
||||||
after_ticks = 0
|
|
||||||
|
|
||||||
if after_ticks == 0 and self._processing_timeouts:
|
|
||||||
after_ticks += 1
|
|
||||||
|
|
||||||
timeout_key = ticks_add(ticks_ms(), after_ticks)
|
|
||||||
|
|
||||||
if timeout_key not in self._timeouts:
|
|
||||||
self._timeouts[timeout_key] = []
|
|
||||||
|
|
||||||
idx = len(self._timeouts[timeout_key])
|
|
||||||
self._timeouts[timeout_key].append(callback)
|
|
||||||
|
|
||||||
return (timeout_key, idx)
|
|
||||||
|
|
||||||
def cancel_timeout(self, timeout_key: int) -> None:
|
def cancel_timeout(self, timeout_key: int) -> None:
|
||||||
try:
|
cancel_task(timeout_key)
|
||||||
self._timeouts[timeout_key[0]][timeout_key[1]] = None
|
|
||||||
except (KeyError, IndexError):
|
|
||||||
if debug.enabled:
|
|
||||||
debug(f'no such timeout: {timeout_key}')
|
|
||||||
|
|
||||||
def _process_timeouts(self) -> None:
|
def _process_timeouts(self) -> None:
|
||||||
if not self._timeouts:
|
for task in get_due_task():
|
||||||
return
|
task()
|
||||||
|
|
||||||
# Copy timeout keys to a temporary list to allow sorting.
|
|
||||||
# Prevent net timeouts set during handling from running on the current
|
|
||||||
# cycle by setting a flag `_processing_timeouts`.
|
|
||||||
current_time = ticks_ms()
|
|
||||||
timeout_keys = []
|
|
||||||
self._processing_timeouts = True
|
|
||||||
|
|
||||||
for k in self._timeouts.keys():
|
|
||||||
if ticks_diff(k, current_time) <= 0:
|
|
||||||
timeout_keys.append(k)
|
|
||||||
|
|
||||||
if timeout_keys and debug.enabled:
|
|
||||||
debug('processing timeouts')
|
|
||||||
|
|
||||||
for k in sorted(timeout_keys):
|
|
||||||
for callback in self._timeouts[k]:
|
|
||||||
if callback:
|
|
||||||
callback()
|
|
||||||
del self._timeouts[k]
|
|
||||||
|
|
||||||
self._processing_timeouts = False
|
|
||||||
|
|
||||||
def _init_sanity_check(self) -> None:
|
def _init_sanity_check(self) -> None:
|
||||||
'''
|
'''
|
||||||
|
67
kmk/scheduler.py
Normal file
67
kmk/scheduler.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
'''
|
||||||
|
Here we're abusing _asyncios TaskQueue to implement a very simple priority
|
||||||
|
queue task scheduler.
|
||||||
|
Despite documentation, Circuitpython doesn't usually ship with a min-heap
|
||||||
|
module; it does however implement a pairing-heap for `TaskQueue` in native code.
|
||||||
|
'''
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import Callable
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from supervisor import ticks_ms
|
||||||
|
|
||||||
|
from _asyncio import Task, TaskQueue
|
||||||
|
|
||||||
|
from kmk.kmktime import ticks_add, ticks_diff
|
||||||
|
|
||||||
|
_task_queue = TaskQueue()
|
||||||
|
|
||||||
|
|
||||||
|
class PeriodicTaskMeta:
|
||||||
|
def __init__(self, func: Callable[[None], None], period: int) -> None:
|
||||||
|
self._task = Task(self.call)
|
||||||
|
self._coro = func
|
||||||
|
self.period = period
|
||||||
|
|
||||||
|
def call(self) -> None:
|
||||||
|
self._coro()
|
||||||
|
after_ms = ticks_add(self._task.ph_key, self.period)
|
||||||
|
_task_queue.push_sorted(self._task, after_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def create_task(
|
||||||
|
func: Callable[[None], None],
|
||||||
|
*,
|
||||||
|
after_ms: int = 0,
|
||||||
|
period_ms: int = 0,
|
||||||
|
) -> [Task, PeriodicTaskMeta]:
|
||||||
|
if period_ms:
|
||||||
|
r = PeriodicTaskMeta(func, period_ms)
|
||||||
|
t = r._task
|
||||||
|
else:
|
||||||
|
t = r = Task(func)
|
||||||
|
|
||||||
|
if after_ms:
|
||||||
|
after_ms = ticks_add(ticks_ms(), after_ms)
|
||||||
|
_task_queue.push_sorted(t, after_ms)
|
||||||
|
else:
|
||||||
|
_task_queue.push_head(t)
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
def get_due_task() -> [Callable, None]:
|
||||||
|
while True:
|
||||||
|
t = _task_queue.peek()
|
||||||
|
if not t or ticks_diff(t.ph_key, ticks_ms()) > 0:
|
||||||
|
break
|
||||||
|
_task_queue.pop_head()
|
||||||
|
yield t.coro
|
||||||
|
|
||||||
|
|
||||||
|
def cancel_task(t: [Task, PeriodicTaskMeta]) -> None:
|
||||||
|
if isinstance(t, PeriodicTaskMeta):
|
||||||
|
t = t._task
|
||||||
|
_task_queue.remove(t)
|
@ -6,6 +6,7 @@ from kmk.keys import KC, ModifierKey
|
|||||||
from kmk.kmk_keyboard import KMKKeyboard
|
from kmk.kmk_keyboard import KMKKeyboard
|
||||||
from kmk.scanners import DiodeOrientation
|
from kmk.scanners import DiodeOrientation
|
||||||
from kmk.scanners.digitalio import MatrixScanner
|
from kmk.scanners.digitalio import MatrixScanner
|
||||||
|
from kmk.scheduler import _task_queue
|
||||||
|
|
||||||
|
|
||||||
class DigitalInOut(Mock):
|
class DigitalInOut(Mock):
|
||||||
@ -81,7 +82,7 @@ class KeyboardTest:
|
|||||||
timeout = time.time_ns() + 10 * 1_000_000_000
|
timeout = time.time_ns() + 10 * 1_000_000_000
|
||||||
while timeout > time.time_ns():
|
while timeout > time.time_ns():
|
||||||
self.do_main_loop()
|
self.do_main_loop()
|
||||||
if not self.keyboard._timeouts and not self.keyboard._resume_buffer:
|
if not _task_queue.peek() and not self.keyboard._resume_buffer:
|
||||||
break
|
break
|
||||||
assert timeout > time.time_ns(), 'infinite loop detected'
|
assert timeout > time.time_ns(), 'infinite loop detected'
|
||||||
|
|
||||||
|
@ -9,6 +9,10 @@ class KeyEvent:
|
|||||||
self.pressed = pressed
|
self.pressed = pressed
|
||||||
|
|
||||||
|
|
||||||
|
def ticks_ms():
|
||||||
|
return (time.time_ns() // 1_000_000) % (1 << 29)
|
||||||
|
|
||||||
|
|
||||||
def init_circuit_python_modules_mocks():
|
def init_circuit_python_modules_mocks():
|
||||||
sys.modules['usb_hid'] = Mock()
|
sys.modules['usb_hid'] = Mock()
|
||||||
sys.modules['digitalio'] = Mock()
|
sys.modules['digitalio'] = Mock()
|
||||||
@ -26,4 +30,8 @@ def init_circuit_python_modules_mocks():
|
|||||||
sys.modules['micropython'].const = lambda x: x
|
sys.modules['micropython'].const = lambda x: x
|
||||||
|
|
||||||
sys.modules['supervisor'] = Mock()
|
sys.modules['supervisor'] = Mock()
|
||||||
sys.modules['supervisor'].ticks_ms = lambda: time.time_ns() // 1_000_000
|
sys.modules['supervisor'].ticks_ms = ticks_ms
|
||||||
|
|
||||||
|
from . import task
|
||||||
|
|
||||||
|
sys.modules['_asyncio'] = task
|
||||||
|
196
tests/task.py
Normal file
196
tests/task.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
# MicroPython uasyncio module
|
||||||
|
# MIT license; Copyright (c) 2019-2020 Damien P. George
|
||||||
|
|
||||||
|
# This file contains the core TaskQueue based on a pairing heap, and the core Task class.
|
||||||
|
# They can optionally be replaced by C implementations.
|
||||||
|
|
||||||
|
# This file is a modified version, based on the extmod in Circuitpython, for
|
||||||
|
# unit testing in KMK only.
|
||||||
|
|
||||||
|
from supervisor import ticks_ms
|
||||||
|
|
||||||
|
from kmk.kmktime import ticks_diff
|
||||||
|
|
||||||
|
cur_task = None
|
||||||
|
__task_queue = None
|
||||||
|
|
||||||
|
|
||||||
|
class CancelledError(BaseException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# pairing-heap meld of 2 heaps; O(1)
|
||||||
|
def ph_meld(h1, h2):
|
||||||
|
if h1 is None:
|
||||||
|
return h2
|
||||||
|
if h2 is None:
|
||||||
|
return h1
|
||||||
|
lt = ticks_diff(h1.ph_key, h2.ph_key) < 0
|
||||||
|
if lt:
|
||||||
|
if h1.ph_child is None:
|
||||||
|
h1.ph_child = h2
|
||||||
|
else:
|
||||||
|
h1.ph_child_last.ph_next = h2
|
||||||
|
h1.ph_child_last = h2
|
||||||
|
h2.ph_next = None
|
||||||
|
h2.ph_rightmost_parent = h1
|
||||||
|
return h1
|
||||||
|
else:
|
||||||
|
h1.ph_next = h2.ph_child
|
||||||
|
h2.ph_child = h1
|
||||||
|
if h1.ph_next is None:
|
||||||
|
h2.ph_child_last = h1
|
||||||
|
h1.ph_rightmost_parent = h2
|
||||||
|
return h2
|
||||||
|
|
||||||
|
|
||||||
|
# pairing-heap pairing operation; amortised O(log N)
|
||||||
|
def ph_pairing(child):
|
||||||
|
heap = None
|
||||||
|
while child is not None:
|
||||||
|
n1 = child
|
||||||
|
child = child.ph_next
|
||||||
|
n1.ph_next = None
|
||||||
|
if child is not None:
|
||||||
|
n2 = child
|
||||||
|
child = child.ph_next
|
||||||
|
n2.ph_next = None
|
||||||
|
n1 = ph_meld(n1, n2)
|
||||||
|
heap = ph_meld(heap, n1)
|
||||||
|
return heap
|
||||||
|
|
||||||
|
|
||||||
|
# pairing-heap delete of a node; stable, amortised O(log N)
|
||||||
|
def ph_delete(heap, node):
|
||||||
|
if node is heap:
|
||||||
|
child = heap.ph_child
|
||||||
|
node.ph_child = None
|
||||||
|
return ph_pairing(child)
|
||||||
|
# Find parent of node
|
||||||
|
parent = node
|
||||||
|
while parent.ph_next is not None:
|
||||||
|
parent = parent.ph_next
|
||||||
|
parent = parent.ph_rightmost_parent
|
||||||
|
if parent is None or parent.ph_child is None:
|
||||||
|
return heap
|
||||||
|
# Replace node with pairing of its children
|
||||||
|
if node is parent.ph_child and node.ph_child is None:
|
||||||
|
parent.ph_child = node.ph_next
|
||||||
|
node.ph_next = None
|
||||||
|
return heap
|
||||||
|
elif node is parent.ph_child:
|
||||||
|
child = node.ph_child
|
||||||
|
next = node.ph_next
|
||||||
|
node.ph_child = None
|
||||||
|
node.ph_next = None
|
||||||
|
node = ph_pairing(child)
|
||||||
|
parent.ph_child = node
|
||||||
|
else:
|
||||||
|
n = parent.ph_child
|
||||||
|
while node is not n.ph_next:
|
||||||
|
n = n.ph_next
|
||||||
|
if not n:
|
||||||
|
return heap
|
||||||
|
child = node.ph_child
|
||||||
|
next = node.ph_next
|
||||||
|
node.ph_child = None
|
||||||
|
node.ph_next = None
|
||||||
|
node = ph_pairing(child)
|
||||||
|
if node is None:
|
||||||
|
node = n
|
||||||
|
else:
|
||||||
|
n.ph_next = node
|
||||||
|
node.ph_next = next
|
||||||
|
if next is None:
|
||||||
|
node.ph_rightmost_parent = parent
|
||||||
|
parent.ph_child_last = node
|
||||||
|
return heap
|
||||||
|
|
||||||
|
|
||||||
|
# TaskQueue class based on the above pairing-heap functions.
|
||||||
|
class TaskQueue:
|
||||||
|
def __init__(self):
|
||||||
|
self.heap = None
|
||||||
|
|
||||||
|
def peek(self):
|
||||||
|
return self.heap
|
||||||
|
|
||||||
|
def push_sorted(self, v, key):
|
||||||
|
v.data = None
|
||||||
|
v.ph_key = key
|
||||||
|
v.ph_child = None
|
||||||
|
v.ph_next = None
|
||||||
|
self.heap = ph_meld(v, self.heap)
|
||||||
|
|
||||||
|
def push_head(self, v):
|
||||||
|
self.push_sorted(v, ticks_ms())
|
||||||
|
|
||||||
|
def pop_head(self):
|
||||||
|
v = self.heap
|
||||||
|
self.heap = ph_pairing(v.ph_child)
|
||||||
|
# v.ph_child = None
|
||||||
|
return v
|
||||||
|
|
||||||
|
def remove(self, v):
|
||||||
|
self.heap = ph_delete(self.heap, v)
|
||||||
|
|
||||||
|
|
||||||
|
# Task class representing a coroutine, can be waited on and cancelled.
|
||||||
|
class Task:
|
||||||
|
def __init__(self, coro, globals=None):
|
||||||
|
self.coro = coro # Coroutine of this Task
|
||||||
|
self.data = None # General data for queue it is waiting on
|
||||||
|
self.state = True # None, False, True or a TaskQueue instance
|
||||||
|
self.ph_key = 0 # Pairing heap
|
||||||
|
self.ph_child = None # Paring heap
|
||||||
|
self.ph_child_last = None # Paring heap
|
||||||
|
self.ph_next = None # Paring heap
|
||||||
|
self.ph_rightmost_parent = None # Paring heap
|
||||||
|
|
||||||
|
def __await__(self):
|
||||||
|
if not self.state:
|
||||||
|
# Task finished, signal that is has been await'ed on.
|
||||||
|
self.state = False
|
||||||
|
elif self.state is True:
|
||||||
|
# Allocated head of linked list of Tasks waiting on completion of this task.
|
||||||
|
self.state = TaskQueue()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if not self.state:
|
||||||
|
if self.data is None:
|
||||||
|
# Task finished but has already been sent to the loop's exception handler.
|
||||||
|
raise StopIteration
|
||||||
|
else:
|
||||||
|
# Task finished, raise return value to caller so it can continue.
|
||||||
|
raise self.data
|
||||||
|
else:
|
||||||
|
# Put calling task on waiting queue.
|
||||||
|
self.state.push_head(cur_task)
|
||||||
|
# Set calling task's data to this task that it waits on, to double-link it.
|
||||||
|
cur_task.data = self
|
||||||
|
|
||||||
|
def done(self):
|
||||||
|
return not self.state
|
||||||
|
|
||||||
|
def cancel(self):
|
||||||
|
# Check if task is already finished.
|
||||||
|
if not self.state:
|
||||||
|
return False
|
||||||
|
# Can't cancel self (not supported yet).
|
||||||
|
if self is cur_task:
|
||||||
|
raise RuntimeError("can't cancel self")
|
||||||
|
# If Task waits on another task then forward the cancel to the one it's waiting on.
|
||||||
|
while isinstance(self.data, Task):
|
||||||
|
self = self.data
|
||||||
|
# Reschedule Task as a cancelled task.
|
||||||
|
if hasattr(self.data, 'remove'):
|
||||||
|
# Not on the main running queue, remove the task from the queue it's on.
|
||||||
|
self.data.remove(self)
|
||||||
|
__task_queue.push_head(self)
|
||||||
|
elif ticks_diff(self.ph_key, ticks_ms()) > 0:
|
||||||
|
# On the main running queue but scheduled in the future, so bring it forward to now.
|
||||||
|
__task_queue.remove(self)
|
||||||
|
__task_queue.push_head(self)
|
||||||
|
self.data = CancelledError
|
||||||
|
return True
|
Loading…
Reference in New Issue
Block a user