Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions examples/jig_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
MuxGroup,
PinValueAddressHandler,
VirtualSwitch,
RelayMatrixMux,
)


Expand Down Expand Up @@ -59,3 +60,79 @@ class JigMuxGroup(MuxGroup):
jig.mux.mux_two("sig5")
jig.mux.mux_three("On")
jig.mux.mux_three(False)


# VirtualMuxes can be made generic
from typing import Literal

# note: the type keyword can't be used inside functions!
# generally we want to use type to avoid confusion around the type system
# this makes it clear we are creating something for typehinting
# e.g type MyInt = int - won't work in functions
# variable = int - is not obvious what the intent is and can behave differently depending on its scope

# the type keyword can be used to create reusable definitions
# otherwise Literal can be used directly
type MyTypedMuxSignals = Literal["signal_1", "signal_2"]


def do_some_stuff():
# otherwise the mux is created as normal
class MyTypedMux(VirtualMux[MyTypedMuxSignals]):
pin_list = ("x0", "x1")
map_list = (
("signal_1", "x0"),
("signal_2", "x1"),
)

mymux = MyTypedMux()

# signal names will appear in the autocompletion options (including the empty signal "")
mymux.multiplex("")
mymux.multiplex("signal_1")
mymux.multiplex("signal_2")

# anything that isn't a signal will be flagged
try:
mymux.multiplex("not_a_signal")
except ValueError as e:
print(e)

# the annotations can also be used directly with Literal
class MyDirectlyTypedMux(VirtualMux[Literal["Sig_1", "Sig_2"]]):
pin_list = ("x0", "x1")
# Note neither definition currently point out the incorrect signal mapping below!
# it is still up to the user to set up muxes correctly
map_list = (
("signal_1", "x0"),
("signal_2", "x1"),
)

# suggestions will work as normal
myothermux = MyDirectlyTypedMux()

myothermux.multiplex("")
myothermux.multiplex("Sig_1")
myothermux.multiplex("Sig_2")

try:
myothermux.multiplex("not_a_signal")
except ValueError as e:
print(e)

# general subclasses and RelayMatrixMux also work with this
# currently VirtualSwitch doesn't, it creates its own signal names doesn't really benefit from this
class MyTypedRelay(RelayMatrixMux[MyTypedMuxSignals]):
pin_list = ("x3", "x4")
map_list = (
("signal_1", "x3"),
("signal_2", "x4"),
)

myrelay = MyTypedRelay()
myrelay.multiplex("")
myrelay.multiplex("signal_1")
myrelay.multiplex("signal_2")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could also show a concrete example of the alternate you hint at in the comment above

class MyTypedMux(VirtualMux[Literal["signal_1", "signal_2"]]):
        pin_list = ("x0", "x1")
        map_list = (
            ("signal_1", "x0"),
            ("signal_2", "x1"),
        )

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


do_some_stuff()
88 changes: 43 additions & 45 deletions src/fixate/_switching.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,28 @@
import itertools
import time
from typing import (
Generic,
Optional,
Callable,
Sequence,
TypeVar,
Generator,
Union,
Collection,
Dict,
FrozenSet,
Iterable,
TypeGuard,
Any,
Literal,
)
from dataclasses import dataclass
from functools import reduce
from operator import or_

Signal = str
Pin = str
PinList = Sequence[Pin]
PinSet = FrozenSet[Pin]
SignalMap = Dict[Signal, PinSet]
TreeDef = Sequence[Union[Signal, "TreeDef"]]
type Signal = str
type EmptySignal = Literal[""]
type Pin = str
type PinList = Sequence[Pin]
type PinSet = frozenset[Pin]
type PinUpdateCallback = Callable[[PinUpdate, bool], None]
type MuxSignal[M: Signal] = M | EmptySignal
type SignalMap[S: Signal] = dict[S, PinSet]
type TreeDef[S: Signal] = Sequence[S | "TreeDef"]


@dataclass(frozen=True)
Expand Down Expand Up @@ -85,17 +85,23 @@ def __or__(self, other: PinUpdate) -> PinUpdate:
return NotImplemented


PinUpdateCallback = Callable[[PinUpdate, bool], None]


class VirtualMux:
class VirtualMux[S: Signal]:
# define the union of what the user supplied and the automatically created
# signal here so we don't have to keep typing this union everywhere
pin_list: PinList = ()
clearing_time: float = 0.0

###########################################################################
# These methods are the public API for the class

def __init__(self, update_pins: Optional[PinUpdateCallback] = None):
def isSignal(self, obj: Any) -> TypeGuard[S]:
# we can tell the typechecker about our user specified signals here
# at runtime we just check if this is a string
# in the future S can be inspected using get_origin, get_args
# resolve_bases and get_original_bases
return isinstance(obj, Signal.__value__)

def __init__(self, update_pins: PinUpdateCallback | None = None):
self._last_update_time = time.monotonic()

self._update_pins: PinUpdateCallback
Expand All @@ -110,9 +116,9 @@ def __init__(self, update_pins: Optional[PinUpdateCallback] = None):
# we convert here and keep a reference to the set for future use.
self._pin_set = frozenset(self.pin_list)

self._state = ""
self._state: MuxSignal[S] = ""

self._signal_map: SignalMap = self._map_signals()
self._signal_map: SignalMap[MuxSignal[S]] = self._map_signals()

# Define the implicit signal "" which can be used to turn off all pins.
# If the signal map already has this defined, raise an error. In the old
Expand All @@ -129,7 +135,7 @@ def __init__(self, update_pins: Optional[PinUpdateCallback] = None):
if hasattr(self, "default_signal"):
raise ValueError("'default_signal' should not be set on a VirtualMux")

def __call__(self, signal: Signal, trigger_update: bool = True) -> None:
def __call__(self, signal: MuxSignal[S], trigger_update: bool = True) -> None:
"""
Convenience to avoid having to type jig.mux.<MuxName>.multiplex.

Expand All @@ -138,7 +144,7 @@ def __call__(self, signal: Signal, trigger_update: bool = True) -> None:
"""
self.multiplex(signal, trigger_update)

def multiplex(self, signal: Signal, trigger_update: bool = True) -> None:
def multiplex(self, signal: MuxSignal[S], trigger_update: bool = True) -> None:
"""
Update the multiplexer state to signal.

Expand All @@ -163,7 +169,7 @@ def multiplex(self, signal: Signal, trigger_update: bool = True) -> None:
self._last_update_time = time.monotonic()
self._state = signal

def all_signals(self) -> tuple[Signal, ...]:
def all_signals(self) -> tuple[MuxSignal[S], ...]:
return tuple(self._signal_map.keys())

def reset(self, trigger_update: bool = True) -> None:
Expand Down Expand Up @@ -191,7 +197,7 @@ def pins(self) -> frozenset[Pin]:
# The following methods are potential candidates to override in a subclass

def _calculate_pins(
self, old_signal: Signal, new_signal: Signal
self, old_signal: MuxSignal[S], new_signal: MuxSignal[S]
) -> tuple[PinSetState, PinSetState]:
"""
Calculate the pin sets for the two-step state change.
Expand All @@ -218,7 +224,7 @@ def _calculate_pins(
# The following methods are intended as implementation detail and
# subclasses should avoid overriding.

def _map_signals(self) -> SignalMap:
def _map_signals(self) -> SignalMap[MuxSignal[S]]:
"""
Default implementation of the signal mapping

Expand All @@ -239,7 +245,9 @@ def _map_signals(self) -> SignalMap:
"VirtualMux subclass must define either map_tree or map_list"
)

def _map_tree(self, tree: TreeDef, pins: PinList, fixed_pins: PinSet) -> SignalMap:
def _map_tree(
self, tree: TreeDef[S], pins: PinList, fixed_pins: PinSet
) -> SignalMap[MuxSignal[S]]:
"""recursively add nested signal lists to the signal map.
tree: is the current sub-branch to be added. At the first call
level, this would be initialised with self.map_tree. It can be
Expand Down Expand Up @@ -399,7 +407,7 @@ class Mux(VirtualMux):
mux_b = TreeMap(("a1_b0", "a1_b1", "a1_b2", None), ("x2", "x3"))
map_tree = TreeMap(("a0", mux_b, "a2", mux_c), ("x1", "x0"))
"""
signal_map: SignalMap = dict()
signal_map: SignalMap[MuxSignal[S]] = dict()

bits_at_this_level = (len(tree) - 1).bit_length()
pins_at_this_level = pins[:bits_at_this_level]
Expand All @@ -409,7 +417,7 @@ class Mux(VirtualMux):
):
if signal_or_tree is None:
continue
if isinstance(signal_or_tree, Signal):
if self.isSignal(signal_or_tree):
signal_map[signal_or_tree] = frozenset(pins_for_signal) | fixed_pins
else:
signal_map.update(
Expand Down Expand Up @@ -456,9 +464,7 @@ class VirtualSwitch(VirtualMux):
pin_name: Pin = ""
map_tree = ("Off", "On")

def multiplex(
self, signal: Union[Signal, bool], trigger_update: bool = True
) -> None:
def multiplex(self, signal: Signal | bool, trigger_update: bool = True) -> None:
if signal is True:
converted_signal = "On"
elif signal is False:
Expand All @@ -467,26 +473,24 @@ def multiplex(
converted_signal = signal
super().multiplex(converted_signal, trigger_update=trigger_update)

def __call__(
self, signal: Union[Signal, bool], trigger_update: bool = True
) -> None:
def __call__(self, signal: Signal | bool, trigger_update: bool = True) -> None:
"""Override call to set the type on signal_output correctly."""
self.multiplex(signal, trigger_update)

def __init__(
self,
update_pins: Optional[PinUpdateCallback] = None,
update_pins: PinUpdateCallback | None = None,
):
if not self.pin_list:
self.pin_list = [self.pin_name]
super().__init__(update_pins)


class RelayMatrixMux(VirtualMux):
class RelayMatrixMux[S: Signal](VirtualMux[S]):
clearing_time = 0.01

def _calculate_pins(
self, old_signal: Signal, new_signal: Signal
self, old_signal: MuxSignal[S], new_signal: MuxSignal[S]
) -> tuple[PinSetState, PinSetState]:
"""
Override of _calculate_pins to implement break-before-make switching.
Expand Down Expand Up @@ -684,10 +688,7 @@ def active_signals(self) -> list[str]:
return [str(mux) for mux in self.get_multiplexers()]


JigSpecificMuxGroup = TypeVar("JigSpecificMuxGroup", bound=MuxGroup)


class JigDriver(Generic[JigSpecificMuxGroup]):
class JigDriver[M: MuxGroup]():
"""
Combine multiple VirtualMux's and multiple AddressHandler's.

Expand All @@ -696,7 +697,7 @@ class JigDriver(Generic[JigSpecificMuxGroup]):

def __init__(
self,
mux_group_factory: Callable[[], JigSpecificMuxGroup],
mux_group_factory: Callable[[], M],
handlers: Sequence[AddressHandler],
):
# keep a reference to handlers so that we can close them if required.
Expand Down Expand Up @@ -774,10 +775,7 @@ def _validate(self) -> None:
)


_T = TypeVar("_T")


def _generate_bit_sets(bits: Sequence[_T]) -> Generator[set[_T], None, None]:
def _generate_bit_sets[_T](bits: Sequence[_T]) -> Generator[set[_T], None, None]:
"""
Create subsets of bits, representing bits of a list of integers

Expand Down
49 changes: 48 additions & 1 deletion test/test_switching.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Collection, Sequence
from typing import Collection, Sequence, Literal

from fixate._switching import (
Pin,
Expand Down Expand Up @@ -289,6 +289,13 @@ class MuxA(VirtualMux):
map_list = (("sig_a1", "a0", "a1"), ("sig_a2", "a1"))


class MuxATyped(VirtualMux[Literal["sig_a1", "sig_a2"]]):
"""A mux definition used by a few tests"""

pin_list = ("a0", "a1")
map_list = (("sig_a1", "a0", "a1"), ("sig_a2", "a1"))


def test_virtual_mux_basic():
updates = []
mux_a = MuxA(lambda x, y: updates.append((x, y)))
Expand All @@ -309,6 +316,46 @@ def test_virtual_mux_basic():
]


def test_virtual_mux_basic_typed():
updates = []
mux_a = MuxATyped(lambda x, y: updates.append((x, y)))

# test both the __call__ and multiplex methods trigger
# the appropriate update callback.
mux_a("sig_a1")
mux_a.multiplex("sig_a2", trigger_update=False)
mux_a("")

clear = PinSetState(off=frozenset({"a0", "a1"}))
a1 = PinSetState(on=frozenset({"a0", "a1"}))
a2 = PinSetState(on=frozenset({"a1"}), off=frozenset({"a0"}))
assert updates == [
(PinUpdate(PinSetState(), a1), True),
(PinUpdate(PinSetState(), a2), False),
(PinUpdate(PinSetState(), clear), True),
]


@pytest.mark.xfail(reason="Signal narrowowing not implemented")
def test_virtual_mux_typed_isSignal():
mux_a = MuxATyped()

assert mux_a.isSignal("sig_a1") # should pass
assert not mux_a.isSignal("") # this shouldn't be supplied by the user
assert not mux_a.isSignal(1) # wrong type
assert not mux_a.isSignal("1") # not a signal for MuxATyped - not yet implemented


def test_virtual_mux_isSignal():
mux_a = MuxA()
# this mux isn't type, so anything that is a string should pass this
# to check we don't accidentally break untyped muxes in the future
assert mux_a.isSignal("sig_a1") # should pass
assert mux_a.isSignal("") # should pass
assert not mux_a.isSignal(1) # wrong type
assert mux_a.isSignal("1") # should pass


def test_virtual_mux_reset():
"""Check that reset sends an update that sets all pins off"""

Expand Down
Loading