diff options
author | Evgeny Zinoviev <me@ch1p.io> | 2021-02-16 02:02:14 +0300 |
---|---|---|
committer | Evgeny Zinoviev <me@ch1p.io> | 2021-02-16 02:02:14 +0300 |
commit | d9ea8224613d5bd27bf527b14fa4ef02f827e482 (patch) | |
tree | e6db2aab27d5468b4ac73919729e91a5991a337c /mqtt.py | |
parent | b80d4936ce1ce434a892d361413ec9e77d2b0a79 (diff) |
refactor code, support multiple mqtt packets in one tcp message, support other mqtt packets
Diffstat (limited to 'mqtt.py')
-rw-r--r-- | mqtt.py | 351 |
1 files changed, 351 insertions, 0 deletions
@@ -0,0 +1,351 @@ +from mitmproxy.utils import strutils +from typing import List + +import struct + + +class MQTTControlPacket: + # Packet types + ( + CONNECT, + CONNACK, + PUBLISH, + PUBACK, + PUBREC, + PUBREL, + PUBCOMP, + SUBSCRIBE, + SUBACK, + UNSUBSCRIBE, + UNSUBACK, + PINGREQ, + PINGRESP, + DISCONNECT, + ) = range(1, 15) + + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Table_2.1_- + NAMES = [ + "reserved", + "CONNECT", + "CONNACK", + "PUBLISH", + "PUBACK", + "PUBREC", + "PUBREL", + "PUBCOMP", + "SUBSCRIBE", + "SUBACK", + "UNSUBSCRIBE", + "UNSUBACK", + "PINGREQ", + "PINGRESP", + "DISCONNECT", + "reserved", + ] + + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Table_3.1_- + CONNECT_RETURN_CODES = [ + "Connection Accepted", + "Connection Refused, unacceptable protocol version", + "Connection Refused, identifier rejected", + "Connection Refused, Server unavailable", + "Connection Refused, bad user name or password", + "Connection Refused, not authorized", + ] + + SUBACK_RETURN_CODES = { + 0x00: "Success - Maximum QoS 0", + 0x01: "Success - Maximum QoS 1", + 0x02: "Success - Maximum QoS 2", + 0x80: "Failure", + } + + PACKETS_WITH_IDENTIFIER = [ + PUBACK, + PUBREC, + PUBREL, + PUBCOMP, + SUBSCRIBE, + SUBACK, + UNSUBSCRIBE, + UNSUBACK, + ] + + def __init__(self, buf: bytes, packet_type: int, packet_flags: int, length: int, length_size=1): + self.buf = buf + self.packet_type = packet_type + self.packet_type_human = self.NAMES[self.packet_type] + self.packet_flags = packet_flags + self.remaining_length = length + self.remaining_length_size = length_size + self.packet_identifier = None + self.payload = {} + + self.dup, self.qos, self.retain = self._parse_flags() + + def parse(self): + # Variable header & Payload + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718024 + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718026 + if self.packet_type == self.CONNECT: + self._parse_connect_variable_headers() + self._parse_connect_payload() + + elif self.packet_type == self.CONNACK: + self._parse_connack_variable_headers() + + elif self.packet_type == self.PUBLISH: + self._parse_publish_variable_headers() + self._parse_publish_payload() + + elif self.packet_type in (self.SUBSCRIBE, self.UNSUBSCRIBE): + flags = self.packet_flags + if flags != 0x2: + raise Exception(f'Packet is malformed: flags = {flags} != 0x2') + + self._parse_packet_identifier() + self._parse_subscribe_payload(with_qos=(self.packet_type == self.SUBSCRIBE)) + + elif self.packet_type == self.SUBACK: + self._parse_packet_identifier() + self._parse_suback_payload() + + elif self.packet_type == self.UNSUBACK: + self._parse_packet_identifier() + + elif self.packet_type == self.PUBACK: + self._parse_packet_identifier() + + elif self.packet_type == self.UNSUBSCRIBE: + pass + + # else: + # self.payload = None + + def pprint(self): + pid = f' {self.packet_identifier:04x}' if self.packet_identifier is not None else '' + s = f'[{self.NAMES[self.packet_type]}{pid}]' + + if self.packet_type == self.CONNECT: + s += f""" +Protocol Level: {self.variable_headers['ProtocolLevel'][0]} +Client Id: {self.payload['ClientId']} +Will Topic: {self.payload.get('WillTopic')} +Will Message: {strutils.bytes_to_escaped_str(self.payload.get('WillMessage', b'None'))} +User Name: {self.payload.get('UserName')} +Password: {strutils.bytes_to_escaped_str(self.payload.get('Password', b'None'))} +""" + + elif self.packet_type == self.CONNACK: + rc = self.connack_headers["ReturnCode"] + rc_desc = self.CONNECT_RETURN_CODES[rc] if rc < len(self.CONNECT_RETURN_CODES) else f'{rc:02x}' + s += f" SessionPresent: {self.connack_headers['SessionPresent']}. {rc_desc}" + + elif self.packet_type in (self.SUBSCRIBE, self.UNSUBSCRIBE): + s += " sent topic filters: " + s += ", ".join([f"'{tf}'" for tf in self.topic_filters]) + + elif self.packet_type == self.SUBACK: + rc = self.payload['ReturnCode'] + s += " " + s += self.SUBACK_RETURN_CODES[rc] if rc in self.SUBACK_RETURN_CODES else f'{rc:02xs}' + + elif self.packet_type == self.PUBLISH: + topic_name = strutils.bytes_to_escaped_str(self.topic_name) + payload = strutils.bytes_to_escaped_str(self.payload) + + s += f" '{payload}' to topic '{topic_name}'" + + elif self.packet_type in (self.PINGREQ, self.PINGRESP, self.UNSUBACK, self.PUBACK, self.DISCONNECT): + # just print packet type with packet identifier (if any) + pass + + else: + s = f"Packet type {self.NAMES[self.packet_type]} is not supported yet!" + + return s + + def _parse_length_prefixed_bytes(self, offset): + field_length_bytes = self.buf[offset: offset + 2] + field_length = struct.unpack("!H", field_length_bytes)[0] + offset += 2 + + field_content_bytes = self.buf[offset: offset + field_length] + + return field_length + 2, field_content_bytes + + def _parse_publish_variable_headers(self): + offset = len(self.buf) - self.remaining_length + + field_length, field_content_bytes = self._parse_length_prefixed_bytes(offset) + self.topic_name = field_content_bytes + + if self.qos in [0x01, 0x02]: + offset += field_length + self.packet_identifier = self.buf[offset: offset + 2][0] + + def _parse_publish_payload(self): + fixed_header_length = len(self.buf) - self.remaining_length + variable_header_length = 2 + len(self.topic_name) + + if self.qos in [0x01, 0x02]: + variable_header_length += 2 + + offset = fixed_header_length + variable_header_length + + self.payload = self.buf[offset:] + + def _parse_subscribe_payload(self, with_qos=True): + # skip packet identifier + offset = self.remaining_length_size + offset += 1 # fixed header + offset += 2 # packet identifier + + self.topic_filters = [] + + while len(self.buf) - offset > 0: + field_length, topic_filter_bytes = self._parse_length_prefixed_bytes(offset) + offset += field_length + + topic_filter = { + 'topic': topic_filter_bytes.decode("utf-8") + } + + if with_qos: + topic_filter['qos'] = self.buf[offset: offset + 1][0] & 0x3 + offset += 1 + + self.topic_filters.append(topic_filter) + + def _parse_suback_payload(self): + offset = len(self.buf) - self.remaining_length + 2 + rc = self.buf[offset] + self.payload['ReturnCode'] = rc + + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718030 + def _parse_connect_variable_headers(self): + offset = len(self.buf) - self.remaining_length + + self.variable_headers = {} + self.connect_flags = {} + + self.variable_headers["ProtocolName"] = self.buf[offset: offset + 6] + self.variable_headers["ProtocolLevel"] = self.buf[offset + 6: offset + 7] + self.variable_headers["ConnectFlags"] = self.buf[offset + 7: offset + 8] + self.variable_headers["KeepAlive"] = self.buf[offset + 8: offset + 10] + + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349229 + self.connect_flags["CleanSession"] = bool(self.variable_headers["ConnectFlags"][0] & 0x02) + self.connect_flags["Will"] = bool(self.variable_headers["ConnectFlags"][0] & 0x04) + self.will_qos = (self.variable_headers["ConnectFlags"][0] >> 3) & 0x03 + self.connect_flags["WillRetain"] = bool(self.variable_headers["ConnectFlags"][0] & 0x20) + self.connect_flags["Password"] = bool(self.variable_headers["ConnectFlags"][0] & 0x40) + self.connect_flags["UserName"] = bool(self.variable_headers["ConnectFlags"][0] & 0x80) + + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718031 + def _parse_connect_payload(self): + fields = [] + offset = len(self.buf) - self.remaining_length + 10 + + while len(self.buf) - offset > 0: + field_length, field_content = self._parse_length_prefixed_bytes(offset) + fields.append(field_content) + offset += field_length + + self.payload = {} + + for f in fields: + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349242 + if "ClientId" not in self.payload: + try: + self.payload["ClientId"] = f.decode("utf-8") + except: + self.payload["ClientId"] = str(f) + + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349243 + elif self.connect_flags["Will"] and "WillTopic" not in self.payload: + self.payload["WillTopic"] = f.decode("utf-8") + + elif self.connect_flags["Will"] and "WillMessage" not in self.payload: + self.payload["WillMessage"] = f + + elif self.connect_flags["UserName"] and "UserName" not in self.payload: + self.payload["UserName"] = f.decode("utf-8") + + elif self.connect_flags["Password"] and "Password" not in self.payload: + self.payload["Password"] = f + + else: + raise Exception("") + + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718035 + def _parse_connack_variable_headers(self): + self.connack_headers = {} + + offset = len(self.buf) - self.remaining_length + + self.connack_headers["SessionPresent"] = self.buf[offset: offset + 1][0] & 0x01 == 0x01 + self.connack_headers["ReturnCode"] = self.buf[offset + 1: offset + 2][0] + + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718022 + def _parse_flags(self): + dup = None + qos = None + retain = None + + if self.packet_type == self.PUBLISH: + dup = (self.buf[0] >> 3) & 0x01 + qos = (self.buf[0] >> 1) & 0x03 + retain = self.buf[0] & 0x01 + + return dup, qos, retain + + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Table_2.5_- + def _parse_packet_identifier(self): + offset = 1 + self.remaining_length_size + self.packet_identifier = struct.unpack('!H', self.buf[offset: offset + 2])[0] + + +def _get_packet_type(buf: bytes) -> int: + return buf[0] >> 4 + + +def _get_packet_flags(buf: bytes) -> int: + return buf[0] & 0xf + + +def _get_remaining_length(buf: bytes) -> tuple: + multiplier = 1 + value = 0 + i = 1 + + while True: + encoded_byte = buf[i-1] + value += (encoded_byte & 127) * multiplier + multiplier *= 128 + + if multiplier > 128 * 128 * 128: + raise Exception("Malformed Remaining Length") + + if encoded_byte & 128 == 0: + break + + i += 1 + + return i, value + + +def read_packets(buf: bytes) -> List[MQTTControlPacket]: + packets = [] + while len(buf) > 0: + # Fixed header + # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718020 + packet_type = _get_packet_type(buf) + packet_flags = _get_packet_flags(buf) + length_size, length = _get_remaining_length(buf[1:]) + + packets.append(MQTTControlPacket(buf[:1+length_size+length], packet_type, packet_flags, length, length_size=length_size)) + + buf = buf[1+length_size+length:] + + return packets |