diff options
Diffstat (limited to 'src/home/mqtt/payload/base_payload.py')
-rw-r--r-- | src/home/mqtt/payload/base_payload.py | 44 |
1 files changed, 20 insertions, 24 deletions
diff --git a/src/home/mqtt/payload/base_payload.py b/src/home/mqtt/payload/base_payload.py index c9ec907..108e0c0 100644 --- a/src/home/mqtt/payload/base_payload.py +++ b/src/home/mqtt/payload/base_payload.py @@ -1,7 +1,8 @@ import abc import struct +import re -from typing import Generic, TypeVar +from typing import Optional, Tuple class MQTTPayload(abc.ABC): @@ -20,15 +21,9 @@ class MQTTPayload(abc.ABC): bf_progress = 0 for field, field_type in self.__class__.__annotations__.items(): - field_type_origin = None - if hasattr(field_type, '__extra__') or hasattr(field_type, '__origin__'): - try: - field_type_origin = field_type.__extra__ - except AttributeError: - field_type_origin = field_type.__origin__ - - if field_type_origin is not None and issubclass(field_type_origin, MQTTPayloadBitField): - n, s, b = field_type.__args__ + bfp = _bit_field_params(field_type) + if bfp: + n, s, b = bfp if n != bf_number: if bf_number != -1: args.append(bf_arg) @@ -61,15 +56,9 @@ class MQTTPayload(abc.ABC): bf_progress = 0 for field, field_type in cls.__annotations__.items(): - field_type_origin = None - if hasattr(field_type, '__extra__') or hasattr(field_type, '__origin__'): - try: - field_type_origin = field_type.__extra__ - except AttributeError: - field_type_origin = field_type.__origin__ - - if field_type_origin is not None and issubclass(field_type_origin, MQTTPayloadBitField): - n, s, b = field_type.__args__ + bfp = _bit_field_params(field_type) + if bfp: + n, s, b = bfp if n != bf_number: bf_number = n bf_progress = 0 @@ -86,6 +75,7 @@ class MQTTPayload(abc.ABC): else: kwargs[field] = cls._unpack_field(field, data[i]) i += 1 + return cls(**kwargs) def _pack_field(self, name): @@ -120,10 +110,16 @@ class MQTTPayloadCustomField(abc.ABC): pass -NT = TypeVar('NT') # number of bit field -ST = TypeVar('ST') # size in bytes -BT = TypeVar('BT') # size in bits of particular value +def bit_field(seq_no: int, total_bits: int, bits: int): + return type(f'MQTTPayloadBitField_{seq_no}_{total_bits}_{bits}', (object,), { + 'seq_no': seq_no, + 'total_bits': total_bits, + 'bits': bits + }) -class MQTTPayloadBitField(int, Generic[NT, ST, BT]): - pass +def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: + match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) + if match is not None: + return tuple([int(match.group(i)) for i in range(1, 4)]) + return None |