Clean up the decorators code for the API.

(imported from commit b3fd6cfa475f021e35043148ad9a38633d9bddfe)
This commit is contained in:
Tim Abbott 2012-10-16 16:32:47 -04:00
parent a859c10017
commit 8388353859
1 changed files with 8 additions and 18 deletions

View File

@ -42,13 +42,13 @@ def require_post(view_func):
# api_key_required will add the authenticated user's user_profile to # api_key_required will add the authenticated user's user_profile to
# the view function's arguments list, since we have to look it up # the view function's arguments list, since we have to look it up
# anyway. # anyway.
def api_key_required(view_func): def login_required_api_view(view_func):
@csrf_exempt
@require_post
def _wrapped_view_func(request, *args, **kwargs): def _wrapped_view_func(request, *args, **kwargs):
# Arguably @require_post should protect us from having to do # Arguably @require_post should protect us from having to do
# this, but I don't want to count on us always getting the # this, but I don't want to count on us always getting the
# decorator ordering right. # decorator ordering right.
if request.method != "POST":
return HttpResponseBadRequest('This form can only be submitted by POST.')
try: try:
user_profile = UserProfile.objects.get(user__email=request.POST.get("email")) user_profile = UserProfile.objects.get(user__email=request.POST.get("email"))
except UserProfile.DoesNotExist: except UserProfile.DoesNotExist:
@ -305,18 +305,14 @@ def json_get_updates(request, handler):
# Yes, this has a name similar to the previous function. I think this # Yes, this has a name similar to the previous function. I think this
# new name is better and expect the old function to be deleted and # new name is better and expect the old function to be deleted and
# replaced by the new one soon, so I'm not going to worry about it. # replaced by the new one soon, so I'm not going to worry about it.
@csrf_exempt @login_required_api_view
@asynchronous @asynchronous
@require_post
@api_key_required
def api_get_messages(request, user_profile, handler): def api_get_messages(request, user_profile, handler):
return get_updates_backend(request, user_profile, handler, return get_updates_backend(request, user_profile, handler,
apply_markdown=(request.POST.get("apply_markdown") is not None), apply_markdown=(request.POST.get("apply_markdown") is not None),
mit_sync_bot=request.POST.get("mit_sync_bot")) mit_sync_bot=request.POST.get("mit_sync_bot"))
@csrf_exempt @login_required_api_view
@require_post
@api_key_required
def api_send_message(request, user_profile): def api_send_message(request, user_profile):
return send_message_backend(request, user_profile, user_profile.user) return send_message_backend(request, user_profile, user_profile.user)
@ -450,9 +446,7 @@ def send_message_backend(request, user_profile, sender):
return json_success() return json_success()
@csrf_exempt @login_required_api_view
@require_post
@api_key_required
def api_get_public_streams(request, user_profile): def api_get_public_streams(request, user_profile):
streams = sorted([stream.name for stream in streams = sorted([stream.name for stream in
Stream.objects.filter(realm=user_profile.realm)]) Stream.objects.filter(realm=user_profile.realm)])
@ -464,9 +458,7 @@ def gather_subscriptions(user_profile):
return sorted([get_display_recipient(sub.recipient) for sub in subscriptions return sorted([get_display_recipient(sub.recipient) for sub in subscriptions
if sub.recipient.type == Recipient.STREAM]) if sub.recipient.type == Recipient.STREAM])
@csrf_exempt @login_required_api_view
@require_post
@api_key_required
def api_get_subscriptions(request, user_profile): def api_get_subscriptions(request, user_profile):
return json_success({"streams": gather_subscriptions(user_profile)}) return json_success({"streams": gather_subscriptions(user_profile)})
@ -500,9 +492,7 @@ def valid_stream_name(name):
# Streams must start with a letter or number. # Streams must start with a letter or number.
return re.match("^[.a-zA-Z0-9][.a-z A-Z0-9_-]*$", name) return re.match("^[.a-zA-Z0-9][.a-z A-Z0-9_-]*$", name)
@csrf_exempt @login_required_api_view
@require_post
@api_key_required
def api_subscribe(request, user_profile): def api_subscribe(request, user_profile):
if "streams" not in request.POST: if "streams" not in request.POST:
return json_error("Missing streams argument.") return json_error("Missing streams argument.")