settings: Make SHARED_SECRET mandatory.

This implements get_mandatory_secret that ensures SHARED_SECRET is
set when we hit zerver.decorator.authenticate_notify. To avoid getting
ZulipSettingsError when setting up the secrets, we set an environment
variable DISABLE_MANDATORY_SECRET_CHECK to skip the check and default
its value to an empty string.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-08-11 22:41:13 -04:00 committed by Tim Abbott
parent b1293a84f8
commit 059d0e7be8
6 changed files with 39 additions and 5 deletions

View File

@ -11,6 +11,7 @@ from scripts.lib.zulip_tools import get_config, get_config_file
setup_path() setup_path()
os.environ["DISABLE_MANDATORY_SECRET_CHECK"] = "True"
os.environ["DJANGO_SETTINGS_MODULE"] = "zproject.settings" os.environ["DJANGO_SETTINGS_MODULE"] = "zproject.settings"
import argparse import argparse

View File

@ -1392,7 +1392,6 @@ class TestScriptMTA(ZulipTestCase):
mail_template = self.fixture_data("simple.txt", type="email") mail_template = self.fixture_data("simple.txt", type="email")
mail = mail_template.format(stream_to_address=stream_to_address, sender=sender) mail = mail_template.format(stream_to_address=stream_to_address, sender=sender)
assert settings.SHARED_SECRET is not None
subprocess.run( subprocess.run(
[script, "-r", stream_to_address, "-s", settings.SHARED_SECRET, "-t"], [script, "-r", stream_to_address, "-s", settings.SHARED_SECRET, "-t"],
input=mail, input=mail,
@ -1408,7 +1407,6 @@ class TestScriptMTA(ZulipTestCase):
stream_to_address = encode_email_address(stream) stream_to_address = encode_email_address(stream)
mail_template = self.fixture_data("simple.txt", type="email") mail_template = self.fixture_data("simple.txt", type="email")
mail = mail_template.format(stream_to_address=stream_to_address, sender=sender) mail = mail_template.format(stream_to_address=stream_to_address, sender=sender)
assert settings.SHARED_SECRET is not None
p = subprocess.run( p = subprocess.run(
[script, "-s", settings.SHARED_SECRET, "-t"], [script, "-s", settings.SHARED_SECRET, "-t"],
input=mail, input=mail,

View File

@ -220,7 +220,6 @@ class EventsEndpointTest(ZulipTestCase):
self.assertEqual(str(access_denied_error.exception), "Access denied") self.assertEqual(str(access_denied_error.exception), "Access denied")
self.assertEqual(access_denied_error.exception.http_status_code, 403) self.assertEqual(access_denied_error.exception.http_status_code, 403)
assert settings.SHARED_SECRET is not None
post_data["secret"] = settings.SHARED_SECRET post_data["secret"] = settings.SHARED_SECRET
req = HostRequestMock(post_data, tornado_handler=dummy_handler) req = HostRequestMock(post_data, tornado_handler=dummy_handler)
req.META["REMOTE_ADDR"] = "127.0.0.1" req.META["REMOTE_ADDR"] = "127.0.0.1"

View File

@ -0,0 +1,20 @@
import os
from unittest import mock
from zerver.lib.test_classes import ZulipTestCase
from zproject import config
class ConfigTest(ZulipTestCase):
def test_get_mandatory_secret_succeed(self) -> None:
secret = config.get_mandatory_secret("shared_secret")
self.assertGreater(len(secret), 0)
def test_get_mandatory_secret_failed(self) -> None:
with self.assertRaisesRegex(config.ZulipSettingsError, "nonexistent"):
config.get_mandatory_secret("nonexistent")
def test_disable_mandatory_secret_check(self) -> None:
with mock.patch.dict(os.environ, {"DISABLE_MANDATORY_SECRET_CHECK": "True"}):
secret = config.get_mandatory_secret("nonexistent")
self.assertEqual(secret, "")

View File

@ -19,6 +19,7 @@ from .config import (
config_file, config_file,
get_config, get_config,
get_from_file_if_exists, get_from_file_if_exists,
get_mandatory_secret,
get_secret, get_secret,
) )
from .configured_settings import ( from .configured_settings import (
@ -75,7 +76,7 @@ from .configured_settings import (
SECRET_KEY = get_secret("secret_key") SECRET_KEY = get_secret("secret_key")
# A shared secret, used to authenticate different parts of the app to each other. # A shared secret, used to authenticate different parts of the app to each other.
SHARED_SECRET = get_secret("shared_secret") SHARED_SECRET = get_mandatory_secret("shared_secret")
# We use this salt to hash a user's email into a filename for their user-uploaded # We use this salt to hash a user's email into a filename for their user-uploaded
# avatar. If this salt is discovered, attackers will only be able to determine # avatar. If this salt is discovered, attackers will only be able to determine

View File

@ -2,6 +2,13 @@ import configparser
import os import os
from typing import Optional, overload from typing import Optional, overload
from django.core.exceptions import ImproperlyConfigured
class ZulipSettingsError(ImproperlyConfigured):
pass
DEPLOY_ROOT = os.path.realpath(os.path.dirname(os.path.dirname(__file__))) DEPLOY_ROOT = os.path.realpath(os.path.dirname(os.path.dirname(__file__)))
config_file = configparser.RawConfigParser() config_file = configparser.RawConfigParser()
@ -10,7 +17,6 @@ config_file.read("/etc/zulip/zulip.conf")
# Whether this instance of Zulip is running in a production environment. # Whether this instance of Zulip is running in a production environment.
PRODUCTION = config_file.has_option("machine", "deploy_type") PRODUCTION = config_file.has_option("machine", "deploy_type")
DEVELOPMENT = not PRODUCTION DEVELOPMENT = not PRODUCTION
secrets_file = configparser.RawConfigParser() secrets_file = configparser.RawConfigParser()
if PRODUCTION: if PRODUCTION:
secrets_file.read("/etc/zulip/zulip-secrets.conf") secrets_file.read("/etc/zulip/zulip-secrets.conf")
@ -38,6 +44,15 @@ def get_secret(
return secrets_file.get("secrets", key, fallback=default_value) return secrets_file.get("secrets", key, fallback=default_value)
def get_mandatory_secret(key: str) -> str:
secret = get_secret(key)
if secret is None:
if os.environ.get("DISABLE_MANDATORY_SECRET_CHECK") == "True":
return ""
raise ZulipSettingsError(f'Mandatory secret "{key}" is not set')
return secret
@overload @overload
def get_config(section: str, key: str, default_value: str) -> str: def get_config(section: str, key: str, default_value: str) -> str:
... ...