aboutsummaryrefslogtreecommitdiff
path: root/include/py/homekit/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'include/py/homekit/util.py')
-rw-r--r--include/py/homekit/util.py255
1 files changed, 255 insertions, 0 deletions
diff --git a/include/py/homekit/util.py b/include/py/homekit/util.py
new file mode 100644
index 0000000..22bba86
--- /dev/null
+++ b/include/py/homekit/util.py
@@ -0,0 +1,255 @@
+from __future__ import annotations
+
+import json
+import socket
+import time
+import subprocess
+import traceback
+import logging
+import string
+import random
+import re
+
+from enum import Enum
+from datetime import datetime
+from typing import Optional, List
+from zlib import adler32
+
+logger = logging.getLogger(__name__)
+
+
+def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bool:
+ if re.match(r'^(\d{1,3}\.){3}\d{1,3}$', address):
+ parts = address.split('.')
+ if all(0 <= int(part) < 256 for part in parts):
+ return True
+ else:
+ if raise_exception:
+ raise ValueError(f"invalid IPv4 address: {address}")
+ return False
+
+ if re.match(r'^[a-zA-Z0-9.-]+$', address):
+ return True
+ else:
+ if raise_exception:
+ raise ValueError(f"invalid hostname: {address}")
+ return False
+
+
+def validate_mac_address(mac_address: str) -> bool:
+ mac_pattern = r'^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$'
+ if re.match(mac_pattern, mac_address):
+ return True
+ else:
+ return False
+
+
+class Addr:
+ host: str
+ port: Optional[int]
+
+ def __init__(self, host: str, port: Optional[int] = None):
+ self.host = host
+ self.port = port
+
+ @staticmethod
+ def fromstring(addr: str) -> Addr:
+ colons = addr.count(':')
+ if colons != 1:
+ raise ValueError('invalid host:port format')
+
+ if not colons:
+ host = addr
+ port = None
+ else:
+ host, port = addr.split(':')
+
+ validate_ipv4_or_hostname(host, raise_exception=True)
+
+ if port is not None:
+ port = int(port)
+ if not 0 <= port <= 65535:
+ raise ValueError(f'invalid port {port}')
+
+ return Addr(host, port)
+
+ def __str__(self):
+ buf = self.host
+ if self.port is not None:
+ buf += ':'+str(self.port)
+ return buf
+
+ def __iter__(self):
+ yield self.host
+ yield self.port
+
+
+# https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks
+def chunks(lst, n):
+ """Yield successive n-sized chunks from lst."""
+ for i in range(0, len(lst), n):
+ yield lst[i:i + n]
+
+
+def json_serial(obj):
+ """JSON serializer for datetime objects"""
+ if isinstance(obj, datetime):
+ return obj.timestamp()
+ if isinstance(obj, Enum):
+ return obj.value
+ raise TypeError("Type %s not serializable" % type(obj))
+
+
+def stringify(v) -> str:
+ return json.dumps(v, separators=(',', ':'), default=json_serial)
+
+
+def ipv4_valid(ip: str) -> bool:
+ try:
+ socket.inet_aton(ip)
+ return True
+ except socket.error:
+ return False
+
+
+def strgen(n: int):
+ return ''.join(random.choices(string.ascii_letters + string.digits, k=n))
+
+
+class MySimpleSocketClient:
+ host: str
+ port: int
+
+ def __init__(self, host: str, port: int):
+ self.host = host
+ self.port = port
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.sock.connect((self.host, self.port))
+ self.sock.settimeout(5)
+
+ def __del__(self):
+ self.sock.close()
+
+ def write(self, line: str) -> None:
+ self.sock.sendall((line + '\r\n').encode())
+
+ def read(self) -> str:
+ buf = bytearray()
+ while True:
+ buf.extend(self.sock.recv(256))
+ if b'\r\n' in buf:
+ break
+
+ response = buf.decode().strip()
+ return response
+
+
+def send_datagram(message: str, addr: Addr) -> None:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ sock.sendto(message.encode(), addr)
+
+
+def format_tb(exc) -> Optional[List[str]]:
+ tb = traceback.format_tb(exc.__traceback__)
+ if not tb:
+ return None
+
+ tb = list(map(lambda s: s.strip(), tb))
+ tb.reverse()
+ if tb[0][-1:] == ':':
+ tb[0] = tb[0][:-1]
+
+ return tb
+
+
+class ChildProcessInfo:
+ pid: int
+ cmd: str
+
+ def __init__(self,
+ pid: int,
+ cmd: str):
+ self.pid = pid
+ self.cmd = cmd
+
+
+def find_child_processes(ppid: int) -> List[ChildProcessInfo]:
+ p = subprocess.run(['pgrep', '-P', str(ppid), '--list-full'], capture_output=True)
+ if p.returncode != 0:
+ raise OSError(f'pgrep returned {p.returncode}')
+
+ children = []
+
+ lines = p.stdout.decode().strip().split('\n')
+ for line in lines:
+ try:
+ space_idx = line.index(' ')
+ except ValueError as exc:
+ logger.exception(exc)
+ continue
+
+ pid = int(line[0:space_idx])
+ cmd = line[space_idx+1:]
+
+ children.append(ChildProcessInfo(pid, cmd))
+
+ return children
+
+
+class Stopwatch:
+ elapsed: float
+ time_started: Optional[float]
+
+ def __init__(self):
+ self.elapsed = 0
+ self.time_started = None
+
+ def go(self):
+ if self.time_started is not None:
+ raise StopwatchError('stopwatch was already started')
+
+ self.time_started = time.time()
+
+ def pause(self):
+ if self.time_started is None:
+ raise StopwatchError('stopwatch was paused')
+
+ self.elapsed += time.time() - self.time_started
+ self.time_started = None
+
+ def get_elapsed_time(self):
+ elapsed = self.elapsed
+ if self.time_started is not None:
+ elapsed += time.time() - self.time_started
+ return elapsed
+
+ def reset(self):
+ self.time_started = None
+ self.elapsed = 0
+
+ def is_paused(self):
+ return self.time_started is None
+
+
+class StopwatchError(RuntimeError):
+ pass
+
+
+def filesize_fmt(num, suffix="B") -> str:
+ for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
+ if abs(num) < 1024.0:
+ return f"{num:3.1f} {unit}{suffix}"
+ num /= 1024.0
+ return f"{num:.1f} Yi{suffix}"
+
+
+class HashableEnum(Enum):
+ def hash(self) -> int:
+ return adler32(self.name.encode())
+
+
+def next_tick_gen(freq):
+ t = time.time()
+ while True:
+ t += freq
+ yield max(t - time.time(), 0) \ No newline at end of file