diff --git a/zephyr/lib/actions.py b/zephyr/lib/actions.py index 8d47d566cc..1d6546b066 100644 --- a/zephyr/lib/actions.py +++ b/zephyr/lib/actions.py @@ -168,10 +168,21 @@ def do_send_message(message, no_log=False): # doesn't have to. message.to_dict(apply_markdown=True) message.to_dict(apply_markdown=False) - requests.post(settings.TORNADO_SERVER + '/notify_new_message', data=dict( + data = dict( secret = settings.SHARED_SECRET, message = message.id, - users = simplejson.dumps([str(user.id) for user in recipients]))) + users = simplejson.dumps([str(user.id) for user in recipients])) + if message.recipient.type == Recipient.STREAM: + # Note: This is where authorization for single-stream + # get_updates happens! We only attach stream data to the + # notify_new_message request if it's a public stream, + # ensuring that in the tornado server, non-public stream + # messages are only associated to their subscribed users. + stream = Stream.objects.get(id=message.recipient.type_id) + if stream.is_public(): + data['realm_id'] = stream.realm.id + data['stream_name'] = stream.name + requests.post(settings.TORNADO_SERVER + '/notify_new_message', data=data) def create_stream_if_needed(realm, stream_name): (stream, created) = Stream.objects.get_or_create( diff --git a/zephyr/tornadoviews.py b/zephyr/tornadoviews.py index 5b2fcb663b..2528fdec6f 100644 --- a/zephyr/tornadoviews.py +++ b/zephyr/tornadoviews.py @@ -1,5 +1,5 @@ from django.conf import settings -from zephyr.models import Message, UserProfile, UserMessage, UserActivity +from zephyr.models import Message, UserProfile, UserMessage, UserActivity, Recipient, Stream from zephyr.decorator import asynchronous, authenticated_api_view, \ authenticated_json_post_view, internal_notify_view, RespondAsynchronously, \ @@ -23,9 +23,20 @@ from zephyr.lib.message_cache import cache_save_message, cache_get_message SERVER_GENERATION = int(time.time()) class Callbacks(object): - TYPE_RECEIVE = 0 - TYPE_POINTER_UPDATE = 1 - TYPE_MAX = 2 + # A user received a message. The key is user_profile.id. + TYPE_USER_RECEIVE = 0 + + # A stream received a message. The key is a tuple + # (realm_id, lowercased stream name). + # See comment attached to the global stream_messages for why. + # Callers of this callback need to be careful to provide + # a lowercased stream name. + TYPE_STREAM_RECEIVE = 1 + + # A user's pointer was updated. The key is user_profile.id. + TYPE_POINTER_UPDATE = 2 + + TYPE_MAX = 3 def __init__(self): self.table = {} @@ -49,8 +60,11 @@ class Callbacks(object): callbacks_table = Callbacks() -def add_receive_callback(user_profile, cb): - callbacks_table.add(user_profile.id, Callbacks.TYPE_RECEIVE, cb) +def add_user_receive_callback(user_profile, cb): + callbacks_table.add(user_profile.id, Callbacks.TYPE_USER_RECEIVE, cb) + +def add_stream_receive_callback(realm_id, stream_name, cb): + callbacks_table.add((realm_id, stream_name.lower()), Callbacks.TYPE_STREAM_RECEIVE, cb) def add_pointer_update_callback(user_profile, cb): callbacks_table.add(user_profile.id, Callbacks.TYPE_POINTER_UPDATE, cb) @@ -64,6 +78,16 @@ def add_pointer_update_callback(user_profile, cb): # * O(k) read of highest k message ids # * Automatic maximum size support. user_messages = {} + +# Same deal as user_messages, but for streams. +# +# stream_messages: Map (realm_id, lowercased stream name) => [deque of message ids it received] +# +# Why don't we index by the stream_id? Because the client will make a +# request that specifies a particular realm and stream name, and since +# we're running within tornado, we don't want to have to do a database +# lookup to find the matching entry in this table. +stream_messages = {} USERMESSAGE_CACHE_COUNT = 25000 cache_minimum_id = sys.maxint def initialize_user_messages(): @@ -76,21 +100,40 @@ def initialize_user_messages(): for um in UserMessage.objects.filter(message_id__gte=cache_minimum_id).order_by("message"): add_user_message(um.user_profile_id, um.message_id) + streams = {} + for stream in Stream.objects.select_related().all(): + streams[stream.id] = stream + for m in (Message.objects.select_related() + .filter(id__gte=cache_minimum_id, + recipient__type=Recipient.STREAM).order_by("id")): + stream = streams[m.recipient.type_id] + add_stream_message(stream.realm.id, stream.name, m.id) + # Filling the memcached cache is a little slow, so do it in a child process. subprocess.Popen(["python", os.path.join(os.path.dirname(__file__), "..", "manage.py"), "fill_message_cache"]) def add_user_message(user_profile_id, message_id): + add_table_message(user_messages, user_profile_id, message_id) + +def add_stream_message(realm_id, stream_name, message_id): + add_table_message(stream_messages, (realm_id, stream_name.lower()), message_id) + +def add_table_message(table, key, message_id): if cache_minimum_id == sys.maxint: initialize_user_messages() - global user_messages - user_messages.setdefault(user_profile_id, collections.deque(maxlen=400)) - user_messages[user_profile_id].appendleft(message_id) + table.setdefault(key, collections.deque(maxlen=400)) + table[key].appendleft(message_id) def fetch_user_messages(user_profile_id, last): + return fetch_table_messages(user_messages, user_profile_id, last) + +def fetch_stream_messages(realm_id, stream_name, last): + return fetch_table_messages(stream_messages, (realm_id, stream_name.lower()), last) + +def fetch_table_messages(table, key, last): if cache_minimum_id == sys.maxint: initialize_user_messages() - global user_messages # We need to do this check after initialize_user_messages has been called. if last < cache_minimum_id: @@ -99,21 +142,28 @@ def fetch_user_messages(user_profile_id, last): raise JsonableError("last value of %d too old! Minimum valid is %d!" % (last, cache_minimum_id)) - # We need to initialize the deque here for any new users that were - # created since Tornado was started - user_messages.setdefault(user_profile_id, collections.deque(maxlen=400)) + # We need to initialize the deque here for any new users or + # streams that were created since Tornado was started + table.setdefault(key, collections.deque(maxlen=400)) message_list = [] - for message_id in user_messages[user_profile_id]: + for message_id in table[key]: if message_id <= last: return reversed(message_list) message_list.append(message_id) return [] # The user receives this message -def receive_message(user_profile_id, message): +def user_receive_message(user_profile_id, message): add_user_message(user_profile_id, message.id) - callbacks_table.call(user_profile_id, Callbacks.TYPE_RECEIVE, + callbacks_table.call(user_profile_id, Callbacks.TYPE_USER_RECEIVE, + messages=[message], update_types=["new_messages"]) + +# The stream receives this message +def stream_receive_message(realm_id, stream_name, message): + add_stream_message(realm_id, stream_name, message.id) + callbacks_table.call((realm_id, stream_name.lower()), + Callbacks.TYPE_STREAM_RECEIVE, messages=[message], update_types=["new_messages"]) # Simple caching implementation module for user pointers @@ -149,7 +199,12 @@ def notify_new_message(request): message = cache_get_message(int(request.POST['message'])) for user_profile_id in recipient_profile_ids: - receive_message(user_profile_id, message) + user_receive_message(user_profile_id, message) + + if 'stream_name' in request.POST: + realm_id = int(request.POST['realm_id']) + stream_name = request.POST['stream_name'] + stream_receive_message(realm_id, stream_name, message) return json_success() @@ -324,7 +379,7 @@ def get_updates_backend(request, user_profile, handler, client_id, except socket.error: pass - add_receive_callback(user_profile, handler.async_callback(cb)) + add_user_receive_callback(user_profile, handler.async_callback(cb)) if client_pointer is not None: add_pointer_update_callback(user_profile, handler.async_callback(cb))