# bitmask/test_bitmask.py -- unit tests for the `Bitmask` class.
#
# Listing metadata: 192 lines, 5.4 KiB, Python.

from bitmask import Bitmask
from enum import IntFlag
import unittest
class Desc(IntFlag):
    """Sample flag enum used as the flag set for the `Bitmask` tests.

    Each member occupies its own bit, so any combination of members maps
    to a distinct integer value.
    """

    SMALL = 1  # bit 0 -> 1
    ROUND = 1 << 1  # bit 1 -> 2
    FUNKY = 1 << 2  # bit 2 -> 4
class TestBitmask(unittest.TestCase):
    """Unit tests for `Bitmask`, using `Desc` as the flag enum."""

    def setUp(self):
        """Initialize `Bitmask` instances for testing.

        Fixtures created fresh for every test:
            bmask: SMALL and FUNKY set (integer value 5).
            bmask2, bmask3: two separate instances with only ROUND set.
            bmask_empty: no flags set.
        """
        self.bmask = Bitmask(Desc, Desc.SMALL, Desc.FUNKY)
        self.bmask2 = Bitmask(Desc, Desc.ROUND)
        self.bmask3 = Bitmask(Desc, Desc.ROUND)
        self.bmask_empty = Bitmask(Desc)

    def test_eq(self):
        """Test equality between bitmasks and against bare flags."""
        # Distinct instances holding the same flags compare equal.
        self.assertEqual(self.bmask2, self.bmask3)
        self.assertNotEqual(self.bmask, self.bmask3)
        self.assertNotEqual(self.bmask, self.bmask_empty)
        # A bitmask is not equal to a bare flag value.
        self.assertNotEqual(self.bmask, Desc.SMALL)
        self.assertNotEqual(self.bmask, Desc.ROUND)
        self.assertNotEqual(self.bmask_empty, Desc.ROUND)

    def test_repr(self):
        """Test that `repr()` output evaluates back to an equal bitmask."""
        # `eval` is acceptable here: the input is the repr of our own
        # fixtures, and `Bitmask` / `Desc` are in this module's scope.
        self.assertEqual(eval(repr(self.bmask)), self.bmask)
        self.assertEqual(eval(repr(self.bmask_empty)), self.bmask_empty)

    def test_add(self):
        """Test the `Bitmask.add()` method."""
        self.bmask.add(Desc.ROUND)
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.SMALL, Desc.FUNKY, Desc.ROUND),
        )

        # Adding the same flag twice must leave it set exactly once.
        self.bmask_empty.add(Desc.ROUND)
        self.bmask_empty.add(Desc.ROUND)
        self.assertEqual(
            self.bmask_empty,
            Bitmask(Desc, Desc.ROUND),
        )

    def test_add_operator(self):
        """Test the + operator."""
        self.assertEqual(
            self.bmask + Desc.ROUND,
            Bitmask(Desc, Desc.SMALL, Desc.FUNKY, Desc.ROUND),
        )
        self.assertEqual(
            self.bmask_empty + Desc.ROUND,
            Bitmask(Desc, Desc.ROUND),
        )

        # Test `__radd__`.
        self.assertEqual(
            Desc.ROUND + self.bmask,
            Bitmask(Desc, Desc.SMALL, Desc.FUNKY, Desc.ROUND),
        )

        # Test combining bitmasks
        self.assertEqual(
            self.bmask + self.bmask2 + self.bmask3,
            Bitmask(Desc, Desc.SMALL, Desc.FUNKY, Desc.ROUND),
        )

    def test_remove(self):
        """Test the `Bitmask.remove()` method."""
        self.bmask.remove(Desc.SMALL)
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.FUNKY),
        )

        # Removing a flag that is not set is an error.
        with self.assertRaises(KeyError):
            self.bmask.remove(Desc.SMALL)
        with self.assertRaises(KeyError):
            self.bmask_empty.remove(Desc.SMALL)

        # A whole `Bitmask` is not a valid flag argument.
        with self.assertRaises(TypeError):
            self.bmask_empty.remove(self.bmask2)

    def test_discard(self):
        """Test the `Bitmask.discard()` method."""
        self.bmask.discard(Desc.SMALL)
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.FUNKY),
        )

        # Discarding a flag that is not set leaves the bitmask unchanged.
        self.bmask.discard(Desc.SMALL)
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.FUNKY),
        )
        self.bmask_empty.discard(Desc.SMALL)
        self.assertEqual(
            self.bmask_empty,
            Bitmask(Desc),
        )

        # A whole `Bitmask` is not a valid flag argument. This must call
        # `discard()` (not `remove()`, which `test_remove` already covers)
        # so that this method's own argument check is exercised.
        with self.assertRaises(TypeError):
            self.bmask_empty.discard(self.bmask2)

    def test_subtract(self):
        """Test various subtraction operators."""
        # Operation
        self.assertEqual(
            self.bmask - Desc.SMALL,
            Bitmask(Desc, Desc.FUNKY),
        )

        # Assignment
        self.bmask -= Desc.SMALL
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.FUNKY),
        )

    def test_value(self):
        """Ensure Bitmask.value lines up with the state."""
        # SMALL | FUNKY == 1 | 4 == 5
        self.assertEqual(self.bmask.value, 5)
        self.bmask.add(Desc.ROUND)
        self.assertEqual(self.bmask.value, 7)
        self.assertEqual(self.bmask_empty.value, 0)

        # Setting values directly
        self.bmask.value = 0
        self.assertEqual(
            self.bmask,
            Bitmask(Desc),
        )
        self.bmask.value = 1
        self.assertEqual(
            self.bmask,
            Bitmask(Desc, Desc.SMALL),
        )

        # Non-integer values are rejected.
        with self.assertRaises(TypeError):
            self.bmask.value = 1j
        with self.assertRaises(TypeError):
            self.bmask.value = 2.5

    def test_contains(self):
        """Test `flag in mask` check."""
        self.assertIn(Desc.FUNKY, self.bmask)
        self.assertIn(Desc.SMALL, self.bmask)
        self.assertNotIn(Desc.ROUND, self.bmask)
        self.assertNotIn(Desc.ROUND, self.bmask_empty)

        # Membership tests only accept flags, not whole bitmasks. The
        # result is bound to `_` so the comparison is not a bare statement.
        with self.assertRaises(TypeError):
            _ = self.bmask in self.bmask

    def test_iter(self):
        """Test that iteration yields the set flags."""
        self.assertEqual(
            list(self.bmask),
            [Desc.SMALL, Desc.FUNKY],
        )

    def test_str(self):
        """Test string conversion."""
        self.assertEqual(
            str(self.bmask),
            "SMALL|FUNKY",
        )
        self.bmask.add(Desc.ROUND)
        self.assertEqual(
            str(self.bmask),
            "SMALL|ROUND|FUNKY",
        )

        # An empty bitmask renders as "0".
        self.assertEqual(
            str(self.bmask_empty),
            "0",
        )
        self.bmask_empty.add(Desc.ROUND)
        self.assertEqual(
            str(self.bmask_empty),
            "ROUND",
        )

    def test_int(self):
        """Test that `int(mask)` agrees with `mask.value`."""
        self.assertEqual(
            int(self.bmask),
            self.bmask.value,
        )
        self.bmask.value = 4
        self.assertEqual(
            int(self.bmask),
            self.bmask.value,
        )
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()