# idb_channel_bot/idb_channel_bot.py
# 2025-02-28 22:56:03 +03:00
#
# 424 lines
# 14 KiB
# Python
# Executable File
#!/usr/bin/env python3
import os
import re
import sqlite3
import logging
import asyncio
from _textutils import markdown_to_html, html_to_text, remove_paragraph_tags
from enum import Enum
from html import escape
from argparse import ArgumentParser, ArgumentError
from time import time, sleep
from dotenv import load_dotenv
from telegram import Bot, LinkPreviewOptions
from telegram.constants import ParseMode
from telegram.error import RetryAfter
from typing import Optional
logging.basicConfig(level=logging.INFO)
load_dotenv()
TOKEN = os.getenv('TELEGRAM_TOKEN')
CHANNEL_ID = os.getenv("CHANNEL_ID")
MESSAGE_LINK_TEMPLATE = os.getenv("MESSAGE_LINK_TEMPLATE")
POSTS_DIR = os.path.join(os.path.dirname(__file__), 'posts')
DB_PATH = os.path.join(os.path.dirname(__file__), "bot.db")
MAX_MESSAGE_LENGTH = 3900 # actually it's 4096, but my calculation is not perfect; telegram is full of surprises too
MAX_MESSAGE_MARKUP_LENGTH = 30000 # i read on stackoverflow that it's 32768 but i'm not sure
logger = logging.getLogger(__name__)
class PostNotFound(Exception):
    """Raised when a post has no 'published' record in the database."""
    pass
class NoPostsLeft(Exception):
    """Raised when no unpublished post matches the current selection criteria."""
    pass
class PostType(Enum):
    """Kinds of posts, taken from the '@type' meta line (defaults to ARTICLE)."""
    ARTICLE = 'article'
    ELEVEN11 = '11:11'
class PostMessage:
    """One Telegram-sized chunk of a post: HTML markup plus its plain-text form."""

    def __init__(self, html: str, text: str):
        self.html = html  # HTML markup that will be sent via the Bot API
        self.text = text  # plain-text rendering, used only for length accounting

    @property
    def len(self) -> int:
        # Length of the plain-text rendering (NOTE: name shadows the builtin,
        # but only as an attribute, so callers are unaffected).
        return len(self.text)
class Post:
    """A single markdown post loaded from POSTS_DIR.

    File format: a header of '@key value' meta lines, followed by the
    markdown body.  Links of the form '[name](#other-post.md)' are resolved
    against already-published posts during construction.
    """

    file_path: str    # absolute path of the source .md file
    meta: dict        # lower-cased '@key' directives from the file header
    text: str         # markdown body, with self-links already rewritten
    links: list[str]  # names of other post files this post links to

    def __init__(self, file_name: str):
        self.file_path = os.path.realpath(os.path.join(POSTS_DIR, file_name))
        self.meta, self.text = self.parse()
        self.links = []
        self.validate_self_links()

    def parse(self) -> tuple[dict, str]:
        """Split the file into its meta header and markdown body.

        Meta lines start with '@' and must come first; the first line that
        does not start with '@' ends the header for good (later '@' lines
        are treated as body text).
        """
        meta = {}
        text_lines = []
        in_meta = True
        with open(self.file_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if in_meta and line.startswith('@'):
                    parts = line.split(maxsplit=1)
                    key = parts[0][1:].strip()
                    value = parts[1].strip() if len(parts) == 2 else ''
                    meta[key.lower()] = value
                else:
                    in_meta = False
                    text_lines.append(line)
        text = '\n'.join(text_lines).strip()
        return meta, text

    def validate_self_links(self):
        """Rewrite '[name](#post.md)' links in self.text and collect them.

        Each linked post must exist on disk.  If it has already been
        published, the link becomes its Telegram message URL; otherwise the
        '#post.md' placeholder anchor is kept.

        Raises:
            FileNotFoundError: when a linked post file does not exist.
        """
        def callback(match):
            name = match.group(1)
            href = match.group(2)[1:]  # drop the leading '#'
            if href not in self.links:
                self.links.append(href)
            if not os.path.exists(os.path.join(POSTS_DIR, href)):
                raise FileNotFoundError(f'{href}: linked post not found')
            try:
                message_id = db.get_published_post_message_ids(href)[0]
                url = get_message_url(message_id)
            except PostNotFound:
                # not published yet — keep the placeholder anchor
                url = f'#{href}'
            return f'[{name}]({url})'
        self.text = re.sub(r'\[(.*?)]\((#.*?\.md)\)', callback, self.text)

    def is_eleven11(self) -> bool:
        """True when this is an '11:11'-type post."""
        return self.type == PostType.ELEVEN11

    @property
    def file_name(self) -> str:
        """Base name of the post file (used as its database key)."""
        return os.path.basename(self.file_path)

    @property
    def link(self) -> Optional[str]:
        """External 'read more' URL from the '@link' meta line, if any."""
        return self.meta['link'] if 'link' in self.meta else None

    @property
    def image(self) -> Optional[str]:
        """Preview image URL from the '@image' meta line, if any."""
        return self.meta['image'] if 'image' in self.meta else None

    @property
    def type(self) -> PostType:
        """Post type from the '@type' meta line; defaults to ARTICLE."""
        return PostType(self.meta['type']) if 'type' in self.meta else PostType.ARTICLE

    def get_messages(self) -> list[PostMessage]:
        """Render the post as one or more Telegram-sized messages.

        The body is split on blank lines into paragraphs and greedily packed
        into messages that fit MAX_MESSAGE_LENGTH (plain text) and
        MAX_MESSAGE_MARKUP_LENGTH (HTML).  A single optional '@split'
        directive marks the preferred break point.  Multi-part posts get a
        '(N из N)' header whose numbers are patched in at the end.

        Raises:
            ValueError: when more than one '@split' directive is present.
        """
        messages = []
        paragraphs = []
        preferred_split = None
        for i, p in enumerate(self.text.split('\n\n')):
            if p.strip().startswith('@split'):
                if preferred_split is not None:
                    raise ValueError('only one @split directive per document allowed')
                preferred_split = i
                p = re.sub('^@split\n\n?', '', p)
            paragraphs.append(p)
        splitting = False
        start = 0
        end = len(paragraphs)
        while start <= len(paragraphs)-1:
            html = ''
            if splitting:
                # part-counter placeholder; real numbers substituted below
                html += '<p>(N из N'
                if start > 0:
                    html += ', начало <a href="{start_url}">тут</a>'
                html += ')</p>\n\n'
            if start == 0:
                # first message carries the title (and optional date) header
                html += f'<b><u>{escape(self.meta["title"])}</u></b>\n'
                if 'date' in self.meta:
                    html += f'<i>({escape(self.meta["date"])})</i>\n'
                html += '\n'
            html += markdown_to_html('\n\n'.join(paragraphs[start:end]))
            if self.link and (splitting is False or start > 0):
                html += '\n\n'
                html += f'<b>Читать далее:</b>\n{self.link}'
            html_to_send = remove_paragraph_tags(html)
            text = html_to_text(html)
            if len(text) > MAX_MESSAGE_LENGTH or len(html_to_send) > MAX_MESSAGE_MARKUP_LENGTH:
                # too long: shrink the window and retry this message
                logger.debug(f'long message ({len(text)})')
                if preferred_split is not None and start < preferred_split < end:
                    end = preferred_split
                else:
                    # avoid ending a message on a divider/bold lead-in paragraph
                    if paragraphs[end-2].startswith('___') or paragraphs[end-2].startswith('**'):
                        end -= 2
                    else:
                        end -= 1
                splitting = True
            else:
                messages.append(PostMessage(html_to_send, text))
                start = end
                end = len(paragraphs)
        if splitting:
            # replace the '(N из N' placeholder with the actual part numbers
            for i, message in enumerate(messages):
                messages[i].html = re.sub(r'^\(N из N', '('+str(i+1)+' из '+str(len(messages)), message.html)
        logger.debug(f'{len(messages)} posts, lengths {", ".join(map(lambda m: str(m.len), messages))}')
        return messages
class Database:
    """SQLite-backed record of published posts and of messages that link to posts."""

    def __init__(self):
        self.conn = sqlite3.connect(DB_PATH)
        self.initialize()

    def initialize(self) -> None:
        """Create the schema if it does not exist yet (idempotent)."""
        cur = self.conn.cursor()
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS published (
                post_name TEXT PRIMARY KEY,
                published_at INTEGER NOT NULL,
                main_message_id INTEGER NOT NULL,
                all_message_ids TEXT NOT NULL
            )
            """
        )
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS links (
                main_message_id INTEGER NOT NULL,
                post_name TEXT NOT NULL
            )"""
        )
        self.conn.commit()

    # FIX: the old annotation promised Optional[...] but the method never
    # returns None — it raises PostNotFound instead.
    def get_published_post_message_ids(self, post_name: str) -> tuple[int, list[int]]:
        """Return (main_message_id, all_message_ids) for a published post.

        Raises:
            PostNotFound: when the post has not been published.
        """
        cur = self.conn.cursor()
        cur.execute("SELECT main_message_id, all_message_ids FROM published WHERE post_name=?", (post_name,))
        row = cur.fetchone()
        if not row:
            raise PostNotFound()
        # all_message_ids is stored as a comma-separated string
        return int(row[0]), list(map(int, row[1].split(',')))

    def reset_published(self) -> None:
        """Delete every 'published' record (used by the `reset` command)."""
        cur = self.conn.cursor()
        cur.execute("DELETE FROM published")
        self.conn.commit()

    def is_already_published(self, post_name: str) -> bool:
        """True when the post already has a 'published' record."""
        cur = self.conn.cursor()
        cur.execute("SELECT 1 FROM published WHERE post_name=?", (post_name,))
        return cur.fetchone() is not None

    def unmark_as_published(self, post_name: str) -> None:
        """Remove the 'published' record for a post (no-op if absent)."""
        cur = self.conn.cursor()
        cur.execute("DELETE FROM published WHERE post_name=?", (post_name,))
        self.conn.commit()

    def mark_as_published(self,
                          post_name: str,
                          message_id: int,
                          message_ids: list[int]) -> None:
        """Record a post as published with its main and all Telegram message ids."""
        cur = self.conn.cursor()
        cur.execute("INSERT INTO published (post_name, published_at, main_message_id, all_message_ids) VALUES (?, ?, ?, ?)",
                    (post_name, int(time()), message_id, ','.join(map(str, message_ids))))
        self.conn.commit()

    def add_links(self,
                  main_message_id: int,
                  links: list[str]) -> None:
        """Record which posts the given main message links to.

        Uses a single executemany batch instead of one execute per link.
        """
        cur = self.conn.cursor()
        cur.executemany("INSERT INTO links (main_message_id, post_name) VALUES (?, ?)",
                        [(main_message_id, post_name) for post_name in links])
        self.conn.commit()

    # FIXME raise if not found
    def get_linking_main_message_ids(self, post_name: str) -> list[int]:
        """Return main message ids of every message that links to this post."""
        cur = self.conn.cursor()
        cur.execute("SELECT main_message_id FROM links WHERE post_name=?", (post_name,))
        return [row[0] for row in cur.fetchall()]

    # FIXME
    def delete_link(self, main_message_id: int, post_name: str) -> None:
        """Remove the link record between a main message and a post."""
        cur = self.conn.cursor()
        cur.execute("DELETE FROM links WHERE main_message_id=? AND post_name=?", (main_message_id, post_name))
        self.conn.commit()
def get_message_url(message_id) -> str:
    """Build the public URL of a channel message by filling the '{id}' slot."""
    slot_value = str(message_id)
    return MESSAGE_LINK_TEMPLATE.replace('{id}', slot_value)
def get_files() -> list[str]:
    """List post files (*.md) in POSTS_DIR in natural sort order.

    Returns an empty list when there are no posts.  (FIX: previously
    returned None, which crashed callers that iterate the result —
    `for filename in get_files()` in do_send — or test membership on it.)
    """
    files = [f for f in os.listdir(POSTS_DIR) if f.endswith('.md')]
    if not files:
        logger.warning('no posts found')
        return []
    # natural sort: digit runs compare numerically, text case-insensitively
    files.sort(key=lambda f: [int(c) if c.isdigit() else c.lower() for c in re.split(r'(\d+)', f)])
    return files
def get_subparser(subparsers, *args, **kwargs):
    """Register a sub-command parser that also accepts the shared -V/--verbose flag."""
    sub = subparsers.add_parser(*args, **kwargs)
    sub.add_argument('-V', '--verbose', action='store_true')
    return sub
async def do_send(type: Optional[PostType] = None,
                  mark_as_published=False,
                  post_name=None):
    """Send the next unpublished post (or a specific one) to the channel.

    Args:
        type: when selecting automatically, only posts of this type qualify
            (NOTE: the parameter name shadows the builtin, kept for
            backward compatibility with keyword callers).
        mark_as_published: record the post in the database after sending.
        post_name: send this exact post instead of scanning the queue.

    Raises:
        NoPostsLeft: when no unpublished post matches the criteria.
    """
    post = None
    if post_name:
        post = Post(post_name)
    else:
        # pick the first unpublished post of the requested type
        for filename in get_files():
            if db.is_already_published(filename):
                continue
            post = Post(filename)
            if post.type == type:
                break
            post = None
    if not post:
        raise NoPostsLeft()
    n = 1
    messages = post.get_messages()
    main_message_id = None
    all_message_ids = []
    bot = Bot(TOKEN)
    for i, message in enumerate(messages):
        kwargs = {}
        if post.image and n == 1:
            # show the preview image above the first message only
            kwargs['link_preview_options'] = LinkPreviewOptions(url=post.image, show_above_text=True)
        else:
            kwargs['disable_web_page_preview'] = True
        text = message.html
        if i > 0:
            # follow-up parts link back to the first message
            text = text.replace('{start_url}', get_message_url(main_message_id))
        retried = False
        while True:
            try:
                sent_message = await bot.send_message(chat_id=CHANNEL_ID, text=text, parse_mode=ParseMode.HTML, **kwargs)
                all_message_ids.append(sent_message.id)
                n += 1
                if not main_message_id:
                    main_message_id = sent_message.id
            except RetryAfter as e:
                # retry exactly once after Telegram's flood-control delay
                if retried:
                    raise e
                retried = True
                # FIX: time.sleep() here would block the event loop inside an
                # async coroutine — use the non-blocking asyncio.sleep instead.
                await asyncio.sleep(e.retry_after)
                continue
            break
    logger.info(f'{post.file_name}: sent')
    if mark_as_published:
        db.mark_as_published(post.file_name, main_message_id, all_message_ids)
        logger.info(f'{post.file_name}: marked as published')
    if post.links:
        db.add_links(main_message_id, post.links)
        logger.info(f'{post.file_name}: {len(post.links)} links added')
def do_list_posts(only_of_type=None):
    """Print the post queue; published entries are dimmed and tagged [done]."""
    wanted = PostType(only_of_type) if only_of_type is not None else None
    files = get_files()
    if wanted is not None:
        files = [f for f in files if Post(f).type == wanted]
    for file in files:
        if db.is_already_published(file):
            # ANSI dim-gray for already-published posts
            print(f'\033[90m[done] {file}\033[0m')
        else:
            print(f'[wait] {file}')
db: Optional[Database] = None
if __name__ == '__main__':
    # CLI entry point: send / list / mark / unmark / reset sub-commands.
    parser = ArgumentParser()
    parser.add_argument('-V', '--verbose', action='store_true')
    subparsers = parser.add_subparsers(dest='command', required=True, help='top-level commands')
    send_parser = get_subparser(subparsers, 'send', help='send messages')
    send_parser.add_argument('--prod', action='store_true', help='use production channel and save published posts to database')
    send_parser.add_argument('-t', '--type', choices=[i.value for i in PostType], help='post type')
    send_parser.add_argument('--mark', action='store_true', help='mark as published even in dev mode')
    send_parser.add_argument('-c', '--count', type=int, default=1, help='number of posts to send')
    send_parser.add_argument('--delay-between-posts', type=int, default=0, help='delay in seconds between posts')
    send_parser.add_argument('name', nargs='?', help='post name')
    list_parser = get_subparser(subparsers, 'list', help='list post queue')
    list_parser.add_argument('--only-of-type', choices=[i.value for i in PostType], help='post type')
    mark_parser = get_subparser(subparsers, 'mark', help='mark post as published')
    mark_parser.add_argument('name', help='post name')
    unmark_parser = get_subparser(subparsers, 'unmark', help='remove published mark')
    unmark_parser.add_argument('name', help='post name')
    reset_parser = get_subparser(subparsers, 'reset', help='reset database')
    args = parser.parse_args()
    db = Database()
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.command == 'send':
        if not args.name and not args.type:
            raise ArgumentError(None, 'must specify post name or post type')
        if args.prod:
            # production mode rebinds the channel id and link template globals
            CHANNEL_ID = os.getenv('PROD_CHANNEL_ID')
            MESSAGE_LINK_TEMPLATE = os.getenv('PROD_MESSAGE_LINK_TEMPLATE')
        for i in range(args.count):
            if i > 0 and args.delay_between_posts:
                # pace consecutive posts (synchronous sleep is fine here,
                # outside the event loop)
                sleep(args.delay_between_posts)
            try:
                asyncio.run(do_send(type=PostType(args.type) if args.type else None,
                                    mark_as_published=args.prod or args.mark,
                                    post_name=args.name))
            except NoPostsLeft:
                logger.warning('no unpublished posts matching the current mode criteria')
                break
    elif args.command == 'list':
        do_list_posts(only_of_type=args.only_of_type)
    elif args.command in ('mark', 'unmark'):
        if args.name not in get_files():
            logger.error(f'{args.name}: file not found')
            exit(1)
        if args.command == 'mark':
            # message ids 0/[0] act as placeholders — presumably the post was
            # published out of band; TODO confirm
            db.mark_as_published(args.name, 0, [0])
        elif args.command == 'unmark':
            db.unmark_as_published(args.name)
    elif args.command == 'reset':
        db.reset_published()