summaryrefslogtreecommitdiff
path: root/src/home/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/home/util.py')
-rw-r--r--src/home/util.py70
1 files changed, 52 insertions, 18 deletions
diff --git a/src/home/util.py b/src/home/util.py
index 93a9d8f..35505bc 100644
--- a/src/home/util.py
+++ b/src/home/util.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import socket
import time
@@ -6,17 +8,57 @@ import traceback
import logging
import string
import random
+import re
from enum import Enum
from datetime import datetime
from typing import Tuple, Optional, List
from zlib import adler32
-Addr = Tuple[str, int] # network address type (host, port)
-
logger = logging.getLogger(__name__)
+def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bool:
+ if re.match(r'^(\d{1,3}\.){3}\d{1,3}$', address):
+ parts = address.split('.')
+ if all(0 <= int(part) < 256 for part in parts):
+ return True
+ else:
+ if raise_exception:
+ raise ValueError(f"invalid IPv4 address: {address}")
+ return False
+
+ if re.match(r'^[a-zA-Z0-9.-]+$', address):
+ return True
+ else:
+ if raise_exception:
+ raise ValueError(f"invalid hostname: {address}")
+ return False
+
+
+class Addr:
+ host: str
+ port: int
+
+ def __init__(self, host: str, port: int):
+ self.host = host
+ self.port = port
+
+ @staticmethod
+ def fromstring(addr: str) -> Addr:
+ if addr.count(':') != 1:
+ raise ValueError('invalid host:port format')
+
+ host, port = addr.split(':')
+ validate_ipv4_or_hostname(host, raise_exception=True)
+
+ port = int(port)
+ if not 0 <= port <= 65535:
+ raise ValueError(f'invalid port {port}')
+
+ return Addr(host, port)
+
+
# https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
@@ -45,21 +87,6 @@ def ipv4_valid(ip: str) -> bool:
return False
-def parse_addr(addr: str) -> Addr:
- if addr.count(':') != 1:
- raise ValueError('invalid host:port format')
-
- host, port = addr.split(':')
- if not ipv4_valid(host):
- raise ValueError('invalid ipv4 address')
-
- port = int(port)
- if not 0 <= port <= 65535:
- raise ValueError('invalid port')
-
- return host, port
-
-
def strgen(n: int):
return ''.join(random.choices(string.ascii_letters + string.digits, k=n))
@@ -193,4 +220,11 @@ def filesize_fmt(num, suffix="B") -> str:
class HashableEnum(Enum):
def hash(self) -> int:
- return adler32(self.name.encode()) \ No newline at end of file
+ return adler32(self.name.encode())
+
+
+def next_tick_gen(freq):
+ t = time.time()
+ while True:
+ t += freq
+ yield max(t - time.time(), 0) \ No newline at end of file