Fix Bug Where Long Messages Got Cut Off
authorGeorgios Atheridis <georgios@atheridis.org>
Fri, 30 Jun 2023 16:53:29 +0000 (17:53 +0100)
committerGeorgios Atheridis <georgios@atheridis.org>
Fri, 30 Jun 2023 16:53:29 +0000 (17:53 +0100)
.flake8 [new file with mode: 0644]
aptbot/bot.py

diff --git a/.flake8 b/.flake8
new file mode 100644 (file)
index 0000000..541e00a
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+ignore = E203, W503
+max-line-length = 88
index cabd595706475a4483fc6dd4b0ad892576d4aac5..1ef4b474960a3b874f6058a9bd92adf220bbdded 100644 (file)
@@ -1,7 +1,6 @@
 import logging
 import re
 import socket
-import sys
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
@@ -44,19 +43,69 @@ class ABCBot(ABC):
 
 
 class Bot(ABCBot):
+    """
+    Creates and manages a connection to Twitch's IRC servers.
+    """
+
+    # Connectiion information
+    _SERVER = "irc.chat.twitch.tv"
+    _PORT = 6667
+
+    # max character length allowed to send through twitch chat
+    _MAX_LEN = 500
+
+    _LOGIN = (
+        b"CAP REQ :twitch.tv/membership\r\n"
+        b"CAP REQ :twitch.tv/tags\r\n"
+        b"CAP REQ :twitch.tv/commands\r\n"
+        b"PASS oauth:%(pass)s\r\n"
+        b"NICK %(nick)s\r\n"
+    )
+
+    _RECONNECT = b":tmi.twitch.tv RECONNECT"
+    _PING = b"PING :tmi.twitch.tv"
+    _PONG = b"PONG :tmi.twitch.tv"
+    _JOIN = b"JOIN #%s"
+    _PART = b"PART #%s"
+
+    # constants to show if the connection was successful
+    _AUTH_SUCC = (
+        b":tmi.twitch.tv 001 %(nick)s :Welcome, GLHF!\r\n"
+        b":tmi.twitch.tv 002 %(nick)s :Your host is tmi.twitch.tv\r\n"
+        b":tmi.twitch.tv 003 %(nick)s :This server is rather new\r\n"
+        b":tmi.twitch.tv 004 %(nick)s :-\r\n"
+        b":tmi.twitch.tv 375 %(nick)s :-\r\n"
+        b":tmi.twitch.tv 372 %(nick)s :You are in a maze of twisty passages, "
+        b"all alike.\r\n"
+        b":tmi.twitch.tv 376 %(nick)s :>\r\n"
+    )
+    _AUTH_FAIL = b":tmi.twitch.tv NOTICE * :Login authentication failed"
+    _AUTH_BAD_FORMAT = b":tmi.twitch.tv NOTICE * :Improperly formatted auth"
+
+    # reply and PRIVMSG are strings, not raw bytes.
+    _REPLY = "@reply-parent-msg-id=%s "
+    _PRIVMSG = "PRIVMSG #%s :%s"
+
+    # time until authentication timeout in seconds
+    _AUTH_TIMEOUT = 15
+
+    # maximum kilobytes to read from socket
+    _BYTE_READ = 1024 * 256
+
     def __init__(self, nick: str, oauth_token: str):
-        self._server = "irc.chat.twitch.tv"
-        self._port = 6667
-        self._nick = nick
-        self._oauth_token = oauth_token
-        self._connected_channels = set()
+        self._nick = nick.encode()
+        self._oauth_token = oauth_token.encode()
+        self._connected_channels: set[str] = set()
+
         self._buffered_messages = []
-        self.empty_message_count = 0
+        self._empty_message_count = 0
+
+        self._recv_buff = b""
 
-    def _send_command(self, command: str):
-        if "PASS" not in command:
+    def _send_command(self, command: bytes):
+        if b"PASS" not in command:
             logger.info(f"< {command}")
-        self._irc.send((command + "\r\n").encode())
+        self._irc.send((command + b"\r\n"))
 
     def connect(self) -> bool:
         self._connect()
@@ -65,16 +114,14 @@ class Bot(ABCBot):
 
     def _connect(self) -> None:
         self._irc = socket.socket()
-        self._irc.connect((self._server, self._port))
+        self._irc.connect((self._SERVER, self._PORT))
         logger.debug("Connecting...")
-        self._send_command(f"PASS oauth:{self._oauth_token}")
-        self._send_command(f"NICK {self._nick}")
         self._send_command(
-            f"CAP REQ :twitch.tv/membership twitch.tv/tags twitch.tv/commands"
+            self._LOGIN % {b"pass": self._oauth_token, b"nick": self._nick}
         )
 
     def join_channel(self, channel: str):
-        self._send_command(f"{Commands.JOIN.value} #{channel}")
+        self._send_command(self._JOIN % channel.encode())
         self._connected_channels.add(channel)
 
     def join_channels(self, channels: Iterable):
@@ -82,7 +129,7 @@ class Bot(ABCBot):
             self.join_channel(channel)
 
     def leave_channel(self, channel: str):
-        self._send_command(f"{Commands.PART.value} #{channel}")
+        self._send_command(self._PART % channel.encode())
         try:
             self._connected_channels.remove(channel)
         except KeyError as e:
@@ -97,10 +144,10 @@ class Bot(ABCBot):
         if isinstance(text, list):
             for t in text:
                 command = replied_command + f"{Commands.PRIVMSG.value} #{channel} :{t}"
-                self._send_command(command)
+                self._send_command(command.encode())
         else:
             command = replied_command + f"{Commands.PRIVMSG.value} #{channel} :{text}"
-            self._send_command(command)
+            self._send_command(command.encode())
 
     @staticmethod
     def _replace_escaped_characters_in_tags(tag_value: str) -> str:
@@ -166,68 +213,62 @@ class Bot(ABCBot):
             value=value,
         )
 
-    def _handle_message(self, received_msg: str) -> Message:
-        logger.info(f"> {received_msg}")
-        if received_msg == "PING :tmi.twitch.tv":
-            self._send_command("PONG :tmi.twitch.tv")
-            return Message(command=Commands.PING)
-        elif received_msg == ":tmi.twitch.tv RECONNECT":
-            self._restart_connection()
-        elif not received_msg:
-            return Message()
-        return Bot._parse_message(received_msg)
-
     def _receive_messages(self) -> bytes:
-        for _ in range(10):
-            try:
-                received_msgs = self._irc.recv(2048)
-            except ConnectionResetError as e:
-                logger.exception(e)
-                time.sleep(1)
-                self._restart_connection()
-            else:
-                break
-        else:
-            logger.error("Unable to connect to twitch. Exiting")
-            sys.exit(1)
-        return received_msgs
+        try:
+            data = self._irc.recv(self._BYTE_READ)
+        except ConnectionError as e:
+            logger.exception(e)
+            self._restart_connection()
+            return b""
+        if len(data) == 0:
+            logger.error("Connection was terminated.")
+            self._restart_connection()
+        return data
 
     def _connected(self) -> bool:
-        received_msgs = self._receive_messages()
-        for received_msg in received_msgs.decode("utf-8").split("\r\n"):
-            self._buffered_messages.append(self._handle_message(received_msg))
-        if self._buffered_messages[0] == Message(
-            {},
-            "",
-            Commands.NOTICE,
-            "",
-            "Login authentication failed",
-        ):
-            logger.debug(f"Not connected")
-            return False
-        logger.debug(f"Connected")
-        return True
-
-    def get_messages(self) -> list[Message]:
-        messages = []
-        messages.extend(self._buffered_messages)
-        self._buffered_messages = []
-        received_msgs = self._receive_messages()
-        for received_msg in received_msgs.decode("utf-8").split("\r\n"):
-            message = self._handle_message(received_msg)
-            messages.append(message)
-
-            # If twitch closes the connection,
-            # we get spammed by empty messages.
-            # So we restart the connection
-            if message == Message():
-                self.empty_message_count += 1
-            else:
-                self.empty_message_count = 0
-            if self.empty_message_count > 10:
-                self.empty_message_count = 0
-                self._restart_connection()
-                return messages
+        """
+        Authenticate that the connection was successful
+        """
+        message = b""
+        timeout_start = time.time()
+        while time.time() - timeout_start < self._AUTH_TIMEOUT:
+            data = self._receive_messages()
+            message += data
+            if self._AUTH_SUCC % {b"nick": self._nick} in message:
+                logger.info("Connection with %s authenticated", self._nick)
+                self.disconnect()
+                return True
+            elif self._AUTH_BAD_FORMAT in message:
+                logger.critical(
+                    "Message with %s is badly formatted: %s",
+                    self._nick,
+                    message,
+                )
+                self.disconnect()
+                return False
+            elif self._AUTH_FAIL in message:
+                logger.warning("%s failed authentication", self._nick)
+                return False
+        self.disconnect()
+        logger.critical("Connection with %s timed out with: %s", self._nick, message)
+        return False
+
+    def get_messagess(self) -> list[Message]:
+        messages: list[Message] = []
+        data = self._receive_messages()
+        if not data:
+            return []
+        data = self._recv_buff + data
+        self._recv_buff = b""
+        split_messages = data.split(b"\r\n")
+        self._recv_buff = split_messages[-1]
+        if self._RECONNECT in split_messages:
+            logger.warning("Reconnecting due to twitch.")
+            self._restart_connection()
+        elif self._PING in split_messages:
+            self._send_command(self._PONG)
+        for msg in split_messages[:-1]:
+            messages.append(Bot._parse_message(msg.decode()))
         return messages
 
     def disconnect(self) -> None:
@@ -237,10 +278,10 @@ class Bot(ABCBot):
     def _restart_connection(self):
         logger.warning("Restarting twitch connection")
         self.disconnect()
-        time.sleep(2)
+        time.sleep(1.5)
         self._connect()
         self.join_channels(self._connected_channels)
-        time.sleep(2)
+        time.sleep(0.5)
 
     # Aliasing method names for backwards compatibility
     send_privmsg = send_message