diff --git a/zerver/middleware.py b/zerver/middleware.py index 5403cff7db..4ca1babd69 100644 --- a/zerver/middleware.py +++ b/zerver/middleware.py @@ -9,7 +9,10 @@ from zerver.lib.cache import get_memcached_time, get_memcached_requests from zerver.lib.bugdown import get_bugdown_time, get_bugdown_requests from zerver.models import flush_per_request_caches from zerver.exceptions import RateLimited +from django.contrib.sessions.middleware import SessionMiddleware from django.views.csrf import csrf_failure as html_csrf_failure +from django.utils.cache import patch_vary_headers +from django.utils.http import cookie_date import logging import time @@ -286,3 +289,37 @@ class FlushDisplayRecipientCache(object): # are not shared at all between requests. flush_per_request_caches() return response + +class SessionHostDomainMiddleware(SessionMiddleware): + def process_response(self, request, response): + """ + If request.session was modified, or if the configuration is to save the + session every time, save the changes and set a session cookie. + """ + try: + accessed = request.session.accessed + modified = request.session.modified + except AttributeError: + pass + else: + if accessed: + patch_vary_headers(response, ('Cookie',)) + if modified or settings.SESSION_SAVE_EVERY_REQUEST: + if request.session.get_expire_at_browser_close(): + max_age = None + expires = None + else: + max_age = request.session.get_expiry_age() + expires_time = time.time() + max_age + expires = cookie_date(expires_time) + # Save the session data and refresh the client cookie. + # Skip session save for 500 responses, refs #3881. + if response.status_code != 500: + request.session.save() + response.set_cookie(settings.SESSION_COOKIE_NAME, + request.session.session_key, max_age=max_age, + expires=expires, domain=settings.SESSION_COOKIE_DOMAIN, + path=settings.SESSION_COOKIE_PATH, + secure=settings.SESSION_COOKIE_SECURE or None, + httponly=settings.SESSION_COOKIE_HTTPONLY or None) + return response diff --git a/zproject/settings.py b/zproject/settings.py index 2f7aed0a4b..bea6c45073 100644 --- a/zproject/settings.py +++ b/zproject/settings.py @@ -180,7 +180,7 @@ MIDDLEWARE_CLASSES = ( 'zerver.middleware.RateLimitMiddleware', 'zerver.middleware.FlushDisplayRecipientCache', 'django.middleware.common.CommonMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', + 'zerver.middleware.SessionHostDomainMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', )