commit e5c26eeb86
parent 8010d06f9e
Author: Anders Kaseorg <anders@zulip.com>
Date:   2022-09-22 13:09:34 -07:00
Committed by: Tim Abbott

tornado: Support sharding by user ID.

Signed-off-by: Anders Kaseorg <anders@zulip.com>

11 changed files with 146 additions and 35 deletions
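
Before the per-file diffs, a minimal sketch (not part of the commit) of the scheme it introduces: a realm can now be mapped to a group of Tornado ports rather than a single port, and each user is pinned to one port in that group by a stable modulo on the user ID. The helper name below is hypothetical; the real one is get_user_id_tornado_port in zerver.tornado.sharding.

from typing import List


def pick_user_port(realm_ports: List[int], user_id: int) -> int:
    # Same arithmetic as the diff's get_user_id_tornado_port: a stable
    # modulo keeps a given user's long-polling requests and queue events
    # on the same Tornado shard.
    return realm_ports[user_id % len(realm_ports)]


# With a realm sharded across ports [9800, 9801]:
assert pick_user_port([9800, 9801], 41) == 9801
assert pick_user_port([9800, 9801], 42) == 9800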

File 1/11: nginx include (/api/v1/events long polling and Tornado routing)

@@ -42,6 +42,13 @@ location /api/v1/events {
     include /etc/nginx/zulip-include/proxy_longpolling;
 }
 
+# Handle X-Accel-Redirect from Tornado to Tornado
+location ~ ^/tornado/(\d+)(/.*)$ {
+    internal;
+    proxy_pass http://tornado$1$2$is_args$args;
+    include /etc/nginx/zulip-include/proxy_longpolling;
+}
+
 # Send everything else to Django via uWSGI
 location / {
     include uwsgi_params;

File 2/11: Puppet class zulip::tornado_sharding

@@ -41,9 +41,10 @@ class zulip::tornado_sharding {
     loglevel => warning,
   }
 
-  # The ports of Tornado processes to run on the server; defaults to
-  # 9800.
-  $tornado_ports = unique(zulipconf_keys('tornado_sharding').map |$key| { regsubst($key, /_regex$/, '') })
+  # The ports of Tornado processes to run on the server, computed from
+  # the zulip.conf configuration. Default is just port 9800.
+  $tornado_groups = zulipconf_keys('tornado_sharding').map |$key| { $key.regsubst(/_regex$/, '').split('_') }.unique
+  $tornado_ports = $tornado_groups.flatten.unique
 
   file { '/etc/nginx/zulip-include/tornado-upstreams':
     require => [Package[$zulip::common::nginx], Exec['stage_updated_sharding']],

File 3/11: nginx Tornado upstreams template (ERB)

@@ -5,6 +5,17 @@ upstream tornado<%= port %> {
     keepalive 10000;
 }
 <% end -%>
+<% @tornado_groups.each do |group| -%>
+<% if group.length > 1 -%>
+upstream tornado<%= group.join('_') %> {
+    random;
+<% group.each do |port| -%>
+    server 127.0.0.1:<%= port %>;
+<% end -%>
+    keepalive 10000;
+}
+<% end -%>
+<% end -%>
 <% else -%>
 upstream tornado {
     server 127.0.0.1:9800;

File 4/11: sharding configuration generator (write_updated_configs)

@@ -5,7 +5,7 @@ import json
 import os
 import subprocess
 import sys
-from typing import Dict
+from typing import Dict, List, Tuple, Union
 
 BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(BASE_DIR)

@@ -47,29 +47,31 @@ def write_updated_configs() -> None:
         nginx_sharding_conf_f.write("map $http_host $tornado_server {\n")
         nginx_sharding_conf_f.write("    default http://tornado9800;\n")
-        shard_map: Dict[str, int] = {}
-        shard_regexes = []
+        shard_map: Dict[str, Union[int, List[int]]] = {}
+        shard_regexes: List[Tuple[str, Union[int, List[int]]]] = []
         external_host = subprocess.check_output(
             [os.path.join(BASE_DIR, "scripts/get-django-setting"), "EXTERNAL_HOST"],
             text=True,
         ).strip()
         for key, shards in config_file["tornado_sharding"].items():
             if key.endswith("_regex"):
-                port = int(key[: -len("_regex")])
-                shard_regexes.append((shards, port))
+                ports = [int(port) for port in key[: -len("_regex")].split("_")]
+                shard_regexes.append((shards, ports[0] if len(ports) == 1 else ports))
                 nginx_sharding_conf_f.write(
-                    f"    {nginx_quote('~*' + shards)} http://tornado{port};\n"
+                    f"    {nginx_quote('~*' + shards)} http://tornado{'_'.join(map(str, ports))};\n"
                 )
             else:
-                port = int(key)
+                ports = [int(port) for port in key.split("_")]
                 for shard in shards.split():
                     if "." in shard:
                         host = shard
                     else:
                         host = f"{shard}.{external_host}"
                     assert host not in shard_map, f"host {host} duplicated"
-                    shard_map[host] = port
-                    nginx_sharding_conf_f.write(f"    {nginx_quote(host)} http://tornado{port};\n")
+                    shard_map[host] = ports[0] if len(ports) == 1 else ports
+                    nginx_sharding_conf_f.write(
+                        f"    {nginx_quote(host)} http://tornado{'_'.join(map(str, ports))};\n"
+                    )
         nginx_sharding_conf_f.write("\n")
         nginx_sharding_conf_f.write("}\n")
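
To make the generated artifacts concrete, here is a hand-run of the loop above for a hypothetical zulip.conf entry (9800_9801 = chat) and a hypothetical EXTERNAL_HOST; nginx_quote is omitted since this host needs no escaping.

key, shards = "9800_9801", "chat"      # hypothetical [tornado_sharding] entry
external_host = "zulip.example.com"    # hypothetical EXTERNAL_HOST

ports = [int(port) for port in key.split("_")]
host = f"{shards}.{external_host}"
print(f"    {host} http://tornado{'_'.join(map(str, ports))};")
# ->     chat.zulip.example.com http://tornado9800_9801;
# sharding.json would then map "chat.zulip.example.com" to [9800, 9801],
# matching the tornado9800_9801 upstream produced by the ERB template above.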

File 5/11: get_tornado_ports helper

@@ -595,8 +595,9 @@ def get_tornado_ports(config_file: configparser.RawConfigParser) -> List[int]:
     if config_file.has_section("tornado_sharding"):
         ports = sorted(
             {
-                int(key[: -len("_regex")] if key.endswith("_regex") else key)
+                int(port)
                 for key in config_file.options("tornado_sharding")
+                for port in (key[: -len("_regex")] if key.endswith("_regex") else key).split("_")
             }
         )
     if not ports:
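
The effect of the rewritten set comprehension, run against hypothetical section keys:

keys = ["9800", "9801_9802", "9803_regex"]  # hypothetical tornado_sharding keys
ports = sorted(
    {
        int(port)
        for key in keys
        for port in (key[: -len("_regex")] if key.endswith("_regex") else key).split("_")
    }
)
assert ports == [9800, 9801, 9802, 9803]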

File 6/11: management command that runs the Tornado server

@@ -19,6 +19,7 @@ if settings.PRODUCTION:
 from zerver.lib.async_utils import NoAutoCreateEventLoopPolicy
 from zerver.lib.debug import interactive_debug_listen
 from zerver.tornado.application import create_tornado_application, setup_tornado_rabbitmq
+from zerver.tornado.descriptors import set_current_port
 from zerver.tornado.event_queue import (
     add_client_gc_hook,
     dump_event_queues,

@@ -91,6 +92,7 @@ class Command(BaseCommand):
             )
 
             await sync_to_async(add_signal_handlers, thread_sensitive=True)()
+            set_current_port(port)
 
             translation.activate(settings.LANGUAGE_CODE)
             # We pass display_num_errors=False, since Django will

File 7/11: zerver.tornado.descriptors

@@ -1,5 +1,7 @@
 from typing import TYPE_CHECKING, Dict, Optional
 
+from django.conf import settings
+
 if TYPE_CHECKING:
     from zerver.tornado.event_queue import ClientDescriptor
 

@@ -16,3 +18,15 @@ def set_descriptor_by_handler_id(handler_id: int, client_descriptor: "ClientDesc
 
 def clear_descriptor_by_handler_id(handler_id: int) -> None:
     del descriptors_by_handler_id[handler_id]
+
+
+current_port: Optional[int] = None
+
+
+def is_current_port(port: int) -> Optional[int]:
+    return settings.TEST_SUITE or current_port == port
+
+
+def set_current_port(port: int) -> None:
+    global current_port
+    current_port = port

File 8/11: Django-side Tornado client API (request_event_queue, send_event)

@@ -1,3 +1,4 @@
+from collections import defaultdict
 from functools import lru_cache
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
 from urllib.parse import urlparse

@@ -11,7 +12,13 @@ from urllib3.util import Retry
 
 from zerver.lib.queue import queue_json_publish
 from zerver.models import Client, Realm, UserProfile
-from zerver.tornado.sharding import get_tornado_port, get_tornado_uri, notify_tornado_queue_name
+from zerver.tornado.sharding import (
+    get_realm_tornado_ports,
+    get_tornado_uri,
+    get_user_id_tornado_port,
+    get_user_tornado_port,
+    notify_tornado_queue_name,
+)
 
 
 class TornadoAdapter(HTTPAdapter):

@@ -81,7 +88,7 @@ def request_event_queue(
     if not settings.USING_TORNADO:
         return None
 
-    tornado_uri = get_tornado_uri(user_profile.realm)
+    tornado_uri = get_tornado_uri(get_user_tornado_port(user_profile))
     req = {
         "dont_block": "true",
         "apply_markdown": orjson.dumps(apply_markdown),

@@ -113,7 +120,7 @@ def get_user_events(
     if not settings.USING_TORNADO:
         return []
 
-    tornado_uri = get_tornado_uri(user_profile.realm)
+    tornado_uri = get_tornado_uri(get_user_tornado_port(user_profile))
     post_data: Dict[str, Any] = {
         "queue_id": queue_id,
         "last_event_id": last_event_id,

@@ -126,7 +133,7 @@ def get_user_events(
     return resp.json()["events"]
 
 
-def send_notification_http(realm: Realm, data: Mapping[str, Any]) -> None:
+def send_notification_http(port: int, data: Mapping[str, Any]) -> None:
     if not settings.USING_TORNADO or settings.RUNNING_INSIDE_TORNADO:
         # To allow the backend test suite to not require a separate
         # Tornado process, we simply call the process_notification

@@ -141,7 +148,7 @@ def send_notification_http(realm: Realm, data: Mapping[str, Any]) -> None:
         process_notification(data)
     else:
-        tornado_uri = get_tornado_uri(realm)
+        tornado_uri = get_tornado_uri(port)
         requests_client().post(
             tornado_uri + "/notify_tornado",
             data=dict(data=orjson.dumps(data), secret=settings.SHARED_SECRET),

@@ -163,9 +170,18 @@ def send_event(
 ) -> None:
     """`users` is a list of user IDs, or in some special cases like message
     send/update or embeds, dictionaries containing extra data."""
-    port = get_tornado_port(realm)
-    queue_json_publish(
-        notify_tornado_queue_name(port),
-        dict(event=event, users=list(users)),
-        lambda *args, **kwargs: send_notification_http(realm, *args, **kwargs),
-    )
+    realm_ports = get_realm_tornado_ports(realm)
+    if len(realm_ports) == 1:
+        port_user_map = {realm_ports[0]: list(users)}
+    else:
+        port_user_map = defaultdict(list)
+        for user in users:
+            user_id = user if isinstance(user, int) else user["id"]
+            port_user_map[get_user_id_tornado_port(realm_ports, user_id)].append(user)
+
+    for port, port_users in port_user_map.items():
+        queue_json_publish(
+            notify_tornado_queue_name(port),
+            dict(event=event, users=port_users),
+            lambda *args, **kwargs: send_notification_http(port, *args, **kwargs),
+        )
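
A worked example of the grouping that send_event now performs (illustrative values; the dict-shaped entry stands in for the message-send special case mentioned in the docstring):

from collections import defaultdict

realm_ports = [9800, 9801]                 # hypothetical two-port realm
users = [41, {"id": 42, "flags": []}, 43]  # user IDs, plus one dict-style entry

port_user_map = defaultdict(list)
for user in users:
    user_id = user if isinstance(user, int) else user["id"]
    port_user_map[realm_ports[user_id % len(realm_ports)]].append(user)

assert dict(port_user_map) == {9801: [41, 43], 9800: [{"id": 42, "flags": []}]}
# Each shard's notify_tornado queue then receives only its own users.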

File 9/11: zerver.tornado.event_queue (process_notification)

@@ -1340,6 +1340,18 @@ def process_notification(notice: Mapping[str, Any]) -> None:
         process_presence_event(event, cast(List[int], users))
     elif event["type"] == "custom_profile_fields":
         process_custom_profile_fields_event(event, cast(List[int], users))
+    elif event["type"] == "cleanup_queue":
+        # cleanup_event_queue may generate this event to forward cleanup
+        # requests to the right shard.
+        assert isinstance(users[0], int)
+        try:
+            client = access_client_descriptor(users[0], event["queue_id"])
+        except BadEventQueueIdError:
+            logging.info(
+                "Ignoring cleanup request for bad queue id %s (%d)", event["queue_id"], users[0]
+            )
+        else:
+            client.cleanup()
     else:
         process_event(event, cast(List[int], users))
     logging.debug(

File 10/11: zerver.tornado.sharding

@@ -1,13 +1,14 @@
 import json
 import os
 import re
+from typing import Dict, List, Pattern, Tuple, Union
 
 from django.conf import settings
 
-from zerver.models import Realm
+from zerver.models import Realm, UserProfile
 
-shard_map = {}
-shard_regexes = []
+shard_map: Dict[str, Union[int, List[int]]] = {}
+shard_regexes: List[Tuple[Pattern[str], Union[int, List[int]]]] = []
 if os.path.exists("/etc/zulip/sharding.json"):
     with open("/etc/zulip/sharding.json") as f:
         data = json.loads(f.read())

@@ -20,19 +21,27 @@ if os.path.exists("/etc/zulip/sharding.json"):
     ]
 
 
-def get_tornado_port(realm: Realm) -> int:
+def get_realm_tornado_ports(realm: Realm) -> List[int]:
     if realm.host in shard_map:
-        return shard_map[realm.host]
+        ports = shard_map[realm.host]
+        return [ports] if isinstance(ports, int) else ports
 
-    for regex, port in shard_regexes:
+    for regex, ports in shard_regexes:
         if regex.match(realm.host):
-            return port
+            return [ports] if isinstance(ports, int) else ports
 
-    return settings.TORNADO_PORTS[0]
+    return [settings.TORNADO_PORTS[0]]
 
 
-def get_tornado_uri(realm: Realm) -> str:
-    port = get_tornado_port(realm)
+def get_user_id_tornado_port(realm_ports: List[int], user_id: int) -> int:
+    return realm_ports[user_id % len(realm_ports)]
+
+
+def get_user_tornado_port(user: UserProfile) -> int:
+    return get_user_id_tornado_port(get_realm_tornado_ports(user.realm), user.id)
+
+
+def get_tornado_uri(port: int) -> str:
     return f"http://127.0.0.1:{port}"

File 11/11: Tornado views (cleanup_event_queue, get_events)

@@ -2,12 +2,14 @@ import time
 from typing import Any, Callable, Mapping, Optional, Sequence, TypeVar
 
 from asgiref.sync import async_to_sync
+from django.conf import settings
 from django.http import HttpRequest, HttpResponse
 from django.utils.translation import gettext as _
 from typing_extensions import ParamSpec
 
 from zerver.decorator import internal_notify_view, process_client
 from zerver.lib.exceptions import JsonableError
+from zerver.lib.queue import get_queue_client
 from zerver.lib.request import REQ, RequestNotes, has_request_variables
 from zerver.lib.response import AsynchronousResponse, json_success
 from zerver.lib.validator import (

@@ -19,7 +21,9 @@ from zerver.lib.validator import (
     to_non_negative_int,
 )
 from zerver.models import Client, UserProfile, get_client, get_user_profile_by_id
+from zerver.tornado.descriptors import is_current_port
 from zerver.tornado.event_queue import access_client_descriptor, fetch_events, process_notification
+from zerver.tornado.sharding import get_user_tornado_port, notify_tornado_queue_name
 
 P = ParamSpec("P")
 T = TypeVar("T")

@@ -45,10 +49,28 @@ def notify(
 def cleanup_event_queue(
     request: HttpRequest, user_profile: UserProfile, queue_id: str = REQ()
 ) -> HttpResponse:
-    client = access_client_descriptor(user_profile.id, queue_id)
     log_data = RequestNotes.get_notes(request).log_data
     assert log_data is not None
     log_data["extra"] = f"[{queue_id}]"
 
+    user_port = get_user_tornado_port(user_profile)
+    if not is_current_port(user_port):
+        # X-Accel-Redirect is not supported for HTTP DELETE requests,
+        # so we notify the shard hosting the acting user's queues via
+        # enqueuing a special event.
+        #
+        # TODO: Because we return a 200 before confirming that the
+        # event queue had been actually deleted by the process hosting
+        # the queue, there's a race where a `GET /events` request can
+        # succeed after getting a 200 from this endpoint.
+        assert settings.USING_RABBITMQ
+        get_queue_client().json_publish(
+            notify_tornado_queue_name(user_port),
+            {"users": [user_profile.id], "event": {"type": "cleanup_queue", "queue_id": queue_id}},
+        )
+        return json_success(request)
+
+    client = access_client_descriptor(user_profile.id, queue_id)
     in_tornado_thread(client.cleanup)()
     return json_success(request)

@@ -60,11 +82,25 @@ def get_events_internal(
 ) -> HttpResponse:
     user_profile = get_user_profile_by_id(user_profile_id)
     RequestNotes.get_notes(request).requestor_for_logs = user_profile.format_requestor_for_logs()
+    assert is_current_port(get_user_tornado_port(user_profile))
     process_client(request, user_profile, client_name="internal")
 
     return get_events_backend(request, user_profile)
 
 
 def get_events(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
+    user_port = get_user_tornado_port(user_profile)
+    if not is_current_port(user_port):
+        # When a single realm is split across multiple Tornado shards,
+        # any `GET /events` requests that are routed to the wrong
+        # shard are redirected to the shard hosting the relevant
+        # user's queues. We use X-Accel-Redirect for this purpose,
+        # which is efficient and keeps this redirect invisible to
+        # clients.
+        return HttpResponse(
+            "", headers={"X-Accel-Redirect": f"/tornado/{user_port}{request.get_full_path()}"}
+        )
+
     return get_events_backend(request, user_profile)
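
Tying the last piece back to the nginx location added in the first file (illustrative values): the response above carries only an X-Accel-Redirect header, and nginx's internal /tornado/ location re-proxies the same request to the right upstream without the client ever seeing a redirect.

import re

user_port = 9801                                # hypothetical shard
full_path = "/api/v1/events?dont_block=true"    # request.get_full_path()
redirect = f"/tornado/{user_port}{full_path}"

# nginx matches `location ~ ^/tornado/(\d+)(/.*)$` against the path part only;
# the query string is re-attached via $is_args$args.
match = re.match(r"^/tornado/(\d+)(/.*)$", redirect.split("?")[0])
assert match is not None and match.groups() == ("9801", "/api/v1/events")
# proxy_pass then targets http://tornado9801/api/v1/events?dont_block=true.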