diff --git a/zerver/decorator.py b/zerver/decorator.py index b0f8a60d76..a6a8561170 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -6,7 +6,7 @@ from django.contrib.auth import REDIRECT_FIELD_NAME, login as django_login from django.views.decorators.csrf import csrf_exempt from django.http import QueryDict, HttpResponseNotAllowed, HttpRequest from django.http.multipartparser import MultiPartParser -from zerver.models import UserProfile, get_client +from zerver.models import UserProfile, get_client, get_user_profile_by_api_key from zerver.lib.response import json_error, json_unauthorized, json_success from django.shortcuts import resolve_url from django.utils.decorators import available_attrs @@ -229,7 +229,7 @@ def validate_account_and_subdomain(request, user_profile): def access_user_by_api_key(request, api_key, email=None): # type: (HttpRequest, Text, Optional[Text]) -> UserProfile try: - user_profile = UserProfile.objects.get(api_key=api_key) + user_profile = get_user_profile_by_api_key(api_key) except UserProfile.DoesNotExist: raise JsonableError(_("Invalid API key")) if email is not None and email != user_profile.email: diff --git a/zerver/lib/cache.py b/zerver/lib/cache.py index c9f1fcc898..4e57beefd6 100644 --- a/zerver/lib/cache.py +++ b/zerver/lib/cache.py @@ -321,6 +321,10 @@ def user_profile_by_id_cache_key(user_profile_id): # type: (int) -> Text return u"user_profile_by_id:%s" % (user_profile_id,) +def user_profile_by_api_key_cache_key(api_key): + # type: (Text) -> Text + return u"user_profile_by_api_key:%s" % (api_key,) + # TODO: Refactor these cache helpers into another file that can import # models.py so that python v3 style type annotations can also work. @@ -361,6 +365,7 @@ def delete_user_profile_caches(user_profiles): for user_profile in user_profiles: keys.append(user_profile_by_email_cache_key(user_profile.email)) keys.append(user_profile_by_id_cache_key(user_profile.id)) + keys.append(user_profile_by_api_key_cache_key(user_profile.api_key)) keys.append(user_profile_cache_key(user_profile.email, user_profile.realm)) cache_delete_many(keys) diff --git a/zerver/lib/cache_helpers.py b/zerver/lib/cache_helpers.py index 7197230563..e3fb625c24 100644 --- a/zerver/lib/cache_helpers.py +++ b/zerver/lib/cache_helpers.py @@ -11,7 +11,9 @@ from zerver.models import Message, UserProfile, Stream, get_stream_cache_key, \ Recipient, get_recipient_cache_key, Client, get_client_cache_key, \ Huddle, huddle_hash_cache_key from zerver.lib.cache import cache_with_key, cache_set, \ - user_profile_by_email_cache_key, user_profile_by_id_cache_key, \ + user_profile_by_api_key_cache_key, \ + user_profile_by_email_cache_key, \ + user_profile_by_id_cache_key, \ user_profile_cache_key, get_remote_cache_time, get_remote_cache_requests, \ cache_set_many, to_dict_cache_key_id from importlib import import_module @@ -38,6 +40,7 @@ def user_cache_items(items_for_remote_cache, user_profile): # type: (Dict[Text, Tuple[UserProfile]], UserProfile) -> None items_for_remote_cache[user_profile_by_email_cache_key(user_profile.email)] = (user_profile,) items_for_remote_cache[user_profile_by_id_cache_key(user_profile.id)] = (user_profile,) + items_for_remote_cache[user_profile_by_api_key_cache_key(user_profile.api_key)] = (user_profile,) items_for_remote_cache[user_profile_cache_key(user_profile.email, user_profile.realm)] = (user_profile,) def stream_cache_items(items_for_remote_cache, stream): diff --git a/zerver/management/commands/rate_limit.py b/zerver/management/commands/rate_limit.py index f9e76168cd..330f1cbf78 100644 --- a/zerver/management/commands/rate_limit.py +++ b/zerver/management/commands/rate_limit.py @@ -4,7 +4,7 @@ from __future__ import print_function from typing import Any from argparse import ArgumentParser -from zerver.models import UserProfile +from zerver.models import UserProfile, get_user_profile_by_api_key from zerver.lib.rate_limiter import block_access, unblock_access, RateLimitedUser from zerver.lib.management import ZulipBaseCommand @@ -51,8 +51,8 @@ class Command(ZulipBaseCommand): user_profile = self.get_user(options['email'], realm) else: try: - user_profile = UserProfile.objects.get(api_key=options['api_key']) - except Exception: + user_profile = get_user_profile_by_api_key(options['api_key']) + except UserProfile.DoesNotExist: print("Unable to get user profile for api key %s" % (options['api_key'],)) exit(1) diff --git a/zerver/models.py b/zerver/models.py index e252a94e9c..362bbd13f7 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -16,6 +16,7 @@ from django.core.validators import URLValidator, MinLengthValidator, \ RegexValidator from django.dispatch import receiver from zerver.lib.cache import cache_with_key, flush_user_profile, flush_realm, \ + user_profile_by_api_key_cache_key, \ user_profile_by_id_cache_key, user_profile_by_email_cache_key, \ user_profile_cache_key, generic_bulk_cached_fetch, cache_set, flush_stream, \ display_recipient_cache_key, cache_delete, \ @@ -1444,6 +1445,11 @@ def get_user_profile_by_email(email): # type: (Text) -> UserProfile return UserProfile.objects.select_related().get(email__iexact=email.strip()) +@cache_with_key(user_profile_by_api_key_cache_key, timeout=3600*24*7) +def get_user_profile_by_api_key(api_key): + # type: (Text) -> UserProfile + return UserProfile.objects.select_related().get(api_key=api_key) + @cache_with_key(user_profile_cache_key, timeout=3600*24*7) def get_user(email, realm): # type: (Text, Realm) -> UserProfile diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index b013d6c111..1b98f699bf 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -1599,7 +1599,7 @@ class SubscriptionAPITest(ZulipTestCase): streams_to_sub, dict(principals=ujson.dumps([email1, email2])), ) - self.assert_length(queries, 43) + self.assert_length(queries, 42) self.assert_length(events, 7) for ev in [x for x in events if x['event']['type'] not in ('message', 'stream')]: @@ -1627,7 +1627,7 @@ class SubscriptionAPITest(ZulipTestCase): streams_to_sub, dict(principals=ujson.dumps([self.test_email])), ) - self.assert_length(queries, 16) + self.assert_length(queries, 15) self.assert_length(events, 2) add_event, add_peer_event = events @@ -1840,7 +1840,7 @@ class SubscriptionAPITest(ZulipTestCase): # Make sure Zephyr mirroring realms such as MIT do not get # any tornado subscription events self.assert_length(events, 0) - self.assert_length(queries, 9) + self.assert_length(queries, 8) def test_bulk_subscribe_many(self): # type: () -> None @@ -1857,7 +1857,7 @@ class SubscriptionAPITest(ZulipTestCase): dict(principals=ujson.dumps([self.test_email])), ) # Make sure we don't make O(streams) queries - self.assert_length(queries, 20) + self.assert_length(queries, 19) @slow("common_subscribe_to_streams is slow") def test_subscriptions_add_for_principal(self):