auth: Extract handle_desktop_flow wrapper.

This commit is contained in:
Mateusz Mandera 2020-05-31 17:31:30 +02:00 committed by Tim Abbott
parent d4958484a1
commit 676305f6ab
1 changed files with 13 additions and 4 deletions

View File

@ -18,6 +18,7 @@ from django.views.decorators.http import require_safe
from django.views.generic import TemplateView
from django.utils.translation import ugettext as _
from django.utils.http import is_safe_url
from functools import wraps
import urllib
from typing import Any, Dict, List, Optional, Mapping, cast
@ -34,6 +35,7 @@ from zerver.lib.request import REQ, has_request_variables, JsonableError
from zerver.lib.response import json_success, json_error
from zerver.lib.sessions import set_expirable_session_var
from zerver.lib.subdomains import get_subdomain, is_subdomain_root_or_alias
from zerver.lib.types import ViewFuncT
from zerver.lib.url_encoding import add_query_to_redirect_url
from zerver.lib.user_agent import parse_user_agent
from zerver.lib.users import get_api_key
@ -451,12 +453,19 @@ def oauth_redirect_to_root(request: HttpRequest, url: str,
return redirect(add_query_to_redirect_url(main_site_uri, urllib.parse.urlencode(params)))
def handle_desktop_flow(func: ViewFuncT) -> ViewFuncT:
@wraps(func)
def wrapper(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
user_agent = parse_user_agent(request.META.get("HTTP_USER_AGENT", "Missing User-Agent"))
if user_agent["name"] == "ZulipElectron":
return render(request, "zerver/desktop_login.html")
return func(request, *args, **kwargs)
return wrapper # type: ignore[return-value] # https://github.com/python/mypy/issues/1927
@handle_desktop_flow
def start_social_login(request: HttpRequest, backend: str, extra_arg: Optional[str]=None
) -> HttpResponse:
user_agent = parse_user_agent(request.META.get("HTTP_USER_AGENT", "Missing User-Agent"))
if user_agent["name"] == "ZulipElectron":
return render(request, "zerver/desktop_login.html")
backend_url = reverse('social:begin', args=[backend])
extra_url_params: Dict[str, str] = {}
if backend == "saml":