donotmerge: Add tool to analyze bulk-fetch performance.

I can clean this up if we ever want to have it in our
permanent codebase, but for now I am most interested in
learning stuff and having folks replicate or debunk
the findings.
Steve Howell 2023-07-21 09:23:27 +00:00
parent 3255281a83
commit 941a3d204e
1 changed file with 231 additions and 0 deletions

tools/analyze.py (new executable file)

@@ -0,0 +1,231 @@
#!/usr/bin/env python3
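"""Benchmark generic_bulk_cached_fetch against direct database fetches
for streams, users, and messages.

Rough usage sketch (assuming a provisioned Zulip development
environment): run `./tools/analyze.py`; each @run-decorated benchmark
below executes immediately and prints per-row timings in milliseconds.
"""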
import os
import sys
import timeit
import django
from django.db import connection
ZULIP_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ZULIP_PATH)
os.chdir(ZULIP_PATH)
# check for the venv
from tools.lib import sanity_check
sanity_check.check_venv(__file__)
os.environ["DJANGO_SETTINGS_MODULE"] = "zproject.settings"
django.setup()
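
# The zerver imports below touch Django models, so they must come
# after django.setup().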
from zerver.lib.cache import (
cache_delete,
display_recipient_cache_key,
generic_bulk_cached_fetch,
to_dict_cache_key_id,
)
from zerver.lib.display_recipient import (
display_recipient_fields,
single_user_display_recipient_cache_key,
user_dict_id_fetcher,
)
from zerver.lib.message import (
MessageDict,
extract_message_dict,
stringify_message_dict,
)
from zerver.models import Message, Stream, UserProfile
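

# Fetch rows with hand-built SQL instead of the ORM, presumably to
# isolate how much of the per-query cost is ORM query construction.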
def get_column_values_from_single_table_using_id_lookup(*, columns, table, id_field, ids):
    if not ids:
        return []
    column_list = ", ".join(columns)
    # Interpolating ids into the SQL string is only safe because we
    # first verify that every id is an integer.
    for row_id in ids:
        assert isinstance(row_id, int)
    id_list = ", ".join(str(row_id) for row_id in ids)
    sql = f"SELECT {column_list} FROM {table} WHERE {id_field} IN ({id_list})"
    with connection.cursor() as cursor:
        cursor.execute(sql)
        desc = cursor.description
        rows = [dict(zip((col[0] for col in desc), row)) for row in cursor.fetchall()]
    return rows
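

# Mirrors generic_bulk_cached_fetch's signature so callers can swap the
# two freely; cache_key_function, extractor, and setter are unused here.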
def direct_db_fetch(
cache_key_function,
query_function,
object_ids,
*,
extractor,
setter,
id_fetcher,
cache_transformer,
):
return {id_fetcher(row): cache_transformer(row) for row in query_function(list(object_ids))}
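

# Single dispatch point for all the benchmarks below: use_cache selects
# either the memcached-backed fetch or the direct database fetch.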
def fetch(use_cache, *args, **kwargs):
f = generic_bulk_cached_fetch if use_cache else direct_db_fetch
return f(*args, **kwargs)


def messages_for_ids(message_ids, *, use_cache):
# This is cribbed from real code, but it excludes some steps
# such as getting user-specific message flags.
cache_transformer = MessageDict.build_dict_from_raw_db_row
id_fetcher = lambda row: row["id"]
message_dicts = fetch(
use_cache,
to_dict_cache_key_id,
MessageDict.get_raw_db_rows,
message_ids,
id_fetcher=id_fetcher,
cache_transformer=cache_transformer,
extractor=extract_message_dict,
setter=stringify_message_dict,
)
return message_dicts
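

# optimize=True swaps the ORM query for the raw-SQL helper above; the
# extractor/setter/cache_transformer callbacks are identity functions
# because these rows are cached as-is.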
def bulk_fetch_single_user_display_recipients(uids, *, use_cache, optimize):
if optimize:
query_function = lambda ids: get_column_values_from_single_table_using_id_lookup(
columns=display_recipient_fields,
table="zerver_userprofile",
id_field="id",
ids=ids,
)
else:
query_function = lambda ids: list(
UserProfile.objects.filter(id__in=ids).values(*display_recipient_fields)
)
return fetch(
use_cache,
cache_key_function=single_user_display_recipient_cache_key,
query_function=query_function,
object_ids=uids,
id_fetcher=user_dict_id_fetcher,
extractor=lambda obj: obj,
setter=lambda obj: obj,
cache_transformer=lambda obj: obj,
)


def bulk_fetch_stream_names(stream_ids, *, use_cache, optimize):
# This is modified from the original version to deal just in stream ids
# and not recipient ids.
def get_tiny_stream_rows(ids):
if optimize:
return get_column_values_from_single_table_using_id_lookup(
columns=["id", "name"],
table="zerver_stream",
id_field="id",
ids=ids,
)
else:
return Stream.objects.filter(id__in=ids).values("id", "name")

    def get_stream_id(row):
        return row["id"]

    def get_name(row):
        return row["name"]

stream_display_recipients = fetch(
use_cache,
cache_key_function=display_recipient_cache_key,
query_function=get_tiny_stream_rows,
object_ids=stream_ids,
id_fetcher=get_stream_id,
cache_transformer=get_name,
setter=lambda obj: obj,
extractor=lambda obj: obj,
)
return stream_display_recipients
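

# Quick-hack runner: run(f) executes f immediately, so each @run block
# below fires in order when the script is invoked. (It leaves the
# decorated name bound to None, which is fine since nothing calls it
# again.)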
def run(f):
print()
print(f"===== Running {f.__name__}")
f()


@run
def benchmark_stream_fetching():
def run(num_ids, *, use_cache, optimize=False):
stream_ids = Stream.objects.all()[:num_ids].values_list("id", flat=True)
assert len(stream_ids) == num_ids
label = "memcache" if use_cache else "optimize" if optimize else "database"
# warm up cache
if use_cache:
bulk_fetch_stream_names(stream_ids, use_cache=True, optimize=False)
f = lambda: bulk_fetch_stream_names(stream_ids, use_cache=use_cache, optimize=optimize)
number = 200
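        # Take the best of 5 repeats, as the timeit docs suggest, to
        # filter out noise from GC pauses and scheduling.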
cost = min(timeit.repeat(f, number=number, repeat=5))
print(label, 1000 * cost / (num_ids * number), "(milliseconds per row)")

    for n in [1, 5, 10, 15]:
print(f"Testing with {n} stream ids")
run(n, use_cache=False)
run(n, use_cache=False, optimize=True)
run(n, use_cache=True)
print()


@run
def benchmark_user_fetching():
def run(num_ids, *, use_cache, optimize=False):
user_ids = UserProfile.objects.all()[:num_ids].values_list("id", flat=True)
assert len(user_ids) == num_ids
label = "memcache" if use_cache else "optimize" if optimize else "database"
        # Flush any stale entries, then warm the cache.
if use_cache:
for user_id in user_ids:
cache_delete(single_user_display_recipient_cache_key(user_id))
bulk_fetch_single_user_display_recipients(user_ids, use_cache=True, optimize=False)
f = lambda: bulk_fetch_single_user_display_recipients(
user_ids, use_cache=use_cache, optimize=optimize
)
number = 200
cost = min(timeit.repeat(f, number=number, repeat=5))
print(label, 1000 * cost / (num_ids * number), "(milliseconds per row)")

    for n in [1, 5, 10, 15, 20, 30, 50]:
print(f"Testing with {n} user ids")
run(n, use_cache=False)
run(n, use_cache=False, optimize=True)
run(n, use_cache=True)
print()


@run
def benchmark_message_fetching():
def run(num_ids, *, use_cache):
print(f"Testing with {num_ids} message ids")
message_ids = Message.objects.all()[:num_ids].values_list("id", flat=True)
assert len(message_ids) == num_ids
label = "memcache" if use_cache else "database"
# warm up cache
if use_cache:
messages_for_ids(message_ids, use_cache=True)
f = lambda: messages_for_ids(message_ids, use_cache=use_cache)
number = 10
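        # Fewer iterations and repeats than the other benchmarks,
        # presumably because message rows are much heavier to fetch
        # and serialize.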
cost = min(timeit.repeat(f, number=number, repeat=3))
print(label, 1000 * cost / (num_ids * number), "(milliseconds per row)")

    run(20, use_cache=False)
run(20, use_cache=True)