refactor: Add Python type annotations wherever appropriate (#269)

* Add Python type annotations wherever appropriate

* Might as well annotate this too
This commit is contained in:
Miko
2025-11-19 18:59:32 +00:00
committed by GitHub
parent be362c5079
commit 0a608afbe6
7 changed files with 679 additions and 698 deletions

29
head-tracking/colors.py Normal file
View File

@@ -0,0 +1,29 @@
import logging
from logging import Formatter, LogRecord
from typing import Dict
class Colors:
RESET: str = "\033[0m"
BOLD: str = "\033[1m"
RED: str = "\033[91m"
GREEN: str = "\033[92m"
YELLOW: str = "\033[93m"
BLUE: str = "\033[94m"
MAGENTA: str = "\033[95m"
CYAN: str = "\033[96m"
WHITE: str = "\033[97m"
BG_BLACK: str = "\033[40m"
class ColorFormatter(Formatter):
FORMATS: Dict[int, str] = {
logging.DEBUG: f"{Colors.BLUE}[%(levelname)s] %(message)s{Colors.RESET}",
logging.INFO: f"{Colors.GREEN}%(message)s{Colors.RESET}",
logging.WARNING: f"{Colors.YELLOW}%(message)s{Colors.RESET}",
logging.ERROR: f"{Colors.RED}[%(levelname)s] %(message)s{Colors.RESET}",
logging.CRITICAL: f"{Colors.RED}{Colors.BOLD}[%(levelname)s] %(message)s{Colors.RESET}"
}
def format(self, record: LogRecord) -> str:
log_fmt: str = self.FORMATS.get(record.levelno)
formatter: Formatter = Formatter(log_fmt, datefmt="%H:%M:%S")
return formatter.format(record)

View File

@@ -1,23 +1,25 @@
import bluetooth import bluetooth
import logging import logging
from bluetooth import BluetoothSocket
from logging import Logger
class ConnectionManager: class ConnectionManager:
INIT_CMD = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00" INIT_CMD: str = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00"
START_CMD = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00" START_CMD: str = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00"
STOP_CMD = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00" STOP_CMD: str = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00"
def __init__(self, bt_addr="28:2D:7F:C2:05:5B", psm=0x1001, logger=None): def __init__(self, bt_addr: str = "28:2D:7F:C2:05:5B", psm: int = 0x1001, logger: Logger = None) -> None:
self.bt_addr = bt_addr self.bt_addr: str = bt_addr
self.psm = psm self.psm: int = psm
self.logger = logger if logger else logging.getLogger(__name__) self.logger: Logger = logger if logger else logging.getLogger(__name__)
self.sock = None self.sock: BluetoothSocket = None
self.connected = False self.connected: bool = False
self.started = False self.started: bool = False
def connect(self): def connect(self) -> bool:
self.logger.info(f"Connecting to {self.bt_addr} on PSM {self.psm:#04x}...") self.logger.info(f"Connecting to {self.bt_addr} on PSM {self.psm:#04x}...")
try: try:
self.sock = bluetooth.BluetoothSocket(bluetooth.L2CAP) self.sock = BluetoothSocket(bluetooth.L2CAP)
self.sock.connect((self.bt_addr, self.psm)) self.sock.connect((self.bt_addr, self.psm))
self.connected = True self.connected = True
self.logger.info("Connected to AirPods.") self.logger.info("Connected to AirPods.")
@@ -28,7 +30,7 @@ class ConnectionManager:
self.connected = False self.connected = False
return self.connected return self.connected
def send_start(self): def send_start(self) -> bool:
if not self.connected: if not self.connected:
self.logger.error("Not connected. Cannot send START command.") self.logger.error("Not connected. Cannot send START command.")
return False return False
@@ -40,7 +42,7 @@ class ConnectionManager:
self.logger.info("START command has already been sent.") self.logger.info("START command has already been sent.")
return True return True
def send_stop(self): def send_stop(self) -> None:
if self.connected and self.started: if self.connected and self.started:
try: try:
self.sock.send(bytes.fromhex(self.STOP_CMD)) self.sock.send(bytes.fromhex(self.STOP_CMD))
@@ -51,7 +53,7 @@ class ConnectionManager:
else: else:
self.logger.info("Cannot send STOP; not started or not connected.") self.logger.info("Cannot send STOP; not started or not connected.")
def disconnect(self): def disconnect(self) -> None:
if self.sock: if self.sock:
try: try:
self.sock.close() self.sock.close()

View File

@@ -1,88 +1,65 @@
import bluetooth
import threading
import time
import logging import logging
import statistics import statistics
import time
from bluetooth import BluetoothSocket
from collections import deque from collections import deque
from colors import *
from connection_manager import ConnectionManager
from logging import Logger, StreamHandler
from threading import Lock, Thread
from typing import Any, Deque, List, Optional, Tuple
class Colors: handler: StreamHandler = StreamHandler()
RESET = "\033[0m"
BOLD = "\033[1m"
RED = "\033[91m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
MAGENTA = "\033[95m"
CYAN = "\033[96m"
WHITE = "\033[97m"
BG_BLACK = "\033[40m"
class ColorFormatter(logging.Formatter):
FORMATS = {
logging.DEBUG: Colors.BLUE + "[%(levelname)s] %(message)s" + Colors.RESET,
logging.INFO: Colors.GREEN + "%(message)s" + Colors.RESET,
logging.WARNING: Colors.YELLOW + "%(message)s" + Colors.RESET,
logging.ERROR: Colors.RED + "[%(levelname)s] %(message)s" + Colors.RESET,
logging.CRITICAL: Colors.RED + Colors.BOLD + "[%(levelname)s] %(message)s" + Colors.RESET
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%H:%M:%S")
return formatter.format(record)
handler = logging.StreamHandler()
handler.setFormatter(ColorFormatter()) handler.setFormatter(ColorFormatter())
log = logging.getLogger(__name__) log: Logger = logging.getLogger(__name__)
log.setLevel(logging.INFO) log.setLevel(logging.INFO)
log.addHandler(handler) log.addHandler(handler)
log.propagate = False log.propagate = False
class GestureDetector: class GestureDetector:
INIT_CMD = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00" INIT_CMD: str = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00"
START_CMD = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00" START_CMD: str = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00"
STOP_CMD = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00" STOP_CMD: str = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00"
def __init__(self, conn=None): def __init__(self, conn: ConnectionManager = None) -> None:
self.sock = None self.sock: BluetoothSocket = None
self.bt_addr = "28:2D:7F:C2:05:5B" self.bt_addr: str = "28:2D:7F:C2:05:5B"
self.psm = 0x1001 self.psm: int = 0x1001
self.running = False self.running: bool = False
self.data_lock = threading.Lock() self.data_lock: Lock = Lock()
self.horiz_buffer = deque(maxlen=100) self.horiz_buffer: Deque[int] = deque(maxlen=100)
self.vert_buffer = deque(maxlen=100) self.vert_buffer: Deque[int] = deque(maxlen=100)
self.horiz_avg_buffer = deque(maxlen=5) self.horiz_avg_buffer: Deque[float] = deque(maxlen=5)
self.vert_avg_buffer = deque(maxlen=5) self.vert_avg_buffer: Deque[float] = deque(maxlen=5)
self.horiz_peaks = [] self.horiz_peaks: List[int] = []
self.horiz_troughs = [] self.horiz_troughs: List[int] = []
self.vert_peaks = [] self.vert_peaks: List[int] = []
self.vert_troughs = [] self.vert_troughs: List[int] = []
self.last_peak_time = 0 self.last_peak_time: float = 0
self.peak_intervals = deque(maxlen=5) self.peak_intervals: Deque[float] = deque(maxlen=5)
self.peak_threshold = 400 self.peak_threshold: int = 400
self.direction_change_threshold = 175 self.direction_change_threshold: int = 175
self.rhythm_consistency_threshold = 0.5 self.rhythm_consistency_threshold: float = 0.5
self.horiz_increasing = None self.horiz_increasing: Optional[bool] = None
self.vert_increasing = None self.vert_increasing: Optional[bool] = None
self.required_extremes = 3 self.required_extremes = 3
self.detection_timeout = 15 self.detection_timeout: int = 15
self.min_confidence_threshold = 0.7 self.min_confidence_threshold: float = 0.7
self.conn = conn self.conn: ConnectionManager = conn
def connect(self): def connect(self) -> bool:
try: try:
log.info(f"Connecting to AirPods at {self.bt_addr}...") log.info(f"Connecting to AirPods at {self.bt_addr}...")
if self.conn is None: if self.conn is None:
from connection_manager import ConnectionManager
self.conn = ConnectionManager(self.bt_addr, self.psm, logger=log) self.conn = ConnectionManager(self.bt_addr, self.psm, logger=log)
if not self.conn.connect(): if not self.conn.connect():
return False return False
@@ -97,13 +74,13 @@ class GestureDetector:
log.error(f"{Colors.RED}Connection failed: {e}{Colors.RESET}") log.error(f"{Colors.RED}Connection failed: {e}{Colors.RESET}")
return False return False
def process_data(self): def process_data(self) -> None:
"""Process incoming head tracking data.""" """Process incoming head tracking data."""
self.conn.send_start() self.conn.send_start()
log.info(f"{Colors.GREEN}✓ Head tracking activated{Colors.RESET}") log.info(f"{Colors.GREEN}✓ Head tracking activated{Colors.RESET}")
self.running = True self.running = True
start_time = time.time() start_time: float = time.time()
log.info(f"{Colors.GREEN}Ready! Make a YES or NO gesture{Colors.RESET}") log.info(f"{Colors.GREEN}Ready! Make a YES or NO gesture{Colors.RESET}")
log.info(f"{Colors.YELLOW}Tip: Use natural, moderate speed head movements{Colors.RESET}") log.info(f"{Colors.YELLOW}Tip: Use natural, moderate speed head movements{Colors.RESET}")
@@ -118,10 +95,10 @@ class GestureDetector:
if not self.sock: if not self.sock:
log.error("Socket not available.") log.error("Socket not available.")
break break
data = self.sock.recv(1024) data: bytes = self.sock.recv(1024)
formatted = self.format_hex(data) formatted: str = self.format_hex(data)
if self.is_valid_tracking_packet(formatted): if self.is_valid_tracking_packet(formatted):
raw_bytes = bytes.fromhex(formatted.replace(" ", "")) raw_bytes: bytes = bytes.fromhex(formatted.replace(" ", ""))
horizontal, vertical = self.extract_orientation_values(raw_bytes) horizontal, vertical = self.extract_orientation_values(raw_bytes)
if horizontal is not None and vertical is not None: if horizontal is not None and vertical is not None:
@@ -132,7 +109,7 @@ class GestureDetector:
self.vert_buffer.append(smooth_v) self.vert_buffer.append(smooth_v)
self.detect_peaks_and_troughs() self.detect_peaks_and_troughs()
gesture = self.detect_gestures() gesture: Optional[str] = self.detect_gestures()
if gesture: if gesture:
self.running = False self.running = False
@@ -143,19 +120,19 @@ class GestureDetector:
log.error(f"Data processing error: {e}") log.error(f"Data processing error: {e}")
break break
def disconnect(self): def disconnect(self) -> None:
"""Disconnect from socket.""" """Disconnect from socket."""
self.conn.disconnect() self.conn.disconnect()
def format_hex(self, data): def format_hex(self, data: bytes) -> str:
"""Format binary data to readable hex string.""" """Format binary data to readable hex string."""
hex_str = data.hex() hex_str: str = data.hex()
return ' '.join(hex_str[i:i+2] for i in range(0, len(hex_str), 2)) return ' '.join(hex_str[i:i+2] for i in range(0, len(hex_str), 2))
def is_valid_tracking_packet(self, hex_string): def is_valid_tracking_packet(self, hex_string: str) -> bool:
"""Verify packet is a valid head tracking packet.""" """Verify packet is a valid head tracking packet."""
standard_header = "04 00 04 00 17 00 00 00 10 00 45 00" standard_header: str = "04 00 04 00 17 00 00 00 10 00 45 00"
alternate_header = "04 00 04 00 17 00 00 00 10 00 44 00" alternate_header: str = "04 00 04 00 17 00 00 00 10 00 44 00"
if not hex_string.startswith(standard_header) and not hex_string.startswith(alternate_header): if not hex_string.startswith(standard_header) and not hex_string.startswith(alternate_header):
return False return False
@@ -164,55 +141,55 @@ class GestureDetector:
return True return True
def extract_orientation_values(self, raw_bytes): def extract_orientation_values(self, raw_bytes: bytes) -> Tuple[Optional[int], Optional[int]]:
"""Extract head orientation data from packet.""" """Extract head orientation data from packet."""
try: try:
horizontal = int.from_bytes(raw_bytes[51:53], byteorder='little', signed=True) horizontal: int = int.from_bytes(raw_bytes[51:53], byteorder='little', signed=True)
vertical = int.from_bytes(raw_bytes[53:55], byteorder='little', signed=True) vertical: int = int.from_bytes(raw_bytes[53:55], byteorder='little', signed=True)
return horizontal, vertical return horizontal, vertical
except Exception as e: except Exception as e:
log.debug(f"Failed to extract orientation: {e}") log.debug(f"Failed to extract orientation: {e}")
return None, None return None, None
def apply_smoothing(self, horizontal, vertical): def apply_smoothing(self, horizontal: int, vertical: int) -> Tuple[float, float]:
"""Apply moving average smoothing (Apple-like filtering).""" """Apply moving average smoothing (Apple-like filtering)."""
self.horiz_avg_buffer.append(horizontal) self.horiz_avg_buffer.append(horizontal)
self.vert_avg_buffer.append(vertical) self.vert_avg_buffer.append(vertical)
smooth_horiz = sum(self.horiz_avg_buffer) / len(self.horiz_avg_buffer) smooth_horiz: float = sum(self.horiz_avg_buffer) / len(self.horiz_avg_buffer)
smooth_vert = sum(self.vert_avg_buffer) / len(self.vert_avg_buffer) smooth_vert: float = sum(self.vert_avg_buffer) / len(self.vert_avg_buffer)
return smooth_horiz, smooth_vert return smooth_horiz, smooth_vert
def detect_peaks_and_troughs(self): def detect_peaks_and_troughs(self) -> None:
"""Detect motion direction changes with Apple-like refinements.""" """Detect motion direction changes with Apple-like refinements."""
if len(self.horiz_buffer) < 4 or len(self.vert_buffer) < 4: if len(self.horiz_buffer) < 4 or len(self.vert_buffer) < 4:
return return
h_values = list(self.horiz_buffer)[-4:] h_values: List[int] = list(self.horiz_buffer)[-4:]
v_values = list(self.vert_buffer)[-4:] v_values: List[int] = list(self.vert_buffer)[-4:]
h_variance = statistics.variance(h_values) if len(h_values) > 1 else 0 h_variance: float = statistics.variance(h_values) if len(h_values) > 1 else 0
v_variance = statistics.variance(v_values) if len(v_values) > 1 else 0 v_variance: float = statistics.variance(v_values) if len(v_values) > 1 else 0
current = self.horiz_buffer[-1] current: int = self.horiz_buffer[-1]
prev = self.horiz_buffer[-2] prev: int = self.horiz_buffer[-2]
if self.horiz_increasing is None: if self.horiz_increasing is None:
self.horiz_increasing = current > prev self.horiz_increasing = current > prev
dynamic_h_threshold = max(100, min(self.direction_change_threshold, h_variance / 3)) dynamic_h_threshold: float = max(100, min(self.direction_change_threshold, h_variance / 3))
if self.horiz_increasing and current < prev - dynamic_h_threshold: if self.horiz_increasing and current < prev - dynamic_h_threshold:
if abs(prev) > self.peak_threshold: if abs(prev) > self.peak_threshold:
self.horiz_peaks.append((len(self.horiz_buffer)-1, prev, time.time())) self.horiz_peaks.append((len(self.horiz_buffer)-1, prev, time.time()))
direction = "➡️ " if prev > 0 else "⬅️ " direction: str = "➡️ " if prev > 0 else "⬅️ "
log.info(f"{Colors.CYAN}{direction} Horizontal max: {prev} (threshold: {dynamic_h_threshold:.1f}){Colors.RESET}") log.info(f"{Colors.CYAN}{direction} Horizontal max: {prev} (threshold: {dynamic_h_threshold:.1f}){Colors.RESET}")
now = time.time() now: float = time.time()
if self.last_peak_time > 0: if self.last_peak_time > 0:
interval = now - self.last_peak_time interval: float = now - self.last_peak_time
self.peak_intervals.append(interval) self.peak_intervals.append(interval)
self.last_peak_time = now self.last_peak_time = now
@@ -221,34 +198,34 @@ class GestureDetector:
elif not self.horiz_increasing and current > prev + dynamic_h_threshold: elif not self.horiz_increasing and current > prev + dynamic_h_threshold:
if abs(prev) > self.peak_threshold: if abs(prev) > self.peak_threshold:
self.horiz_troughs.append((len(self.horiz_buffer)-1, prev, time.time())) self.horiz_troughs.append((len(self.horiz_buffer)-1, prev, time.time()))
direction = "➡️ " if prev > 0 else "⬅️ " direction: str = "➡️ " if prev > 0 else "⬅️ "
log.info(f"{Colors.CYAN}{direction} Horizontal max: {prev} (threshold: {dynamic_h_threshold:.1f}){Colors.RESET}") log.info(f"{Colors.CYAN}{direction} Horizontal max: {prev} (threshold: {dynamic_h_threshold:.1f}){Colors.RESET}")
now = time.time() now: float = time.time()
if self.last_peak_time > 0: if self.last_peak_time > 0:
interval = now - self.last_peak_time interval: float = now - self.last_peak_time
self.peak_intervals.append(interval) self.peak_intervals.append(interval)
self.last_peak_time = now self.last_peak_time = now
self.horiz_increasing = True self.horiz_increasing = True
current = self.vert_buffer[-1] current: int = self.vert_buffer[-1]
prev = self.vert_buffer[-2] prev: int = self.vert_buffer[-2]
if self.vert_increasing is None: if self.vert_increasing is None:
self.vert_increasing = current > prev self.vert_increasing = current > prev
dynamic_v_threshold = max(100, min(self.direction_change_threshold, v_variance / 3)) dynamic_v_threshold: float = max(100, min(self.direction_change_threshold, v_variance / 3))
if self.vert_increasing and current < prev - dynamic_v_threshold: if self.vert_increasing and current < prev - dynamic_v_threshold:
if abs(prev) > self.peak_threshold: if abs(prev) > self.peak_threshold:
self.vert_peaks.append((len(self.vert_buffer)-1, prev, time.time())) self.vert_peaks.append((len(self.vert_buffer)-1, prev, time.time()))
direction = "⬆️ " if prev > 0 else "⬇️ " direction: str = "⬆️ " if prev > 0 else "⬇️ "
log.info(f"{Colors.MAGENTA}{direction} Vertical max: {prev} (threshold: {dynamic_v_threshold:.1f}){Colors.RESET}") log.info(f"{Colors.MAGENTA}{direction} Vertical max: {prev} (threshold: {dynamic_v_threshold:.1f}){Colors.RESET}")
now = time.time() now: float = time.time()
if self.last_peak_time > 0: if self.last_peak_time > 0:
interval = now - self.last_peak_time interval: float = now - self.last_peak_time
self.peak_intervals.append(interval) self.peak_intervals.append(interval)
self.last_peak_time = now self.last_peak_time = now
@@ -257,60 +234,60 @@ class GestureDetector:
elif not self.vert_increasing and current > prev + dynamic_v_threshold: elif not self.vert_increasing and current > prev + dynamic_v_threshold:
if abs(prev) > self.peak_threshold: if abs(prev) > self.peak_threshold:
self.vert_troughs.append((len(self.vert_buffer)-1, prev, time.time())) self.vert_troughs.append((len(self.vert_buffer)-1, prev, time.time()))
direction = "⬆️ " if prev > 0 else "⬇️ " direction: str = "⬆️ " if prev > 0 else "⬇️ "
log.info(f"{Colors.MAGENTA}{direction} Vertical max: {prev} (threshold: {dynamic_v_threshold:.1f}){Colors.RESET}") log.info(f"{Colors.MAGENTA}{direction} Vertical max: {prev} (threshold: {dynamic_v_threshold:.1f}){Colors.RESET}")
now = time.time() now: float = time.time()
if self.last_peak_time > 0: if self.last_peak_time > 0:
interval = now - self.last_peak_time interval: float = now - self.last_peak_time
self.peak_intervals.append(interval) self.peak_intervals.append(interval)
self.last_peak_time = now self.last_peak_time = now
self.vert_increasing = True self.vert_increasing = True
def calculate_rhythm_consistency(self): def calculate_rhythm_consistency(self) -> float:
"""Calculate how consistent the timing between peaks is (Apple-like).""" """Calculate how consistent the timing between peaks is (Apple-like)."""
if len(self.peak_intervals) < 2: if len(self.peak_intervals) < 2:
return 0 return 0
mean_interval = statistics.mean(self.peak_intervals) mean_interval: float = statistics.mean(self.peak_intervals)
if mean_interval == 0: if mean_interval == 0:
return 0 return 0
variances = [(i/mean_interval - 1.0) ** 2 for i in self.peak_intervals] variances: List[float] = [(i/mean_interval - 1.0) ** 2 for i in self.peak_intervals]
consistency = 1.0 - min(1.0, statistics.mean(variances) / self.rhythm_consistency_threshold) consistency: float = 1.0 - min(1.0, statistics.mean(variances) / self.rhythm_consistency_threshold)
return max(0, consistency) return max(0, consistency)
def calculate_confidence_score(self, extremes, is_vertical=True): def calculate_confidence_score(self, extremes: List[Tuple[int, int, float]], is_vertical: bool = True) -> float:
"""Calculate confidence score for gesture detection (Apple-like).""" """Calculate confidence score for gesture detection (Apple-like)."""
if len(extremes) < self.required_extremes: if len(extremes) < self.required_extremes:
return 0.0 return 0.0
sorted_extremes = sorted(extremes, key=lambda x: x[0]) sorted_extremes: List[Tuple[int, int, float]] = sorted(extremes, key=lambda x: x[0])
recent = sorted_extremes[-self.required_extremes:] recent: List[Tuple[int, int, float]] = sorted_extremes[-self.required_extremes:]
avg_amplitude = sum(abs(val) for _, val, _ in recent) / len(recent) avg_amplitude: float = sum(abs(val) for _, val, _ in recent) / len(recent)
amplitude_factor = min(1.0, avg_amplitude / 600) amplitude_factor: float = min(1.0, avg_amplitude / 600)
rhythm_factor = self.calculate_rhythm_consistency() rhythm_factor: float = self.calculate_rhythm_consistency()
signs = [1 if val > 0 else -1 for _, val, _ in recent] signs: List[int] = [1 if val > 0 else -1 for _, val, _ in recent]
alternating = all(signs[i] != signs[i-1] for i in range(1, len(signs))) alternating: bool = all(signs[i] != signs[i-1] for i in range(1, len(signs)))
alternation_factor = 1.0 if alternating else 0.5 alternation_factor: float = 1.0 if alternating else 0.5
if is_vertical: if is_vertical:
vert_amp = sum(abs(val) for _, val, _ in recent) / len(recent) vert_amp: float = sum(abs(val) for _, val, _ in recent) / len(recent)
horiz_vals = list(self.horiz_buffer)[-len(recent)*2:] horiz_vals: List[int] = list(self.horiz_buffer)[-len(recent)*2:]
horiz_amp = sum(abs(val) for val in horiz_vals) / len(horiz_vals) if horiz_vals else 0 horiz_amp: float = sum(abs(val) for val in horiz_vals) / len(horiz_vals) if horiz_vals else 0
isolation_factor = min(1.0, vert_amp / (horiz_amp + 0.1) * 1.2) isolation_factor: float = min(1.0, vert_amp / (horiz_amp + 0.1) * 1.2)
else: else:
horiz_amp = sum(abs(val) for _, val, _ in recent) horiz_amp: float = sum(abs(val) for _, val, _ in recent)
vert_vals = list(self.vert_buffer)[-len(recent)*2:] vert_vals: List[int] = list(self.vert_buffer)[-len(recent)*2:]
vert_amp = sum(abs(val) for val in vert_vals) / len(vert_vals) if vert_vals else 0 vert_amp: float = sum(abs(val) for val in vert_vals) / len(vert_vals) if vert_vals else 0
isolation_factor = min(1.0, horiz_amp / (vert_amp + 0.1) * 1.2) isolation_factor: float = min(1.0, horiz_amp / (vert_amp + 0.1) * 1.2)
confidence = ( confidence: float = (
amplitude_factor * 0.4 + amplitude_factor * 0.4 +
rhythm_factor * 0.2 + rhythm_factor * 0.2 +
alternation_factor * 0.2 + alternation_factor * 0.2 +
@@ -319,12 +296,12 @@ class GestureDetector:
return confidence return confidence
def detect_gestures(self): def detect_gestures(self) -> Optional[str]:
"""Recognize head gesture patterns with Apple-like intelligence.""" """Recognize head gesture patterns with Apple-like intelligence."""
if len(self.vert_peaks) + len(self.vert_troughs) >= self.required_extremes: if len(self.vert_peaks) + len(self.vert_troughs) >= self.required_extremes:
all_extremes = sorted(self.vert_peaks + self.vert_troughs, key=lambda x: x[0]) all_extremes: List[Tuple[int, int, float]] = sorted(self.vert_peaks + self.vert_troughs, key=lambda x: x[0])
confidence = self.calculate_confidence_score(all_extremes, is_vertical=True) confidence: float = self.calculate_confidence_score(all_extremes, is_vertical=True)
log.info(f"Vertical motion confidence: {confidence:.2f} (need {self.min_confidence_threshold:.2f})") log.info(f"Vertical motion confidence: {confidence:.2f} (need {self.min_confidence_threshold:.2f})")
@@ -333,9 +310,9 @@ class GestureDetector:
return "YES" return "YES"
if len(self.horiz_peaks) + len(self.horiz_troughs) >= self.required_extremes: if len(self.horiz_peaks) + len(self.horiz_troughs) >= self.required_extremes:
all_extremes = sorted(self.horiz_peaks + self.horiz_troughs, key=lambda x: x[0]) all_extremes: List[Tuple[int, int, float]] = sorted(self.horiz_peaks + self.horiz_troughs, key=lambda x: x[0])
confidence = self.calculate_confidence_score(all_extremes, is_vertical=False) confidence: float = self.calculate_confidence_score(all_extremes, is_vertical=False)
log.info(f"Horizontal motion confidence: {confidence:.2f} (need {self.min_confidence_threshold:.2f})") log.info(f"Horizontal motion confidence: {confidence:.2f} (need {self.min_confidence_threshold:.2f})")
@@ -345,7 +322,7 @@ class GestureDetector:
return None return None
def start_detection(self): def start_detection(self) -> None:
"""Begin gesture detection process.""" """Begin gesture detection process."""
log.info(f"{Colors.BOLD}{Colors.WHITE}Starting gesture detection...{Colors.RESET}") log.info(f"{Colors.BOLD}{Colors.WHITE}Starting gesture detection...{Colors.RESET}")
@@ -353,7 +330,7 @@ class GestureDetector:
log.error(f"{Colors.RED}Failed to connect to AirPods.{Colors.RESET}") log.error(f"{Colors.RED}Failed to connect to AirPods.{Colors.RESET}")
return return
data_thread = threading.Thread(target=self.process_data) data_thread: Thread = Thread(target=self.process_data)
data_thread.daemon = True data_thread.daemon = True
data_thread.start() data_thread.start()
@@ -377,5 +354,5 @@ if __name__ == "__main__":
print(f"{Colors.GREEN}• YES: {Colors.WHITE}nodding head up and down{Colors.RESET}") print(f"{Colors.GREEN}• YES: {Colors.WHITE}nodding head up and down{Colors.RESET}")
print(f"{Colors.RED}• NO: {Colors.WHITE}shaking head left and right{Colors.RESET}\n") print(f"{Colors.RED}• NO: {Colors.WHITE}shaking head left and right{Colors.RESET}\n")
detector = GestureDetector() detector: GestureDetector = GestureDetector()
detector.start_detection() detector.start_detection()

View File

@@ -1,63 +1,43 @@
import math import math
import drawille
import numpy as np import numpy as np
import logging import logging
import os import os
from colors import *
from drawille import Canvas
from logging import Logger, StreamHandler
from matplotlib.animation import FuncAnimation
from matplotlib.pyplot import Axes, Figure
from numpy.typing import NDArray
from os import terminal_size as TerminalSize
from typing import Any, Dict, List, Optional, Tuple
class Colors: handler: StreamHandler = StreamHandler()
RESET = "\033[0m"
BOLD = "\033[1m"
RED = "\033[91m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
MAGENTA = "\033[95m"
CYAN = "\033[96m"
WHITE = "\033[97m"
BG_BLACK = "\033[40m"
class ColorFormatter(logging.Formatter):
FORMATS = {
logging.DEBUG: Colors.BLUE + "[%(levelname)s] %(message)s" + Colors.RESET,
logging.INFO: Colors.GREEN + "%(message)s" + Colors.RESET,
logging.WARNING: Colors.YELLOW + "%(message)s" + Colors.RESET,
logging.ERROR: Colors.RED + "[%(levelname)s] %(message)s" + Colors.RESET,
logging.CRITICAL: Colors.RED + Colors.BOLD + "[%(levelname)s] %(message)s" + Colors.RESET
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%H:%M:%S")
return formatter.format(record)
handler = logging.StreamHandler()
handler.setFormatter(ColorFormatter()) handler.setFormatter(ColorFormatter())
log = logging.getLogger(__name__) log: Logger = logging.getLogger(__name__)
log.setLevel(logging.INFO) log.setLevel(logging.INFO)
log.addHandler(handler) log.addHandler(handler)
log.propagate = False log.propagate = False
class HeadOrientation: class HeadOrientation:
def __init__(self, use_terminal=False): def __init__(self, use_terminal: bool = False) -> None:
self.orientation_offset = 5500 self.orientation_offset: int = 5500
self.o1_neutral = 19000 self.o1_neutral: int = 19000
self.o2_neutral = 0 self.o2_neutral: int = 0
self.o3_neutral = 0 self.o3_neutral: int = 0
self.calibration_samples = [] self.calibration_samples: List[List[int]] = []
self.calibration_complete = False self.calibration_complete: bool = False
self.calibration_sample_count = 10 self.calibration_sample_count: int = 10
self.fig = None self.fig: Optional[Figure] = None
self.ax = None self.ax: Optional[Axes] = None
self.arrow = None self.arrow: Any = None
self.animation = None self.animation: Optional[FuncAnimation] = None
self.use_terminal = use_terminal self.use_terminal: bool = use_terminal
def reset_calibration(self): def reset_calibration(self) -> None:
self.calibration_samples = [] self.calibration_samples = []
self.calibration_complete = False self.calibration_complete = False
def add_calibration_sample(self, orientation_values): def add_calibration_sample(self, orientation_values: List[int]) -> bool:
if len(self.calibration_samples) < self.calibration_sample_count: if len(self.calibration_samples) < self.calibration_sample_count:
self.calibration_samples.append(orientation_values) self.calibration_samples.append(orientation_values)
return False return False
@@ -66,57 +46,58 @@ class HeadOrientation:
return True return True
return True return True
def _calculate_calibration(self): def _calculate_calibration(self) -> None:
if len(self.calibration_samples) < 3: if len(self.calibration_samples) < 3:
log.warning("Not enough calibration samples") log.warning("Not enough calibration samples")
return return
samples = np.array(self.calibration_samples) samples: NDArray[[List[int]]] = np.array(self.calibration_samples)
self.o1_neutral = np.mean(samples[:, 0]) self.o1_neutral: float = np.mean(samples[:, 0])
avg_o2 = np.mean(samples[:, 1]) avg_o2: float = np.mean(samples[:, 1])
avg_o3 = np.mean(samples[:, 2]) avg_o3: float = np.mean(samples[:, 2])
self.o2_neutral = avg_o2 self.o2_neutral: float = avg_o2
self.o3_neutral = avg_o3 self.o3_neutral: float = avg_o3
log.info("Calibration complete: o1_neutral=%.2f, o2_neutral=%.2f, o3_neutral=%.2f", log.info("Calibration complete: o1_neutral=%.2f, o2_neutral=%.2f, o3_neutral=%.2f",
self.o1_neutral, self.o2_neutral, self.o3_neutral) self.o1_neutral, self.o2_neutral, self.o3_neutral)
self.calibration_complete = True self.calibration_complete = True
def calculate_orientation(self, o1, o2, o3): def calculate_orientation(self, o1: float, o2: float, o3: float) -> Dict[str, float]:
if not self.calibration_complete: if not self.calibration_complete:
return {'pitch': 0, 'yaw': 0} return {'pitch': 0, 'yaw': 0}
o1_norm = o1 - self.o1_neutral o1_norm: float = o1 - self.o1_neutral
o2_norm = o2 - self.o2_neutral o2_norm: float = o2 - self.o2_neutral
o3_norm = o3 - self.o3_neutral o3_norm: float = o3 - self.o3_neutral
pitch = (o2_norm + o3_norm) / 2 / 32000 * 180 pitch: float = (o2_norm + o3_norm) / 2 / 32000 * 180
yaw = (o2_norm - o3_norm) / 2 / 32000 * 180 yaw: float = (o2_norm - o3_norm) / 2 / 32000 * 180
return {'pitch': pitch, 'yaw': yaw} return {'pitch': pitch, 'yaw': yaw}
def create_face_art(self, pitch, yaw): def create_face_art(self, pitch: float, yaw: float) -> str:
if self.use_terminal: if self.use_terminal:
try: try:
ts = os.get_terminal_size() ts: TerminalSize = os.get_terminal_size()
width, height = ts.columns, ts.lines * 2 width, height = ts.columns, ts.lines * 2
except Exception: except Exception:
width, height = 80, 40 width, height = 80, 40
else: else:
width, height = 80, 40 width, height = 80, 40
center_x, center_y = width // 2, height // 2 center_x, center_y = width // 2, height // 2
radius = (min(width, height) // 2 - 2) // 2 radius: int = (min(width, height) // 2 - 2) // 2
pitch_rad = math.radians(pitch) pitch_rad: float = math.radians(pitch)
yaw_rad = math.radians(yaw) yaw_rad: float = math.radians(yaw)
canvas = drawille.Canvas() canvas: Canvas = Canvas()
def rotate_point(x, y, z, pitch_r, yaw_r):
def rotate_point(x: float, y: float, z: float, pitch_r: float, yaw_r: float) -> Tuple[int, int]:
cos_y, sin_y = math.cos(yaw_r), math.sin(yaw_r) cos_y, sin_y = math.cos(yaw_r), math.sin(yaw_r)
cos_p, sin_p = math.cos(pitch_r), math.sin(pitch_r) cos_p, sin_p = math.cos(pitch_r), math.sin(pitch_r)
x1 = x * cos_y - z * sin_y x1: float = x * cos_y - z * sin_y
z1 = x * sin_y + z * cos_y z1: float = x * sin_y + z * cos_y
y1 = y * cos_p - z1 * sin_p y1: float = y * cos_p - z1 * sin_p
z2 = y * sin_p + z1 * cos_p z2: float = y * sin_p + z1 * cos_p
scale = 1 + (z2 / width) scale: float = 1 + (z2 / width)
return int(center_x + x1 * scale), int(center_y + y1 * scale) return int(center_x + x1 * scale), int(center_y + y1 * scale)
for angle in range(0, 360, 2): for angle in range(0, 360, 2):
rad = math.radians(angle) rad: float = math.radians(angle)
x = radius * math.cos(rad) x: float = radius * math.cos(rad)
y = radius * math.sin(rad) y: float = radius * math.sin(rad)
x1, y1 = rotate_point(x, y, 0, pitch_rad, yaw_rad) x1, y1 = rotate_point(x, y, 0, pitch_rad, yaw_rad)
canvas.set(x1, y1) canvas.set(x1, y1)
for eye in [(-radius//2, -radius//3, 2), (radius//2, -radius//3, 2)]: for eye in [(-radius//2, -radius//3, 2), (radius//2, -radius//3, 2)]:
@@ -129,14 +110,14 @@ class HeadOrientation:
for dx in [-1, 0, 1]: for dx in [-1, 0, 1]:
for dy in [-1, 0, 1]: for dy in [-1, 0, 1]:
canvas.set(nx + dx, ny + dy) canvas.set(nx + dx, ny + dy)
smile_depth = radius // 8 smile_depth: int = radius // 8
mouth_local_y = radius // 4 mouth_local_y: int = radius // 4
mouth_length = radius mouth_length: int = radius
for x_offset in range(-mouth_length // 2, mouth_length // 2 + 1): for x_offset in range(-mouth_length // 2, mouth_length // 2 + 1):
norm = abs(x_offset) / (mouth_length / 2) norm: float = abs(x_offset) / (mouth_length / 2)
y_offset = int((1 - norm ** 2) * smile_depth) y_offset: int = int((1 - norm ** 2) * smile_depth)
local_x = x_offset local_x: int = x_offset
local_y = mouth_local_y + y_offset local_y: int = mouth_local_y + y_offset
mx, my = rotate_point(local_x, local_y, 0, pitch_rad, yaw_rad) mx, my = rotate_point(local_x, local_y, 0, pitch_rad, yaw_rad)
canvas.set(mx, my) canvas.set(mx, my)
return canvas.frame() return canvas.frame()

View File

@@ -1,61 +1,41 @@
import struct
import bluetooth
import threading
import time
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
import asciichartpy as acp import asciichartpy as acp
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import struct
import time
from bluetooth import BluetoothSocket
from colors import *
from connection_manager import ConnectionManager
from datetime import datetime as DateTime
from drawille import Canvas
from head_orientation import HeadOrientation
from logging import Logger, StreamHandler
from matplotlib.animation import FuncAnimation
from matplotlib.legend import Legend
from matplotlib.pyplot import Axes, Figure
from numpy.typing import NDArray
from rich.live import Live from rich.live import Live
from rich.layout import Layout from rich.layout import Layout
from rich.panel import Panel from rich.panel import Panel
from rich.console import Console from rich.console import Console
import drawille from threading import Lock, Thread
from head_orientation import HeadOrientation from typing import Any, Dict, List, Optional, TextIO, Tuple, Union
import logging
from connection_manager import ConnectionManager
class Colors: handler: StreamHandler = StreamHandler()
RESET = "\033[0m"
BOLD = "\033[1m"
RED = "\033[91m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
MAGENTA = "\033[95m"
CYAN = "\033[96m"
WHITE = "\033[97m"
BG_BLACK = "\033[40m"
class ColorFormatter(logging.Formatter):
FORMATS = {
logging.DEBUG: Colors.BLUE + "[%(levelname)s] %(message)s" + Colors.RESET,
logging.INFO: Colors.GREEN + "%(message)s" + Colors.RESET,
logging.WARNING: Colors.YELLOW + "%(message)s" + Colors.RESET,
logging.ERROR: Colors.RED + "[%(levelname)s] %(message)s" + Colors.RESET,
logging.CRITICAL: Colors.RED + Colors.BOLD + "[%(levelname)s] %(message)s" + Colors.RESET
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%H:%M:%S")
return formatter.format(record)
handler = logging.StreamHandler()
handler.setFormatter(ColorFormatter()) handler.setFormatter(ColorFormatter())
logger = logging.getLogger("airpods-head-tracking") logger: Logger = logging.getLogger("airpods-head-tracking")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
logger.addHandler(handler) logger.addHandler(handler)
logger.propagate = True logger.propagate = True
INIT_CMD = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00" INIT_CMD: str = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00"
NOTIF_CMD = "04 00 04 00 0F 00 FF FF FE FF" NOTIF_CMD: str = "04 00 04 00 0F 00 FF FF FE FF"
START_CMD = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00" START_CMD: str = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00"
STOP_CMD = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00" STOP_CMD: str = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00"
KEY_FIELDS = { KEY_FIELDS: Dict[str, Tuple[int, int]] = {
"orientation 1": (43, 2), "orientation 1": (43, 2),
"orientation 2": (45, 2), "orientation 2": (45, 2),
"orientation 3": (47, 2), "orientation 3": (47, 2),
@@ -68,28 +48,28 @@ KEY_FIELDS = {
} }
class AirPodsTracker: class AirPodsTracker:
def __init__(self): def __init__(self) -> None:
self.sock = None self.sock: BluetoothSocket = None
self.recording = False self.recording: bool = False
self.log_file = None self.log_file: Optional[TextIO] = None
self.listener_thread = None self.listener_thread: Optional[Thread] = None
self.bt_addr = "28:2D:7F:C2:05:5B" self.bt_addr: str = "28:2D:7F:C2:05:5B"
self.psm = 0x1001 self.psm: int = 0x1001
self.raw_packets = [] self.raw_packets: List[bytes] = []
self.parsed_packets = [] self.parsed_packets: List[bytes] = []
self.live_data = [] self.live_data: List[bytes] = []
self.live_plotting = False self.live_plotting: bool = False
self.animation = None self.animation: FuncAnimation = None
self.fig = None self.fig: Optional[Figure] = None
self.axes = None self.axes: Optional[Axes] = None
self.lines = {} self.lines: Dict[str, Any] = {}
self.selected_fields = [] self.selected_fields: List[str] = []
self.data_lock = threading.Lock() self.data_lock: Lock = Lock()
self.orientation_offset = 5500 self.orientation_offset: int = 5500
self.use_terminal = True # '--terminal' in sys.argv self.use_terminal: bool = True # '--terminal' in sys.argv
self.orientation_visualizer = HeadOrientation(use_terminal=self.use_terminal) self.orientation_visualizer: HeadOrientation = HeadOrientation(use_terminal=self.use_terminal)
self.conn = None self.conn: Optional[ConnectionManager] = None
def connect(self): def connect(self):
try: try:
@@ -102,35 +82,35 @@ class AirPodsTracker:
self.sock.send(bytes.fromhex(NOTIF_CMD)) self.sock.send(bytes.fromhex(NOTIF_CMD))
logger.info("Sent initialization command.") logger.info("Sent initialization command.")
self.listener_thread = threading.Thread(target=self.listen, daemon=True) self.listener_thread = Thread(target=self.listen, daemon=True)
self.listener_thread.start() self.listener_thread.start()
return True return True
except Exception as e: except Exception as e:
logger.error("Connection error: %s", e) logger.error("Connection error: %s", e)
return False return False
def start_tracking(self, duration=None): def start_tracking(self, duration: Optional[float] = None) -> None:
if not self.recording: if not self.recording:
self.conn.send_start() self.conn.send_start()
filename = "head_tracking_" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".log" filename: str = f"head_tracking_{DateTime.now().strftime('%Y%m%d_%H%M%S')}.log"
self.log_file = open(filename, "w") self.log_file = open(filename, "w")
self.recording = True self.recording = True
logger.info("Recording started. Saving data to %s", filename) logger.info("Recording started. Saving data to %s", filename)
if duration is not None and duration > 0: if duration is not None and duration > 0:
def auto_stop(): def auto_stop() -> None:
time.sleep(duration) time.sleep(duration)
if self.recording: if self.recording:
self.stop_tracking() self.stop_tracking()
logger.info("Recording automatically stopped after %s seconds.", duration) logger.info("Recording automatically stopped after %s seconds.", duration)
timer_thread = threading.Thread(target=auto_stop, daemon=True) timer_thread = Thread(target=auto_stop, daemon=True)
timer_thread.start() timer_thread.start()
logger.info("Will automatically stop recording after %s seconds.", duration) logger.info("Will automatically stop recording after %s seconds.", duration)
else: else:
logger.info("Already recording.") logger.info("Already recording.")
def stop_tracking(self): def stop_tracking(self) -> None:
if self.recording: if self.recording:
self.conn.send_stop() self.conn.send_stop()
self.recording = False self.recording = False
@@ -141,39 +121,41 @@ class AirPodsTracker:
else: else:
logger.info("Not currently recording.") logger.info("Not currently recording.")
def format_hex(self, data): def format_hex(self, data: bytes) -> str:
hex_str = data.hex() hex_str: str = data.hex()
return ' '.join(hex_str[i:i + 2] for i in range(0, len(hex_str), 2)) return ' '.join(hex_str[i:i + 2] for i in range(0, len(hex_str), 2))
def parse_raw_packet(self, hex_string): def parse_raw_packet(self, hex_string: str) -> bytes:
return bytes.fromhex(hex_string.replace(" ", "")) return bytes.fromhex(hex_string.replace(" ", ""))
def interpret_bytes(self, raw_bytes, start, length, data_type="signed_short"): def interpret_bytes(self, raw_bytes: bytes, start: int, length: int, data_type: str = "signed_short") -> Optional[Union[int, float]]:
if start + length > len(raw_bytes): if start + length > len(raw_bytes):
return None return None
if data_type == "signed_short": match data_type:
return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=True) case "signed_short":
elif data_type == "unsigned_short": return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=True)
return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=False) case "unsigned_short":
elif data_type == "signed_short_be": return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=False)
return int.from_bytes(raw_bytes[start:start + 2], byteorder='big', signed=True) case "signed_short_be":
elif data_type == "float_le": return int.from_bytes(raw_bytes[start:start + 2], byteorder='big', signed=True)
if start + 4 <= len(raw_bytes): case "float_le":
return struct.unpack('<f', raw_bytes[start:start + 4])[0] if start + 4 <= len(raw_bytes):
elif data_type == "float_be": return struct.unpack('<f', raw_bytes[start:start + 4])[0]
if start + 4 <= len(raw_bytes): case "float_be":
return struct.unpack('>f', raw_bytes[start:start + 4])[0] if start + 4 <= len(raw_bytes):
return None return struct.unpack('>f', raw_bytes[start:start + 4])[0]
case _:
return None
def normalize_orientation(self, value, field_name): def normalize_orientation(self, value: Optional[Union[int, float]], field_name: str) -> Optional[Union[int, float]]:
if 'orientation' in field_name.lower(): if 'orientation' in field_name.lower():
return value + self.orientation_offset return value + self.orientation_offset
return value return value
def parse_packet_all_fields(self, raw_bytes): def parse_packet_all_fields(self, raw_bytes: bytes) -> Dict[str, Union[int, float]]:
packet = {} packet: Dict[str, Union[int, float]] = {}
packet["seq_num"] = int.from_bytes(raw_bytes[12:14], byteorder='little') packet["seq_num"] = int.from_bytes(raw_bytes[12:14], byteorder='little')
@@ -186,14 +168,14 @@ class AirPodsTracker:
packet[field_name] = self.normalize_orientation(raw_value, field_name) packet[field_name] = self.normalize_orientation(raw_value, field_name)
for i in range(30, min(90, len(raw_bytes) - 1), 2): for i in range(30, min(90, len(raw_bytes) - 1), 2):
field_name = f"byte_{i:02d}" field_name: str = f"byte_{i:02d}"
raw_value = self.interpret_bytes(raw_bytes, i, 2, "signed_short") raw_value: Optional[Union[int, float]] = self.interpret_bytes(raw_bytes, i, 2, "signed_short")
if raw_value is not None: if raw_value is not None:
packet[field_name] = self.normalize_orientation(raw_value, field_name) packet[field_name] = self.normalize_orientation(raw_value, field_name)
return packet return packet
def apply_dark_theme(self, fig, axes): def apply_dark_theme(self, fig: Figure, axes: List[Axes]) -> None:
fig.patch.set_facecolor('#1e1e1e') fig.patch.set_facecolor('#1e1e1e')
for ax in axes: for ax in axes:
ax.set_facecolor('#2d2d2d') ax.set_facecolor('#2d2d2d')
@@ -210,21 +192,21 @@ class AirPodsTracker:
for spine in ax.spines.values(): for spine in ax.spines.values():
spine.set_color('#555555') spine.set_color('#555555')
legend = ax.get_legend() legend: Optional[Legend] = ax.get_legend()
if (legend): if (legend):
legend.get_frame().set_facecolor('#2d2d2d') legend.get_frame().set_facecolor('#2d2d2d')
legend.get_frame().set_alpha(0.7) legend.get_frame().set_alpha(0.7)
for text in legend.get_texts(): for text in legend.get_texts():
text.set_color('white') text.set_color('white')
def listen(self): def listen(self) -> None:
while True: while True:
try: try:
data = self.sock.recv(1024) data: bytes = self.sock.recv(1024)
formatted = self.format_hex(data) formatted: str = self.format_hex(data)
timestamp = datetime.now().isoformat() timestamp: str = DateTime.now().isoformat()
is_valid = self.is_valid_tracking_packet(formatted) is_valid: bool = self.is_valid_tracking_packet(formatted)
if not self.live_plotting: if not self.live_plotting:
if is_valid: if is_valid:
@@ -238,8 +220,8 @@ class AirPodsTracker:
self.log_file.flush() self.log_file.flush()
try: try:
raw_bytes = self.parse_raw_packet(formatted) raw_bytes: bytes = self.parse_raw_packet(formatted)
packet = self.parse_packet_all_fields(raw_bytes) packet: Dict[str, Union[int, float]] = self.parse_packet_all_fields(raw_bytes)
with self.data_lock: with self.data_lock:
self.live_data.append(packet) self.live_data.append(packet)
@@ -253,7 +235,7 @@ class AirPodsTracker:
logger.error("Error receiving data: %s", e) logger.error("Error receiving data: %s", e)
break break
def load_log_file(self, filepath): def load_log_file(self, filepath: str) -> bool:
self.raw_packets = [] self.raw_packets = []
self.parsed_packets = [] self.parsed_packets = []
try: try:
@@ -262,11 +244,11 @@ class AirPodsTracker:
line = line.strip() line = line.strip()
if line: if line:
try: try:
raw_bytes = self.parse_raw_packet(line) raw_bytes: bytes = self.parse_raw_packet(line)
self.raw_packets.append(raw_bytes) self.raw_packets.append(raw_bytes)
packet = self.parse_packet_all_fields(raw_bytes) packet: Dict[str, Union[int, float]] = self.parse_packet_all_fields(raw_bytes)
min_seq_num = min( min_seq_num: int = min(
[parsed_packet["seq_num"] for parsed_packet in self.parsed_packets], default=0 [parsed_packet["seq_num"] for parsed_packet in self.parsed_packets], default=0
) )
@@ -282,26 +264,26 @@ class AirPodsTracker:
logger.error(f"Error loading log file: {e}") logger.error(f"Error loading log file: {e}")
return False return False
def extract_field_values(self, field_name, data_source='loaded'): def extract_field_values(self, field_name: str, data_source: str = 'loaded') -> List[Union[int, float]]:
if data_source == 'loaded': if data_source == 'loaded':
data = self.parsed_packets data: List[Dict[str, Union[int, float]]] = self.parsed_packets
else: else:
with self.data_lock: with self.data_lock:
data = self.live_data.copy() data: List[Dict[str, Union[int, float]]] = self.live_data.copy()
values = [packet.get(field_name, 0) for packet in data if field_name in packet] values: List[Union[int, float]] = [packet.get(field_name, 0) for packet in data if field_name in packet]
if data_source == 'live' and len(values) > 5: if data_source == 'live' and len(values) > 5:
try: try:
values = np.array(values, dtype=float) values: NDArray[Any] = np.array(values, dtype=float)
values = np.convolve(values, np.ones(5) / 5, mode='valid') values = np.convolve(values, np.ones(5) / 5, mode='valid')
except Exception as e: except Exception as e:
logger.warning(f"Smoothing error (non-critical): {e}") logger.warning(f"Smoothing error (non-critical): {e}")
return values return values
def is_valid_tracking_packet(self, hex_string): def is_valid_tracking_packet(self, hex_string: str) -> bool:
standard_header = "04 00 04 00 17 00 00 00 10 00" standard_header: str = "04 00 04 00 17 00 00 00 10 00"
if not hex_string.startswith(standard_header): if not hex_string.startswith(standard_header):
if self.live_plotting: if self.live_plotting:
@@ -316,13 +298,13 @@ class AirPodsTracker:
return True return True
def plot_fields(self, field_names=None): def plot_fields(self, field_names: Optional[List[str]] = None) -> None:
if not self.parsed_packets: if not self.parsed_packets:
logger.error("No data to plot. Load a log file first.") logger.error("No data to plot. Load a log file first.")
return return
if field_names is None: if field_names is None:
field_names = list(KEY_FIELDS.keys()) field_names: List[str] = list(KEY_FIELDS.keys())
if not self.orientation_visualizer.calibration_complete: if not self.orientation_visualizer.calibration_complete:
if len(self.parsed_packets) < self.orientation_visualizer.calibration_sample_count: if len(self.parsed_packets) < self.orientation_visualizer.calibration_sample_count:
@@ -339,16 +321,16 @@ class AirPodsTracker:
self._plot_fields_terminal(field_names) self._plot_fields_terminal(field_names)
else: else:
acceleration_fields = [f for f in field_names if 'acceleration' in f.lower()] acceleration_fields: List[str] = [f for f in field_names if 'acceleration' in f.lower()]
orientation_fields = [f for f in field_names if 'orientation' in f.lower()] orientation_fields: List[str] = [f for f in field_names if 'orientation' in f.lower()]
other_fields = [f for f in field_names if f not in acceleration_fields + orientation_fields] other_fields: List[str] = [f for f in field_names if f not in acceleration_fields + orientation_fields]
fig, axes = plt.subplots(3, 1, figsize=(14, 12), sharex=True) fig, axes = plt.subplots(3, 1, figsize=(14, 12), sharex=True)
self.apply_dark_theme(fig, axes) self.apply_dark_theme(fig, axes)
acceleration_colors = ['#FFFF00', '#00FFFF'] acceleration_colors: List[str] = ['#FFFF00', '#00FFFF']
orientation_colors = ['#FF00FF', '#00FF00', '#FFA500'] orientation_colors: List[str] = ['#FF00FF', '#00FF00', '#FFA500']
other_colors = ['#52b788', '#f4a261', '#e76f51', '#2a9d8f'] other_colors: List[str] = ['#52b788', '#f4a261', '#e76f51', '#2a9d8f']
if acceleration_fields: if acceleration_fields:
for i, field in enumerate(acceleration_fields): for i, field in enumerate(acceleration_fields):
@@ -375,17 +357,17 @@ class AirPodsTracker:
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
def _plot_fields_terminal(self, field_names): def _plot_fields_terminal(self, field_names: List[str]) -> None:
"""Internal method for terminal-based plotting""" """Internal method for terminal-based plotting"""
terminal_width = os.get_terminal_size().columns terminal_width: int = os.get_terminal_size().columns
plot_width = min(terminal_width - 10, 120) plot_width: int = min(terminal_width - 10, 120)
plot_height = 15 plot_height: int = 15
acceleration_fields = [f for f in field_names if 'acceleration' in f.lower()] acceleration_fields: List[str] = [f for f in field_names if 'acceleration' in f.lower()]
orientation_fields = [f for f in field_names if 'orientation' in f.lower()] orientation_fields: List[str] = [f for f in field_names if 'orientation' in f.lower()]
other_fields = [f for f in field_names if f not in acceleration_fields + orientation_fields] other_fields: List[str] = [f for f in field_names if f not in acceleration_fields + orientation_fields]
def plot_group(fields, title): def plot_group(fields: List[str], title: str) -> None:
if not fields: if not fields:
return return
@@ -393,40 +375,39 @@ class AirPodsTracker:
print("=" * len(title)) print("=" * len(title))
for field in fields: for field in fields:
values = self.extract_field_values(field) values: List[float] = self.extract_field_values(field)
if len(values) > plot_width: if len(values) > plot_width:
values = values[-plot_width:] values = values[-plot_width:]
if title == "Acceleration Data": if title == "Acceleration Data":
chart = acp.plot(values, {'height': plot_height}) chart: str = acp.plot(values, {'height': plot_height})
print(chart) print(chart)
else: else:
chart = acp.plot(values, {'height': plot_height}) chart: str = acp.plot(values, {'height': plot_height})
print(chart) print(chart)
print(f"Min: {min(values):.2f}, Max: {max(values):.2f}, " + print(f"Min: {min(values):.2f}, Max: {max(values):.2f}, " + f"Mean: {np.mean(values):.2f}")
f"Mean: {np.mean(values):.2f}")
print() print()
plot_group(acceleration_fields, "Acceleration Data") plot_group(acceleration_fields, "Acceleration Data")
plot_group(orientation_fields, "Orientation Data") plot_group(orientation_fields, "Orientation Data")
plot_group(other_fields, "Other Fields") plot_group(other_fields, "Other Fields")
def create_braille_plot(self, values, width=80, height=20, y_label=True, fixed_y_min=None, fixed_y_max=None): def create_braille_plot(self, values: List[float], width: int = 80, height: int = 20, y_label: bool = True, fixed_y_min: Optional[float] = None, fixed_y_max: Optional[float] = None) -> str:
canvas = drawille.Canvas() canvas: Canvas = Canvas()
if fixed_y_min is None or fixed_y_max is None: if fixed_y_min is None or fixed_y_max is None:
local_min, local_max = min(values), max(values) local_min, local_max = min(values), max(values)
else: else:
local_min, local_max = fixed_y_min, fixed_y_max local_min, local_max = fixed_y_min, fixed_y_max
y_range = local_max - local_min or 1 y_range: float = local_max - local_min or 1
x_step = max(1, len(values) // width) x_step: int = max(1, len(values) // width)
for i, v in enumerate(values[::x_step]): for i, v in enumerate(values[::x_step]):
y = int(((v - local_min) / y_range) * (height * 2 - 1)) y: int = int(((v - local_min) / y_range) * (height * 2 - 1))
canvas.set(i, y) canvas.set(i, y)
frame = canvas.frame() frame: str = canvas.frame()
if y_label: if y_label:
lines = frame.split('\n') lines: List[str] = frame.split('\n')
labeled_lines = [] labeled_lines: List[str] = []
for idx, line in enumerate(lines): for idx, line in enumerate(lines):
if idx == 0: if idx == 0:
labeled_lines.append(f"{local_max:6.0f} {line}") labeled_lines.append(f"{local_max:6.0f} {line}")
@@ -437,17 +418,17 @@ class AirPodsTracker:
frame = "\n".join(labeled_lines) frame = "\n".join(labeled_lines)
return frame return frame
def _start_live_plotting_terminal(self, record_data=False, duration=None): def _start_live_plotting_terminal(self, record_data: bool = False, duration: Optional[float] = None) -> None:
import sys, select, tty, termios import sys, select, tty, termios
old_settings = termios.tcgetattr(sys.stdin) old_settings = termios.tcgetattr(sys.stdin)
tty.setcbreak(sys.stdin.fileno()) tty.setcbreak(sys.stdin.fileno())
console = Console() console: Console = Console()
term_width = console.width term_width: int = console.width
plot_width = round(min(term_width / 2 - 15, 120)) plot_width: int = round(min(term_width / 2 - 15, 120))
ori_height = 10 ori_height: int = 10
def make_compact_layout(): def make_compact_layout() -> Layout:
layout = Layout() layout: Layout = Layout()
layout.split_column( layout.split_column(
Layout(name="header", size=3), Layout(name="header", size=3),
Layout(name="main", ratio=1), Layout(name="main", ratio=1),
@@ -466,7 +447,7 @@ class AirPodsTracker:
) )
return layout return layout
layout = make_compact_layout() layout: Layout = make_compact_layout()
try: try:
import time import time
@@ -479,76 +460,76 @@ class AirPodsTracker:
logger.info("Paused" if self.paused else "Resumed") logger.info("Paused" if self.paused else "Resumed")
if self.paused: if self.paused:
time.sleep(0.1) time.sleep(0.1)
rec_str = " [red][REC][/red]" if record_data else "" rec_str: str = " [red][REC][/red]" if record_data else ""
left = "AirPods Head Tracking - v1.0.0" left: str = "AirPods Head Tracking - v1.0.0"
right = "Ctrl+C - Close | p - Pause" + rec_str right: str = "Ctrl+C - Close | p - Pause" + rec_str
status = "[bold red]Paused[/bold red]" status: str = "[bold red]Paused[/bold red]"
header = list(" " * term_width) header: List[str] = list(" " * term_width)
header[0:len(left)] = list(left) header[0:len(left)] = list(left)
header[term_width - len(right):] = list(right) header[term_width - len(right):] = list(right)
start = (term_width - len(status)) // 2 start: int = (term_width - len(status)) // 2
header[start:start+len(status)] = list(status) header[start:start+len(status)] = list(status)
header_text = "".join(header) header_text: str = "".join(header)
layout["header"].update(Panel(header_text, style="bold white on black")) layout["header"].update(Panel(header_text, style="bold white on black"))
continue continue
with self.data_lock: with self.data_lock:
if len(self.live_data) < 1: if len(self.live_data) < 1:
continue continue
latest = self.live_data[-1] latest: Dict[str, float] = self.live_data[-1]
data = self.live_data[-plot_width:] data: List[Dict[str, float]] = self.live_data[-plot_width:]
if not self.orientation_visualizer.calibration_complete: if not self.orientation_visualizer.calibration_complete:
sample = [ sample: List[float] = [
latest.get('orientation 1', 0), latest.get('orientation 1', 0),
latest.get('orientation 2', 0), latest.get('orientation 2', 0),
latest.get('orientation 3', 0) latest.get('orientation 3', 0)
] ]
self.orientation_visualizer.add_calibration_sample(sample) self.orientation_visualizer.add_calibration_sample(sample)
time.sleep(0.05) time.sleep(0.05)
rec_str = " [red][REC][/red]" if record_data else "" rec_str: str = " [red][REC][/red]" if record_data else ""
left = "AirPods Head Tracking - v1.0.0" left: str = "AirPods Head Tracking - v1.0.0"
status = "[bold yellow]Calibrating...[/bold yellow]" status: str = "[bold yellow]Calibrating...[/bold yellow]"
right = "Ctrl+C - Close | p - Pause" right: str = "Ctrl+C - Close | p - Pause"
remaining = max(term_width - len(left) - len(right), 0) remaining: int = max(term_width - len(left) - len(right), 0)
header_text = f"{left}{status.center(remaining)}{right}{rec_str}" header_text: str = f"{left}{status.center(remaining)}{right}{rec_str}"
layout["header"].update(Panel(header_text, style="bold white on black")) layout["header"].update(Panel(header_text, style="bold white on black"))
live.refresh() live.refresh()
continue continue
o1 = latest.get('orientation 1', 0) o1: float = latest.get('orientation 1', 0)
o2 = latest.get('orientation 2', 0) o2: float = latest.get('orientation 2', 0)
o3 = latest.get('orientation 3', 0) o3: float = latest.get('orientation 3', 0)
orientation = self.orientation_visualizer.calculate_orientation(o1, o2, o3) orientation: Dict[str, float] = self.orientation_visualizer.calculate_orientation(o1, o2, o3)
pitch = orientation['pitch'] pitch: float = orientation['pitch']
yaw = orientation['yaw'] yaw: float = orientation['yaw']
h_accel = [p.get('Horizontal Acceleration', 0) for p in data] h_accel: List[float] = [p.get('Horizontal Acceleration', 0) for p in data]
v_accel = [p.get('Vertical Acceleration', 0) for p in data] v_accel: List[float] = [p.get('Vertical Acceleration', 0) for p in data]
if len(h_accel) > plot_width: if len(h_accel) > plot_width:
h_accel = h_accel[-plot_width:] h_accel = h_accel[-plot_width:]
if len(v_accel) > plot_width: if len(v_accel) > plot_width:
v_accel = v_accel[-plot_width:] v_accel = v_accel[-plot_width:]
global_min = min(min(v_accel), min(h_accel)) global_min: float = min(min(v_accel), min(h_accel))
global_max = max(max(v_accel), max(h_accel)) global_max: float = max(max(v_accel), max(h_accel))
config_acc = {'height': 20, 'min': global_min, 'max': global_max} config_acc: Dict[str, float] = {'height': 20, 'min': global_min, 'max': global_max}
vert_plot = acp.plot(v_accel, config_acc) vert_plot: str = acp.plot(v_accel, config_acc)
horiz_plot = acp.plot(h_accel, config_acc) horiz_plot: str = acp.plot(h_accel, config_acc)
rec_str = " [red][REC][/red]" if record_data else "" rec_str: str = " [red][REC][/red]" if record_data else ""
left = "AirPods Head Tracking - v1.0.0" left: str = "AirPods Head Tracking - v1.0.0"
right = "Ctrl+C - Close | p - Pause" + rec_str right: str = "Ctrl+C - Close | p - Pause" + rec_str
status = "[bold green]Live[/bold green]" status: str = "[bold green]Live[/bold green]"
header = list(" " * term_width) header: List[str] = list(" " * term_width)
header[0:len(left)] = list(left) header[0:len(left)] = list(left)
header[term_width - len(right):] = list(right) header[term_width - len(right):] = list(right)
start = (term_width - len(status)) // 2 start: int = (term_width - len(status)) // 2
header[start:start+len(status)] = list(status) header[start:start+len(status)] = list(status)
header_text = "".join(header) header_text: str = "".join(header)
layout["header"].update(Panel(header_text, style="bold white on black")) layout["header"].update(Panel(header_text, style="bold white on black"))
face_art = self.orientation_visualizer.create_face_art(pitch, yaw) face_art: str = self.orientation_visualizer.create_face_art(pitch, yaw)
layout["accelerations"]["vertical"].update(Panel( layout["accelerations"]["vertical"].update(Panel(
"[bold yellow]Vertical Acceleration[/]\n" + "[bold yellow]Vertical Acceleration[/]\n" +
vert_plot + "\n" + vert_plot + "\n" +
@@ -563,15 +544,15 @@ class AirPodsTracker:
)) ))
layout["orientations"]["face"].update(Panel(face_art, title="[green]Orientation - Visualization[/]", style="green")) layout["orientations"]["face"].update(Panel(face_art, title="[green]Orientation - Visualization[/]", style="green"))
o2_values = [p.get('orientation 2', 0) for p in data[-plot_width:]] o2_values: List[float] = [p.get('orientation 2', 0) for p in data[-plot_width:]]
o3_values = [p.get('orientation 3', 0) for p in data[-plot_width:]] o3_values: List[float] = [p.get('orientation 3', 0) for p in data[-plot_width:]]
o2_values = o2_values[:plot_width] o2_values: List[float] = o2_values[:plot_width]
o3_values = o3_values[:plot_width] o3_values: List[float] = o3_values[:plot_width]
common_min = min(min(o2_values), min(o3_values)) common_min: float = min(min(o2_values), min(o3_values))
common_max = max(max(o2_values), max(o3_values)) common_max: float = max(max(o2_values), max(o3_values))
config_ori = {'height': ori_height, 'min': common_min, 'max': common_max, 'format': "{:6.0f}"} config_ori: Dict[str, float] = {'height': ori_height, 'min': common_min, 'max': common_max, 'format': "{:6.0f}"}
chart_o2 = acp.plot(o2_values, config_ori) chart_o2: str = acp.plot(o2_values, config_ori)
chart_o3 = acp.plot(o3_values, config_ori) chart_o3: str = acp.plot(o3_values, config_ori)
layout["orientations"]["raw"].update(Panel( layout["orientations"]["raw"].update(Panel(
"[bold yellow]Orientation 1:[/]\n" + chart_o2 + "\n" + "[bold yellow]Orientation 1:[/]\n" + chart_o2 + "\n" +
f"Cur: {o2_values[-1]:6.1f} | Min: {min(o2_values):6.1f} | Max: {max(o2_values):6.1f}\n\n" + f"Cur: {o2_values[-1]:6.1f} | Min: {min(o2_values):6.1f} | Max: {max(o2_values):6.1f}\n\n" +
@@ -591,10 +572,10 @@ class AirPodsTracker:
finally: finally:
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
def _start_live_plotting(self, record_data=False, duration=None): def _start_live_plotting(self, record_data: bool = False, duration: Optional[float] = None) -> None:
terminal_width = os.get_terminal_size().columns terminal_width: int = os.get_terminal_size().columns
plot_width = min(terminal_width - 10, 80) plot_width: int = min(terminal_width - 10, 80)
plot_height = 10 plot_height: int = 10
try: try:
while True: while True:
@@ -605,13 +586,13 @@ class AirPodsTracker:
time.sleep(0.1) time.sleep(0.1)
continue continue
data = self.live_data[-plot_width:] data: List[Dict[str, float]] = self.live_data[-plot_width:]
acceleration_fields = [f for f in KEY_FIELDS.keys() if 'acceleration' in f.lower()] acceleration_fields: List[str] = [f for f in KEY_FIELDS.keys() if 'acceleration' in f.lower()]
orientation_fields = [f for f in KEY_FIELDS.keys() if 'orientation' in f.lower()] orientation_fields: List[str] = [f for f in KEY_FIELDS.keys() if 'orientation' in f.lower()]
other_fields = [f for f in KEY_FIELDS.keys() if f not in acceleration_fields + orientation_fields] other_fields: List[str] = [f for f in KEY_FIELDS.keys() if f not in acceleration_fields + orientation_fields]
def plot_group(fields, title): def plot_group(fields: List[str], title: str) -> None:
if not fields: if not fields:
return return
@@ -619,9 +600,9 @@ class AirPodsTracker:
print("=" * len(title)) print("=" * len(title))
for field in fields: for field in fields:
values = [packet.get(field, 0) for packet in data if field in packet] values: List[float] = [packet.get(field, 0) for packet in data if field in packet]
if len(values) > 0: if len(values) > 0:
chart = acp.plot(values, {'height': plot_height}) chart: str = acp.plot(values, {'height': plot_height})
print(chart) print(chart)
print(f"Current: {values[-1]:.2f}, " + print(f"Current: {values[-1]:.2f}, " +
f"Min: {min(values):.2f}, Max: {max(values):.2f}") f"Min: {min(values):.2f}, Max: {max(values):.2f}")
@@ -641,7 +622,7 @@ class AirPodsTracker:
self.stop_tracking() self.stop_tracking()
self.live_plotting = False self.live_plotting = False
def start_live_plotting(self, record_data=False, duration=None): def start_live_plotting(self, record_data: bool = False, duration: Optional[float] = None) -> None:
if self.sock is None: if self.sock is None:
if not self.connect(): if not self.connect():
logger.error("Could not connect to AirPods. Live plotting aborted.") logger.error("Could not connect to AirPods. Live plotting aborted.")
@@ -660,12 +641,12 @@ class AirPodsTracker:
self._start_live_plotting_terminal(record_data, duration) self._start_live_plotting_terminal(record_data, duration)
else: else:
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
fig = plt.figure(figsize=(14, 6)) fig: Figure = plt.figure(figsize=(14, 6))
gs = GridSpec(1, 2, width_ratios=[1, 1]) gs: GridSpec = GridSpec(1, 2, width_ratios=[1, 1])
ax_accel = fig.add_subplot(gs[0]) ax_accel: Axes = fig.add_subplot(gs[0])
subgs = GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[1], height_ratios=[2, 1]) subgs: GridSpecFromSubplotSpec = GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[1], height_ratios=[2, 1])
ax_head_top = fig.add_subplot(subgs[0], projection='3d') ax_head_top: Axes = fig.add_subplot(subgs[0], projection='3d')
ax_ori = fig.add_subplot(subgs[1]) ax_ori: Axes = fig.add_subplot(subgs[1])
ax_accel.set_title("Acceleration Data") ax_accel.set_title("Acceleration Data")
ax_accel.set_xlabel("Packet Index") ax_accel.set_xlabel("Packet Index")
@@ -676,16 +657,16 @@ class AirPodsTracker:
self.apply_dark_theme(fig, [ax_accel, ax_head_top, ax_ori]) self.apply_dark_theme(fig, [ax_accel, ax_head_top, ax_ori])
plt.ion() plt.ion()
def update_plot(_): def update_plot(_: int) -> None:
with self.data_lock: with self.data_lock:
data = self.live_data.copy() data: List[Dict[str, float]] = self.live_data.copy()
if len(data) == 0: if len(data) == 0:
return return
latest = data[-1] latest: Dict[str, float] = data[-1]
if not self.orientation_visualizer.calibration_complete: if not self.orientation_visualizer.calibration_complete:
sample = [ sample: List[float] = [
latest.get('orientation 1', 0), latest.get('orientation 1', 0),
latest.get('orientation 2', 0), latest.get('orientation 2', 0),
latest.get('orientation 3', 0) latest.get('orientation 3', 0)
@@ -696,9 +677,9 @@ class AirPodsTracker:
fig.canvas.draw_idle() fig.canvas.draw_idle()
return return
h_accel = [p.get('Horizontal Acceleration', 0) for p in data] h_accel: List[float] = [p.get('Horizontal Acceleration', 0) for p in data]
v_accel = [p.get('Vertical Acceleration', 0) for p in data] v_accel: List[float] = [p.get('Vertical Acceleration', 0) for p in data]
x_vals = list(range(len(h_accel))) x_vals: List[int] = list(range(len(h_accel)))
ax_accel.cla() ax_accel.cla()
ax_accel.plot(x_vals, v_accel, label='Vertical Acceleration', color='#FFFF00', linewidth=2) ax_accel.plot(x_vals, v_accel, label='Vertical Acceleration', color='#FFFF00', linewidth=2)
ax_accel.plot(x_vals, h_accel, label='Horizontal Acceleration', color='#00FFFF', linewidth=2) ax_accel.plot(x_vals, h_accel, label='Horizontal Acceleration', color='#00FFFF', linewidth=2)
@@ -711,13 +692,13 @@ class AirPodsTracker:
ax_accel.xaxis.label.set_color('white') ax_accel.xaxis.label.set_color('white')
ax_accel.yaxis.label.set_color('white') ax_accel.yaxis.label.set_color('white')
latest = data[-1] latest: Dict[str, float] = data[-1]
o1 = latest.get('orientation 1', 0) o1: float = latest.get('orientation 1', 0)
o2 = latest.get('orientation 2', 0) o2: float = latest.get('orientation 2', 0)
o3 = latest.get('orientation 3', 0) o3: float = latest.get('orientation 3', 0)
orientation = self.orientation_visualizer.calculate_orientation(o1, o2, o3) orientation: Dict[str, float] = self.orientation_visualizer.calculate_orientation(o1, o2, o3)
pitch = orientation['pitch'] pitch: float = orientation['pitch']
yaw = orientation['yaw'] yaw: float = orientation['yaw']
ax_head_top.cla() ax_head_top.cla()
ax_head_top.set_title("Head Orientation") ax_head_top.set_title("Head Orientation")
@@ -727,25 +708,25 @@ class AirPodsTracker:
ax_head_top.set_facecolor('#2d2d2d') ax_head_top.set_facecolor('#2d2d2d')
pitch_rad = np.radians(pitch) pitch_rad = np.radians(pitch)
yaw_rad = np.radians(yaw) yaw_rad = np.radians(yaw)
Rz = np.array([ Rz: NDArray[Any] = np.array([
[np.cos(yaw_rad), np.sin(yaw_rad), 0], [np.cos(yaw_rad), np.sin(yaw_rad), 0],
[-np.sin(yaw_rad), np.cos(yaw_rad), 0], [-np.sin(yaw_rad), np.cos(yaw_rad), 0],
[0, 0, 1] [0, 0, 1]
]) ])
Ry = np.array([ Ry: NDArray[Any] = np.array([
[np.cos(pitch_rad), 0, np.sin(pitch_rad)], [np.cos(pitch_rad), 0, np.sin(pitch_rad)],
[0, 1, 0], [0, 1, 0],
[-np.sin(pitch_rad), 0, np.cos(pitch_rad)] [-np.sin(pitch_rad), 0, np.cos(pitch_rad)]
]) ])
R = Rz @ Ry R: NDArray[Any] = Rz @ Ry
dir_vec = R @ np.array([1, 0, 0]) dir_vec: NDArray[Any] = R @ np.array([1, 0, 0])
ax_head_top.quiver(0, 0, 0, dir_vec[0], dir_vec[1], dir_vec[2], ax_head_top.quiver(0, 0, 0, dir_vec[0], dir_vec[1], dir_vec[2],
color='r', length=0.8, linewidth=3) color='r', length=0.8, linewidth=3)
ax_ori.cla() ax_ori.cla()
o2_values = [p.get('orientation 2', 0) for p in data] o2_values: List[float] = [p.get('orientation 2', 0) for p in data]
o3_values = [p.get('orientation 3', 0) for p in data] o3_values: List[float] = [p.get('orientation 3', 0) for p in data]
x_range = list(range(len(o2_values))) x_range: List[int] = list(range(len(o2_values)))
ax_ori.plot(x_range, o2_values, label='Orientation 1', color='red', linewidth=2) ax_ori.plot(x_range, o2_values, label='Orientation 1', color='red', linewidth=2)
ax_ori.plot(x_range, o3_values, label='Orientation 2', color='green', linewidth=2) ax_ori.plot(x_range, o3_values, label='Orientation 2', color='green', linewidth=2)
ax_ori.set_facecolor('#2d2d2d') ax_ori.set_facecolor('#2d2d2d')
@@ -775,9 +756,9 @@ class AirPodsTracker:
self.animation = None self.animation = None
plt.ioff() plt.ioff()
def interactive_mode(self): def interactive_mode(self) -> None:
from prompt_toolkit import PromptSession from prompt_toolkit import PromptSession
session = PromptSession("> ") session: PromptSession = PromptSession("> ")
logger.info("\nAirPods Head Tracking Analyzer") logger.info("\nAirPods Head Tracking Analyzer")
print("------------------------------") print("------------------------------")
logger.info("Commands:") logger.info("Commands:")
@@ -793,59 +774,61 @@ class AirPodsTracker:
while True: while True:
try: try:
cmd_input = session.prompt("> ") cmd_input: str = session.prompt("> ")
cmd_parts = cmd_input.strip().split() cmd_parts: List[str] = cmd_input.strip().split()
if not cmd_parts: if not cmd_parts:
continue continue
cmd = cmd_parts[0].lower() cmd = cmd_parts[0].lower()
if cmd == "connect": match cmd:
self.connect() case "connect":
elif cmd == "start": self.connect()
duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None case "start":
self.start_tracking(duration) duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None
elif cmd == "stop": self.start_tracking(duration)
self.stop_tracking() case "stop":
elif cmd == "load" and len(cmd_parts) > 1: self.stop_tracking()
self.load_log_file(cmd_parts[1]) case "load":
elif cmd == "plot": if len(cmd_parts) > 1:
self.plot_fields() self.load_log_file(cmd_parts[1])
elif cmd == "live": case "plot":
duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None self.plot_fields()
logger.info("Starting live plotting mode (without recording)%s.", case "live":
f" for {duration} seconds" if duration else "") duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None
self.start_live_plotting(record_data=False, duration=duration) logger.info("Starting live plotting mode (without recording)%s.",
elif cmd == "liver": f" for {duration} seconds" if duration else "")
duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None self.start_live_plotting(record_data=False, duration=duration)
logger.info("Starting live plotting mode WITH recording%s.", case "liver":
f" for {duration} seconds" if duration else "") duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None
self.start_live_plotting(record_data=True, duration=duration) logger.info("Starting live plotting mode WITH recording%s.",
elif cmd == "gestures": f" for {duration} seconds" if duration else "")
from gestures import GestureDetector self.start_live_plotting(record_data=True, duration=duration)
if self.conn is not None: case "gestures":
detector = GestureDetector(conn=self.conn) from gestures import GestureDetector
else: if self.conn is not None:
detector = GestureDetector() detector: GestureDetector = GestureDetector(conn=self.conn)
detector.start_detection() else:
elif cmd == "quit": detector: GestureDetector = GestureDetector()
logger.info("Exiting.") detector.start_detection()
if self.conn != None: case "quit":
self.conn.disconnect() logger.info("Exiting.")
break if self.conn != None:
elif cmd == "help": self.conn.disconnect()
logger.info("\nAirPods Head Tracking Analyzer") break
logger.info("------------------------------") case "help":
logger.info("Commands:") logger.info("\nAirPods Head Tracking Analyzer")
logger.info(" connect - connect to your AirPods") logger.info("------------------------------")
logger.info(" start [seconds] - start recording head tracking data, optionally for specified duration") logger.info("Commands:")
logger.info(" stop - stop recording") logger.info(" connect - connect to your AirPods")
logger.info(" load <file> - load and parse a log file") logger.info(" start [seconds] - start recording head tracking data, optionally for specified duration")
logger.info(" plot - plot all sensor data fields") logger.info(" stop - stop recording")
logger.info(" live [seconds] - start live plotting (without recording), optionally stop recording after seconds") logger.info(" load <file> - load and parse a log file")
logger.info(" liver [seconds] - start live plotting with recording, optionally stop recording after seconds") logger.info(" plot - plot all sensor data fields")
logger.info(" gestures - start gesture detection") logger.info(" live [seconds] - start live plotting (without recording), optionally stop recording after seconds")
logger.info(" quit - exit the program") logger.info(" liver [seconds] - start live plotting with recording, optionally stop recording after seconds")
else: logger.info(" gestures - start gesture detection")
logger.info("Unknown command. Type 'help' to see available commands.") logger.info(" quit - exit the program")
case _:
logger.info("Unknown command. Type 'help' to see available commands.")
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Use 'quit' to exit.") logger.info("Use 'quit' to exit.")
except EOFError: except EOFError:
@@ -856,5 +839,5 @@ class AirPodsTracker:
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
tracker = AirPodsTracker() tracker: AirPodsTracker = AirPodsTracker()
tracker.interactive_mode() tracker.interactive_mode()

View File

@@ -1,10 +1,13 @@
import sys
import socket
import struct
import threading
from queue import Queue
import logging import logging
import signal import signal
import socket
import struct
import sys
import threading
from socket import socket as Socket, TimeoutError
from queue import Queue
from threading import Thread
from typing import Any, Dict, List, Optional
# Configure logging # Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -12,47 +15,47 @@ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QSlider, QCheckBox, QPushButton, QLineEdit, QFormLayout, QGridLayout from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QSlider, QCheckBox, QPushButton, QLineEdit, QFormLayout, QGridLayout
from PyQt5.QtCore import Qt, QTimer, pyqtSignal, QObject from PyQt5.QtCore import Qt, QTimer, pyqtSignal, QObject
OPCODE_READ_REQUEST = 0x0A OPCODE_READ_REQUEST: int = 0x0A
OPCODE_WRITE_REQUEST = 0x12 OPCODE_WRITE_REQUEST: int = 0x12
OPCODE_HANDLE_VALUE_NTF = 0x1B OPCODE_HANDLE_VALUE_NTF: int = 0x1B
ATT_HANDLES = { ATT_HANDLES: Dict[str, int] = {
'TRANSPARENCY': 0x18, 'TRANSPARENCY': 0x18,
'LOUD_SOUND_REDUCTION': 0x1B, 'LOUD_SOUND_REDUCTION': 0x1B,
'HEARING_AID': 0x2A, 'HEARING_AID': 0x2A,
} }
ATT_CCCD_HANDLES = { ATT_CCCD_HANDLES: Dict[str, int] = {
'TRANSPARENCY': ATT_HANDLES['TRANSPARENCY'] + 1, 'TRANSPARENCY': ATT_HANDLES['TRANSPARENCY'] + 1,
'LOUD_SOUND_REDUCTION': ATT_HANDLES['LOUD_SOUND_REDUCTION'] + 1, 'LOUD_SOUND_REDUCTION': ATT_HANDLES['LOUD_SOUND_REDUCTION'] + 1,
'HEARING_AID': ATT_HANDLES['HEARING_AID'] + 1, 'HEARING_AID': ATT_HANDLES['HEARING_AID'] + 1,
} }
PSM_ATT = 31 PSM_ATT: int = 31
class ATTManager: class ATTManager:
def __init__(self, mac_address): def __init__(self, mac_address: str) -> None:
self.mac_address = mac_address self.mac_address: str = mac_address
self.sock = None self.sock: Optional[Socket] = None
self.responses = Queue() self.responses: Queue = Queue()
self.listeners = {} self.listeners: Dict[int, List[Any]] = {}
self.notification_thread = None self.notification_thread: Optional[Thread] = None
self.running = False self.running: bool = False
# Avoid logging full MAC address to prevent sensitive data exposure # Avoid logging full MAC address to prevent sensitive data exposure
mac_tail = ':'.join(mac_address.split(':')[-2:]) if isinstance(mac_address, str) and ':' in mac_address else '[redacted]' mac_tail: str = ':'.join(mac_address.split(':')[-2:]) if isinstance(mac_address, str) and ':' in mac_address else '[redacted]'
logging.info(f"ATTManager initialized") logging.info(f"ATTManager initialized")
def connect(self): def connect(self) -> None:
logging.info("Attempting to connect to ATT socket") logging.info("Attempting to connect to ATT socket")
self.sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) self.sock = Socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP)
self.sock.connect((self.mac_address, PSM_ATT)) self.sock.connect((self.mac_address, PSM_ATT))
self.sock.settimeout(0.1) self.sock.settimeout(0.1)
self.running = True self.running = True
self.notification_thread = threading.Thread(target=self._listen_notifications) self.notification_thread = Thread(target=self._listen_notifications)
self.notification_thread.start() self.notification_thread.start()
logging.info("Connected to ATT socket") logging.info("Connected to ATT socket")
def disconnect(self): def disconnect(self) -> None:
logging.info("Disconnecting from ATT socket") logging.info("Disconnecting from ATT socket")
self.running = False self.running = False
if self.sock: if self.sock:
@@ -63,37 +66,37 @@ class ATTManager:
self.notification_thread.join(timeout=1.0) self.notification_thread.join(timeout=1.0)
logging.info("Disconnected from ATT socket") logging.info("Disconnected from ATT socket")
def register_listener(self, handle, listener): def register_listener(self, handle: int, listener: Any) -> None:
if handle not in self.listeners: if handle not in self.listeners:
self.listeners[handle] = [] self.listeners[handle] = []
self.listeners[handle].append(listener) self.listeners[handle].append(listener)
logging.debug(f"Registered listener for handle {handle}") logging.debug(f"Registered listener for handle {handle}")
def unregister_listener(self, handle, listener): def unregister_listener(self, handle: int, listener: Any) -> None:
if handle in self.listeners: if handle in self.listeners:
self.listeners[handle].remove(listener) self.listeners[handle].remove(listener)
logging.debug(f"Unregistered listener for handle {handle}") logging.debug(f"Unregistered listener for handle {handle}")
def enable_notifications(self, handle): def enable_notifications(self, handle: Any) -> None:
self.write_cccd(handle, b'\x01\x00') self.write_cccd(handle, b'\x01\x00')
logging.info(f"Enabled notifications for handle {handle.name}") logging.info(f"Enabled notifications for handle {handle.name}")
def read(self, handle): def read(self, handle: Any) -> bytes:
handle_value = ATT_HANDLES[handle.name] handle_value: int = ATT_HANDLES[handle.name]
lsb = handle_value & 0xFF lsb: int = handle_value & 0xFF
msb = (handle_value >> 8) & 0xFF msb: int = (handle_value >> 8) & 0xFF
pdu = bytes([OPCODE_READ_REQUEST, lsb, msb]) pdu: bytes = bytes([OPCODE_READ_REQUEST, lsb, msb])
logging.debug(f"Sending read request for handle {handle.name}: {pdu.hex()}") logging.debug(f"Sending read request for handle {handle.name}: {pdu.hex()}")
self._write_raw(pdu) self._write_raw(pdu)
response = self._read_response() response: bytes = self._read_response()
logging.debug(f"Read response for handle {handle.name}: {response.hex()}") logging.debug(f"Read response for handle {handle.name}: {response.hex()}")
return response return response
def write(self, handle, value): def write(self, handle: Any, value: bytes) -> None:
handle_value = ATT_HANDLES[handle.name] handle_value: int = ATT_HANDLES[handle.name]
lsb = handle_value & 0xFF lsb: int = handle_value & 0xFF
msb = (handle_value >> 8) & 0xFF msb: int = (handle_value >> 8) & 0xFF
pdu = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value pdu: bytes = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value
logging.debug(f"Sending write request for handle {handle.name}: {pdu.hex()}") logging.debug(f"Sending write request for handle {handle.name}: {pdu.hex()}")
self._write_raw(pdu) self._write_raw(pdu)
try: try:
@@ -102,11 +105,11 @@ class ATTManager:
except: except:
logging.warning(f"No write response received for handle {handle.name}") logging.warning(f"No write response received for handle {handle.name}")
def write_cccd(self, handle, value): def write_cccd(self, handle: Any, value: bytes) -> None:
handle_value = ATT_CCCD_HANDLES[handle.name] handle_value: int = ATT_CCCD_HANDLES[handle.name]
lsb = handle_value & 0xFF lsb: int = handle_value & 0xFF
msb = (handle_value >> 8) & 0xFF msb: int = (handle_value >> 8) & 0xFF
pdu = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value pdu: bytes = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value
logging.debug(f"Sending CCCD write request for handle {handle.name}: {pdu.hex()}") logging.debug(f"Sending CCCD write request for handle {handle.name}: {pdu.hex()}")
self._write_raw(pdu) self._write_raw(pdu)
try: try:
@@ -115,42 +118,42 @@ class ATTManager:
except: except:
logging.warning(f"No CCCD write response received for handle {handle.name}") logging.warning(f"No CCCD write response received for handle {handle.name}")
def _write_raw(self, pdu): def _write_raw(self, pdu: bytes) -> None:
self.sock.send(pdu) self.sock.send(pdu)
logging.debug(f"Sent PDU: {pdu.hex()}") logging.debug(f"Sent PDU: {pdu.hex()}")
def _read_pdu(self): def _read_pdu(self) -> Optional[bytes]:
try: try:
data = self.sock.recv(512) data: bytes = self.sock.recv(512)
logging.debug(f"Received PDU: {data.hex()}") logging.debug(f"Received PDU: {data.hex()}")
return data return data
except socket.timeout: except TimeoutError:
return None return None
except: except:
raise raise
def _read_response(self, timeout=2.0): def _read_response(self, timeout: float = 2.0) -> bytes:
try: try:
response = self.responses.get(timeout=timeout)[1:] # Skip opcode response: bytes = self.responses.get(timeout=timeout)[1:] # Skip opcode
logging.debug(f"Response received: {response.hex()}") logging.debug(f"Response received: {response.hex()}")
return response return response
except: except:
logging.error("No response received within timeout") logging.error("No response received within timeout")
raise Exception("No response received") raise Exception("No response received")
def _listen_notifications(self): def _listen_notifications(self) -> None:
logging.info("Starting notification listener thread") logging.info("Starting notification listener thread")
while self.running: while self.running:
try: try:
pdu = self._read_pdu() pdu: Optional[bytes] = self._read_pdu()
except: except:
break break
if pdu is None: if pdu is None:
continue continue
if len(pdu) > 0 and pdu[0] == OPCODE_HANDLE_VALUE_NTF: if len(pdu) > 0 and pdu[0] == OPCODE_HANDLE_VALUE_NTF:
logging.debug(f"Notification PDU received: {pdu.hex()}") logging.debug(f"Notification PDU received: {pdu.hex()}")
handle = pdu[1] | (pdu[2] << 8) handle: int = pdu[1] | (pdu[2] << 8)
value = pdu[3:] value: bytes = pdu[3:]
logging.debug(f"Notification for handle {handle}: {value.hex()}") logging.debug(f"Notification for handle {handle}: {value.hex()}")
if handle in self.listeners: if handle in self.listeners:
for listener in self.listeners[handle]: for listener in self.listeners[handle]:
@@ -165,36 +168,36 @@ class ATTManager:
logging.error(f"Reconnection failed: {e}") logging.error(f"Reconnection failed: {e}")
class HearingAidSettings: class HearingAidSettings:
def __init__(self, left_eq, right_eq, left_amp, right_amp, left_tone, right_tone, def __init__(self, left_eq: List[float], right_eq: List[float], left_amp: float, right_amp: float, left_tone: float, right_tone: float,
left_conv, right_conv, left_anr, right_anr, net_amp, balance, own_voice): left_conv: bool, right_conv: bool, left_anr: float, right_anr: float, net_amp: float, balance: float, own_voice: float) -> None:
self.left_eq = left_eq self.left_eq: List[float] = left_eq
self.right_eq = right_eq self.right_eq: List[float] = right_eq
self.left_amplification = left_amp self.left_amplification: float = left_amp
self.right_amplification = right_amp self.right_amplification: float = right_amp
self.left_tone = left_tone self.left_tone: float = left_tone
self.right_tone = right_tone self.right_tone: float = right_tone
self.left_conversation_boost = left_conv self.left_conversation_boost: bool = left_conv
self.right_conversation_boost = right_conv self.right_conversation_boost: bool = right_conv
self.left_ambient_noise_reduction = left_anr self.left_ambient_noise_reduction: float = left_anr
self.right_ambient_noise_reduction = right_anr self.right_ambient_noise_reduction: float = right_anr
self.net_amplification = net_amp self.net_amplification: float = net_amp
self.balance = balance self.balance: float = balance
self.own_voice_amplification = own_voice self.own_voice_amplification: float = own_voice
logging.debug(f"HearingAidSettings created: amp={net_amp}, balance={balance}, tone={left_tone}, anr={left_anr}, conv={left_conv}") logging.debug(f"HearingAidSettings created: amp={net_amp}, balance={balance}, tone={left_tone}, anr={left_anr}, conv={left_conv}")
def parse_hearing_aid_settings(data): def parse_hearing_aid_settings(data: bytes) -> Optional[HearingAidSettings]:
logging.debug(f"Parsing hearing aid settings from data: {data.hex()}") logging.debug(f"Parsing hearing aid settings from data: {data.hex()}")
if len(data) < 104: if len(data) < 104:
logging.warning("Data too short for parsing") logging.warning("Data too short for parsing")
return None return None
buffer = data buffer: bytes = data
offset = 0 offset: int = 0
offset += 4 offset += 4
logging.info(f"Parsing hearing aid settings, starting read at offset 4, value: {buffer[offset]:02x}") logging.info(f"Parsing hearing aid settings, starting read at offset 4, value: {buffer[offset]:02x}")
left_eq = [] left_eq: List[float] = []
for i in range(8): for i in range(8):
val, = struct.unpack('<f', buffer[offset:offset+4]) val, = struct.unpack('<f', buffer[offset:offset+4])
left_eq.append(val) left_eq.append(val)
@@ -228,23 +231,23 @@ def parse_hearing_aid_settings(data):
own_voice, = struct.unpack('<f', buffer[offset:offset+4]) own_voice, = struct.unpack('<f', buffer[offset:offset+4])
avg = (left_amp + right_amp) / 2 avg: float = (left_amp + right_amp) / 2
amplification = max(-1, min(1, avg)) amplification: float = max(-1, min(1, avg))
diff = right_amp - left_amp diff: float = right_amp - left_amp
balance = max(-1, min(1, diff)) balance: float = max(-1, min(1, diff))
settings = HearingAidSettings(left_eq, right_eq, left_amp, right_amp, left_tone, right_tone, settings: HearingAidSettings = HearingAidSettings(left_eq, right_eq, left_amp, right_amp, left_tone, right_tone,
left_conv, right_conv, left_anr, right_anr, amplification, balance, own_voice) left_conv, right_conv, left_anr, right_anr, amplification, balance, own_voice)
logging.info(f"Parsed settings: amp={amplification}, balance={balance}") logging.info(f"Parsed settings: amp={amplification}, balance={balance}")
return settings return settings
def send_hearing_aid_settings(att_manager, settings): def send_hearing_aid_settings(att_manager: ATTManager, settings: HearingAidSettings) -> None:
logging.info("Sending hearing aid settings") logging.info("Sending hearing aid settings")
data = att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})()) data: bytes = att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})())
if len(data) < 104: if len(data) < 104:
logging.error("Read data too short for sending settings") logging.error("Read data too short for sending settings")
return return
buffer = bytearray(data) buffer: bytearray = bytearray(data)
# Modify byte at index 2 to 0x64 # Modify byte at index 2 to 0x64
buffer[2] = 0x64 buffer[2] = 0x64
@@ -272,16 +275,16 @@ def send_hearing_aid_settings(att_manager, settings):
logging.info("Hearing aid settings sent") logging.info("Hearing aid settings sent")
class SignalEmitter(QObject): class SignalEmitter(QObject):
update_ui = pyqtSignal(HearingAidSettings) update_ui: pyqtSignal = pyqtSignal(HearingAidSettings)
class HearingAidApp(QWidget): class HearingAidApp(QWidget):
def __init__(self, mac_address): def __init__(self, mac_address: str) -> None:
super().__init__() super().__init__()
self.mac_address = mac_address self.mac_address: str = mac_address
self.att_manager = ATTManager(mac_address) self.att_manager: ATTManager = ATTManager(mac_address)
self.emitter = SignalEmitter() self.emitter: SignalEmitter = SignalEmitter()
self.emitter.update_ui.connect(self.on_update_ui) self.emitter.update_ui.connect(self.on_update_ui)
self.debounce_timer = QTimer() self.debounce_timer: QTimer = QTimer()
self.debounce_timer.setSingleShot(True) self.debounce_timer.setSingleShot(True)
self.debounce_timer.timeout.connect(self.send_settings) self.debounce_timer.timeout.connect(self.send_settings)
logging.info("HearingAidConfig initialized") logging.info("HearingAidConfig initialized")
@@ -289,25 +292,25 @@ class HearingAidApp(QWidget):
self.init_ui() self.init_ui()
self.connect_att() self.connect_att()
def init_ui(self): def init_ui(self) -> None:
logging.debug("Initializing UI") logging.debug("Initializing UI")
self.setWindowTitle("Hearing Aid Adjustments") self.setWindowTitle("Hearing Aid Adjustments")
layout = QVBoxLayout() layout: QVBoxLayout = QVBoxLayout()
# EQ Inputs # EQ Inputs
eq_layout = QGridLayout() eq_layout: QGridLayout = QGridLayout()
self.left_eq_inputs = [] self.left_eq_inputs: List[QLineEdit] = []
self.right_eq_inputs = [] self.right_eq_inputs: List[QLineEdit] = []
eq_labels = ["250Hz", "500Hz", "1kHz", "2kHz", "3kHz", "4kHz", "6kHz", "8kHz"] eq_labels: List[str] = ["250Hz", "500Hz", "1kHz", "2kHz", "3kHz", "4kHz", "6kHz", "8kHz"]
eq_layout.addWidget(QLabel("Frequency"), 0, 0) eq_layout.addWidget(QLabel("Frequency"), 0, 0)
eq_layout.addWidget(QLabel("Left"), 0, 1) eq_layout.addWidget(QLabel("Left"), 0, 1)
eq_layout.addWidget(QLabel("Right"), 0, 2) eq_layout.addWidget(QLabel("Right"), 0, 2)
for i, label in enumerate(eq_labels): for i, label in enumerate(eq_labels):
eq_layout.addWidget(QLabel(label), i + 1, 0) eq_layout.addWidget(QLabel(label), i + 1, 0)
left_input = QLineEdit() left_input: QLineEdit = QLineEdit()
right_input = QLineEdit() right_input: QLineEdit = QLineEdit()
left_input.setPlaceholderText("Left") left_input.setPlaceholderText("Left")
right_input.setPlaceholderText("Right") right_input.setPlaceholderText("Right")
self.left_eq_inputs.append(left_input) self.left_eq_inputs.append(left_input)
@@ -315,52 +318,52 @@ class HearingAidApp(QWidget):
eq_layout.addWidget(left_input, i + 1, 1) eq_layout.addWidget(left_input, i + 1, 1)
eq_layout.addWidget(right_input, i + 1, 2) eq_layout.addWidget(right_input, i + 1, 2)
eq_group = QWidget() eq_group: QWidget = QWidget()
eq_group.setLayout(eq_layout) eq_group.setLayout(eq_layout)
layout.addWidget(QLabel("Loss, in dBHL")) layout.addWidget(QLabel("Loss, in dBHL"))
layout.addWidget(eq_group) layout.addWidget(eq_group)
# Amplification # Amplification
self.amp_slider = QSlider(Qt.Horizontal) self.amp_slider: QSlider = QSlider(Qt.Horizontal)
self.amp_slider.setRange(-100, 100) self.amp_slider.setRange(-100, 100)
self.amp_slider.setValue(50) self.amp_slider.setValue(50)
layout.addWidget(QLabel("Amplification")) layout.addWidget(QLabel("Amplification"))
layout.addWidget(self.amp_slider) layout.addWidget(self.amp_slider)
# Balance # Balance
self.balance_slider = QSlider(Qt.Horizontal) self.balance_slider: QSlider = QSlider(Qt.Horizontal)
self.balance_slider.setRange(-100, 100) self.balance_slider.setRange(-100, 100)
self.balance_slider.setValue(50) self.balance_slider.setValue(50)
layout.addWidget(QLabel("Balance")) layout.addWidget(QLabel("Balance"))
layout.addWidget(self.balance_slider) layout.addWidget(self.balance_slider)
# Tone # Tone
self.tone_slider = QSlider(Qt.Horizontal) self.tone_slider: QSlider = QSlider(Qt.Horizontal)
self.tone_slider.setRange(-100, 100) self.tone_slider.setRange(-100, 100)
self.tone_slider.setValue(50) self.tone_slider.setValue(50)
layout.addWidget(QLabel("Tone")) layout.addWidget(QLabel("Tone"))
layout.addWidget(self.tone_slider) layout.addWidget(self.tone_slider)
# Ambient Noise Reduction # Ambient Noise Reduction
self.anr_slider = QSlider(Qt.Horizontal) self.anr_slider: QSlider = QSlider(Qt.Horizontal)
self.anr_slider.setRange(0, 100) self.anr_slider.setRange(0, 100)
self.anr_slider.setValue(0) self.anr_slider.setValue(0)
layout.addWidget(QLabel("Ambient Noise Reduction")) layout.addWidget(QLabel("Ambient Noise Reduction"))
layout.addWidget(self.anr_slider) layout.addWidget(self.anr_slider)
# Conversation Boost # Conversation Boost
self.conv_checkbox = QCheckBox("Conversation Boost") self.conv_checkbox: QCheckBox = QCheckBox("Conversation Boost")
layout.addWidget(self.conv_checkbox) layout.addWidget(self.conv_checkbox)
# Own Voice Amplification # Own Voice Amplification
self.own_voice_slider = QSlider(Qt.Horizontal) self.own_voice_slider: QSlider = QSlider(Qt.Horizontal)
self.own_voice_slider.setRange(0, 100) self.own_voice_slider.setRange(0, 100)
self.own_voice_slider.setValue(50) self.own_voice_slider.setValue(50)
# layout.addWidget(QLabel("Own Voice Amplification")) # layout.addWidget(QLabel("Own Voice Amplification"))
# layout.addWidget(self.own_voice_slider) # seems to have no effect # layout.addWidget(self.own_voice_slider) # seems to have no effect
# Reset button # Reset button
self.reset_button = QPushButton("Reset") self.reset_button: QPushButton = QPushButton("Reset")
layout.addWidget(self.reset_button) layout.addWidget(self.reset_button)
# Connect signals # Connect signals
@@ -377,15 +380,15 @@ class HearingAidApp(QWidget):
self.setLayout(layout) self.setLayout(layout)
logging.debug("UI initialized") logging.debug("UI initialized")
def connect_att(self): def connect_att(self) -> None:
logging.info("Connecting to ATT in UI") logging.info("Connecting to ATT in UI")
try: try:
self.att_manager.connect() self.att_manager.connect()
self.att_manager.enable_notifications(type('Handle', (), {'name': 'HEARING_AID'})()) self.att_manager.enable_notifications(type('Handle', (), {'name': 'HEARING_AID'})())
self.att_manager.register_listener(ATT_HANDLES['HEARING_AID'], self.on_notification) self.att_manager.register_listener(ATT_HANDLES['HEARING_AID'], self.on_notification)
# Initial read # Initial read
data = self.att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})()) data: bytes = self.att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})())
settings = parse_hearing_aid_settings(data) settings: Optional[HearingAidSettings] = parse_hearing_aid_settings(data)
if settings: if settings:
self.emitter.update_ui.emit(settings) self.emitter.update_ui.emit(settings)
logging.info("Initial settings loaded") logging.info("Initial settings loaded")
@@ -396,13 +399,13 @@ class HearingAidApp(QWidget):
else: else:
logging.error(f"Connection failed: {e}") logging.error(f"Connection failed: {e}")
def on_notification(self, value): def on_notification(self, value: bytes) -> None:
logging.debug("Notification received") logging.debug("Notification received")
settings = parse_hearing_aid_settings(value) settings: Optional[HearingAidSettings] = parse_hearing_aid_settings(value)
if settings: if settings:
self.emitter.update_ui.emit(settings) self.emitter.update_ui.emit(settings)
def on_update_ui(self, settings): def on_update_ui(self, settings: HearingAidSettings) -> None:
logging.debug("Updating UI with settings") logging.debug("Updating UI with settings")
self.amp_slider.setValue(int(settings.net_amplification * 100)) self.amp_slider.setValue(int(settings.net_amplification * 100))
self.balance_slider.setValue(int(settings.balance * 100)) self.balance_slider.setValue(int(settings.balance * 100))
@@ -416,30 +419,30 @@ class HearingAidApp(QWidget):
for i, value in enumerate(settings.right_eq): for i, value in enumerate(settings.right_eq):
self.right_eq_inputs[i].setText(f"{value:.2f}") self.right_eq_inputs[i].setText(f"{value:.2f}")
def on_value_changed(self): def on_value_changed(self) -> None:
logging.debug("UI value changed, starting debounce") logging.debug("UI value changed, starting debounce")
self.debounce_timer.start(100) self.debounce_timer.start(100)
def send_settings(self): def send_settings(self) -> None:
logging.info("Sending settings from UI") logging.info("Sending settings from UI")
amp = self.amp_slider.value() / 100.0 amp: float = self.amp_slider.value() / 100.0
balance = self.balance_slider.value() / 100.0 balance: float = self.balance_slider.value() / 100.0
tone = self.tone_slider.value() / 100.0 tone: float = self.tone_slider.value() / 100.0
anr = self.anr_slider.value() / 100.0 anr: float = self.anr_slider.value() / 100.0
conv = self.conv_checkbox.isChecked() conv: bool = self.conv_checkbox.isChecked()
own_voice = self.own_voice_slider.value() / 100.0 own_voice: float = self.own_voice_slider.value() / 100.0
left_amp = amp + (0.5 - balance) * amp * 2 if balance < 0 else amp left_amp: float = amp + (0.5 - balance) * amp * 2 if balance < 0 else amp
right_amp = amp + (balance - 0.5) * amp * 2 if balance > 0 else amp right_amp: float = amp + (balance - 0.5) * amp * 2 if balance > 0 else amp
left_eq = [float(input_box.text() or 0) for input_box in self.left_eq_inputs] left_eq: List[float] = [float(input_box.text() or 0) for input_box in self.left_eq_inputs]
right_eq = [float(input_box.text() or 0) for input_box in self.right_eq_inputs] right_eq: List[float] = [float(input_box.text() or 0) for input_box in self.right_eq_inputs]
settings = HearingAidSettings( settings: HearingAidSettings = HearingAidSettings(
left_eq, right_eq, left_amp, right_amp, tone, tone, left_eq, right_eq, left_amp, right_amp, tone, tone,
conv, conv, anr, anr, amp, balance, own_voice conv, conv, anr, anr, amp, balance, own_voice
) )
threading.Thread(target=send_hearing_aid_settings, args=(self.att_manager, settings)).start() Thread(target=send_hearing_aid_settings, args=(self.att_manager, settings)).start()
def reset_settings(self): def reset_settings(self):
logging.debug("Resetting settings to defaults") logging.debug("Resetting settings to defaults")
@@ -451,26 +454,25 @@ class HearingAidApp(QWidget):
self.own_voice_slider.setValue(50) self.own_voice_slider.setValue(50)
self.on_value_changed() self.on_value_changed()
def closeEvent(self, event): def closeEvent(self, event: Any) -> None:
logging.info("Closing app") logging.info("Closing app")
self.att_manager.disconnect() self.att_manager.disconnect()
event.accept() event.accept()
if __name__ == "__main__": if __name__ == "__main__":
mac = None
if len(sys.argv) != 2: if len(sys.argv) != 2:
logging.error("Usage: python hearing-aid-adjustments.py <MAC_ADDRESS>") logging.error("Usage: python hearing-aid-adjustments.py <MAC_ADDRESS>")
sys.exit(1) sys.exit(1)
mac = sys.argv[1] mac: str = sys.argv[1]
mac_regex = r'^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$' mac_regex: str = r'^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$'
import re import re
if not re.match(mac_regex, mac): if not re.match(mac_regex, mac):
logging.error("Invalid MAC address format") logging.error("Invalid MAC address format")
sys.exit(1) sys.exit(1)
logging.info(f"Starting app") logging.info(f"Starting app")
app = QApplication(sys.argv) app: QApplication = QApplication(sys.argv)
def quit_app(signum, frame): def quit_app(signum: int, frame: Any) -> None:
app.quit() app.quit()
signal.signal(signal.SIGINT, quit_app) signal.signal(signal.SIGINT, quit_app)

View File

@@ -4,50 +4,53 @@
# See https://github.com/google/bumble/blob/main/docs/mkdocs/src/platforms/windows.md for usage. # See https://github.com/google/bumble/blob/main/docs/mkdocs/src/platforms/windows.md for usage.
# You need to associate WinUSB with your Bluetooth interface. Once done, you can roll back to the original driver from Device Manager. # You need to associate WinUSB with your Bluetooth interface. Once done, you can roll back to the original driver from Device Manager.
import sys
import asyncio import asyncio
import argparse import colorama
import logging import logging
import platform import platform
from typing import Any, Optional from argparse import ArgumentParser, Namespace
from asyncio import Queue, TimeoutError
from colorama import Fore, Style
from logging import Formatter, LogRecord, Logger, StreamHandler
from socket import socket as Socket
from typing import Any, Dict, List, Optional, Tuple
from colorama import Fore, Style, init as colorama_init colorama.init(autoreset=True)
colorama_init(autoreset=True)
handler = logging.StreamHandler() handler: StreamHandler = StreamHandler()
class ColorFormatter(logging.Formatter): class ColorFormatter(Formatter):
COLORS = { COLORS: Dict[int, str] = {
logging.DEBUG: Fore.BLUE, logging.DEBUG: Fore.BLUE,
logging.INFO: Fore.GREEN, logging.INFO: Fore.GREEN,
logging.WARNING: Fore.YELLOW, logging.WARNING: Fore.YELLOW,
logging.ERROR: Fore.RED, logging.ERROR: Fore.RED,
logging.CRITICAL: Fore.MAGENTA, logging.CRITICAL: Fore.MAGENTA,
} }
def format(self, record): def format(self, record: LogRecord) -> str:
color = self.COLORS.get(record.levelno, "") color: str = self.COLORS.get(record.levelno, "")
prefix = f"{color}[{record.levelname}:{record.name}]{Style.RESET_ALL}" prefix: str = f"{color}[{record.levelname}:{record.name}]{Style.RESET_ALL}"
return f"{prefix} {record.getMessage()}" return f"{prefix} {record.getMessage()}"
handler.setFormatter(ColorFormatter()) handler.setFormatter(ColorFormatter())
logging.basicConfig(level=logging.INFO, handlers=[handler]) logging.basicConfig(level=logging.INFO, handlers=[handler])
logger = logging.getLogger("proximitykeys") logger: Logger = logging.getLogger("proximitykeys")
PROXIMITY_KEY_TYPES = {0x01: "IRK", 0x04: "ENC_KEY"} PROXIMITY_KEY_TYPES: Dict[int, str] = {0x01: "IRK", 0x04: "ENC_KEY"}
def parse_proximity_keys_response(data: bytes): def parse_proximity_keys_response(data: bytes) -> Optional[List[Tuple[str, bytes]]]:
if len(data) < 7 or data[4] != 0x31: if len(data) < 7 or data[4] != 0x31:
return None return None
key_count = data[6] key_count: int = data[6]
keys = [] keys: List[Tuple[str, bytes]] = []
offset = 7 offset: int = 7
for _ in range(key_count): for _ in range(key_count):
if offset + 3 >= len(data): if offset + 3 >= len(data):
break break
key_type = data[offset] key_type: int = data[offset]
key_length = data[offset + 2] key_length: int = data[offset + 2]
offset += 4 offset += 4
if offset + key_length > len(data): if offset + key_length > len(data):
break break
key_bytes = data[offset:offset + key_length] key_bytes: bytes = data[offset:offset + key_length]
keys.append((PROXIMITY_KEY_TYPES.get(key_type, f"TYPE_{key_type:02X}"), key_bytes)) keys.append((PROXIMITY_KEY_TYPES.get(key_type, f"TYPE_{key_type:02X}"), key_bytes))
offset += key_length offset += key_length
return keys return keys
@@ -55,7 +58,7 @@ def parse_proximity_keys_response(data: bytes):
def hexdump(data: bytes) -> str: def hexdump(data: bytes) -> str:
return " ".join(f"{b:02X}" for b in data) return " ".join(f"{b:02X}" for b in data)
async def run_bumble(bdaddr: str): async def run_bumble(bdaddr: str) -> int:
try: try:
from bumble.l2cap import ClassicChannelSpec from bumble.l2cap import ClassicChannelSpec
from bumble.transport import open_transport from bumble.transport import open_transport
@@ -68,19 +71,23 @@ async def run_bumble(bdaddr: str):
logger.error("Bumble not installed") logger.error("Bumble not installed")
return 1 return 1
PSM_PROXIMITY = 0x1001 PSM_PROXIMITY: int = 0x1001
HANDSHAKE = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00") HANDSHAKE: bytes = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00")
KEY_REQ = bytes.fromhex("04 00 04 00 30 00 05 00") KEY_REQ: bytes = bytes.fromhex("04 00 04 00 30 00 05 00")
class KeyStore: class KeyStore:
async def delete(self, name: str): pass async def delete(self, name: str) -> None:
async def update(self, name: str, keys: Any): pass pass
async def get(self, _name: str) -> Optional[Any]: return None async def update(self, name: str, keys: Any) -> None:
async def get_all(self): return [] pass
async def get(self, _name: str) -> Optional[Any]:
return None
async def get_all(self) -> List[Tuple[str, Any]]:
return []
async def get_resolving_keys(self) -> list[tuple[bytes, Any]]: async def get_resolving_keys(self) -> List[Tuple[bytes, Any]]:
all_keys = await self.get_all() all_keys: List[Tuple[str, Any]] = await self.get_all()
resolving_keys = [] resolving_keys: List[Tuple[bytes, Any]] = []
for name, keys in all_keys: for name, keys in all_keys:
if getattr(keys, "irk", None) is not None: if getattr(keys, "irk", None) is not None:
resolving_keys.append(( resolving_keys.append((
@@ -89,8 +96,8 @@ async def run_bumble(bdaddr: str):
)) ))
return resolving_keys return resolving_keys
async def exchange_keys(channel, timeout=5.0): async def exchange_keys(channel: Any, timeout: float = 5.0) -> Optional[List[Tuple[str, bytes]]]:
recv_q: asyncio.Queue = asyncio.Queue() recv_q: Queue = Queue()
channel.sink = lambda sdu: recv_q.put_nowait(sdu) channel.sink = lambda sdu: recv_q.put_nowait(sdu)
logger.info("Sending handshake packet...") logger.info("Sending handshake packet...")
channel.send_pdu(HANDSHAKE) channel.send_pdu(HANDSHAKE)
@@ -99,19 +106,19 @@ async def run_bumble(bdaddr: str):
channel.send_pdu(KEY_REQ) channel.send_pdu(KEY_REQ)
while True: while True:
try: try:
pkt = await asyncio.wait_for(recv_q.get(), timeout) pkt: bytes = await asyncio.wait_for(recv_q.get(), timeout)
except asyncio.TimeoutError: except TimeoutError:
logger.error("Timed out waiting for SDU response") logger.error("Timed out waiting for SDU response")
return None return None
logger.debug("Received SDU (%d bytes): %s", len(pkt), hexdump(pkt)) logger.debug("Received SDU (%d bytes): %s", len(pkt), hexdump(pkt))
keys = parse_proximity_keys_response(pkt) keys: Optional[List[Tuple[str, bytes]]] = parse_proximity_keys_response(pkt)
if keys: if keys:
return keys return keys
async def get_device(): async def get_device() -> Tuple[Any, Device]:
logger.info("Opening transport...") logger.info("Opening transport...")
transport = await open_transport("usb:0") transport: Any = await open_transport("usb:0")
device = Device(host=Host(controller_source=transport.source, controller_sink=transport.sink)) device: Device = Device(host=Host(controller_source=transport.source, controller_sink=transport.sink))
device.classic_enabled = True device.classic_enabled = True
device.le_enabled = False device.le_enabled = False
device.keystore = KeyStore() device.keystore = KeyStore()
@@ -123,15 +130,15 @@ async def run_bumble(bdaddr: str):
logger.info("Device powered on") logger.info("Device powered on")
return transport, device return transport, device
async def create_channel_and_exchange(conn): async def create_channel_and_exchange(conn: Any) -> None:
spec = ClassicChannelSpec(psm=PSM_PROXIMITY, mtu=2048) spec: ClassicChannelSpec = ClassicChannelSpec(psm=PSM_PROXIMITY, mtu=2048)
logger.info("Requesting L2CAP channel on PSM = 0x%04X", spec.psm) logger.info("Requesting L2CAP channel on PSM = 0x%04X", spec.psm)
if not conn.is_encrypted: if not conn.is_encrypted:
logger.info("Enabling link encryption...") logger.info("Enabling link encryption...")
await conn.encrypt() await conn.encrypt()
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
channel = await conn.create_l2cap_channel(spec=spec) channel: Any = await conn.create_l2cap_channel(spec=spec)
keys = await exchange_keys(channel, timeout=8.0) keys: Optional[List[Tuple[str, bytes]]] = await exchange_keys(channel, timeout=8.0)
if not keys: if not keys:
logger.warning("No proximity keys found") logger.warning("No proximity keys found")
return return
@@ -165,14 +172,14 @@ async def run_bumble(bdaddr: str):
logger.info("Transport closed") logger.info("Transport closed")
return 0 return 0
def run_linux(bdaddr: str): def run_linux(bdaddr: str) -> None:
import socket import socket
PSM = 0x1001 PSM: int = 0x1001
handshake = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00") handshake: bytes = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00")
key_req = bytes.fromhex("04 00 04 00 30 00 05 00") key_req: bytes = bytes.fromhex("04 00 04 00 30 00 05 00")
logger.info("Connecting to %s (L2CAP)...", bdaddr) logger.info("Connecting to %s (L2CAP)...", bdaddr)
sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) sock: Socket = Socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP)
try: try:
sock.connect((bdaddr, PSM)) sock.connect((bdaddr, PSM))
logger.info("Connected, sending handshake and key request...") logger.info("Connected, sending handshake and key request...")
@@ -180,9 +187,9 @@ def run_linux(bdaddr: str):
sock.send(key_req) sock.send(key_req)
while True: while True:
pkt = sock.recv(1024) pkt: bytes = sock.recv(1024)
logger.debug("Received packet (%d bytes): %s", len(pkt), hexdump(pkt)) logger.debug("Received packet (%d bytes): %s", len(pkt), hexdump(pkt))
keys = parse_proximity_keys_response(pkt) keys: Optional[List[Tuple[str, bytes]]] = parse_proximity_keys_response(pkt)
if keys: if keys:
logger.info("Keys successfully retrieved") logger.info("Keys successfully retrieved")
print(f"{Fore.CYAN}{Style.BRIGHT}Proximity Keys:{Style.RESET_ALL}") print(f"{Fore.CYAN}{Style.BRIGHT}Proximity Keys:{Style.RESET_ALL}")
@@ -197,12 +204,12 @@ def run_linux(bdaddr: str):
sock.close() sock.close()
logger.info("Connection closed") logger.info("Connection closed")
def main(): def main() -> None:
parser = argparse.ArgumentParser() parser: ArgumentParser = ArgumentParser()
parser.add_argument("bdaddr") parser.add_argument("bdaddr")
parser.add_argument("--debug", action="store_true") parser.add_argument("--debug", action="store_true")
parser.add_argument("--bumble", action="store_true") parser.add_argument("--bumble", action="store_true")
args = parser.parse_args() args: Namespace = parser.parse_args()
logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO) logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO)
if args.bumble or platform.system() == "Windows": if args.bumble or platform.system() == "Windows":