analytics: Correctly type aggregate_table with Union.

Since mypy doesn't accept redefinition of the same variable within the
same scope, we need to use type annotations with Union to correctly
type aggregate_table. Note that the type cast is necessary for mypy to
narrow the type of aggregate_table.
This commit is contained in:
PIG208 2021-07-27 17:37:41 +08:00 committed by Tim Abbott
parent 6ab42b7663
commit ec99c1b58d
1 changed files with 22 additions and 10 deletions

View File

@ -1,7 +1,7 @@
import logging
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from django.conf import settings
from django.db.models.query import QuerySet
@ -219,10 +219,16 @@ def get_chart_data(
remote_realm_id: Optional[int] = None,
server: Optional["RemoteZulipServer"] = None,
) -> HttpResponse:
TableType = Union[
Type[RemoteInstallationCount],
Type[InstallationCount],
Type[RemoteRealmCount],
Type[RealmCount],
]
if for_installation:
if remote:
assert settings.ZILENCER_ENABLED
aggregate_table = RemoteInstallationCount
aggregate_table: TableType = RemoteInstallationCount
assert server is not None
else:
aggregate_table = InstallationCount
@ -235,13 +241,15 @@ def get_chart_data(
else:
aggregate_table = RealmCount
tables: Union[Tuple[TableType], Tuple[TableType, Type[UserCount]]]
if chart_name == "number_of_humans":
stats = [
COUNT_STATS["1day_actives::day"],
COUNT_STATS["realm_active_humans::day"],
COUNT_STATS["active_users_audit:is_bot:day"],
]
tables = [aggregate_table]
tables = (aggregate_table,)
subgroup_to_label: Dict[CountStat, Dict[Optional[str], str]] = {
stats[0]: {None: "_1day"},
stats[1]: {None: "_15day"},
@ -251,13 +259,13 @@ def get_chart_data(
include_empty_subgroups = True
elif chart_name == "messages_sent_over_time":
stats = [COUNT_STATS["messages_sent:is_bot:hour"]]
tables = [aggregate_table, UserCount]
tables = (aggregate_table, UserCount)
subgroup_to_label = {stats[0]: {"false": "human", "true": "bot"}}
labels_sort_function = None
include_empty_subgroups = True
elif chart_name == "messages_sent_by_message_type":
stats = [COUNT_STATS["messages_sent:message_type:day"]]
tables = [aggregate_table, UserCount]
tables = (aggregate_table, UserCount)
subgroup_to_label = {
stats[0]: {
"public_stream": _("Public streams"),
@ -270,7 +278,7 @@ def get_chart_data(
include_empty_subgroups = True
elif chart_name == "messages_sent_by_client":
stats = [COUNT_STATS["messages_sent:client:day"]]
tables = [aggregate_table, UserCount]
tables = (aggregate_table, UserCount)
# Note that the labels are further re-written by client_label_map
subgroup_to_label = {
stats[0]: {str(id): name for id, name in Client.objects.values_list("id", "name")}
@ -279,7 +287,7 @@ def get_chart_data(
include_empty_subgroups = False
elif chart_name == "messages_read_over_time":
stats = [COUNT_STATS["messages_read::hour"]]
tables = [aggregate_table, UserCount]
tables = (aggregate_table, UserCount)
subgroup_to_label = {stats[0]: {None: "read"}}
labels_sort_function = None
include_empty_subgroups = True
@ -310,16 +318,20 @@ def get_chart_data(
# should simply use the first and last data points for the
# table.
assert server is not None
if not aggregate_table.objects.filter(server=server).exists():
assert aggregate_table is RemoteInstallationCount or aggregate_table is RemoteRealmCount
aggregate_table_remote = cast(
Union[Type[RemoteInstallationCount], Type[RemoteRealmCount]], aggregate_table
) # https://stackoverflow.com/questions/68540528/mypy-assertions-on-the-types-of-types
if not aggregate_table_remote.objects.filter(server=server).exists():
raise JsonableError(
_("No analytics data available. Please contact your server administrator.")
)
if start is None:
first = aggregate_table.objects.filter(server=server).first()
first = aggregate_table_remote.objects.filter(server=server).first()
assert first is not None
start = first.end_time
if end is None:
last = aggregate_table.objects.filter(server=server).last()
last = aggregate_table_remote.objects.filter(server=server).last()
assert last is not None
end = last.end_time
else: