commit c15ffe9b4ce2879587d5524014d495f99b2445a8
parent 6dcdbfa0671ebae27e65523fb2c3c91685765773
Author: Asher Morgan <59518073+ashermorgan@users.noreply.github.com>
Date: Tue, 23 Apr 2024 20:33:02 -0700
Update botUtilities to use postgresql db
Diffstat:
3 files changed, 141 insertions(+), 89 deletions(-)
diff --git a/countdown_bot/bot.py b/countdown_bot/bot.py
@@ -24,7 +24,7 @@ class CountdownBot(commands.Bot):
intents.message_content = True
# Initialize bot
- super().__init__(command_prefix=lambda bot, ctx: getPrefix(self.databaseSessionMaker, ctx, self.prefixes), intents=intents)
+ super().__init__(command_prefix=lambda bot, ctx: getPrefix(self.db_connection, ctx, self.prefixes), intents=intents)
@@ -63,11 +63,9 @@ class CountdownBot(commands.Bot):
await obj.channel.send(embed=embed)
# Parse countdown message
- with self.databaseSessionMaker() as session:
- countdown = getCountdown(session, obj.channel.id)
- if (countdown):
- # Add message to countdown and commit changes
- if (await addMessage(countdown, obj)): session.commit()
+ with self.db_connection.cursor() as cur:
+ if (await addMessage(cur, obj)):
+ self.db_connection.commit()
# Run commands
try:
diff --git a/countdown_bot/botUtilities.py b/countdown_bot/botUtilities.py
@@ -3,7 +3,7 @@ import discord
import re
# Import modules
-from .models import Countdown, Message, MessageIncorrectError, MessageNotAllowedError
+from .models import Countdown, Message
@@ -148,6 +148,29 @@ def getCountdown(session, id):
+def isCountdown(cur, id):
+ """
+ Determine whether a channel is a countdown
+
+ Parameters
+ ----------
+ cur : psycopg.cursor
+ The database cursor
+ id : int
+ The countdown ID
+
+ Returns
+ -------
+ bool
+ A boolean indicating whether the channel is a countdown
+ """
+
+ cur.execute("CALL isCountdown(%s, null);",
+ (ctx.channel.id,))
+ return cur.fetchone()[0]
+
+
+
def getContextCountdown(session, ctx):
"""
Get the most relevant countdown to a certain context
@@ -186,80 +209,74 @@ def getContextCountdown(session, ctx):
raise CountdownNotFound()
-
-
-def getPrefix(databaseSessionMaker, ctx, default):
+def getContextCountdown2(cur, ctx):
"""
- Get the bot prefix for a certain context
+ Get the most relevant countdown to a certain context
Parameters
----------
- databaseSessionMaker : sqlalchemy.orm.sessionmaker
- The database session maker
+ cur : psycopg.cursor
+ The database cursor
ctx : discord.ext.commands.Context
The context
- default : list
- The default prefixes
+
+ Returns
+ -------
+ countdownID
+ The countdown ID
"""
- with databaseSessionMaker() as session:
- # Countdown channel
- countdown = getCountdown(session, ctx.channel.id)
- if (countdown and len(countdown.prefixes) > 0):
- return [x.value.lower() for x in countdown.prefixes]
+ if (isinstance(ctx.channel, discord.channel.TextChannel)):
+ # Channel inside a server
+ cur.execute("CALL getServerContextCountdown(%s, %s, %s, null);",
+ (ctx.channel.guild.id, ctx.channel.id, ctx.prefix))
+ return cur.fetchone()[0]
- # Server with countdown channels
- if (isinstance(ctx.channel, discord.channel.TextChannel)):
- serverCountdowns = session.query(Countdown).filter(Countdown.server_id == ctx.channel.guild.id).all()
- # Get list of prefixes
- prefixes = []
- for countdown in serverCountdowns:
- prefixes += [x.value.lower() for x in countdown.prefixes]
- if (len(prefixes) > 0):
- return list(dict.fromkeys(prefixes))
+ if (isinstance(ctx.channel, discord.channel.DMChannel)):
+ # DM with a user
+ cur.execute("CALL getUserContextCountdown(%s, null);",
+ (ctx.author.id,))
+ return cur.fetchone()[0]
- # Return default prefixes
- return [x.lower() for x in default]
+ return None
-def parseMessage(message):
+def getPrefix(conn, ctx, default):
"""
- Parses a countdown message from a Discord message
+ Get the bot prefix for a certain context
Parameters
----------
- message : discord.Message
- The Discord message
-
- Returns
- -------
- Message
+ conn : psycopg.Connection
+ The database connection
+ ctx : discord.ext.commands.Context
+ The context
+ default : list
+ The default prefixes
"""
- return Message(
- id = message.id,
- countdown_id = message.channel.id,
- author_id = message.author.id,
- timestamp = message.created_at,
- number = int(re.findall("^[0-9,]+", message.content)[0].replace(",","")),
- )
+ with conn.cursor() as cur:
+ cur.execute("SELECT * FROM getPrefixes(%s, %s);",
+ (ctx.channel.guild.id, ctx.channel.id))
+ prefixes = cur.fetchall()
+ return [x[0] for x in prefixes] if prefixes else default
-async def addMessage(countdown, rawMessage):
+async def addMessage(cur, message):
"""
Parse a message and add it to a countdown
Notes
-----
- If the message is invalid or incorrect, a reacted will be added accordingly
+ If the message is invalid or incorrect, a reaction will be added accordingly
Parameters
----------
- countdown : Countdown
- The countdown
- rawMessage : discord.Message
+ cur : psycopg.cursor
+ The database cursor
+ message : discord.Message
The Discord message object
Returns
@@ -268,32 +285,32 @@ async def addMessage(countdown, rawMessage):
Whether the message was valid and added to the countdown
"""
- try:
- # Parse message
- message = parseMessage(rawMessage)
-
- # Add message
- countdown.addMessage(message)
-
- # Mark important messages
- if (message.number in [x.number for x in countdown.reactions]):
- for reaction in [x for x in countdown.reactions if x.number == message.number]:
- try:
- await rawMessage.add_reaction(reaction.value)
- except:
- pass
- if (countdown.messages[0].number >= 500 and message.number % (countdown.messages[0].number // 50) == 0):
- await rawMessage.pin()
- except MessageNotAllowedError:
- await rawMessage.add_reaction("⛔")
- return False
- except MessageIncorrectError:
- await rawMessage.add_reaction("❌")
- return False
- except:
- return False
- else:
- return True
+ # Parse message number
+ match = re.search("^[0-9,]+", message.content)
+ if not match: return False
+ number = int(match[0].replace(",", ""))
+
+ # Attempt to add result
+ cur.execute("CALL addMessage(%s,%s,%s,%s,%s,null,null,null);", (
+ message.id, message.channel.id, message.author.id, number,
+ message.created_at
+ ))
+ result = cur.fetchone()
+
+ # Process result
+ if result[0] == 'badNumber':
+ await message.add_reaction("❌")
+ if result[0] == 'badUser':
+ await message.add_reaction("⛔")
+ if result[1]:
+ await message.pin()
+ if result[2]:
+ cur.execute("SELECT * FROM getReactions(%s, %s);",
+ (message.channel.id, number))
+ for reaction in cur.fetchall():
+ await message.add_reaction(reaction[0])
+
+ return result[0] == 'good'
@@ -305,17 +322,24 @@ async def loadCountdown(bot, countdown):
----------
bot : commands.Bot
The bot to load messages with
+ cur : psycopg.cursor
+ The database cursor
countdown : Countdown
The countdown to load messages for
"""
- # Clear countdown
- countdown.messages = []
+ with bot.db_connection.cursor() as cur:
+ # Clear countdown
+ cur.execute("CALL clearCountdown(%s);", (countdown.id,))
+
+ # Get Discord messages
+ messages = [message async for message in
+ bot.get_channel(countdown.id).history(limit=10100)]
+ messages.reverse()
- # Get Discord messages
- rawMessages = [message async for message in bot.get_channel(countdown.id).history(limit=10100)]
- rawMessages.reverse()
+ # Add messages to countdown
+ for message in messages:
+ await addMessage(cur, message)
- # Add messages to countdown
- for rawMessage in rawMessages:
- await addMessage(countdown, rawMessage)
+ # Commit changes
+ bot.db_connection.commit()
diff --git a/models/utilities.sql b/models/utilities.sql
@@ -3,6 +3,8 @@
DROP FUNCTION IF EXISTS getReactions;
DROP PROCEDURE IF EXISTS addMessage;
DROP TYPE IF EXISTS addMessageResults;
+DROP PROCEDURE IF EXISTS clearCountdown;
+DROP PROCEDURE IF EXISTS isCountdown;
DROP PROCEDURE IF EXISTS getUserContextCountdown;
DROP PROCEDURE IF EXISTS getServerContextCountdown;
DROP FUNCTION IF EXISTS getPrefixes;
@@ -89,6 +91,34 @@ BEGIN
END
$$;
+-- Determine if a channel is a countdown
+CREATE PROCEDURE isCountdown (
+ channelID IN BIGINT, -- The channel ID
+ result OUT BOOLEAN -- Whether the channel is a countdown
+)
+LANGUAGE plpgsql AS $$
+BEGIN
+ SELECT EXISTS(
+ SELECT 1
+ FROM countdowns
+ WHERE countdownID = channelID
+ ) INTO result;
+END
+$$;
+
+-- Delete all messages in a countdown
+CREATE PROCEDURE clearCountdown (
+ _countdownID IN BIGINT -- The countdown channel ID
+)
+LANGUAGE plpgsql AS $$
+BEGIN
+ DELETE
+ FROM messages
+ WHERE countdownID = _countdownID;
+END
+$$;
+
+
-- Possible results of the addMessage procedure
CREATE TYPE addMessageResults AS ENUM (
'badCountdown', -- Countdown doesn't exist or has ended
@@ -126,6 +156,10 @@ BEGIN
ORDER BY messages.value ASC
LIMIT 1;
+ -- Initialize pin and reactions
+ pin := FALSE;
+ reactions := FALSE;
+
-- Validate message
IF lastMessage.countdownID IS NULL OR lastMessage.value = 0 THEN
-- Countdown doesn't exist or has ended
@@ -157,8 +191,6 @@ BEGIN
-- Check if message should be pinned
IF total >= 500 AND _value % (total / 50) = 0 AND _value != 0 THEN
pin := TRUE;
- ELSE
- pin := FALSE;
END IF;
-- Check if message has custom reactions
@@ -166,8 +198,6 @@ BEGIN
WHERE countdownID = _countdownID AND number = _value
) THEN
reactions := TRUE;
- ELSE
- reactions := FALSE;
END IF;
END IF;
END