diff --git a/bitmask.py b/bitmask.py index 5ace01c..f03451a 100644 --- a/bitmask.py +++ b/bitmask.py @@ -171,14 +171,40 @@ class Bitmask: return self.__mask_op(other, lambda a, b : a | b) def __radd__(self, other): + """Alias the + operator in reverse.""" return self.__add__(other) + def __iadd__(self, other): + """Union bitmasks/flags together. + + Aliased to `Bitmask.__add__`. + """ + return self + other + + def __sub__(self, other): + """Subtract by bitmask/flag.""" + return self.__mask_op(other, lambda a, b : a & ~b) + + def __isub__(self, other): + """Subtract a bitmask/flag. + + Aliased to `Bitmask.__sub__`. + """ + self = self - other + return self + def discard(self, flag): """Remove flag bitmask if present. This behaves the same as built-in `set.discard()`. + + Raises: + TypeError: `flag` is not a single Enum value. """ - self._flag_op(flag, lambda a, b : a & ~b) + if not issubclass(type(flag), self._AllFlags): + raise TypeError(f"can only discard {self.AllFlags} (not '{type(flag)}') from {type(self)}") + + return self._flag_op(flag, lambda a, b : a & ~b) def remove(self, flag): """Remove `flag` from the bitmask. @@ -189,5 +215,6 @@ class Bitmask: KeyError: flag is not in bitmask. """ if not flag in self: - raise KeyError(flag) + raise KeyError(type(flag), self.AllFlags) + self.discard(flag) diff --git a/test_bitmask.py b/test_bitmask.py index 8982642..bd6a0c9 100644 --- a/test_bitmask.py +++ b/test_bitmask.py @@ -76,6 +76,8 @@ class TestBitmask(unittest.TestCase): self.bmask.remove(Desc.SMALL) with self.assertRaises(KeyError): self.bmask_empty.remove(Desc.SMALL) + with self.assertRaises(TypeError): + self.bmask_empty.remove(self.bmask2) def test_discard(self): """Test the `Bitmask.discard()` method.""" @@ -94,6 +96,22 @@ class TestBitmask(unittest.TestCase): self.bmask_empty, Bitmask(Desc) ) + with self.assertRaises(TypeError): + self.bmask_empty.remove(self.bmask2) + + def test_subtract(self): + """Test various subtraction operators.""" + # Operation + self.assertEqual( + self.bmask - Desc.SMALL, + Bitmask(Desc, Desc.FUNKY) + ) + # Assignment + self.bmask -= Desc.SMALL + self.assertEqual( + self.bmask, + Bitmask(Desc, Desc.FUNKY) + ) def test_value(self): """Ensure Bitmask.value lines up with the state."""