summaryrefslogtreecommitdiff
path: root/mqtt.py
diff options
context:
space:
mode:
authorEvgeny Zinoviev <me@ch1p.io>2021-02-16 02:02:14 +0300
committerEvgeny Zinoviev <me@ch1p.io>2021-02-16 02:02:14 +0300
commitd9ea8224613d5bd27bf527b14fa4ef02f827e482 (patch)
treee6db2aab27d5468b4ac73919729e91a5991a337c /mqtt.py
parentb80d4936ce1ce434a892d361413ec9e77d2b0a79 (diff)
refactor code, support multiple mqtt packets in one tcp message, support other mqtt packets
Diffstat (limited to 'mqtt.py')
-rw-r--r--mqtt.py351
1 files changed, 351 insertions, 0 deletions
diff --git a/mqtt.py b/mqtt.py
new file mode 100644
index 0000000..b2cbb33
--- /dev/null
+++ b/mqtt.py
@@ -0,0 +1,351 @@
+from mitmproxy.utils import strutils
+from typing import List
+
+import struct
+
+
+class MQTTControlPacket:
+ # Packet types
+ (
+ CONNECT,
+ CONNACK,
+ PUBLISH,
+ PUBACK,
+ PUBREC,
+ PUBREL,
+ PUBCOMP,
+ SUBSCRIBE,
+ SUBACK,
+ UNSUBSCRIBE,
+ UNSUBACK,
+ PINGREQ,
+ PINGRESP,
+ DISCONNECT,
+ ) = range(1, 15)
+
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Table_2.1_-
+ NAMES = [
+ "reserved",
+ "CONNECT",
+ "CONNACK",
+ "PUBLISH",
+ "PUBACK",
+ "PUBREC",
+ "PUBREL",
+ "PUBCOMP",
+ "SUBSCRIBE",
+ "SUBACK",
+ "UNSUBSCRIBE",
+ "UNSUBACK",
+ "PINGREQ",
+ "PINGRESP",
+ "DISCONNECT",
+ "reserved",
+ ]
+
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Table_3.1_-
+ CONNECT_RETURN_CODES = [
+ "Connection Accepted",
+ "Connection Refused, unacceptable protocol version",
+ "Connection Refused, identifier rejected",
+ "Connection Refused, Server unavailable",
+ "Connection Refused, bad user name or password",
+ "Connection Refused, not authorized",
+ ]
+
+ SUBACK_RETURN_CODES = {
+ 0x00: "Success - Maximum QoS 0",
+ 0x01: "Success - Maximum QoS 1",
+ 0x02: "Success - Maximum QoS 2",
+ 0x80: "Failure",
+ }
+
+ PACKETS_WITH_IDENTIFIER = [
+ PUBACK,
+ PUBREC,
+ PUBREL,
+ PUBCOMP,
+ SUBSCRIBE,
+ SUBACK,
+ UNSUBSCRIBE,
+ UNSUBACK,
+ ]
+
+ def __init__(self, buf: bytes, packet_type: int, packet_flags: int, length: int, length_size=1):
+ self.buf = buf
+ self.packet_type = packet_type
+ self.packet_type_human = self.NAMES[self.packet_type]
+ self.packet_flags = packet_flags
+ self.remaining_length = length
+ self.remaining_length_size = length_size
+ self.packet_identifier = None
+ self.payload = {}
+
+ self.dup, self.qos, self.retain = self._parse_flags()
+
+ def parse(self):
+ # Variable header & Payload
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718024
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718026
+ if self.packet_type == self.CONNECT:
+ self._parse_connect_variable_headers()
+ self._parse_connect_payload()
+
+ elif self.packet_type == self.CONNACK:
+ self._parse_connack_variable_headers()
+
+ elif self.packet_type == self.PUBLISH:
+ self._parse_publish_variable_headers()
+ self._parse_publish_payload()
+
+ elif self.packet_type in (self.SUBSCRIBE, self.UNSUBSCRIBE):
+ flags = self.packet_flags
+ if flags != 0x2:
+ raise Exception(f'Packet is malformed: flags = {flags} != 0x2')
+
+ self._parse_packet_identifier()
+ self._parse_subscribe_payload(with_qos=(self.packet_type == self.SUBSCRIBE))
+
+ elif self.packet_type == self.SUBACK:
+ self._parse_packet_identifier()
+ self._parse_suback_payload()
+
+ elif self.packet_type == self.UNSUBACK:
+ self._parse_packet_identifier()
+
+ elif self.packet_type == self.PUBACK:
+ self._parse_packet_identifier()
+
+ elif self.packet_type == self.UNSUBSCRIBE:
+ pass
+
+ # else:
+ # self.payload = None
+
+ def pprint(self):
+ pid = f' {self.packet_identifier:04x}' if self.packet_identifier is not None else ''
+ s = f'[{self.NAMES[self.packet_type]}{pid}]'
+
+ if self.packet_type == self.CONNECT:
+ s += f"""
+Protocol Level: {self.variable_headers['ProtocolLevel'][0]}
+Client Id: {self.payload['ClientId']}
+Will Topic: {self.payload.get('WillTopic')}
+Will Message: {strutils.bytes_to_escaped_str(self.payload.get('WillMessage', b'None'))}
+User Name: {self.payload.get('UserName')}
+Password: {strutils.bytes_to_escaped_str(self.payload.get('Password', b'None'))}
+"""
+
+ elif self.packet_type == self.CONNACK:
+ rc = self.connack_headers["ReturnCode"]
+ rc_desc = self.CONNECT_RETURN_CODES[rc] if rc < len(self.CONNECT_RETURN_CODES) else f'{rc:02x}'
+ s += f" SessionPresent: {self.connack_headers['SessionPresent']}. {rc_desc}"
+
+ elif self.packet_type in (self.SUBSCRIBE, self.UNSUBSCRIBE):
+ s += " sent topic filters: "
+ s += ", ".join([f"'{tf}'" for tf in self.topic_filters])
+
+ elif self.packet_type == self.SUBACK:
+ rc = self.payload['ReturnCode']
+ s += " "
+ s += self.SUBACK_RETURN_CODES[rc] if rc in self.SUBACK_RETURN_CODES else f'{rc:02xs}'
+
+ elif self.packet_type == self.PUBLISH:
+ topic_name = strutils.bytes_to_escaped_str(self.topic_name)
+ payload = strutils.bytes_to_escaped_str(self.payload)
+
+ s += f" '{payload}' to topic '{topic_name}'"
+
+ elif self.packet_type in (self.PINGREQ, self.PINGRESP, self.UNSUBACK, self.PUBACK, self.DISCONNECT):
+ # just print packet type with packet identifier (if any)
+ pass
+
+ else:
+ s = f"Packet type {self.NAMES[self.packet_type]} is not supported yet!"
+
+ return s
+
+ def _parse_length_prefixed_bytes(self, offset):
+ field_length_bytes = self.buf[offset: offset + 2]
+ field_length = struct.unpack("!H", field_length_bytes)[0]
+ offset += 2
+
+ field_content_bytes = self.buf[offset: offset + field_length]
+
+ return field_length + 2, field_content_bytes
+
+ def _parse_publish_variable_headers(self):
+ offset = len(self.buf) - self.remaining_length
+
+ field_length, field_content_bytes = self._parse_length_prefixed_bytes(offset)
+ self.topic_name = field_content_bytes
+
+ if self.qos in [0x01, 0x02]:
+ offset += field_length
+ self.packet_identifier = self.buf[offset: offset + 2][0]
+
+ def _parse_publish_payload(self):
+ fixed_header_length = len(self.buf) - self.remaining_length
+ variable_header_length = 2 + len(self.topic_name)
+
+ if self.qos in [0x01, 0x02]:
+ variable_header_length += 2
+
+ offset = fixed_header_length + variable_header_length
+
+ self.payload = self.buf[offset:]
+
+ def _parse_subscribe_payload(self, with_qos=True):
+ # skip packet identifier
+ offset = self.remaining_length_size
+ offset += 1 # fixed header
+ offset += 2 # packet identifier
+
+ self.topic_filters = []
+
+ while len(self.buf) - offset > 0:
+ field_length, topic_filter_bytes = self._parse_length_prefixed_bytes(offset)
+ offset += field_length
+
+ topic_filter = {
+ 'topic': topic_filter_bytes.decode("utf-8")
+ }
+
+ if with_qos:
+ topic_filter['qos'] = self.buf[offset: offset + 1][0] & 0x3
+ offset += 1
+
+ self.topic_filters.append(topic_filter)
+
+ def _parse_suback_payload(self):
+ offset = len(self.buf) - self.remaining_length + 2
+ rc = self.buf[offset]
+ self.payload['ReturnCode'] = rc
+
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718030
+ def _parse_connect_variable_headers(self):
+ offset = len(self.buf) - self.remaining_length
+
+ self.variable_headers = {}
+ self.connect_flags = {}
+
+ self.variable_headers["ProtocolName"] = self.buf[offset: offset + 6]
+ self.variable_headers["ProtocolLevel"] = self.buf[offset + 6: offset + 7]
+ self.variable_headers["ConnectFlags"] = self.buf[offset + 7: offset + 8]
+ self.variable_headers["KeepAlive"] = self.buf[offset + 8: offset + 10]
+
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349229
+ self.connect_flags["CleanSession"] = bool(self.variable_headers["ConnectFlags"][0] & 0x02)
+ self.connect_flags["Will"] = bool(self.variable_headers["ConnectFlags"][0] & 0x04)
+ self.will_qos = (self.variable_headers["ConnectFlags"][0] >> 3) & 0x03
+ self.connect_flags["WillRetain"] = bool(self.variable_headers["ConnectFlags"][0] & 0x20)
+ self.connect_flags["Password"] = bool(self.variable_headers["ConnectFlags"][0] & 0x40)
+ self.connect_flags["UserName"] = bool(self.variable_headers["ConnectFlags"][0] & 0x80)
+
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718031
+ def _parse_connect_payload(self):
+ fields = []
+ offset = len(self.buf) - self.remaining_length + 10
+
+ while len(self.buf) - offset > 0:
+ field_length, field_content = self._parse_length_prefixed_bytes(offset)
+ fields.append(field_content)
+ offset += field_length
+
+ self.payload = {}
+
+ for f in fields:
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349242
+ if "ClientId" not in self.payload:
+ try:
+ self.payload["ClientId"] = f.decode("utf-8")
+ except:
+ self.payload["ClientId"] = str(f)
+
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349243
+ elif self.connect_flags["Will"] and "WillTopic" not in self.payload:
+ self.payload["WillTopic"] = f.decode("utf-8")
+
+ elif self.connect_flags["Will"] and "WillMessage" not in self.payload:
+ self.payload["WillMessage"] = f
+
+ elif self.connect_flags["UserName"] and "UserName" not in self.payload:
+ self.payload["UserName"] = f.decode("utf-8")
+
+ elif self.connect_flags["Password"] and "Password" not in self.payload:
+ self.payload["Password"] = f
+
+ else:
+ raise Exception("")
+
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718035
+ def _parse_connack_variable_headers(self):
+ self.connack_headers = {}
+
+ offset = len(self.buf) - self.remaining_length
+
+ self.connack_headers["SessionPresent"] = self.buf[offset: offset + 1][0] & 0x01 == 0x01
+ self.connack_headers["ReturnCode"] = self.buf[offset + 1: offset + 2][0]
+
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718022
+ def _parse_flags(self):
+ dup = None
+ qos = None
+ retain = None
+
+ if self.packet_type == self.PUBLISH:
+ dup = (self.buf[0] >> 3) & 0x01
+ qos = (self.buf[0] >> 1) & 0x03
+ retain = self.buf[0] & 0x01
+
+ return dup, qos, retain
+
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Table_2.5_-
+ def _parse_packet_identifier(self):
+ offset = 1 + self.remaining_length_size
+ self.packet_identifier = struct.unpack('!H', self.buf[offset: offset + 2])[0]
+
+
+def _get_packet_type(buf: bytes) -> int:
+ return buf[0] >> 4
+
+
+def _get_packet_flags(buf: bytes) -> int:
+ return buf[0] & 0xf
+
+
+def _get_remaining_length(buf: bytes) -> tuple:
+ multiplier = 1
+ value = 0
+ i = 1
+
+ while True:
+ encoded_byte = buf[i-1]
+ value += (encoded_byte & 127) * multiplier
+ multiplier *= 128
+
+ if multiplier > 128 * 128 * 128:
+ raise Exception("Malformed Remaining Length")
+
+ if encoded_byte & 128 == 0:
+ break
+
+ i += 1
+
+ return i, value
+
+
+def read_packets(buf: bytes) -> List[MQTTControlPacket]:
+ packets = []
+ while len(buf) > 0:
+ # Fixed header
+ # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718020
+ packet_type = _get_packet_type(buf)
+ packet_flags = _get_packet_flags(buf)
+ length_size, length = _get_remaining_length(buf[1:])
+
+ packets.append(MQTTControlPacket(buf[:1+length_size+length], packet_type, packet_flags, length, length_size=length_size))
+
+ buf = buf[1+length_size+length:]
+
+ return packets