Implement a heap based task scheduler
This commit is contained in:
parent
bc5fb9dc9e
commit
3c4e064201
@ -1,19 +1,17 @@
|
||||
try:
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, Optional
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from supervisor import ticks_ms
|
||||
|
||||
from collections import namedtuple
|
||||
from keypad import Event as KeyEvent
|
||||
|
||||
from kmk.consts import UnicodeMode
|
||||
from kmk.hid import BLEHID, USBHID, AbstractHID, HIDModes
|
||||
from kmk.keys import KC, Key
|
||||
from kmk.kmktime import ticks_add, ticks_diff
|
||||
from kmk.modules import Module
|
||||
from kmk.scanners.keypad import MatrixScanner
|
||||
from kmk.scheduler import Task, cancel_task, create_task, get_due_task
|
||||
from kmk.utils import Debug
|
||||
|
||||
debug = Debug('kmk.keyboard')
|
||||
@ -266,60 +264,17 @@ class KMKKeyboard:
|
||||
def tap_key(self, keycode: Key) -> None:
|
||||
self.add_key(keycode)
|
||||
# On the next cycle, we'll remove the key.
|
||||
self.set_timeout(False, lambda: self.remove_key(keycode))
|
||||
self.set_timeout(0, lambda: self.remove_key(keycode))
|
||||
|
||||
def set_timeout(
|
||||
self, after_ticks: int, callback: Callable[[None], None]
|
||||
) -> Tuple[int, int]:
|
||||
# We allow passing False as an implicit "run this on the next process timeouts cycle"
|
||||
if after_ticks is False:
|
||||
after_ticks = 0
|
||||
|
||||
if after_ticks == 0 and self._processing_timeouts:
|
||||
after_ticks += 1
|
||||
|
||||
timeout_key = ticks_add(ticks_ms(), after_ticks)
|
||||
|
||||
if timeout_key not in self._timeouts:
|
||||
self._timeouts[timeout_key] = []
|
||||
|
||||
idx = len(self._timeouts[timeout_key])
|
||||
self._timeouts[timeout_key].append(callback)
|
||||
|
||||
return (timeout_key, idx)
|
||||
def set_timeout(self, after_ticks: int, callback: Callable[[None], None]) -> [Task]:
|
||||
return create_task(callback, after_ms=after_ticks)
|
||||
|
||||
def cancel_timeout(self, timeout_key: int) -> None:
|
||||
try:
|
||||
self._timeouts[timeout_key[0]][timeout_key[1]] = None
|
||||
except (KeyError, IndexError):
|
||||
if debug.enabled:
|
||||
debug(f'no such timeout: {timeout_key}')
|
||||
cancel_task(timeout_key)
|
||||
|
||||
def _process_timeouts(self) -> None:
|
||||
if not self._timeouts:
|
||||
return
|
||||
|
||||
# Copy timeout keys to a temporary list to allow sorting.
|
||||
# Prevent net timeouts set during handling from running on the current
|
||||
# cycle by setting a flag `_processing_timeouts`.
|
||||
current_time = ticks_ms()
|
||||
timeout_keys = []
|
||||
self._processing_timeouts = True
|
||||
|
||||
for k in self._timeouts.keys():
|
||||
if ticks_diff(k, current_time) <= 0:
|
||||
timeout_keys.append(k)
|
||||
|
||||
if timeout_keys and debug.enabled:
|
||||
debug('processing timeouts')
|
||||
|
||||
for k in sorted(timeout_keys):
|
||||
for callback in self._timeouts[k]:
|
||||
if callback:
|
||||
callback()
|
||||
del self._timeouts[k]
|
||||
|
||||
self._processing_timeouts = False
|
||||
for task in get_due_task():
|
||||
task()
|
||||
|
||||
def _init_sanity_check(self) -> None:
|
||||
'''
|
||||
|
67
kmk/scheduler.py
Normal file
67
kmk/scheduler.py
Normal file
@ -0,0 +1,67 @@
|
||||
'''
|
||||
Here we're abusing _asyncios TaskQueue to implement a very simple priority
|
||||
queue task scheduler.
|
||||
Despite documentation, Circuitpython doesn't usually ship with a min-heap
|
||||
module; it does however implement a pairing-heap for `TaskQueue` in native code.
|
||||
'''
|
||||
|
||||
try:
|
||||
from typing import Callable
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from supervisor import ticks_ms
|
||||
|
||||
from _asyncio import Task, TaskQueue
|
||||
|
||||
from kmk.kmktime import ticks_add, ticks_diff
|
||||
|
||||
_task_queue = TaskQueue()
|
||||
|
||||
|
||||
class PeriodicTaskMeta:
|
||||
def __init__(self, func: Callable[[None], None], period: int) -> None:
|
||||
self._task = Task(self.call)
|
||||
self._coro = func
|
||||
self.period = period
|
||||
|
||||
def call(self) -> None:
|
||||
self._coro()
|
||||
after_ms = ticks_add(self._task.ph_key, self.period)
|
||||
_task_queue.push_sorted(self._task, after_ms)
|
||||
|
||||
|
||||
def create_task(
|
||||
func: Callable[[None], None],
|
||||
*,
|
||||
after_ms: int = 0,
|
||||
period_ms: int = 0,
|
||||
) -> [Task, PeriodicTaskMeta]:
|
||||
if period_ms:
|
||||
r = PeriodicTaskMeta(func, period_ms)
|
||||
t = r._task
|
||||
else:
|
||||
t = r = Task(func)
|
||||
|
||||
if after_ms:
|
||||
after_ms = ticks_add(ticks_ms(), after_ms)
|
||||
_task_queue.push_sorted(t, after_ms)
|
||||
else:
|
||||
_task_queue.push_head(t)
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def get_due_task() -> [Callable, None]:
|
||||
while True:
|
||||
t = _task_queue.peek()
|
||||
if not t or ticks_diff(t.ph_key, ticks_ms()) > 0:
|
||||
break
|
||||
_task_queue.pop_head()
|
||||
yield t.coro
|
||||
|
||||
|
||||
def cancel_task(t: [Task, PeriodicTaskMeta]) -> None:
|
||||
if isinstance(t, PeriodicTaskMeta):
|
||||
t = t._task
|
||||
_task_queue.remove(t)
|
@ -6,6 +6,7 @@ from kmk.keys import KC, ModifierKey
|
||||
from kmk.kmk_keyboard import KMKKeyboard
|
||||
from kmk.scanners import DiodeOrientation
|
||||
from kmk.scanners.digitalio import MatrixScanner
|
||||
from kmk.scheduler import _task_queue
|
||||
|
||||
|
||||
class DigitalInOut(Mock):
|
||||
@ -81,7 +82,7 @@ class KeyboardTest:
|
||||
timeout = time.time_ns() + 10 * 1_000_000_000
|
||||
while timeout > time.time_ns():
|
||||
self.do_main_loop()
|
||||
if not self.keyboard._timeouts and not self.keyboard._resume_buffer:
|
||||
if not _task_queue.peek() and not self.keyboard._resume_buffer:
|
||||
break
|
||||
assert timeout > time.time_ns(), 'infinite loop detected'
|
||||
|
||||
|
@ -9,6 +9,10 @@ class KeyEvent:
|
||||
self.pressed = pressed
|
||||
|
||||
|
||||
def ticks_ms():
|
||||
return (time.time_ns() // 1_000_000) % (1 << 29)
|
||||
|
||||
|
||||
def init_circuit_python_modules_mocks():
|
||||
sys.modules['usb_hid'] = Mock()
|
||||
sys.modules['digitalio'] = Mock()
|
||||
@ -26,4 +30,8 @@ def init_circuit_python_modules_mocks():
|
||||
sys.modules['micropython'].const = lambda x: x
|
||||
|
||||
sys.modules['supervisor'] = Mock()
|
||||
sys.modules['supervisor'].ticks_ms = lambda: time.time_ns() // 1_000_000
|
||||
sys.modules['supervisor'].ticks_ms = ticks_ms
|
||||
|
||||
from . import task
|
||||
|
||||
sys.modules['_asyncio'] = task
|
||||
|
196
tests/task.py
Normal file
196
tests/task.py
Normal file
@ -0,0 +1,196 @@
|
||||
# MicroPython uasyncio module
|
||||
# MIT license; Copyright (c) 2019-2020 Damien P. George
|
||||
|
||||
# This file contains the core TaskQueue based on a pairing heap, and the core Task class.
|
||||
# They can optionally be replaced by C implementations.
|
||||
|
||||
# This file is a modified version, based on the extmod in Circuitpython, for
|
||||
# unit testing in KMK only.
|
||||
|
||||
from supervisor import ticks_ms
|
||||
|
||||
from kmk.kmktime import ticks_diff
|
||||
|
||||
cur_task = None
|
||||
__task_queue = None
|
||||
|
||||
|
||||
class CancelledError(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
# pairing-heap meld of 2 heaps; O(1)
|
||||
def ph_meld(h1, h2):
|
||||
if h1 is None:
|
||||
return h2
|
||||
if h2 is None:
|
||||
return h1
|
||||
lt = ticks_diff(h1.ph_key, h2.ph_key) < 0
|
||||
if lt:
|
||||
if h1.ph_child is None:
|
||||
h1.ph_child = h2
|
||||
else:
|
||||
h1.ph_child_last.ph_next = h2
|
||||
h1.ph_child_last = h2
|
||||
h2.ph_next = None
|
||||
h2.ph_rightmost_parent = h1
|
||||
return h1
|
||||
else:
|
||||
h1.ph_next = h2.ph_child
|
||||
h2.ph_child = h1
|
||||
if h1.ph_next is None:
|
||||
h2.ph_child_last = h1
|
||||
h1.ph_rightmost_parent = h2
|
||||
return h2
|
||||
|
||||
|
||||
# pairing-heap pairing operation; amortised O(log N)
|
||||
def ph_pairing(child):
|
||||
heap = None
|
||||
while child is not None:
|
||||
n1 = child
|
||||
child = child.ph_next
|
||||
n1.ph_next = None
|
||||
if child is not None:
|
||||
n2 = child
|
||||
child = child.ph_next
|
||||
n2.ph_next = None
|
||||
n1 = ph_meld(n1, n2)
|
||||
heap = ph_meld(heap, n1)
|
||||
return heap
|
||||
|
||||
|
||||
# pairing-heap delete of a node; stable, amortised O(log N)
|
||||
def ph_delete(heap, node):
|
||||
if node is heap:
|
||||
child = heap.ph_child
|
||||
node.ph_child = None
|
||||
return ph_pairing(child)
|
||||
# Find parent of node
|
||||
parent = node
|
||||
while parent.ph_next is not None:
|
||||
parent = parent.ph_next
|
||||
parent = parent.ph_rightmost_parent
|
||||
if parent is None or parent.ph_child is None:
|
||||
return heap
|
||||
# Replace node with pairing of its children
|
||||
if node is parent.ph_child and node.ph_child is None:
|
||||
parent.ph_child = node.ph_next
|
||||
node.ph_next = None
|
||||
return heap
|
||||
elif node is parent.ph_child:
|
||||
child = node.ph_child
|
||||
next = node.ph_next
|
||||
node.ph_child = None
|
||||
node.ph_next = None
|
||||
node = ph_pairing(child)
|
||||
parent.ph_child = node
|
||||
else:
|
||||
n = parent.ph_child
|
||||
while node is not n.ph_next:
|
||||
n = n.ph_next
|
||||
if not n:
|
||||
return heap
|
||||
child = node.ph_child
|
||||
next = node.ph_next
|
||||
node.ph_child = None
|
||||
node.ph_next = None
|
||||
node = ph_pairing(child)
|
||||
if node is None:
|
||||
node = n
|
||||
else:
|
||||
n.ph_next = node
|
||||
node.ph_next = next
|
||||
if next is None:
|
||||
node.ph_rightmost_parent = parent
|
||||
parent.ph_child_last = node
|
||||
return heap
|
||||
|
||||
|
||||
# TaskQueue class based on the above pairing-heap functions.
|
||||
class TaskQueue:
|
||||
def __init__(self):
|
||||
self.heap = None
|
||||
|
||||
def peek(self):
|
||||
return self.heap
|
||||
|
||||
def push_sorted(self, v, key):
|
||||
v.data = None
|
||||
v.ph_key = key
|
||||
v.ph_child = None
|
||||
v.ph_next = None
|
||||
self.heap = ph_meld(v, self.heap)
|
||||
|
||||
def push_head(self, v):
|
||||
self.push_sorted(v, ticks_ms())
|
||||
|
||||
def pop_head(self):
|
||||
v = self.heap
|
||||
self.heap = ph_pairing(v.ph_child)
|
||||
# v.ph_child = None
|
||||
return v
|
||||
|
||||
def remove(self, v):
|
||||
self.heap = ph_delete(self.heap, v)
|
||||
|
||||
|
||||
# Task class representing a coroutine, can be waited on and cancelled.
|
||||
class Task:
|
||||
def __init__(self, coro, globals=None):
|
||||
self.coro = coro # Coroutine of this Task
|
||||
self.data = None # General data for queue it is waiting on
|
||||
self.state = True # None, False, True or a TaskQueue instance
|
||||
self.ph_key = 0 # Pairing heap
|
||||
self.ph_child = None # Paring heap
|
||||
self.ph_child_last = None # Paring heap
|
||||
self.ph_next = None # Paring heap
|
||||
self.ph_rightmost_parent = None # Paring heap
|
||||
|
||||
def __await__(self):
|
||||
if not self.state:
|
||||
# Task finished, signal that is has been await'ed on.
|
||||
self.state = False
|
||||
elif self.state is True:
|
||||
# Allocated head of linked list of Tasks waiting on completion of this task.
|
||||
self.state = TaskQueue()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if not self.state:
|
||||
if self.data is None:
|
||||
# Task finished but has already been sent to the loop's exception handler.
|
||||
raise StopIteration
|
||||
else:
|
||||
# Task finished, raise return value to caller so it can continue.
|
||||
raise self.data
|
||||
else:
|
||||
# Put calling task on waiting queue.
|
||||
self.state.push_head(cur_task)
|
||||
# Set calling task's data to this task that it waits on, to double-link it.
|
||||
cur_task.data = self
|
||||
|
||||
def done(self):
|
||||
return not self.state
|
||||
|
||||
def cancel(self):
|
||||
# Check if task is already finished.
|
||||
if not self.state:
|
||||
return False
|
||||
# Can't cancel self (not supported yet).
|
||||
if self is cur_task:
|
||||
raise RuntimeError("can't cancel self")
|
||||
# If Task waits on another task then forward the cancel to the one it's waiting on.
|
||||
while isinstance(self.data, Task):
|
||||
self = self.data
|
||||
# Reschedule Task as a cancelled task.
|
||||
if hasattr(self.data, 'remove'):
|
||||
# Not on the main running queue, remove the task from the queue it's on.
|
||||
self.data.remove(self)
|
||||
__task_queue.push_head(self)
|
||||
elif ticks_diff(self.ph_key, ticks_ms()) > 0:
|
||||
# On the main running queue but scheduled in the future, so bring it forward to now.
|
||||
__task_queue.remove(self)
|
||||
__task_queue.push_head(self)
|
||||
self.data = CancelledError
|
||||
return True
|
Loading…
Reference in New Issue
Block a user