mirror of https://github.com/zulip/zulip.git
sessions: Implement the concept of expirable session variables.
This can be useful in the future for various things, and right now it'll specifically be used in the signup mobile/desktop flows.
This commit is contained in:
parent
eb23c6fa6c
commit
fe33966642
|
@ -1,13 +1,15 @@
|
|||
import logging
|
||||
|
||||
from datetime import timedelta
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import SESSION_KEY, get_user_model
|
||||
from django.contrib.sessions.models import Session
|
||||
from django.utils.timezone import now as timezone_now
|
||||
from importlib import import_module
|
||||
from typing import List, Mapping, Optional
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from zerver.models import Realm, UserProfile, get_user_profile_by_id
|
||||
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
|
||||
|
||||
session_engine = import_module(settings.SESSION_ENGINE)
|
||||
|
||||
|
@ -53,3 +55,26 @@ def delete_all_deactivated_user_sessions() -> None:
|
|||
if not user_profile.is_active or user_profile.realm.deactivated:
|
||||
logging.info("Deactivating session for deactivated user %s" % (user_profile.id,))
|
||||
delete_session(session)
|
||||
|
||||
def set_expirable_session_var(session: Session, var_name: str, var_value: Any, expiry_seconds: int) -> None:
|
||||
expire_at = datetime_to_timestamp(timezone_now() + timedelta(seconds=expiry_seconds))
|
||||
session[var_name] = {'value': var_value, 'expire_at': expire_at}
|
||||
|
||||
def get_expirable_session_var(session: Session, var_name: str, default_value: Any=None,
|
||||
delete: bool=False) -> Any:
|
||||
if var_name not in session:
|
||||
return default_value
|
||||
|
||||
try:
|
||||
value, expire_at = (session[var_name]['value'], session[var_name]['expire_at'])
|
||||
except (KeyError, TypeError) as e:
|
||||
logging.warning("get_expirable_session_var: Variable {}: {}".format(var_name, e))
|
||||
return default_value
|
||||
|
||||
if timestamp_to_datetime(expire_at) < timezone_now():
|
||||
del session[var_name]
|
||||
return default_value
|
||||
|
||||
if delete:
|
||||
del session[var_name]
|
||||
return value
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from datetime import timedelta
|
||||
from django.utils.timezone import now as timezone_now
|
||||
from typing import Any, Callable
|
||||
|
||||
from zerver.lib.sessions import (
|
||||
|
@ -7,6 +9,8 @@ from zerver.lib.sessions import (
|
|||
delete_realm_user_sessions,
|
||||
delete_all_user_sessions,
|
||||
delete_all_deactivated_user_sessions,
|
||||
get_expirable_session_var,
|
||||
set_expirable_session_var,
|
||||
)
|
||||
|
||||
from zerver.models import (
|
||||
|
@ -15,6 +19,7 @@ from zerver.models import (
|
|||
|
||||
from zerver.lib.test_classes import ZulipTestCase
|
||||
|
||||
import mock
|
||||
|
||||
class TestSessions(ZulipTestCase):
|
||||
|
||||
|
@ -93,3 +98,35 @@ class TestSessions(ZulipTestCase):
|
|||
delete_all_deactivated_user_sessions()
|
||||
result = self.client_get("/")
|
||||
self.assertEqual('/login/', result.url)
|
||||
|
||||
class TestExpirableSessionVars(ZulipTestCase):
|
||||
def setUp(self) -> None:
|
||||
self.session = self.client.session
|
||||
super().setUp()
|
||||
|
||||
def test_set_and_get_basic(self) -> None:
|
||||
start_time = timezone_now()
|
||||
with mock.patch('zerver.lib.sessions.timezone_now', return_value=start_time):
|
||||
set_expirable_session_var(self.session, 'test_set_and_get_basic', 'some_value', expiry_seconds=10)
|
||||
value = get_expirable_session_var(self.session, 'test_set_and_get_basic')
|
||||
self.assertEqual(value, 'some_value')
|
||||
with mock.patch('zerver.lib.sessions.timezone_now', return_value=start_time + timedelta(seconds=11)):
|
||||
value = get_expirable_session_var(self.session, 'test_set_and_get_basic')
|
||||
self.assertEqual(value, None)
|
||||
|
||||
def test_set_and_get_with_delete(self) -> None:
|
||||
set_expirable_session_var(self.session, 'test_set_and_get_with_delete', 'some_value', expiry_seconds=10)
|
||||
value = get_expirable_session_var(self.session, 'test_set_and_get_with_delete', delete=True)
|
||||
self.assertEqual(value, 'some_value')
|
||||
self.assertEqual(get_expirable_session_var(self.session, 'test_set_and_get_with_delete'), None)
|
||||
|
||||
def test_get_var_not_set(self) -> None:
|
||||
value = get_expirable_session_var(self.session, 'test_get_var_not_set', default_value='default')
|
||||
self.assertEqual(value, 'default')
|
||||
|
||||
def test_get_var_is_not_expirable(self) -> None:
|
||||
self.session["test_get_var_is_not_expirable"] = 0
|
||||
with mock.patch('zerver.lib.sessions.logging.warning') as mock_warn:
|
||||
value = get_expirable_session_var(self.session, 'test_get_var_is_not_expirable', default_value='default')
|
||||
self.assertEqual(value, 'default')
|
||||
mock_warn.assert_called_once()
|
||||
|
|
Loading…
Reference in New Issue