#!/usr/bin/env python3
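"""Publish markdown posts from the posts/ directory to a Telegram channel.

Posts carry leading @key meta headers, long posts are split into several
messages, and the published state plus cross-post links are tracked in a
local SQLite database.
"""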
import os
import re
import sqlite3
import logging
import asyncio
from _textutils import markdown_to_html, html_to_text, remove_paragraph_tags
from enum import Enum
from html import escape
from argparse import ArgumentParser, ArgumentError
from time import time, sleep
from dotenv import load_dotenv
from telegram import Bot, LinkPreviewOptions
from telegram.constants import ParseMode
from telegram.error import RetryAfter
from typing import Optional

logging.basicConfig(level=logging.INFO)
load_dotenv()

TOKEN = os.getenv('TELEGRAM_TOKEN')
CHANNEL_ID = os.getenv("CHANNEL_ID")
MESSAGE_LINK_TEMPLATE = os.getenv("MESSAGE_LINK_TEMPLATE")
POSTS_DIR = os.path.join(os.path.dirname(__file__), 'posts')
DB_PATH = os.path.join(os.path.dirname(__file__), "bot.db")
MAX_MESSAGE_LENGTH = 3900  # Telegram's hard limit is 4096; kept lower because the length accounting here is approximate
MAX_MESSAGE_MARKUP_LENGTH = 30000  # the raw markup limit is reportedly 32768; kept lower to be safe

logger = logging.getLogger(__name__)


class PostNotFound(Exception):
    """Raised when a post has no row in the published table."""


class NoPostsLeft(Exception):
    """Raised when no unpublished post matches the requested criteria."""


class PostType(Enum):
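    """Kinds of posts, taken from the @type meta header."""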
    ARTICLE = 'article'
    ELEVEN11 = '11:11'


class PostMessage:
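    """A single rendered message: the HTML markup to send and its plain-text form (used for length checks)."""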
    def __init__(self, html, text):
        self.html = html
        self.text = text

    @property
    def len(self):
        return len(self.text)


class Post:
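    """A post file from POSTS_DIR: leading @key meta headers, a markdown body and links to other posts."""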
    file_path: str
    meta: dict
    text: str
    links: list[str]

    def __init__(self, file_name):
        self.file_path = os.path.realpath(os.path.join(POSTS_DIR, file_name))
        self.meta, self.text = self.parse()
        self.links = []

        self.validate_self_links()

    def parse(self):
        """Split the file into leading @key meta headers and the body text."""
        meta = {}
        text_lines = []
        in_meta = True
        with open(self.file_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if in_meta and line.startswith('@'):
                    parts = line.split(maxsplit=1)
                    key = parts[0][1:].strip()
                    value = parts[1].strip() if len(parts) == 2 else ''
                    meta[key.lower()] = value
                else:
                    in_meta = False
                    text_lines.append(line)
        text = '\n'.join(text_lines).strip()
        return meta, text

    def validate_self_links(self):
        """Rewrite [name](#post.md) links to point at the published Telegram message,
        keep the '#post.md' placeholder for targets not published yet,
        and collect the referenced post names in self.links."""
        def callback(match):
            name = match.group(1)
            href = match.group(2)[1:]

            if href not in self.links:
                self.links.append(href)

            if not os.path.exists(os.path.join(POSTS_DIR, href)):
                raise FileNotFoundError(f'{href}: linked post not found')

            try:
                message_id = db.get_published_post_message_ids(href)[0]
                url = get_message_url(message_id)
            except PostNotFound:
                url = f'#{href}'
            return f'[{name}]({url})'

        self.text = re.sub(r'\[(.*?)]\((#.*?\.md)\)', callback, self.text)

    def is_eleven11(self) -> bool:
        return self.type == PostType.ELEVEN11

    @property
    def file_name(self) -> str:
        return os.path.basename(self.file_path)

    @property
    def link(self) -> Optional[str]:
        return self.meta.get('link')

    @property
    def image(self) -> Optional[str]:
        return self.meta.get('image')

    @property
    def type(self) -> PostType:
        return PostType(self.meta['type']) if 'type' in self.meta else PostType.ARTICLE

    def get_messages(self) -> list[PostMessage]:
        """Render the post as one or more Telegram-sized messages,
        splitting on paragraph boundaries when the limits are exceeded."""
        messages = []
        paragraphs = []
        preferred_split = None
        for i, p in enumerate(self.text.split('\n\n')):
            if p.strip().startswith('@split'):
                if preferred_split is not None:
                    raise ValueError('only one @split directive per document allowed')
                preferred_split = i
                p = re.sub('^@split\n\n?', '', p)
            paragraphs.append(p)
        splitting = False
        start = 0
        end = len(paragraphs)

        while start <= len(paragraphs) - 1:
            html = ''
            if splitting:
                # '(N из N' means '(part N of N'; the real numbers are filled in below once the total is known,
                # and 'начало ... тут' means 'the beginning is here'
                html += '<p>(N из N'
                if start > 0:
                    html += ', начало <a href="{start_url}">тут</a>'
                html += ')</p>\n\n'

            if start == 0:
                html += f'<b><u>{escape(self.meta["title"])}</u></b>\n'
                if 'date' in self.meta:
                    html += f'<i>({escape(self.meta["date"])})</i>\n'
                html += '\n'

            html += markdown_to_html('\n\n'.join(paragraphs[start:end]))
            if self.link and (splitting is False or start > 0):
                html += '\n\n'
                # 'Читать далее' means 'Read more'
                html += f'<b>Читать далее:</b>\n{self.link}'

            html_to_send = remove_paragraph_tags(html)
            text = html_to_text(html)

            if len(text) > MAX_MESSAGE_LENGTH or len(html_to_send) > MAX_MESSAGE_MARKUP_LENGTH:
                logger.debug(f'long message ({len(text)})')
                if preferred_split is not None and start < preferred_split < end:
                    end = preferred_split
                else:
                    # avoid ending a part right after a separator or an emphasised heading
                    if paragraphs[end - 2].startswith('___') or paragraphs[end - 2].startswith('**'):
                        end -= 2
                    else:
                        end -= 1
                splitting = True
            else:
                messages.append(PostMessage(html_to_send, text))
                start = end
                end = len(paragraphs)

        if splitting:
            for i, message in enumerate(messages):
                messages[i].html = re.sub(r'^\(N из N', '(' + str(i + 1) + ' из ' + str(len(messages)), message.html)

        logger.debug(f'{len(messages)} messages, lengths {", ".join(map(lambda m: str(m.len), messages))}')
        return messages


class Database:
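    """SQLite-backed record of published posts and of the links between posts."""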
    def __init__(self):
        self.conn = sqlite3.connect(DB_PATH)
        self.initialize()

    def initialize(self):
        cur = self.conn.cursor()
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS published (
                post_name TEXT PRIMARY KEY,
                published_at INTEGER NOT NULL,
                main_message_id INTEGER NOT NULL,
                all_message_ids TEXT NOT NULL
            )
            """
        )
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS links (
                main_message_id INTEGER NOT NULL,
                post_name TEXT NOT NULL
            )"""
        )
        self.conn.commit()

    def get_published_post_message_ids(self, post_name: str) -> tuple[int, list[int]]:
        """Return (main_message_id, all message ids) for a published post; raise PostNotFound otherwise."""
        cur = self.conn.cursor()
        cur.execute("SELECT main_message_id, all_message_ids FROM published WHERE post_name=?", (post_name,))
        row = cur.fetchone()
        if not row:
            raise PostNotFound()
        return int(row[0]), list(map(int, row[1].split(',')))

    def reset_published(self) -> None:
        cur = self.conn.cursor()
        cur.execute("DELETE FROM published")
        self.conn.commit()

    def is_already_published(self, post_name) -> bool:
        cur = self.conn.cursor()
        cur.execute("SELECT 1 FROM published WHERE post_name=?", (post_name,))
        return cur.fetchone() is not None

    def unmark_as_published(self, post_name) -> None:
        cur = self.conn.cursor()
        cur.execute("DELETE FROM published WHERE post_name=?", (post_name,))
        self.conn.commit()

    def mark_as_published(self,
                          post_name: str,
                          message_id: int,
                          message_ids: list[int]) -> None:
        cur = self.conn.cursor()
        cur.execute("INSERT INTO published (post_name, published_at, main_message_id, all_message_ids) VALUES (?, ?, ?, ?)",
                    (post_name, int(time()), message_id, ','.join(map(str, message_ids))))
        self.conn.commit()

    def add_links(self,
                  main_message_id: int,
                  links: list[str]) -> None:
        cur = self.conn.cursor()
        for post_name in links:
            cur.execute("INSERT INTO links (main_message_id, post_name) VALUES (?, ?)", (main_message_id, post_name))
        self.conn.commit()

    # FIXME raise if not found
    def get_linking_main_message_ids(self, post_name: str) -> list[int]:
        cur = self.conn.cursor()
        cur.execute("SELECT main_message_id FROM links WHERE post_name=?", (post_name,))
        return list(map(lambda n: n[0], cur.fetchall()))

    # FIXME
    def delete_link(self, main_message_id, post_name) -> None:
        cur = self.conn.cursor()
        cur.execute("DELETE FROM links WHERE main_message_id=? AND post_name=?", (main_message_id, post_name))
        self.conn.commit()


def get_message_url(message_id):
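    """Build the public URL of a channel message by filling the {id} placeholder in MESSAGE_LINK_TEMPLATE."""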
    return MESSAGE_LINK_TEMPLATE.replace('{id}', str(message_id))


def get_files() -> list[str]:
    """Return .md post file names from POSTS_DIR in natural sort order ([] if there are none)."""
    files = [f for f in os.listdir(POSTS_DIR) if f.endswith('.md')]
    if not files:
        logger.warning('no posts found')
        return []
    files.sort(key=lambda f: [int(c) if c.isdigit() else c.lower() for c in re.split(r'(\d+)', f)])
    return files


def get_subparser(subparsers, *args, **kwargs):
    p = subparsers.add_parser(*args, **kwargs)
    p.add_argument('-V', '--verbose', action='store_true')
    return p


async def do_send(type: Optional[PostType] = None,
                  mark_as_published=False,
                  post_name=None):
    """Send the named post, or the next unpublished post of the given type, to the channel."""
    post = None
    if post_name:
        post = Post(post_name)
    else:
        for filename in get_files():
            if db.is_already_published(filename):
                continue
            post = Post(filename)
            if post.type == type:
                break
            post = None

    if not post:
        raise NoPostsLeft()

    n = 1
    messages = post.get_messages()
    main_message_id = None
    all_message_ids = []
    bot = Bot(TOKEN)
    for i, message in enumerate(messages):
        kwargs = {}
        if post.image and n == 1:
            # only the first message carries the image preview
            kwargs['link_preview_options'] = LinkPreviewOptions(url=post.image, show_above_text=True)
        else:
            kwargs['disable_web_page_preview'] = True
        text = message.html
        if i > 0:
            text = text.replace('{start_url}', get_message_url(main_message_id))

        retried = False
        while True:
            try:
                sent_message = await bot.send_message(chat_id=CHANNEL_ID, text=text, parse_mode=ParseMode.HTML, **kwargs)
                all_message_ids.append(sent_message.id)
                n += 1
                if not main_message_id:
                    main_message_id = sent_message.id
            except RetryAfter as e:
                # flood control: wait once without blocking the event loop, then retry
                if retried:
                    raise
                retried = True
                await asyncio.sleep(e.retry_after)
                continue
            break

    logger.info(f'{post.file_name}: sent')

    if mark_as_published:
        db.mark_as_published(post.file_name, main_message_id, all_message_ids)
        logger.info(f'{post.file_name}: marked as published')

    if post.links:
        db.add_links(main_message_id, post.links)
        logger.info(f'{post.file_name}: {len(post.links)} links added')


def do_list_posts(only_of_type=None):
    """Print the post queue; already published posts are dimmed and marked [done]."""
    if only_of_type is not None:
        only_of_type = PostType(only_of_type)
    files = get_files()
    if only_of_type is not None:
        files = list(filter(lambda f: Post(f).type == only_of_type, files))

    for file in files:
        published = db.is_already_published(file)
        if published:
            print('\033[90m[done] ', end='')
        else:
            print('[wait] ', end='')
        print(file, end='')
        if published:
            print('\033[0m')
        else:
            print()


# Global database handle; initialised in __main__ before any command runs.
db: Optional[Database] = None

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('-V', '--verbose', action='store_true')
    subparsers = parser.add_subparsers(dest='command', required=True, help='top-level commands')

    send_parser = get_subparser(subparsers, 'send', help='send messages')
    send_parser.add_argument('--prod', action='store_true', help='use production channel and save published posts to database')
    send_parser.add_argument('-t', '--type', choices=[i.value for i in PostType], help='post type')
    send_parser.add_argument('--mark', action='store_true', help='mark as published even in dev mode')
    send_parser.add_argument('-c', '--count', type=int, default=1, help='number of posts to send')
    send_parser.add_argument('--delay-between-posts', type=int, default=0, help='delay in seconds between posts')
    send_parser.add_argument('name', nargs='?', help='post name')

    list_parser = get_subparser(subparsers, 'list', help='list post queue')
    list_parser.add_argument('--only-of-type', choices=[i.value for i in PostType], help='post type')

    mark_parser = get_subparser(subparsers, 'mark', help='mark post as published')
    mark_parser.add_argument('name', help='post name')

    unmark_parser = get_subparser(subparsers, 'unmark', help='remove published mark')
    unmark_parser.add_argument('name', help='post name')

    reset_parser = get_subparser(subparsers, 'reset', help='reset database')

    args = parser.parse_args()
    db = Database()

    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if args.command == 'send':
        if not args.name and not args.type:
            raise ArgumentError(None, 'must specify post name or post type')
        if args.prod:
            CHANNEL_ID = os.getenv('PROD_CHANNEL_ID')
            MESSAGE_LINK_TEMPLATE = os.getenv('PROD_MESSAGE_LINK_TEMPLATE')
        for i in range(args.count):
            if i > 0 and args.delay_between_posts:
                sleep(args.delay_between_posts)
            try:
                asyncio.run(do_send(type=PostType(args.type) if args.type else None,
                                    mark_as_published=args.prod or args.mark,
                                    post_name=args.name))
            except NoPostsLeft:
                logger.warning('no unpublished posts matching the current mode criteria')
                break

    elif args.command == 'list':
        do_list_posts(only_of_type=args.only_of_type)

    elif args.command in ('mark', 'unmark'):
        if args.name not in get_files():
            logger.error(f'{args.name}: file not found')
            exit(1)
        if args.command == 'mark':
            db.mark_as_published(args.name, 0, [0])
        elif args.command == 'unmark':
            db.unmark_as_published(args.name)

    elif args.command == 'reset':
        db.reset_published()