zulip/tools/tail-ses

177 lines
6.1 KiB
Python
Executable File

#!/usr/bin/env python3
import os
import sys
from email.headerregistry import Address
from email.utils import parseaddr
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from scripts.lib.setup_path import setup_path
setup_path()
os.environ["DJANGO_SETTINGS_MODULE"] = "zproject.settings"
import argparse
import secrets
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypedDict
import boto3.session
import orjson
from django.conf import settings
from mypy_boto3_ses import SESClient
from mypy_boto3_sns import SNSClient
from mypy_boto3_sqs import SQSClient
from mypy_boto3_sqs.type_defs import DeleteMessageBatchRequestEntryTypeDef
class IdentityArgsDict(TypedDict, total=False):
    """Optional keyword arguments forwarded to the ``--identity`` argparse option.

    Both keys are optional (``total=False``): ``main`` sets exactly one of
    them depending on whether a default identity could be determined.
    """

    # Identity to use when --identity is not given on the command line.
    default: str
    # Set when no default could be determined, so the user must pass one.
    required: bool
def main() -> None:
    """Parse the command line and tail one kind of SES notification.

    Lists the SES identities in the configured region, lets the user pick
    one (defaulting to the identity that matches the no-reply address, when
    there is one), looks up the SNS topic for the requested notification
    type, and prints messages from that topic through a temporary SQS queue
    and SNS subscription that are torn down on exit.
    """
    aws_session = boto3.session.Session(region_name=settings.S3_REGION)

    # Strip off the realname, if present, and extract just the domain name
    noreply_address = parseaddr(settings.NOREPLY_EMAIL_ADDRESS)[1]
    noreply_domain = Address(addr_spec=noreply_address).domain

    ses: SESClient = aws_session.client("ses")
    identity_pages = ses.get_paginator("list_identities").paginate()
    possible_identities = [
        identity for page in identity_pages for identity in page["Identities"]
    ]

    # Prefer the whole sending domain, then the exact address; when neither
    # is a known identity there is no default, so --identity becomes required.
    identity_args: IdentityArgsDict
    if noreply_domain in possible_identities:
        identity_args = {"default": noreply_domain}
    elif noreply_address in possible_identities:
        identity_args = {"default": noreply_address}
    else:
        identity_args = {"required": True}

    parser = argparse.ArgumentParser(description="Tail SES delivery or bounces")
    parser.add_argument(
        "--identity",
        "-i",
        choices=possible_identities,
        help="Sending identity in SES",
        **identity_args,
    )

    # Exactly one notification type must be selected.
    topic_group = parser.add_mutually_exclusive_group(required=True)
    for long_flag, short_flag in (
        ("--bounces", "-b"),
        ("--deliveries", "-d"),
        ("--complaints", "-c"),
    ):
        topic_group.add_argument(long_flag, short_flag, action="store_true")

    args = parser.parse_args()
    sns_topic_arn = get_ses_arn(aws_session, args)

    # The queue must exist before the subscription that targets it, and the
    # subscription is removed before the queue on the way out.
    with our_sqs_queue(aws_session, sns_topic_arn) as (queue_arn, queue_url):
        with our_sns_subscription(aws_session, sns_topic_arn, queue_arn):
            print_messages(aws_session, queue_url)
def get_ses_arn(session: boto3.session.Session, args: argparse.Namespace) -> str:
    """Return the SNS topic ARN for the notification type selected in *args*.

    Fetches the notification attributes of ``args.identity`` from SES and
    returns the topic matching whichever of ``args.bounces``,
    ``args.complaints`` or ``args.deliveries`` is set (checked in that order).

    Raises:
        KeyError: if the response has no entry for the identity or for the
            selected topic.
        AssertionError: if none of the three flags is set; the argument
            parser in ``main`` makes one of them mandatory.
    """
    identity = args.identity
    ses: SESClient = session.client("ses")
    response = ses.get_identity_notification_attributes(Identities=[identity])
    # Named to avoid shadowing the Django ``settings`` imported at module level.
    topic_attributes = response["NotificationAttributes"][identity]

    # Flags are read lazily, in order, so only those up to the first set one
    # are inspected.
    topic_key_by_flag = (
        ("bounces", "BounceTopic"),
        ("complaints", "ComplaintTopic"),
        ("deliveries", "DeliveryTopic"),
    )
    for flag_name, topic_key in topic_key_by_flag:
        if getattr(args, flag_name):
            return topic_attributes[topic_key]
    raise AssertionError  # Unreachable
@contextmanager
def our_sqs_queue(session: boto3.session.Session, ses_topic_arn: str) -> Iterator[tuple[str, str]]:
    """Create a temporary SQS queue that *ses_topic_arn* is allowed to send to.

    The queue gets a random ``tail-ses-`` name and a policy permitting
    ``SQS:SendMessage`` only when the source ARN equals *ses_topic_arn*.
    It is deleted again when the ``with`` block exits, normally or not.

    Args:
        session: boto3 session used to create the SQS client.
        ses_topic_arn: ARN of the SNS topic; it must have exactly six
            colon-separated fields, of which the fourth and fifth are reused
            as the region and account id of the queue ARN.

    Yields:
        A ``(queue_arn, queue_url)`` tuple for the new queue.

    Raises:
        ValueError: if *ses_topic_arn* does not split into six fields.
    """
    (_, _, _, region, account_id, _) = ses_topic_arn.split(":")

    sqs: SQSClient = session.client("sqs")
    queue_name = "tail-ses-" + secrets.token_hex(10)
    queue_arn = f"arn:aws:sqs:{region}:{account_id}:{queue_name}"

    # Bind this before the try block: if create_queue raises, the finally
    # clause must see None instead of hitting UnboundLocalError, which would
    # hide the real failure.
    queue_url: str | None = None
    try:
        resp = sqs.create_queue(
            QueueName=queue_name,
            Attributes={
                "Policy": orjson.dumps(
                    {
                        "Version": "2012-10-17",
                        "Id": secrets.token_hex(10),
                        "Statement": [
                            {
                                "Sid": "Sid" + secrets.token_hex(10),
                                "Effect": "Allow",
                                "Principal": {"AWS": "*"},
                                "Action": "SQS:SendMessage",
                                "Resource": queue_arn,
                                # Only the topic being tailed may publish here.
                                "Condition": {"ArnEquals": {"aws:SourceArn": ses_topic_arn}},
                            }
                        ],
                    }
                ).decode("UTF-8")
            },
        )
        queue_url = resp["QueueUrl"]
        yield queue_arn, queue_url
    finally:
        # Only clean up a queue that was actually created.
        if queue_url is not None:
            print("Deleting temporary queue...", file=sys.stderr)
            sqs.delete_queue(QueueUrl=queue_url)
@contextmanager
def our_sns_subscription(
    session: boto3.session.Session, ses_topic_arn: str, queue_arn: str
) -> Iterator[str]:
    """Temporarily subscribe the SQS queue *queue_arn* to the SNS topic.

    The subscription is created with raw message delivery disabled and is
    removed again when the ``with`` block exits, normally or not.

    Args:
        session: boto3 session used to create the SNS client.
        ses_topic_arn: ARN of the SNS topic to subscribe to.
        queue_arn: ARN of the SQS queue that receives the messages.

    Yields:
        The ARN of the new subscription.
    """
    sns: SNSClient = session.client("sns")

    # Bind this before the try block: if subscribe raises, the finally clause
    # must see None instead of hitting UnboundLocalError, which would hide
    # the real failure.
    subscription_arn: str | None = None
    try:
        resp = sns.subscribe(
            TopicArn=ses_topic_arn,
            Protocol="sqs",
            Endpoint=queue_arn,
            Attributes={"RawMessageDelivery": "false"},
            ReturnSubscriptionArn=True,
        )
        subscription_arn = resp["SubscriptionArn"]
        yield subscription_arn
    finally:
        # Only clean up a subscription that was actually created.
        if subscription_arn is not None:
            print("Deleting temporary SNS subscription...", file=sys.stderr)
            sns.unsubscribe(SubscriptionArn=subscription_arn)
def print_messages(session: boto3.session.Session, queue_url: str) -> None:
    """Print every message arriving on *queue_url* until interrupted.

    Repeatedly requests up to 10 messages (``WaitTimeSeconds=5`` per
    request). Each message body is parsed as JSON, and its ``"Message"``
    field, itself JSON, is printed indented after the body's
    ``"Timestamp"``. Each batch is deleted from the queue once printed.
    Returns quietly on ``KeyboardInterrupt``.
    """
    sqs: SQSClient = session.client("sqs")
    try:
        while True:
            received = sqs.receive_message(
                QueueUrl=queue_url,
                MaxNumberOfMessages=10,
                WaitTimeSeconds=5,
                MessageAttributeNames=["All"],
            ).get("Messages", [])

            to_delete: list[DeleteMessageBatchRequestEntryTypeDef] = []
            for message in received:
                envelope = orjson.loads(message["Body"])
                notification = orjson.loads(envelope["Message"])
                timestamp = envelope["Timestamp"]
                rendered = orjson.dumps(notification, option=orjson.OPT_INDENT_2).decode("utf-8")
                print(timestamp + " " + rendered)
                to_delete.append(
                    {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}
                )

            # Skip the call entirely when the poll returned nothing.
            if to_delete:
                sqs.delete_message_batch(QueueUrl=queue_url, Entries=to_delete)
    except KeyboardInterrupt:
        # Ctrl-C is how the user stops tailing; it is not an error.
        pass
# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()