commit 23e108169aef67f737cba0b61f221b2c072d9d2e
parent 81044b0a3be33fc3ddaa4ea4d02415369a62b466
Author: ashermorgan <59518073+ashermorgan@users.noreply.github.com>
Date: Sat, 26 Jun 2021 19:22:30 -0700
Create CountdownBot class
Diffstat:
7 files changed, 106 insertions(+), 99 deletions(-)
diff --git a/run.py b/run.py
@@ -1,2 +1,19 @@
-import src
-src.countdownBot.run(src.settings["token"])
+# Import dependencies
+import json
+import os
+
+# Import modules
+from src import CountdownBot
+
+
+
+# Load settings
+settings = {}
+with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), "settings.json"), "a+") as f:
+ f.seek(0)
+ settings = json.load(f)
+
+
+
+# Run countdown-bot
+CountdownBot(settings["database"], settings["prefixes"]).run(settings["token"])
diff --git a/src/__init__.py b/src/__init__.py
@@ -1,26 +1,2 @@
-# Import dependencies
-import os
-import json
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-
-
-
-# Load settings
-settings = {}
-with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), "settings.json"), "a+") as f:
- f.seek(0)
- settings = json.load(f)
-
-
-
-# Connect to database and create tables
-from src.models import Base
-engine = create_engine(settings["database"])
-Base.metadata.create_all(bind=engine)
-Session = sessionmaker(bind=engine)
-
-
-
-# Import bot so it can be easily imported from src
-from src.bot import bot as countdownBot
+# Import CountdownBot so it can be easily imported from src
+from src.bot import CountdownBot
diff --git a/src/analyticsCog.py b/src/analyticsCog.py
@@ -9,15 +9,15 @@ import re
import tempfile
# Import modules
-from src import Session
from src.botUtilities import COLORS, getContextCountdown, getNickname, getUsername
from src.models import POINT_RULES
class Analytics(commands.Cog):
- def __init__(self, bot):
+ def __init__(self, bot, databaseSessionMaker):
self.bot = bot
+ self.databaseSessionMaker = databaseSessionMaker
@@ -27,7 +27,7 @@ class Analytics(commands.Cog):
Shows all countdown analytics
"""
- with Session() as session:
+ with self.databaseSessionMaker() as session:
# Get countdown channel
countdown = getContextCountdown(session, ctx)
@@ -54,7 +54,7 @@ class Analytics(commands.Cog):
Shows information about countdown contributors
"""
- with Session() as session:
+ with self.databaseSessionMaker() as session:
# Get countdown channel
countdown = getContextCountdown(session, ctx)
@@ -160,7 +160,7 @@ class Analytics(commands.Cog):
Shows information about the estimated completion date
"""
- with Session() as session:
+ with self.databaseSessionMaker() as session:
# Get countdown channel
countdown = getContextCountdown(session, ctx)
@@ -244,7 +244,7 @@ class Analytics(commands.Cog):
Shows the countdown leaderboard
"""
- with Session() as session:
+ with self.databaseSessionMaker() as session:
# Get countdown channel
countdown = getContextCountdown(session, ctx)
@@ -301,7 +301,7 @@ class Analytics(commands.Cog):
if (rank == None):
# Get user from nickname
for contributor in leaderboard:
- nickname = await getNickname(countdown.server_id, contributor["author"])
+ nickname = await getNickname(self.bot, countdown.server_id, contributor["author"])
if (nickname.lower().startswith(user.lower())):
rank = leaderboard.index(contributor)
@@ -345,7 +345,7 @@ class Analytics(commands.Cog):
Shows information about countdown progress
"""
- with Session() as session:
+ with self.databaseSessionMaker() as session:
# Get countdown channel
countdown = getContextCountdown(session, ctx)
@@ -416,7 +416,7 @@ class Analytics(commands.Cog):
Shows information about countdown speed
"""
- with Session() as session:
+ with self.databaseSessionMaker() as session:
# Get countdown channel
countdown = getContextCountdown(session, ctx)
diff --git a/src/bot.py b/src/bot.py
@@ -4,58 +4,59 @@ from discord.ext import commands
# Import modules
-from src import analyticsCog, utilitiesCog, Session
+from src import analyticsCog, utilitiesCog
from src.botUtilities import addMessage, COLORS, getCountdown, getPrefix
+from src.models import getSessionMaker
-# Create Discord bot
-bot = commands.Bot(command_prefix=getPrefix, case_insensitive=True)
+class CountdownBot(commands.Bot):
+ def __init__(self, databaseLocation, prefixes=["c."]):
+ # Initialize bot
+ commands.Bot.__init__(self, command_prefix=lambda bot, ctx: getPrefix(self.databaseSessionMaker, ctx, self.prefixes), case_insensitive=True)
+ # Set properties
+ self.databaseSessionMaker = getSessionMaker(databaseLocation)
+ self.prefixes = prefixes
+ # Add cogs
+ self.add_cog(analyticsCog.Analytics(self, self.databaseSessionMaker))
+ self.add_cog(utilitiesCog.Utilities(self, self.databaseSessionMaker))
-# Add cogs
-bot.add_cog(analyticsCog.Analytics(bot))
-bot.add_cog(utilitiesCog.Utilities(bot))
+ async def on_ready(self):
+ print(f"Connected to Discord as {self.user}")
-@bot.event
-async def on_ready():
- print(f"Connected to Discord as {bot.user}")
+ async def on_message(self, obj):
+ # Respond to @mentions
+ if self.user in obj.mentions:
+ embed=discord.Embed(title="countdown-bot", description=f"Use `{(await self.get_prefix(obj))[0]}help` to view help information", color=COLORS["embed"])
+ await obj.channel.send(embed=embed)
+ # Parse countdown message
+ with self.databaseSessionMaker() as session:
+ countdown = getCountdown(session, obj.channel.id)
+ if (countdown and obj.author.name != "countdown-bot"):
+ # Add message to countdown and commit changes
+ if (await addMessage(countdown, obj)): session.commit()
-@bot.event
-async def on_message(obj):
- # Respond to @mentions
- if bot.user in obj.mentions:
- embed=discord.Embed(title="countdown-bot", description=f"Use `{(await bot.get_prefix(obj))[0]}help` to view help information", color=COLORS["embed"])
- await obj.channel.send(embed=embed)
+ # Run commands
+ try:
+ await self.process_commands(obj)
+ except:
+ pass
- # Parse countdown message
- with Session() as session:
- countdown = getCountdown(session, obj.channel.id)
- if (countdown and obj.author.name != "countdown-bot"):
- # Add message to countdown and commit changes
- if (await addMessage(countdown, obj)): session.commit()
- # Run commands
- try:
- await bot.process_commands(obj)
- except:
- pass
-
-
-@bot.event
-async def on_command_error(ctx, error):
- # Send error embed
- embed=discord.Embed(title="Error", description=str(error), color=COLORS["error"])
- if (isinstance(error, commands.CommandNotFound)):
- embed.description = f"Command not found: `{str(error)[9:-14]}`"
- else:
- embed.description = str(error)
- embed.description += f"\nUse `{(await bot.get_prefix(ctx))[0]}help` to view help information\n"
- await ctx.send(embed=embed)
+ async def on_command_error(self, ctx, error):
+ # Send error embed
+ embed=discord.Embed(title="Error", description=str(error), color=COLORS["error"])
+ if (isinstance(error, commands.CommandNotFound)):
+ embed.description = f"Command not found: `{str(error)[9:-14]}`"
+ else:
+ embed.description = str(error)
+ embed.description += f"\nUse `{(await self.get_prefix(ctx))[0]}help` to view help information\n"
+ await ctx.send(embed=embed)
diff --git a/src/botUtilities.py b/src/botUtilities.py
@@ -3,7 +3,6 @@ import discord
import re
# Import modules
-from src import Session, settings
from src.models import Countdown, Message, MessageIncorrectError, MessageNotAllowedError
@@ -56,7 +55,7 @@ async def getNickname(bot, server, id):
The nickname
"""
- return (await (bot.get_guild(server)).fetch_member(id)).nick or await getUsername(id)
+ return (await (bot.get_guild(server)).fetch_member(id)).nick or await getUsername(bot, id)
@@ -100,8 +99,6 @@ def getContextCountdown(session, ctx, resortToFirst=True):
The countdown
"""
- global settings
-
if (isinstance(ctx.channel, discord.channel.TextChannel)):
# Countdown channel
countdown = getCountdown(session, ctx.channel.guild.id)
@@ -118,21 +115,22 @@ def getContextCountdown(session, ctx, resortToFirst=True):
-def getPrefix(bot, ctx):
+def getPrefix(databaseSessionMaker, ctx, default):
"""
Get the bot prefix for a certain context.
Parameters
----------
- bot : commands.Bot
- The bot
+ databaseSessionMaker : sqlalchemy.orm.sessionmaker
+ The database session maker
ctx : discord.ext.commands.Context
The context
+ default : list
+ The default prefixes
"""
- with Session() as session:
+ with databaseSessionMaker() as session:
# Countdown channel
- global settings
countdown = getCountdown(session, ctx.channel.id)
if (countdown and len(countdown.prefixes) > 0):
return [x.value for x in countdown.prefixes]
@@ -148,7 +146,7 @@ def getPrefix(bot, ctx):
return list(dict.fromkeys(prefixes))
# Return default prefixes
- return settings["prefixes"]
+ return default
diff --git a/src/models.py b/src/models.py
@@ -1,12 +1,32 @@
# Import dependencies
from datetime import datetime, timedelta
import math
-from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey
-from sqlalchemy.orm import relationship
+from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, ForeignKey
+from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
+Base = declarative_base()
+
+
+
+def getSessionMaker(location):
+ """
+ Create a sessionmaker from a database URI
+
+ Parameters
+ ----------
+ location : str
+ The location of the database
+ """
+
+ engine = create_engine(location)
+ Base.metadata.create_all(bind=engine)
+ return sessionmaker(bind=engine)
+
+
+
# The rules for awarding leaderboard points
POINT_RULES = {
"1000s": 1000,
@@ -34,11 +54,6 @@ class MessageIncorrectError(Exception):
-# Initialize declarative base
-Base = declarative_base()
-
-
-
class Countdown(Base):
"""
A Discord countdown
diff --git a/src/utilitiesCog.py b/src/utilitiesCog.py
@@ -3,15 +3,15 @@ import discord
from discord.ext import commands
# Import modules
-from src import Session, settings
from src.botUtilities import COLORS, getContextCountdown, getCountdown, loadCountdown
from src.models import Countdown, Prefix, Reaction
class Utilities(commands.Cog):
- def __init__(self, bot):
+ def __init__(self, bot, databaseSessionMaker):
self.bot = bot
+ self.databaseSessionMaker = databaseSessionMaker
self.bot.remove_command("help")
@@ -22,7 +22,7 @@ class Utilities(commands.Cog):
Turns a channel into a countdown
"""
- with Session() as session:
+ with self.databaseSessionMaker() as session:
# Channel is already a coutndown
if (getCountdown(session, ctx.channel.id)):
embed = discord.Embed(title="Error", description="This channel is already a countdown", color=COLORS["error"])
@@ -45,7 +45,7 @@ class Utilities(commands.Cog):
id = ctx.channel.id,
server_id = ctx.channel.guild.id,
timezone = 0,
- prefixes = [Prefix(countdown_id=ctx.channel.id, value=x) for x in settings["prefixes"]],
+ prefixes = [Prefix(countdown_id=ctx.channel.id, value=x) for x in self.bot.prefixes],
reactions = [],
messages = [],
)
@@ -76,7 +76,7 @@ class Utilities(commands.Cog):
# Create embed
embed = discord.Embed(title=":gear: Countdown Settings", color=COLORS["embed"])
- with Session() as session:
+ with self.databaseSessionMaker() as session:
# Get countdown channel
try:
countdown = getContextCountdown(session, ctx, resortToFirst=False)
@@ -152,7 +152,7 @@ class Utilities(commands.Cog):
Deactivates a countdown channel
"""
- with Session() as session:
+ with self.databaseSessionMaker() as session:
# Channel isn't a countdown
countdown = getCountdown(session, ctx.channel.id)
if (not countdown):
@@ -363,7 +363,7 @@ class Utilities(commands.Cog):
Reloads the countdown cache
"""
- with Session() as session:
+ with self.databaseSessionMaker() as session:
countdown = getCountdown(session, ctx.channel.id)
if (countdown):
# Send inital responce