diff --git a/puppet/zulip/files/nginx/zulip-include-frontend/app b/puppet/zulip/files/nginx/zulip-include-frontend/app index 7917cb7fc7..4b7e76e13e 100644 --- a/puppet/zulip/files/nginx/zulip-include-frontend/app +++ b/puppet/zulip/files/nginx/zulip-include-frontend/app @@ -42,6 +42,13 @@ location /api/v1/events { include /etc/nginx/zulip-include/proxy_longpolling; } +# Handle X-Accel-Redirect from Tornado to Tornado +location ~ ^/tornado/(\d+)(/.*)$ { + internal; + proxy_pass http://tornado$1$2$is_args$args; + include /etc/nginx/zulip-include/proxy_longpolling; +} + # Send everything else to Django via uWSGI location / { include uwsgi_params; diff --git a/puppet/zulip/manifests/tornado_sharding.pp b/puppet/zulip/manifests/tornado_sharding.pp index ded1e6a298..89a8580b3c 100644 --- a/puppet/zulip/manifests/tornado_sharding.pp +++ b/puppet/zulip/manifests/tornado_sharding.pp @@ -41,9 +41,10 @@ class zulip::tornado_sharding { loglevel => warning, } - # The ports of Tornado processes to run on the server; defaults to - # 9800. - $tornado_ports = unique(zulipconf_keys('tornado_sharding').map |$key| { regsubst($key, /_regex$/, '') }) + # The ports of Tornado processes to run on the server, computed from + # the zulip.conf configuration. Default is just port 9800. 
+ $tornado_groups = zulipconf_keys('tornado_sharding').map |$key| { $key.regsubst(/_regex$/, '').split('_') }.unique + $tornado_ports = $tornado_groups.flatten.unique file { '/etc/nginx/zulip-include/tornado-upstreams': require => [Package[$zulip::common::nginx], Exec['stage_updated_sharding']], diff --git a/puppet/zulip/templates/nginx/tornado-upstreams.conf.template.erb b/puppet/zulip/templates/nginx/tornado-upstreams.conf.template.erb index 2591d5603f..d0ad8043e3 100644 --- a/puppet/zulip/templates/nginx/tornado-upstreams.conf.template.erb +++ b/puppet/zulip/templates/nginx/tornado-upstreams.conf.template.erb @@ -5,6 +5,17 @@ upstream tornado<%= port %> { keepalive 10000; } <% end -%> +<% @tornado_groups.each do |group| -%> +<% if group.length > 1 -%> +upstream tornado<%= group.join('_') %> { + random; +<% group.each do |port| -%> + server 127.0.0.1:<%= port %>; +<% end -%> + keepalive 10000; +} +<% end -%> +<% end -%> <% else -%> upstream tornado { server 127.0.0.1:9800; diff --git a/scripts/lib/sharding.py b/scripts/lib/sharding.py index cd5283c079..b3c8425fd3 100755 --- a/scripts/lib/sharding.py +++ b/scripts/lib/sharding.py @@ -5,7 +5,7 @@ import json import os import subprocess import sys -from typing import Dict +from typing import Dict, List, Tuple, Union BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(BASE_DIR) @@ -47,29 +47,31 @@ def write_updated_configs() -> None: nginx_sharding_conf_f.write("map $http_host $tornado_server {\n") nginx_sharding_conf_f.write(" default http://tornado9800;\n") - shard_map: Dict[str, int] = {} - shard_regexes = [] + shard_map: Dict[str, Union[int, List[int]]] = {} + shard_regexes: List[Tuple[str, Union[int, List[int]]]] = [] external_host = subprocess.check_output( [os.path.join(BASE_DIR, "scripts/get-django-setting"), "EXTERNAL_HOST"], text=True, ).strip() for key, shards in config_file["tornado_sharding"].items(): if key.endswith("_regex"): - port = int(key[: 
-len("_regex")]) - shard_regexes.append((shards, port)) + ports = [int(port) for port in key[: -len("_regex")].split("_")] + shard_regexes.append((shards, ports[0] if len(ports) == 1 else ports)) nginx_sharding_conf_f.write( - f" {nginx_quote('~*' + shards)} http://tornado{port};\n" + f" {nginx_quote('~*' + shards)} http://tornado{'_'.join(map(str, ports))};\n" ) else: - port = int(key) + ports = [int(port) for port in key.split("_")] for shard in shards.split(): if "." in shard: host = shard else: host = f"{shard}.{external_host}" assert host not in shard_map, f"host {host} duplicated" - shard_map[host] = port - nginx_sharding_conf_f.write(f" {nginx_quote(host)} http://tornado{port};\n") + shard_map[host] = ports[0] if len(ports) == 1 else ports + nginx_sharding_conf_f.write( + f" {nginx_quote(host)} http://tornado{'_'.join(map(str, ports))};\n" + ) nginx_sharding_conf_f.write("\n") nginx_sharding_conf_f.write("}\n") diff --git a/scripts/lib/zulip_tools.py b/scripts/lib/zulip_tools.py index 6d9ba59b78..77094d8e08 100755 --- a/scripts/lib/zulip_tools.py +++ b/scripts/lib/zulip_tools.py @@ -595,8 +595,9 @@ def get_tornado_ports(config_file: configparser.RawConfigParser) -> List[int]: if config_file.has_section("tornado_sharding"): ports = sorted( { - int(key[: -len("_regex")] if key.endswith("_regex") else key) + int(port) for key in config_file.options("tornado_sharding") + for port in (key[: -len("_regex")] if key.endswith("_regex") else key).split("_") } ) if not ports: diff --git a/zerver/management/commands/runtornado.py b/zerver/management/commands/runtornado.py index ca13be9f9b..e74ea0d969 100644 --- a/zerver/management/commands/runtornado.py +++ b/zerver/management/commands/runtornado.py @@ -19,6 +19,7 @@ if settings.PRODUCTION: from zerver.lib.async_utils import NoAutoCreateEventLoopPolicy from zerver.lib.debug import interactive_debug_listen from zerver.tornado.application import create_tornado_application, setup_tornado_rabbitmq +from 
zerver.tornado.descriptors import set_current_port from zerver.tornado.event_queue import ( add_client_gc_hook, dump_event_queues, @@ -91,6 +92,7 @@ class Command(BaseCommand): ) await sync_to_async(add_signal_handlers, thread_sensitive=True)() + set_current_port(port) translation.activate(settings.LANGUAGE_CODE) # We pass display_num_errors=False, since Django will diff --git a/zerver/tornado/descriptors.py b/zerver/tornado/descriptors.py index e6e85fdd21..4a91fc0f1c 100644 --- a/zerver/tornado/descriptors.py +++ b/zerver/tornado/descriptors.py @@ -1,5 +1,7 @@ from typing import TYPE_CHECKING, Dict, Optional +from django.conf import settings + if TYPE_CHECKING: from zerver.tornado.event_queue import ClientDescriptor @@ -16,3 +18,15 @@ def set_descriptor_by_handler_id(handler_id: int, client_descriptor: "ClientDesc def clear_descriptor_by_handler_id(handler_id: int) -> None: del descriptors_by_handler_id[handler_id] + + +current_port: Optional[int] = None + + +def is_current_port(port: int) -> bool: + return settings.TEST_SUITE or current_port == port + + +def set_current_port(port: int) -> None: + global current_port + current_port = port diff --git a/zerver/tornado/django_api.py b/zerver/tornado/django_api.py index e9ce17bab4..0ef51e069e 100644 --- a/zerver/tornado/django_api.py +++ b/zerver/tornado/django_api.py @@ -1,3 +1,4 @@ +from collections import defaultdict from functools import lru_cache from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union from urllib.parse import urlparse @@ -11,7 +12,13 @@ from urllib3.util import Retry from zerver.lib.queue import queue_json_publish from zerver.models import Client, Realm, UserProfile -from zerver.tornado.sharding import get_tornado_port, get_tornado_uri, notify_tornado_queue_name +from zerver.tornado.sharding import ( + get_realm_tornado_ports, + get_tornado_uri, + get_user_id_tornado_port, + get_user_tornado_port, + notify_tornado_queue_name, +) class 
TornadoAdapter(HTTPAdapter): @@ -81,7 +88,7 @@ def request_event_queue( if not settings.USING_TORNADO: return None - tornado_uri = get_tornado_uri(user_profile.realm) + tornado_uri = get_tornado_uri(get_user_tornado_port(user_profile)) req = { "dont_block": "true", "apply_markdown": orjson.dumps(apply_markdown), @@ -113,7 +120,7 @@ def get_user_events( if not settings.USING_TORNADO: return [] - tornado_uri = get_tornado_uri(user_profile.realm) + tornado_uri = get_tornado_uri(get_user_tornado_port(user_profile)) post_data: Dict[str, Any] = { "queue_id": queue_id, "last_event_id": last_event_id, @@ -126,7 +133,7 @@ def get_user_events( return resp.json()["events"] -def send_notification_http(realm: Realm, data: Mapping[str, Any]) -> None: +def send_notification_http(port: int, data: Mapping[str, Any]) -> None: if not settings.USING_TORNADO or settings.RUNNING_INSIDE_TORNADO: # To allow the backend test suite to not require a separate # Tornado process, we simply call the process_notification @@ -141,7 +148,7 @@ def send_notification_http(realm: Realm, data: Mapping[str, Any]) -> None: process_notification(data) else: - tornado_uri = get_tornado_uri(realm) + tornado_uri = get_tornado_uri(port) requests_client().post( tornado_uri + "/notify_tornado", data=dict(data=orjson.dumps(data), secret=settings.SHARED_SECRET), @@ -163,9 +170,18 @@ def send_event( ) -> None: """`users` is a list of user IDs, or in some special cases like message send/update or embeds, dictionaries containing extra data.""" - port = get_tornado_port(realm) - queue_json_publish( - notify_tornado_queue_name(port), - dict(event=event, users=list(users)), - lambda *args, **kwargs: send_notification_http(realm, *args, **kwargs), - ) + realm_ports = get_realm_tornado_ports(realm) + if len(realm_ports) == 1: + port_user_map = {realm_ports[0]: list(users)} + else: + port_user_map = defaultdict(list) + for user in users: + user_id = user if isinstance(user, int) else user["id"] + 
port_user_map[get_user_id_tornado_port(realm_ports, user_id)].append(user) + + for port, port_users in port_user_map.items(): + queue_json_publish( + notify_tornado_queue_name(port), + dict(event=event, users=port_users), + lambda *args, port=port, **kwargs: send_notification_http(port, *args, **kwargs), + ) diff --git a/zerver/tornado/event_queue.py b/zerver/tornado/event_queue.py index cb096221fe..c8f3304834 100644 --- a/zerver/tornado/event_queue.py +++ b/zerver/tornado/event_queue.py @@ -1340,6 +1340,18 @@ def process_notification(notice: Mapping[str, Any]) -> None: process_presence_event(event, cast(List[int], users)) elif event["type"] == "custom_profile_fields": process_custom_profile_fields_event(event, cast(List[int], users)) + elif event["type"] == "cleanup_queue": + # cleanup_event_queue may generate this event to forward cleanup + # requests to the right shard. + assert isinstance(users[0], int) + try: + client = access_client_descriptor(users[0], event["queue_id"]) + except BadEventQueueIdError: + logging.info( + "Ignoring cleanup request for bad queue id %s (%d)", event["queue_id"], users[0] + ) + else: + client.cleanup() else: process_event(event, cast(List[int], users)) logging.debug( diff --git a/zerver/tornado/sharding.py b/zerver/tornado/sharding.py index 338a73aeaf..f7617943a9 100644 --- a/zerver/tornado/sharding.py +++ b/zerver/tornado/sharding.py @@ -1,13 +1,14 @@ import json import os import re +from typing import Dict, List, Pattern, Tuple, Union from django.conf import settings -from zerver.models import Realm +from zerver.models import Realm, UserProfile -shard_map = {} -shard_regexes = [] +shard_map: Dict[str, Union[int, List[int]]] = {} +shard_regexes: List[Tuple[Pattern[str], Union[int, List[int]]]] = [] if os.path.exists("/etc/zulip/sharding.json"): with open("/etc/zulip/sharding.json") as f: data = json.loads(f.read()) @@ -20,19 +21,27 @@ if os.path.exists("/etc/zulip/sharding.json"): ] -def get_tornado_port(realm: Realm) -> int: +def 
get_realm_tornado_ports(realm: Realm) -> List[int]: if realm.host in shard_map: - return shard_map[realm.host] + ports = shard_map[realm.host] + return [ports] if isinstance(ports, int) else ports - for regex, port in shard_regexes: + for regex, ports in shard_regexes: if regex.match(realm.host): - return port + return [ports] if isinstance(ports, int) else ports - return settings.TORNADO_PORTS[0] + return [settings.TORNADO_PORTS[0]] -def get_tornado_uri(realm: Realm) -> str: - port = get_tornado_port(realm) +def get_user_id_tornado_port(realm_ports: List[int], user_id: int) -> int: + return realm_ports[user_id % len(realm_ports)] + + +def get_user_tornado_port(user: UserProfile) -> int: + return get_user_id_tornado_port(get_realm_tornado_ports(user.realm), user.id) + + +def get_tornado_uri(port: int) -> str: return f"http://127.0.0.1:{port}" diff --git a/zerver/tornado/views.py b/zerver/tornado/views.py index bf152d9f54..3f511bde20 100644 --- a/zerver/tornado/views.py +++ b/zerver/tornado/views.py @@ -2,12 +2,14 @@ import time from typing import Any, Callable, Mapping, Optional, Sequence, TypeVar from asgiref.sync import async_to_sync +from django.conf import settings from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ from typing_extensions import ParamSpec from zerver.decorator import internal_notify_view, process_client from zerver.lib.exceptions import JsonableError +from zerver.lib.queue import get_queue_client from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.response import AsynchronousResponse, json_success from zerver.lib.validator import ( @@ -19,7 +21,9 @@ from zerver.lib.validator import ( to_non_negative_int, ) from zerver.models import Client, UserProfile, get_client, get_user_profile_by_id +from zerver.tornado.descriptors import is_current_port from zerver.tornado.event_queue import access_client_descriptor, fetch_events, process_notification +from 
zerver.tornado.sharding import get_user_tornado_port, notify_tornado_queue_name P = ParamSpec("P") T = TypeVar("T") @@ -45,10 +49,28 @@ def notify( def cleanup_event_queue( request: HttpRequest, user_profile: UserProfile, queue_id: str = REQ() ) -> HttpResponse: - client = access_client_descriptor(user_profile.id, queue_id) log_data = RequestNotes.get_notes(request).log_data assert log_data is not None log_data["extra"] = f"[{queue_id}]" + + user_port = get_user_tornado_port(user_profile) + if not is_current_port(user_port): + # X-Accel-Redirect is not supported for HTTP DELETE requests, + # so we notify the shard hosting the acting user's queues via + # enqueuing a special event. + # + # TODO: Because we return a 200 before confirming that the + # event queue had been actually deleted by the process hosting + # the queue, there's a race where a `GET /events` request can + # succeed after getting a 200 from this endpoint. + assert settings.USING_RABBITMQ + get_queue_client().json_publish( + notify_tornado_queue_name(user_port), + {"users": [user_profile.id], "event": {"type": "cleanup_queue", "queue_id": queue_id}}, + ) + return json_success(request) + + client = access_client_descriptor(user_profile.id, queue_id) in_tornado_thread(client.cleanup)() return json_success(request) @@ -60,11 +82,25 @@ def get_events_internal( ) -> HttpResponse: user_profile = get_user_profile_by_id(user_profile_id) RequestNotes.get_notes(request).requestor_for_logs = user_profile.format_requestor_for_logs() + assert is_current_port(get_user_tornado_port(user_profile)) + process_client(request, user_profile, client_name="internal") return get_events_backend(request, user_profile) def get_events(request: HttpRequest, user_profile: UserProfile) -> HttpResponse: + user_port = get_user_tornado_port(user_profile) + if not is_current_port(user_port): + # When a single realm is split across multiple Tornado shards, + # any `GET /events` requests that are routed to the wrong + # shard are 
redirected to the shard hosting the relevant + # user's queues. We use X-Accel-Redirect for this purpose, + # which is efficient and keeps this redirect invisible to + # clients. + return HttpResponse( + "", headers={"X-Accel-Redirect": f"/tornado/{user_port}{request.get_full_path()}"} + ) + return get_events_backend(request, user_profile)