From 1ef223e2381a105ea89c5fe18dffe4b595c1370b Mon Sep 17 00:00:00 2001 From: KnugiHK <24708955+KnugiHK@users.noreply.github.com> Date: Sun, 2 Mar 2025 01:41:44 +0800 Subject: [PATCH] Refactor the data model --- Whatsapp_Chat_Exporter/__main__.py | 4 +- Whatsapp_Chat_Exporter/android_handler.py | 49 ++++++------ Whatsapp_Chat_Exporter/data_model.py | 96 ++++++++++++++++++++++- Whatsapp_Chat_Exporter/ios_handler.py | 53 +++++++------ 4 files changed, 150 insertions(+), 52 deletions(-) diff --git a/Whatsapp_Chat_Exporter/__main__.py b/Whatsapp_Chat_Exporter/__main__.py index 849b1b6..051f166 100644 --- a/Whatsapp_Chat_Exporter/__main__.py +++ b/Whatsapp_Chat_Exporter/__main__.py @@ -10,7 +10,7 @@ import glob import importlib.metadata from Whatsapp_Chat_Exporter import android_crypt, exported_handler, android_handler from Whatsapp_Chat_Exporter import ios_handler, ios_media_handler -from Whatsapp_Chat_Exporter.data_model import ChatStore +from Whatsapp_Chat_Exporter.data_model import ChatCollection, ChatStore from Whatsapp_Chat_Exporter.utility import APPLE_TIME, Crypt, check_update, DbType from Whatsapp_Chat_Exporter.utility import readable_to_bytes, sanitize_filename from Whatsapp_Chat_Exporter.utility import import_from_json, bytes_to_readable @@ -403,7 +403,7 @@ def main(): parser.error("Enter a phone number in the chat filter. 
See https://wts.knugi.dev/docs?dest=chat") filter_chat = (args.filter_chat_include, args.filter_chat_exclude) - data = {} + data = ChatCollection() if args.enrich_from_vcards is not None: if not vcards_deps_installed: diff --git a/Whatsapp_Chat_Exporter/android_handler.py b/Whatsapp_Chat_Exporter/android_handler.py index ca6e8a7..3374c8c 100644 --- a/Whatsapp_Chat_Exporter/android_handler.py +++ b/Whatsapp_Chat_Exporter/android_handler.py @@ -32,9 +32,9 @@ def contacts(db, data, enrich_from_vcards): c.execute("""SELECT jid, COALESCE(display_name, wa_name) as display_name, status FROM wa_contacts; """) row = c.fetchone() while row is not None: - data[row["jid"]] = ChatStore(Device.ANDROID, row["display_name"]) + current_chat = data.add_chat(row["jid"], ChatStore(Device.ANDROID, row["display_name"])) if row["status"] is not None: - data[row["jid"]].status = row["status"] + current_chat.status = row["status"] row = c.fetchone() @@ -207,8 +207,10 @@ def messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, else: break while content is not None: - if content["key_remote_jid"] not in data: - data[content["key_remote_jid"]] = ChatStore(Device.ANDROID, content["chat_subject"]) + if not data.get_chat(content["key_remote_jid"]): + current_chat = data.add_chat(content["key_remote_jid"], ChatStore(Device.ANDROID, content["chat_subject"])) + else: + current_chat = data.get_chat(content["key_remote_jid"]) if content["key_remote_jid"] is None: continue # Not sure if "sender_jid_row_id" in content: @@ -232,7 +234,7 @@ def messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, f"""('Decode')&input={b64encode(b64encode(content["data"])).decode()}">""") message.data += b64encode(content["data"]).decode("utf-8") + "" message.safe = message.meta = True - data[content["key_remote_jid"]].add_message(content["_id"], message) + current_chat.add_message(content["_id"], message) i += 1 content = c.fetchone() continue @@ -242,13 +244,13 @@ def 
messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, if content["sender_jid_row_id"] > 0: _jid = content["group_sender_jid"] if _jid in data: - name = data[_jid].name + name = data.get_chat(_jid).name if "@" in _jid: fallback = _jid.split('@')[0] else: if content["remote_resource"] is not None: if content["remote_resource"] in data: - name = data[content["remote_resource"]].name + name = data.get_chat(content["remote_resource"]).name if "@" in content["remote_resource"]: fallback = content["remote_resource"].split('@')[0] @@ -281,7 +283,7 @@ def messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, if content["sender_jid_row_id"] > 0: _jid = content["group_sender_jid"] if _jid in data: - name = data[_jid].name + name = data.get_chat(_jid).name if "@" in _jid: fallback = _jid.split('@')[0] else: @@ -290,7 +292,7 @@ def messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, _jid = content["remote_resource"] if _jid is not None: if _jid in data: - name = data[_jid].name + name = data.get_chat(_jid).name if "@" in _jid: fallback = _jid.split('@')[0] else: @@ -343,7 +345,7 @@ def messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, msg = msg.replace("\n", "
") message.data = msg - data[content["key_remote_jid"]].add_message(content["_id"], message) + current_chat.add_message(content["_id"], message) i += 1 if i % 1000 == 0: print(f"Processing messages...({i}/{total_row_number})", end="\r") @@ -451,7 +453,8 @@ def media(db, data, media_folder, filter_date, filter_chat, filter_empty, separa Path(f"{media_folder}/thumbnails").mkdir(parents=True, exist_ok=True) while content is not None: file_path = f"{media_folder}/{content['file_path']}" - message = data[content["key_remote_jid"]].get_message(content["message_row_id"]) + current_chat = data.get_chat(content["key_remote_jid"]) + message = current_chat.get_message(content["message_row_id"]) message.media = True if os.path.isfile(file_path): message.data = file_path @@ -464,7 +467,7 @@ def media(db, data, media_folder, filter_date, filter_chat, filter_empty, separa else: message.mime = content["mime_type"] if separate_media: - chat_display_name = slugify(data[content["key_remote_jid"]].name or message.sender \ + chat_display_name = slugify(current_chat.name or message.sender \ or content["key_remote_jid"].split('@')[0], True) current_filename = file_path.split("/")[-1] new_folder = os.path.join(media_folder, "separated", chat_display_name) @@ -547,7 +550,7 @@ def vcard(db, data, media_folder, filter_date, filter_chat, filter_empty): if not os.path.isfile(file_path): with open(file_path, "w", encoding="utf-8") as f: f.write(row["vcard"]) - message = data[row["key_remote_jid"]].get_message(row["message_row_id"]) + message = data.get_chat(row["key_remote_jid"]).get_message(row["message_row_id"]) message.data = "This media include the following vCard file(s):
" \ f'{htmle(media_name)}' message.mime = "text/x-vcard" @@ -603,7 +606,7 @@ def calls(db, data, timezone_offset, filter_chat): read_timestamp=None # TODO: Add timestamp ) _jid = content["raw_string"] - name = data[_jid].name if _jid in data else content["chat_subject"] or None + name = data.get_chat(_jid).name if _jid in data else content["chat_subject"] or None if _jid is not None and "@" in _jid: fallback = _jid.split('@')[0] else: @@ -632,7 +635,7 @@ def calls(db, data, timezone_offset, filter_chat): call.data += "in an unknown state." chat.add_message(content["_id"], call) content = c.fetchone() - data["000000000000000"] = chat + data.add_chat("000000000000000", chat) def create_html( @@ -657,8 +660,8 @@ def create_html( w3css = get_status_location(output_folder, offline_static) for current, contact in enumerate(data): - chat = data[contact] - safe_file_name, name = get_file_name(contact, chat) + current_chat = data.get_chat(contact) + safe_file_name, name = get_file_name(contact, current_chat) if maximum_size is not None: current_size = 0 @@ -666,8 +669,8 @@ def create_html( render_box = [] if maximum_size == 0: maximum_size = MAX_SIZE - last_msg = chat.get_last_message().key_id - for message in chat.get_messages(): + last_msg = current_chat.get_last_message().key_id + for message in current_chat.get_messages(): if message.data is not None and not message.meta and not message.media: current_size += len(message.data) + ROW_SIZE else: @@ -681,7 +684,7 @@ def create_html( render_box, contact, w3css, - chat, + current_chat, headline, next=f"{safe_file_name}-{current_page + 1}.html", previous=f"{safe_file_name}-{current_page - 1}.html" if current_page > 1 else False @@ -703,7 +706,7 @@ def create_html( render_box, contact, w3css, - chat, + current_chat, headline, False, previous=f"{safe_file_name}-{current_page - 1}.html" @@ -714,10 +717,10 @@ def create_html( output_file_name, template, name, - chat.get_messages(), + current_chat.get_messages(), contact, w3css, - 
class ChatCollection(MutableMapping):
    """
    A dictionary-like collection of chats keyed by chat ID.

    The five MutableMapping primitives (``__getitem__``, ``__setitem__``,
    ``__delitem__``, ``__iter__``, ``__len__``) are implemented on top of a
    private dict, so ``collection[chat_id]``, ``chat_id in collection``,
    ``len(collection)`` and iteration over chat IDs all work. On top of that
    the class offers explicit chat management helpers (``get_chat``,
    ``add_chat``, ``remove_chat``) and a ``to_dict`` export.
    """

    def __init__(self) -> None:
        """Initialize an empty chat collection."""
        self._chats: Dict[str, 'ChatStore'] = {}

    def __getitem__(self, key: str) -> 'ChatStore':
        """
        Get a chat by its ID. Required for dict-like access.

        Raises:
            KeyError: If no chat is stored under ``key``
        """
        return self._chats[key]

    def __setitem__(self, key: str, value: 'ChatStore') -> None:
        """
        Set a chat by its ID. Required for dict-like access.

        Raises:
            TypeError: If value is not a ChatStore object
        """
        if not isinstance(value, ChatStore):
            raise TypeError("Value must be a ChatStore object")
        self._chats[key] = value

    def __delitem__(self, key: str) -> None:
        """
        Delete a chat by its ID. Required for dict-like access.

        Raises:
            KeyError: If no chat is stored under ``key``
        """
        del self._chats[key]

    def __iter__(self):
        """Iterate over chat IDs. Required for dict-like access."""
        return iter(self._chats)

    def __len__(self) -> int:
        """Get number of chats. Required for dict-like access."""
        return len(self._chats)

    def get_chat(self, chat_id: str) -> Optional['ChatStore']:
        """
        Get a chat by its ID without raising when it is missing.

        Args:
            chat_id (str): The ID of the chat to retrieve

        Returns:
            Optional['ChatStore']: The chat if found, None otherwise
        """
        return self._chats.get(chat_id)

    def add_chat(self, chat_id: str, chat: 'ChatStore') -> 'ChatStore':
        """
        Add a chat to the collection.

        A chat already stored under the same ID is replaced.

        Args:
            chat_id (str): The ID for the chat
            chat (ChatStore): The chat to add

        Returns:
            ChatStore: The chat that was just stored, so the caller can keep
                working with it without a second lookup

        Raises:
            TypeError: If chat is not a ChatStore object
        """
        if not isinstance(chat, ChatStore):
            raise TypeError("Chat must be a ChatStore object")
        self._chats[chat_id] = chat
        return chat

    def remove_chat(self, chat_id: str) -> None:
        """
        Remove a chat from the collection.

        Unlike ``del collection[chat_id]``, a missing ID is silently ignored.

        Args:
            chat_id (str): The ID of the chat to remove
        """
        self._chats.pop(chat_id, None)

    def items(self):
        """Get chat items (id, chat) pairs as a view of the underlying dict."""
        return self._chats.items()

    def values(self):
        """Get all chats as a view of the underlying dict."""
        return self._chats.values()

    def keys(self):
        """Get all chat IDs as a view of the underlying dict."""
        return self._chats.keys()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the collection to a dictionary.

        Each chat is converted with its own ``to_json()`` method.

        Returns:
            Dict[str, Any]: Mapping of chat ID to the ``to_json()`` result of that chat
        """
        return {chat_id: chat.to_json() for chat_id, chat in self._chats.items()}
@@ -122,7 +214,7 @@ class ChatStore: """Get the most recent message in the chat.""" return tuple(self._messages.values())[-1] - def get_messages(self) -> Dict[str, 'Message']: + def get_messages(self) -> 'Message': """Get all messages in the chat.""" return self._messages.values() diff --git a/Whatsapp_Chat_Exporter/ios_handler.py b/Whatsapp_Chat_Exporter/ios_handler.py index a164e2d..a68183b 100644 --- a/Whatsapp_Chat_Exporter/ios_handler.py +++ b/Whatsapp_Chat_Exporter/ios_handler.py @@ -22,8 +22,9 @@ def contacts(db, data): while content is not None: if not content["ZWHATSAPPID"].endswith("@s.whatsapp.net"): ZWHATSAPPID = content["ZWHATSAPPID"] + "@s.whatsapp.net" - data[ZWHATSAPPID] = ChatStore(Device.IOS) - data[ZWHATSAPPID].status = content["ZABOUTTEXT"] + current_chat = ChatStore(Device.IOS) + current_chat.status = content["ZABOUTTEXT"] + data.add_chat(ZWHATSAPPID, current_chat) content = c.fetchone() @@ -76,20 +77,21 @@ def messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, contact_name = content["ZPUSHNAME"] contact_id = content["ZCONTACTJID"] if contact_id not in data: - data[contact_id] = ChatStore(Device.IOS, contact_name, media_folder) + current_chat = data.add_chat(contact_id, ChatStore(Device.IOS, contact_name, media_folder)) else: - data[contact_id].name = contact_name - data[contact_id].my_avatar = os.path.join(media_folder, "Media/Profile/Photo.jpg") + current_chat = data.get_chat(contact_id) + current_chat.name = contact_name + current_chat.my_avatar = os.path.join(media_folder, "Media/Profile/Photo.jpg") path = f'{media_folder}/Media/Profile/{contact_id.split("@")[0]}' avatars = glob(f"{path}*") if 0 < len(avatars) <= 1: - data[contact_id].their_avatar = avatars[0] + current_chat.their_avatar = avatars[0] else: for avatar in avatars: - if avatar.endswith(".thumb") and data[content["ZCONTACTJID"]].their_avatar_thumb is None: - data[contact_id].their_avatar_thumb = avatar - elif avatar.endswith(".jpg") and 
data[content["ZCONTACTJID"]].their_avatar is None: - data[contact_id].their_avatar = avatar + if avatar.endswith(".thumb") and current_chat.their_avatar_thumb is None: + current_chat.their_avatar_thumb = avatar + elif avatar.endswith(".jpg") and current_chat.their_avatar is None: + current_chat.their_avatar = avatar content = c.fetchone() # Get message history @@ -135,17 +137,19 @@ def messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, Z_PK = content["Z_PK"] is_group_message = content["ZGROUPINFO"] is not None if ZCONTACTJID not in data: - data[ZCONTACTJID] = ChatStore(Device.IOS) + current_chat = data.add_chat(ZCONTACTJID, ChatStore(Device.IOS)) path = f'{media_folder}/Media/Profile/{ZCONTACTJID.split("@")[0]}' avatars = glob(f"{path}*") if 0 < len(avatars) <= 1: - data[ZCONTACTJID].their_avatar = avatars[0] + current_chat.their_avatar = avatars[0] else: for avatar in avatars: if avatar.endswith(".thumb"): - data[ZCONTACTJID].their_avatar_thumb = avatar + current_chat.their_avatar_thumb = avatar elif avatar.endswith(".jpg"): - data[ZCONTACTJID].their_avatar = avatar + current_chat.their_avatar = avatar + else: + current_chat = data.get_chat(ZCONTACTJID) ts = APPLE_TIME + content["ZMESSAGEDATE"] message = Message( from_me=content["ZISFROMME"], @@ -162,7 +166,7 @@ def messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, name = None if content["ZMEMBERJID"] is not None: if content["ZMEMBERJID"] in data: - name = data[content["ZMEMBERJID"]].name + name = data.get_chat(content["ZMEMBERJID"]).name if "@" in content["ZMEMBERJID"]: fallback = content["ZMEMBERJID"].split('@')[0] else: @@ -230,7 +234,7 @@ def messages(db, data, media_folder, timezone_offset, filter_date, filter_chat, msg = msg.replace("\n", "
") message.data = msg if not invalid: - data[ZCONTACTJID].add_message(Z_PK, message) + current_chat.add_message(Z_PK, message) i += 1 if i % 1000 == 0: print(f"Processing messages...({i}/{total_row_number})", end="\r") @@ -281,12 +285,11 @@ def media(db, data, media_folder, filter_date, filter_chat, filter_empty, separa mime = MimeTypes() while content is not None: file_path = f"{media_folder}/Message/{content['ZMEDIALOCALPATH']}" - ZMESSAGE = content["ZMESSAGE"] - contact = data[content["ZCONTACTJID"]] - message = contact.get_message(ZMESSAGE) + current_chat = data.get_chat(content["ZCONTACTJID"]) + message = current_chat.get_message(content["ZMESSAGE"]) message.media = True - if contact.media_base == "": - contact.media_base = media_folder + "/" + if current_chat.media_base == "": + current_chat.media_base = media_folder + "/" if os.path.isfile(file_path): message.data = '/'.join(file_path.split("/")[1:]) if content["ZVCARDSTRING"] is None: @@ -298,7 +301,7 @@ def media(db, data, media_folder, filter_date, filter_chat, filter_empty, separa else: message.mime = content["ZVCARDSTRING"] if separate_media: - chat_display_name = slugify(contact.name or message.sender \ + chat_display_name = slugify(current_chat.name or message.sender \ or content["ZCONTACTJID"].split('@')[0], True) current_filename = file_path.split("/")[-1] new_folder = os.path.join(media_folder, "separated", chat_display_name) @@ -367,7 +370,7 @@ def vcard(db, data, media_folder, filter_date, filter_chat, filter_empty): vcard_summary = "This media include the following vCard file(s):
" vcard_summary += " | ".join([f'{htmle(name)}' for name, fp in zip(vcard_names, file_paths)]) - message = data[content["ZCONTACTJID"]].get_message(content["ZMESSAGE"]) + message = data.get_chat(content["ZCONTACTJID"]).get_message(content["ZMESSAGE"]) message.data = vcard_summary message.mime = "text/x-vcard" message.media = True @@ -415,7 +418,7 @@ def calls(db, data, timezone_offset, filter_chat): timezone_offset=timezone_offset if timezone_offset else CURRENT_TZ_OFFSET ) _jid = content["ZGROUPCALLCREATORUSERJIDSTRING"] - name = data[_jid].name if _jid in data else None + name = data.get_chat(_jid).name if _jid in data else None if _jid is not None and "@" in _jid: fallback = _jid.split('@')[0] else: @@ -443,4 +446,4 @@ def calls(db, data, timezone_offset, filter_chat): call.data += "in an unknown state." chat.add_message(call.key_id, call) content = c.fetchone() - data["000000000000000"] = chat \ No newline at end of file + data.add_chat("000000000000000", chat)