diff --git a/examples/jig_driver.py b/examples/jig_driver.py index 977318a..8ec950f 100644 --- a/examples/jig_driver.py +++ b/examples/jig_driver.py @@ -11,6 +11,7 @@ MuxGroup, PinValueAddressHandler, VirtualSwitch, + RelayMatrixMux, ) @@ -59,3 +60,79 @@ class JigMuxGroup(MuxGroup): jig.mux.mux_two("sig5") jig.mux.mux_three("On") jig.mux.mux_three(False) + + +# VirtualMuxes can be made generic +from typing import Literal + +# note: the type keyword can't be used inside functions! +# generally we want to use type to avoid confusion around the type system +# this makes it clear we are creating something for typehinting +# e.g type MyInt = int - won't work in functions +# variable = int - is not obvious what the intent is and can behave differently depending on its scope + +# the type keyword can be used to create reusable definitions +# otherwise Literal can be used directly +type MyTypedMuxSignals = Literal["signal_1", "signal_2"] + + +def do_some_stuff(): + # otherwise the mux is created as normal + class MyTypedMux(VirtualMux[MyTypedMuxSignals]): + pin_list = ("x0", "x1") + map_list = ( + ("signal_1", "x0"), + ("signal_2", "x1"), + ) + + mymux = MyTypedMux() + + # signal names will appear in the autocompletion options (including the empty signal "") + mymux.multiplex("") + mymux.multiplex("signal_1") + mymux.multiplex("signal_2") + + # anything that isn't a signal will be flagged + try: + mymux.multiplex("not_a_signal") + except ValueError as e: + print(e) + + # the annotations can also be used directly with Literal + class MyDirectlyTypedMux(VirtualMux[Literal["Sig_1", "Sig_2"]]): + pin_list = ("x0", "x1") + # Note neither definition currently point out the incorrect signal mapping below! + # it is still up to the user to set up muxes correctly + map_list = ( + ("signal_1", "x0"), + ("signal_2", "x1"), + ) + + # suggestions will work as normal + myothermux = MyDirectlyTypedMux() + + myothermux.multiplex("") + myothermux.multiplex("Sig_1") + myothermux.multiplex("Sig_2") + + try: + myothermux.multiplex("not_a_signal") + except ValueError as e: + print(e) + + # general subclasses and RelayMatrixMux also work with this + # currently VirtualSwitch doesn't, it creates its own signal names doesn't really benefit from this + class MyTypedRelay(RelayMatrixMux[MyTypedMuxSignals]): + pin_list = ("x3", "x4") + map_list = ( + ("signal_1", "x3"), + ("signal_2", "x4"), + ) + + myrelay = MyTypedRelay() + myrelay.multiplex("") + myrelay.multiplex("signal_1") + myrelay.multiplex("signal_2") + + +do_some_stuff() diff --git a/src/fixate/_switching.py b/src/fixate/_switching.py index eaf671e..6b62282 100644 --- a/src/fixate/_switching.py +++ b/src/fixate/_switching.py @@ -32,28 +32,28 @@ import itertools import time from typing import ( - Generic, - Optional, Callable, Sequence, - TypeVar, Generator, - Union, Collection, - Dict, - FrozenSet, Iterable, + TypeGuard, + Any, + Literal, ) from dataclasses import dataclass from functools import reduce from operator import or_ -Signal = str -Pin = str -PinList = Sequence[Pin] -PinSet = FrozenSet[Pin] -SignalMap = Dict[Signal, PinSet] -TreeDef = Sequence[Union[Signal, "TreeDef"]] +type Signal = str +type EmptySignal = Literal[""] +type Pin = str +type PinList = Sequence[Pin] +type PinSet = frozenset[Pin] +type PinUpdateCallback = Callable[[PinUpdate, bool], None] +type MuxSignal[M: Signal] = M | EmptySignal +type SignalMap[S: Signal] = dict[S, PinSet] +type TreeDef[S: Signal] = Sequence[S | "TreeDef"] @dataclass(frozen=True) @@ -85,17 +85,23 @@ def __or__(self, other: PinUpdate) -> PinUpdate: return NotImplemented -PinUpdateCallback = Callable[[PinUpdate, bool], None] - - -class VirtualMux: +class VirtualMux[S: Signal]: + # define the union of what the user supplied and the automatically created + # signal here so we don't have to keep typing this union everywhere pin_list: PinList = () clearing_time: float = 0.0 ########################################################################### # These methods are the public API for the class - def __init__(self, update_pins: Optional[PinUpdateCallback] = None): + def isSignal(self, obj: Any) -> TypeGuard[S]: + # we can tell the typechecker about our user specified signals here + # at runtime we just check if this is a string + # in the future S can be inspected using get_origin, get_args + # resolve_bases and get_original_bases + return isinstance(obj, Signal.__value__) + + def __init__(self, update_pins: PinUpdateCallback | None = None): self._last_update_time = time.monotonic() self._update_pins: PinUpdateCallback @@ -110,9 +116,9 @@ def __init__(self, update_pins: Optional[PinUpdateCallback] = None): # we convert here and keep a reference to the set for future use. self._pin_set = frozenset(self.pin_list) - self._state = "" + self._state: MuxSignal[S] = "" - self._signal_map: SignalMap = self._map_signals() + self._signal_map: SignalMap[MuxSignal[S]] = self._map_signals() # Define the implicit signal "" which can be used to turn off all pins. # If the signal map already has this defined, raise an error. In the old @@ -129,7 +135,7 @@ def __init__(self, update_pins: Optional[PinUpdateCallback] = None): if hasattr(self, "default_signal"): raise ValueError("'default_signal' should not be set on a VirtualMux") - def __call__(self, signal: Signal, trigger_update: bool = True) -> None: + def __call__(self, signal: MuxSignal[S], trigger_update: bool = True) -> None: """ Convenience to avoid having to type jig.mux..multiplex. @@ -138,7 +144,7 @@ def __call__(self, signal: Signal, trigger_update: bool = True) -> None: """ self.multiplex(signal, trigger_update) - def multiplex(self, signal: Signal, trigger_update: bool = True) -> None: + def multiplex(self, signal: MuxSignal[S], trigger_update: bool = True) -> None: """ Update the multiplexer state to signal. @@ -163,7 +169,7 @@ def multiplex(self, signal: Signal, trigger_update: bool = True) -> None: self._last_update_time = time.monotonic() self._state = signal - def all_signals(self) -> tuple[Signal, ...]: + def all_signals(self) -> tuple[MuxSignal[S], ...]: return tuple(self._signal_map.keys()) def reset(self, trigger_update: bool = True) -> None: @@ -191,7 +197,7 @@ def pins(self) -> frozenset[Pin]: # The following methods are potential candidates to override in a subclass def _calculate_pins( - self, old_signal: Signal, new_signal: Signal + self, old_signal: MuxSignal[S], new_signal: MuxSignal[S] ) -> tuple[PinSetState, PinSetState]: """ Calculate the pin sets for the two-step state change. @@ -218,7 +224,7 @@ def _calculate_pins( # The following methods are intended as implementation detail and # subclasses should avoid overriding. - def _map_signals(self) -> SignalMap: + def _map_signals(self) -> SignalMap[MuxSignal[S]]: """ Default implementation of the signal mapping @@ -239,7 +245,9 @@ def _map_signals(self) -> SignalMap: "VirtualMux subclass must define either map_tree or map_list" ) - def _map_tree(self, tree: TreeDef, pins: PinList, fixed_pins: PinSet) -> SignalMap: + def _map_tree( + self, tree: TreeDef[S], pins: PinList, fixed_pins: PinSet + ) -> SignalMap[MuxSignal[S]]: """recursively add nested signal lists to the signal map. tree: is the current sub-branch to be added. At the first call level, this would be initialised with self.map_tree. It can be @@ -399,7 +407,7 @@ class Mux(VirtualMux): mux_b = TreeMap(("a1_b0", "a1_b1", "a1_b2", None), ("x2", "x3")) map_tree = TreeMap(("a0", mux_b, "a2", mux_c), ("x1", "x0")) """ - signal_map: SignalMap = dict() + signal_map: SignalMap[MuxSignal[S]] = dict() bits_at_this_level = (len(tree) - 1).bit_length() pins_at_this_level = pins[:bits_at_this_level] @@ -409,7 +417,7 @@ class Mux(VirtualMux): ): if signal_or_tree is None: continue - if isinstance(signal_or_tree, Signal): + if self.isSignal(signal_or_tree): signal_map[signal_or_tree] = frozenset(pins_for_signal) | fixed_pins else: signal_map.update( @@ -456,9 +464,7 @@ class VirtualSwitch(VirtualMux): pin_name: Pin = "" map_tree = ("Off", "On") - def multiplex( - self, signal: Union[Signal, bool], trigger_update: bool = True - ) -> None: + def multiplex(self, signal: Signal | bool, trigger_update: bool = True) -> None: if signal is True: converted_signal = "On" elif signal is False: @@ -467,26 +473,24 @@ def multiplex( converted_signal = signal super().multiplex(converted_signal, trigger_update=trigger_update) - def __call__( - self, signal: Union[Signal, bool], trigger_update: bool = True - ) -> None: + def __call__(self, signal: Signal | bool, trigger_update: bool = True) -> None: """Override call to set the type on signal_output correctly.""" self.multiplex(signal, trigger_update) def __init__( self, - update_pins: Optional[PinUpdateCallback] = None, + update_pins: PinUpdateCallback | None = None, ): if not self.pin_list: self.pin_list = [self.pin_name] super().__init__(update_pins) -class RelayMatrixMux(VirtualMux): +class RelayMatrixMux[S: Signal](VirtualMux[S]): clearing_time = 0.01 def _calculate_pins( - self, old_signal: Signal, new_signal: Signal + self, old_signal: MuxSignal[S], new_signal: MuxSignal[S] ) -> tuple[PinSetState, PinSetState]: """ Override of _calculate_pins to implement break-before-make switching. @@ -684,10 +688,7 @@ def active_signals(self) -> list[str]: return [str(mux) for mux in self.get_multiplexers()] -JigSpecificMuxGroup = TypeVar("JigSpecificMuxGroup", bound=MuxGroup) - - -class JigDriver(Generic[JigSpecificMuxGroup]): +class JigDriver[M: MuxGroup](): """ Combine multiple VirtualMux's and multiple AddressHandler's. @@ -696,7 +697,7 @@ class JigDriver(Generic[JigSpecificMuxGroup]): def __init__( self, - mux_group_factory: Callable[[], JigSpecificMuxGroup], + mux_group_factory: Callable[[], M], handlers: Sequence[AddressHandler], ): # keep a reference to handlers so that we can close them if required. @@ -774,10 +775,7 @@ def _validate(self) -> None: ) -_T = TypeVar("_T") - - -def _generate_bit_sets(bits: Sequence[_T]) -> Generator[set[_T], None, None]: +def _generate_bit_sets[_T](bits: Sequence[_T]) -> Generator[set[_T], None, None]: """ Create subsets of bits, representing bits of a list of integers diff --git a/test/test_switching.py b/test/test_switching.py index 611d9dc..ce4ee09 100644 --- a/test/test_switching.py +++ b/test/test_switching.py @@ -1,4 +1,4 @@ -from typing import Collection, Sequence +from typing import Collection, Sequence, Literal from fixate._switching import ( Pin, @@ -289,6 +289,13 @@ class MuxA(VirtualMux): map_list = (("sig_a1", "a0", "a1"), ("sig_a2", "a1")) +class MuxATyped(VirtualMux[Literal["sig_a1", "sig_a2"]]): + """A mux definition used by a few tests""" + + pin_list = ("a0", "a1") + map_list = (("sig_a1", "a0", "a1"), ("sig_a2", "a1")) + + def test_virtual_mux_basic(): updates = [] mux_a = MuxA(lambda x, y: updates.append((x, y))) @@ -309,6 +316,46 @@ def test_virtual_mux_basic(): ] +def test_virtual_mux_basic_typed(): + updates = [] + mux_a = MuxATyped(lambda x, y: updates.append((x, y))) + + # test both the __call__ and multiplex methods trigger + # the appropriate update callback. + mux_a("sig_a1") + mux_a.multiplex("sig_a2", trigger_update=False) + mux_a("") + + clear = PinSetState(off=frozenset({"a0", "a1"})) + a1 = PinSetState(on=frozenset({"a0", "a1"})) + a2 = PinSetState(on=frozenset({"a1"}), off=frozenset({"a0"})) + assert updates == [ + (PinUpdate(PinSetState(), a1), True), + (PinUpdate(PinSetState(), a2), False), + (PinUpdate(PinSetState(), clear), True), + ] + + +@pytest.mark.xfail(reason="Signal narrowowing not implemented") +def test_virtual_mux_typed_isSignal(): + mux_a = MuxATyped() + + assert mux_a.isSignal("sig_a1") # should pass + assert not mux_a.isSignal("") # this shouldn't be supplied by the user + assert not mux_a.isSignal(1) # wrong type + assert not mux_a.isSignal("1") # not a signal for MuxATyped - not yet implemented + + +def test_virtual_mux_isSignal(): + mux_a = MuxA() + # this mux isn't type, so anything that is a string should pass this + # to check we don't accidentally break untyped muxes in the future + assert mux_a.isSignal("sig_a1") # should pass + assert mux_a.isSignal("") # should pass + assert not mux_a.isSignal(1) # wrong type + assert mux_a.isSignal("1") # should pass + + def test_virtual_mux_reset(): """Check that reset sends an update that sets all pins off"""