192 lines
5.4 KiB
Python
192 lines
5.4 KiB
Python
from bitmask import Bitmask
|
|
from enum import IntFlag
|
|
import unittest
|
|
|
|
class Desc(IntFlag):
    """Toy set of single-bit flags that the `Bitmask` tests operate on."""

    # Each member owns exactly one bit, so every combination is a distinct value
    # (e.g. SMALL | FUNKY == 0b101 == 5).
    SMALL = 0b001
    ROUND = 0b010
    FUNKY = 0b100
|
|
|
|
class TestBitmask(unittest.TestCase):
    """Unit tests for `Bitmask`, using `Desc` as the flag type."""

    def setUp(self):
        """Initialize `Bitmask` instances for testing."""
        # SMALL | FUNKY -> numeric value 5 (checked in test_value).
        self.bmask = Bitmask(Desc, Desc.SMALL, Desc.FUNKY)
        # Two separately constructed masks with identical contents, so equality
        # is tested by value rather than by identity.
        self.bmask2 = Bitmask(Desc, Desc.ROUND)
        self.bmask3 = Bitmask(Desc, Desc.ROUND)
        self.bmask_empty = Bitmask(Desc)

    def test_eq(self):
        """Test equality."""
        self.assertEqual(self.bmask2, self.bmask3)
        self.assertNotEqual(self.bmask, self.bmask3)
        self.assertNotEqual(self.bmask, self.bmask_empty)
        # Bare `Desc` members must not compare equal to these masks.
        self.assertNotEqual(self.bmask, Desc.SMALL)
        self.assertNotEqual(self.bmask, Desc.ROUND)
        self.assertNotEqual(self.bmask_empty, Desc.ROUND)

    def test_repr(self):
        """Test that `repr()` round-trips through `eval()`."""
        # eval() runs in this module's namespace, so the repr has to be
        # expressible with the names available here (`Bitmask`, `Desc`).
        # The evaluated text comes from our own objects, not external input.
        self.assertEqual(eval(repr(self.bmask)), self.bmask)
        self.assertEqual(eval(repr(self.bmask_empty)), self.bmask_empty)

    def test_add(self):
        """Test the `Bitmask.add()` method."""
        self.bmask.add(Desc.ROUND)
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.SMALL, Desc.FUNKY, Desc.ROUND)
        )

        # Adding a flag that is already set leaves the mask unchanged.
        self.bmask_empty.add(Desc.ROUND)
        self.bmask_empty.add(Desc.ROUND)
        self.assertEqual(
            self.bmask_empty,
            Bitmask(Desc, Desc.ROUND)
        )

    def test_add_operator(self):
        """Test the + operator."""
        self.assertEqual(
            self.bmask + Desc.ROUND,
            Bitmask(Desc, Desc.SMALL, Desc.FUNKY, Desc.ROUND)
        )
        self.assertEqual(
            self.bmask_empty + Desc.ROUND,
            Bitmask(Desc, Desc.ROUND)
        )
        # Test `__radd__`: the flag is the left operand.
        self.assertEqual(
            Desc.ROUND + self.bmask,
            Bitmask(Desc, Desc.SMALL, Desc.FUNKY, Desc.ROUND)
        )
        # Test combining bitmasks; overlapping flags (ROUND twice) collapse.
        self.assertEqual(
            self.bmask + self.bmask2 + self.bmask3,
            Bitmask(Desc, Desc.SMALL, Desc.FUNKY, Desc.ROUND)
        )

    def test_remove(self):
        """Test the `Bitmask.remove()` method."""
        self.bmask.remove(Desc.SMALL)
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.FUNKY)
        )
        # Removing a flag that is not set raises KeyError.
        with self.assertRaises(KeyError):
            self.bmask.remove(Desc.SMALL)
        with self.assertRaises(KeyError):
            self.bmask_empty.remove(Desc.SMALL)
        # Passing a whole Bitmask instead of a flag is a type error.
        with self.assertRaises(TypeError):
            self.bmask_empty.remove(self.bmask2)

    def test_discard(self):
        """Test the `Bitmask.discard()` method."""
        self.bmask.discard(Desc.SMALL)
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.FUNKY)
        )
        # Unlike remove(), discarding a flag that is not set is a silent no-op.
        self.bmask.discard(Desc.SMALL)
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.FUNKY)
        )
        self.bmask_empty.discard(Desc.SMALL)
        self.assertEqual(
            self.bmask_empty,
            Bitmask(Desc)
        )
        # Type validation must still apply to discard() (this previously
        # called remove() by mistake, leaving discard() untested here).
        with self.assertRaises(TypeError):
            self.bmask_empty.discard(self.bmask2)

    def test_subtract(self):
        """Test various subtraction operators."""
        # Operation
        self.assertEqual(
            self.bmask - Desc.SMALL,
            Bitmask(Desc, Desc.FUNKY)
        )
        # Assignment
        self.bmask -= Desc.SMALL
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.FUNKY)
        )

    def test_value(self):
        """Ensure Bitmask.value lines up with the state."""
        # SMALL (1) | FUNKY (4) == 5; adding ROUND (2) gives 7.
        self.assertEqual(self.bmask.value, 5)
        self.bmask.add(Desc.ROUND)
        self.assertEqual(self.bmask.value, 7)

        self.assertEqual(self.bmask_empty.value, 0)

        # Setting values directly
        self.bmask.value = 0
        self.assertEqual(
            self.bmask,
            Bitmask(Desc)
        )
        self.bmask.value = 1
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.SMALL)
        )
        # Non-integer values are rejected.
        with self.assertRaises(TypeError):
            self.bmask.value = 1j
        with self.assertRaises(TypeError):
            self.bmask.value = 2.5

    def test_contains(self):
        """Test `flag in mask` check."""
        self.assertIn(Desc.FUNKY, self.bmask)
        self.assertIn(Desc.SMALL, self.bmask)
        self.assertNotIn(Desc.ROUND, self.bmask)
        self.assertNotIn(Desc.ROUND, self.bmask_empty)
        # Membership is only defined for flags, not for other masks.
        with self.assertRaises(TypeError):
            self.bmask in self.bmask

    def test_iter(self):
        """Test iteration."""
        # Set flags are expected in ascending bit order.
        self.assertEqual(
            list(self.bmask),
            [Desc.SMALL, Desc.FUNKY]
        )

    def test_str(self):
        """Test string conversion."""
        self.assertEqual(
            str(self.bmask),
            "SMALL|FUNKY"
        )
        self.bmask.add(Desc.ROUND)
        self.assertEqual(
            str(self.bmask),
            "SMALL|ROUND|FUNKY"
        )

        # An empty mask renders as "0".
        self.assertEqual(
            str(self.bmask_empty),
            "0"
        )
        self.bmask_empty.add(Desc.ROUND)
        self.assertEqual(
            str(self.bmask_empty),
            "ROUND"
        )

    def test_int(self):
        """Test int conversion."""
        self.assertEqual(
            int(self.bmask),
            self.bmask.value
        )
        self.bmask.value = 4
        self.assertEqual(
            int(self.bmask),
            self.bmask.value
        )
|
|
|
|
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|