refactor: Move non-Markdown logic from markdown into mention.py.

This commit is contained in:
akshatdalton 2021-06-12 22:51:30 +00:00 committed by Tim Abbott
parent a9320accdc
commit c507931ac8
8 changed files with 155 additions and 153 deletions

View File

@ -96,8 +96,9 @@ from zerver.lib.export import get_realm_exports_serialized
from zerver.lib.external_accounts import DEFAULT_EXTERNAL_ACCOUNTS from zerver.lib.external_accounts import DEFAULT_EXTERNAL_ACCOUNTS
from zerver.lib.hotspots import get_next_hotspots from zerver.lib.hotspots import get_next_hotspots
from zerver.lib.i18n import get_language_name from zerver.lib.i18n import get_language_name
from zerver.lib.markdown import MentionData, topic_links from zerver.lib.markdown import topic_links
from zerver.lib.markdown import version as markdown_version from zerver.lib.markdown import version as markdown_version
from zerver.lib.mention import MentionData
from zerver.lib.message import ( from zerver.lib.message import (
MessageDict, MessageDict,
SendMessageRequest, SendMessageRequest,

View File

@ -1,14 +1,13 @@
# Zulip's main Markdown implementation. See docs/subsystems/markdown.md for # Zulip's main Markdown implementation. See docs/subsystems/markdown.md for
# detailed documentation on our Markdown syntax. # detailed documentation on our Markdown syntax.
import datetime import datetime
import functools
import html import html
import logging import logging
import re import re
import time import time
import urllib import urllib
import urllib.parse import urllib.parse
from collections import defaultdict, deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
Any, Any,
@ -39,7 +38,6 @@ import markdown.treeprocessors
import markdown.util import markdown.util
import requests import requests
from django.conf import settings from django.conf import settings
from django.db.models import Q
from markdown.blockparser import BlockParser from markdown.blockparser import BlockParser
from markdown.extensions import codehilite, nl2br, sane_lists, tables from markdown.extensions import codehilite, nl2br, sane_lists, tables
from tlds import tld_set from tlds import tld_set
@ -52,7 +50,7 @@ from zerver.lib.emoji import EMOTICON_RE, codepoint_to_name, name_to_codepoint,
from zerver.lib.exceptions import MarkdownRenderingException from zerver.lib.exceptions import MarkdownRenderingException
from zerver.lib.markdown import fenced_code from zerver.lib.markdown import fenced_code
from zerver.lib.markdown.fenced_code import FENCE_RE from zerver.lib.markdown.fenced_code import FENCE_RE
from zerver.lib.mention import possible_mentions, possible_user_group_mentions from zerver.lib.mention import MentionData, get_stream_name_info
from zerver.lib.subdomains import is_static_or_current_realm_url from zerver.lib.subdomains import is_static_or_current_realm_url
from zerver.lib.tex import render_tex from zerver.lib.tex import render_tex
from zerver.lib.thumbnail import user_uploads_or_external from zerver.lib.thumbnail import user_uploads_or_external
@ -61,15 +59,7 @@ from zerver.lib.timezone import common_timezones
from zerver.lib.types import LinkifierDict from zerver.lib.types import LinkifierDict
from zerver.lib.url_encoding import encode_stream, hash_util_encode from zerver.lib.url_encoding import encode_stream, hash_util_encode
from zerver.lib.url_preview import preview as link_preview from zerver.lib.url_preview import preview as link_preview
from zerver.models import ( from zerver.models import Message, Realm, linkifiers_for_realm
Message,
Realm,
UserGroup,
UserGroupMembership,
UserProfile,
get_active_streams,
linkifiers_for_realm,
)
ReturnT = TypeVar("ReturnT") ReturnT = TypeVar("ReturnT")
@ -92,12 +82,6 @@ def one_time(method: Callable[[], ReturnT]) -> Callable[[], ReturnT]:
return cache_wrapper return cache_wrapper
class FullNameInfo(TypedDict):
    """Row describing one user who may be mentioned in a message.

    The keys match the UserProfile columns selected by
    get_possible_mentions_info ("id", "full_name", "email").
    """

    # UserProfile primary key.
    id: int
    email: str
    # Name as stored on the user; callers lowercase it before using it as a key.
    full_name: str
class LinkInfo(TypedDict): class LinkInfo(TypedDict):
parent: Element parent: Element
title: Optional[str] title: Optional[str]
@ -2393,132 +2377,6 @@ def privacy_clean_markdown(content: str) -> str:
return repr(_privacy_re.sub("x", content)) return repr(_privacy_re.sub("x", content))
def get_possible_mentions_info(realm_id: int, mention_texts: Set[str]) -> List[FullNameInfo]:
    """Fetch the active users of a realm that the given mention texts could refer to.

    Each mention text is read in one of three syntaxes:

    * ``name|id`` -- the user must match both the ID and, ignoring case, the full name.
    * ``|id``     -- the user is matched on the ID alone.
    * ``name``    -- the user is matched on the full name, ignoring case.

    Returns one ``id``/``full_name``/``email`` row per matching user.  An
    empty ``mention_texts`` yields an empty list without building a query.
    """
    if not mention_texts:
        return []

    name_re = r"(?P<full_name>.+)?\|(?P<mention_id>\d+)$"

    def condition_for(mention_text: str) -> Q:
        parsed = re.match(name_re, mention_text)
        if parsed is None:
            # Plain **name** syntax.
            return Q(full_name__iexact=mention_text)

        mention_id = parsed.group("mention_id")
        full_name = parsed.group("full_name")
        if not full_name:
            # **|id** syntax: nothing precedes the pipe.
            return Q(id=mention_id)
        # **name|id** syntax: the regex guarantees mention_id is present here.
        return Q(full_name__iexact=full_name, id=mention_id)

    conditions = {condition_for(mention_text) for mention_text in mention_texts}
    # A user qualifies when any one of the per-mention conditions holds.
    any_condition = functools.reduce(lambda a, b: a | b, conditions)

    candidates = UserProfile.objects.filter(realm_id=realm_id, is_active=True).filter(
        any_condition
    )
    return list(candidates.values("id", "full_name", "email"))
class MentionData:
    """Lookup tables for the users and user groups a message's content may mention.

    Construction scans ``content`` for possible user and user-group
    mentions and loads the matching rows for ``realm_id``; the accessor
    methods then answer from the in-memory tables.
    """

    def __init__(self, realm_id: int, content: str) -> None:
        mention_texts, self.has_wildcards = possible_mentions(content)
        candidates = get_possible_mentions_info(realm_id, mention_texts)

        self.user_id_info = {candidate["id"]: candidate for candidate in candidates}
        # Keyed by lowercased full name; users sharing a name overwrite each other.
        self.full_name_info = {
            candidate["full_name"].lower(): candidate for candidate in candidates
        }

        self.init_user_group_data(realm_id=realm_id, content=content)

    def message_has_wildcards(self) -> bool:
        """Return the wildcard flag that possible_mentions reported for the content."""
        return self.has_wildcards

    def init_user_group_data(self, realm_id: int, content: str) -> None:
        """Load the possibly-mentioned user groups and the member IDs of each."""
        self.user_group_name_info = get_user_group_name_info(
            realm_id, possible_user_group_mentions(content)
        )
        self.user_group_members: Dict[int, List[int]] = defaultdict(list)

        group_ids = [group.id for group in self.user_group_name_info.values()]
        if group_ids:
            # Only reach for the ORM when some group matched; the
            # query cost shows up in profiles otherwise.
            memberships = UserGroupMembership.objects.filter(
                user_group_id__in=group_ids
            ).values("user_group_id", "user_profile_id")
            for membership in memberships:
                members = self.user_group_members[membership["user_group_id"]]
                members.append(membership["user_profile_id"])

    def get_user_by_name(self, name: str) -> Optional[FullNameInfo]:
        """Look a user up by full name, ignoring case; None when unknown.

        Warning: not dependable when two mentioned users share a full
        name.  Prefer get_user_by_id where possible.
        """
        return self.full_name_info.get(name.lower())

    def get_user_by_id(self, id: int) -> Optional[FullNameInfo]:
        """Look a user up by ID; None when unknown."""
        return self.user_id_info.get(id)

    def get_user_ids(self) -> Set[int]:
        """
        Return the IDs of users this content might mention.  The
        content has not been parsed here, so escaping and code blocks
        are not taken into account and the result overestimates the
        users actually mentioned.
        """
        return set(self.user_id_info)

    def get_user_group(self, name: str) -> Optional[UserGroup]:
        """Look a user group up by name, ignoring case; None when unknown."""
        return self.user_group_name_info.get(name.lower())

    def get_group_members(self, user_group_id: int) -> List[int]:
        """Return the member user IDs loaded for a group, or [] when none were loaded."""
        # Membership test first, so the defaultdict never grows on a miss.
        if user_group_id in self.user_group_members:
            return self.user_group_members[user_group_id]
        return []
def get_user_group_name_info(realm_id: int, user_group_names: Set[str]) -> Dict[str, UserGroup]:
    """Return the realm's UserGroup rows named in user_group_names, keyed by lowercased name.

    An empty ``user_group_names`` yields an empty dict without running a query.
    """
    groups_by_name: Dict[str, UserGroup] = {}
    if user_group_names:
        # NOTE(review): the filter uses the names exactly as given while the
        # keys are lowercased; whether differently-cased names match
        # depends on how the database compares them -- confirm.
        for group in UserGroup.objects.filter(realm_id=realm_id, name__in=user_group_names):
            groups_by_name[group.name.lower()] = group
    return groups_by_name
def get_stream_name_info(realm: Realm, stream_names: Set[str]) -> Dict[str, FullNameInfo]:
    """Return ``id``/``name`` rows for the realm's active streams named in stream_names.

    The result is keyed by stream name exactly as stored.  An empty
    ``stream_names`` yields an empty dict without running a query.
    """
    # NOTE(review): each row holds only "id" and "name", which does not
    # match the fields of FullNameInfo used in the return annotation.
    if not stream_names:
        return {}

    name_conditions = {Q(name=stream_name) for stream_name in stream_names}
    # A stream qualifies when its name equals any one of the requested names.
    any_name = functools.reduce(lambda a, b: a | b, name_conditions)

    stream_rows = get_active_streams(realm=realm).filter(any_name).values("id", "name")
    return {stream_row["name"]: stream_row for stream_row in stream_rows}
def do_convert( def do_convert(
content: str, content: str,
realm_alert_words_automaton: Optional[ahocorasick.Automaton] = None, realm_alert_words_automaton: Optional[ahocorasick.Automaton] = None,

View File

@ -1,5 +1,12 @@
import functools
import re import re
from typing import Match, Optional, Set, Tuple from collections import defaultdict
from typing import Dict, List, Match, Optional, Set, Tuple
from django.db.models import Q
from zerver.lib.types import FullNameInfo
from zerver.models import Realm, UserGroup, UserGroupMembership, UserProfile, get_active_streams
# Match multi-word string between @** ** or match any one-word # Match multi-word string between @** ** or match any one-word
# sequences after @ # sequences after @
@ -35,3 +42,129 @@ def possible_mentions(content: str) -> Tuple[Set[str], bool]:
def possible_user_group_mentions(content: str) -> Set[str]: def possible_user_group_mentions(content: str) -> Set[str]:
return {m.group("match") for m in USER_GROUP_MENTIONS_RE.finditer(content)} return {m.group("match") for m in USER_GROUP_MENTIONS_RE.finditer(content)}
def get_possible_mentions_info(realm_id: int, mention_texts: Set[str]) -> List[FullNameInfo]:
    """Fetch the active users of a realm that the given mention texts could refer to.

    Each mention text is read in one of three syntaxes:

    * ``name|id`` -- the user must match both the ID and, ignoring case, the full name.
    * ``|id``     -- the user is matched on the ID alone.
    * ``name``    -- the user is matched on the full name, ignoring case.

    Returns one ``id``/``full_name``/``email`` row per matching user.  An
    empty ``mention_texts`` yields an empty list without building a query.
    """
    if not mention_texts:
        return []

    name_re = r"(?P<full_name>.+)?\|(?P<mention_id>\d+)$"

    def condition_for(mention_text: str) -> Q:
        parsed = re.match(name_re, mention_text)
        if parsed is None:
            # Plain **name** syntax.
            return Q(full_name__iexact=mention_text)

        mention_id = parsed.group("mention_id")
        full_name = parsed.group("full_name")
        if not full_name:
            # **|id** syntax: nothing precedes the pipe.
            return Q(id=mention_id)
        # **name|id** syntax: the regex guarantees mention_id is present here.
        return Q(full_name__iexact=full_name, id=mention_id)

    conditions = {condition_for(mention_text) for mention_text in mention_texts}
    # A user qualifies when any one of the per-mention conditions holds.
    any_condition = functools.reduce(lambda a, b: a | b, conditions)

    candidates = UserProfile.objects.filter(realm_id=realm_id, is_active=True).filter(
        any_condition
    )
    return list(candidates.values("id", "full_name", "email"))
def get_user_group_name_info(realm_id: int, user_group_names: Set[str]) -> Dict[str, UserGroup]:
    """Return the realm's UserGroup rows named in user_group_names, keyed by lowercased name.

    An empty ``user_group_names`` yields an empty dict without running a query.
    """
    groups_by_name: Dict[str, UserGroup] = {}
    if user_group_names:
        # NOTE(review): the filter uses the names exactly as given while the
        # keys are lowercased; whether differently-cased names match
        # depends on how the database compares them -- confirm.
        for group in UserGroup.objects.filter(realm_id=realm_id, name__in=user_group_names):
            groups_by_name[group.name.lower()] = group
    return groups_by_name
class MentionData:
    """Lookup tables for the users and user groups a message's content may mention.

    Construction scans ``content`` for possible user and user-group
    mentions and loads the matching rows for ``realm_id``; the accessor
    methods then answer from the in-memory tables.
    """

    def __init__(self, realm_id: int, content: str) -> None:
        mention_texts, self.has_wildcards = possible_mentions(content)
        candidates = get_possible_mentions_info(realm_id, mention_texts)

        self.user_id_info = {candidate["id"]: candidate for candidate in candidates}
        # Keyed by lowercased full name; users sharing a name overwrite each other.
        self.full_name_info = {
            candidate["full_name"].lower(): candidate for candidate in candidates
        }

        self.init_user_group_data(realm_id=realm_id, content=content)

    def message_has_wildcards(self) -> bool:
        """Return the wildcard flag that possible_mentions reported for the content."""
        return self.has_wildcards

    def init_user_group_data(self, realm_id: int, content: str) -> None:
        """Load the possibly-mentioned user groups and the member IDs of each."""
        self.user_group_name_info = get_user_group_name_info(
            realm_id, possible_user_group_mentions(content)
        )
        self.user_group_members: Dict[int, List[int]] = defaultdict(list)

        group_ids = [group.id for group in self.user_group_name_info.values()]
        if group_ids:
            # Only reach for the ORM when some group matched; the
            # query cost shows up in profiles otherwise.
            memberships = UserGroupMembership.objects.filter(
                user_group_id__in=group_ids
            ).values("user_group_id", "user_profile_id")
            for membership in memberships:
                members = self.user_group_members[membership["user_group_id"]]
                members.append(membership["user_profile_id"])

    def get_user_by_name(self, name: str) -> Optional[FullNameInfo]:
        """Look a user up by full name, ignoring case; None when unknown.

        Warning: not dependable when two mentioned users share a full
        name.  Prefer get_user_by_id where possible.
        """
        return self.full_name_info.get(name.lower())

    def get_user_by_id(self, id: int) -> Optional[FullNameInfo]:
        """Look a user up by ID; None when unknown."""
        return self.user_id_info.get(id)

    def get_user_ids(self) -> Set[int]:
        """
        Return the IDs of users this content might mention.  The
        content has not been parsed here, so escaping and code blocks
        are not taken into account and the result overestimates the
        users actually mentioned.
        """
        return set(self.user_id_info)

    def get_user_group(self, name: str) -> Optional[UserGroup]:
        """Look a user group up by name, ignoring case; None when unknown."""
        return self.user_group_name_info.get(name.lower())

    def get_group_members(self, user_group_id: int) -> List[int]:
        """Return the member user IDs loaded for a group, or [] when none were loaded."""
        # Membership test first, so the defaultdict never grows on a miss.
        if user_group_id in self.user_group_members:
            return self.user_group_members[user_group_id]
        return []
def get_stream_name_info(realm: Realm, stream_names: Set[str]) -> Dict[str, FullNameInfo]:
    """Return ``id``/``name`` rows for the realm's active streams named in stream_names.

    The result is keyed by stream name exactly as stored.  An empty
    ``stream_names`` yields an empty dict without running a query.
    """
    # NOTE(review): each row holds only "id" and "name", which does not
    # match the fields of FullNameInfo used in the return annotation.
    if not stream_names:
        return {}

    name_conditions = {Q(name=stream_name) for stream_name in stream_names}
    # A stream qualifies when its name equals any one of the requested names.
    any_name = functools.reduce(lambda a, b: a | b, name_conditions)

    stream_rows = get_active_streams(realm=realm).filter(any_name).values("id", "name")
    return {stream_row["name"]: stream_row for stream_row in stream_rows}

View File

@ -27,8 +27,9 @@ from zerver.lib.display_recipient import (
UserDisplayRecipient, UserDisplayRecipient,
bulk_fetch_display_recipients, bulk_fetch_display_recipients,
) )
from zerver.lib.markdown import MentionData, markdown_convert, topic_links from zerver.lib.markdown import markdown_convert, topic_links
from zerver.lib.markdown import version as markdown_version from zerver.lib.markdown import version as markdown_version
from zerver.lib.mention import MentionData
from zerver.lib.request import JsonableError from zerver.lib.request import JsonableError
from zerver.lib.stream_subscription import ( from zerver.lib.stream_subscription import (
get_stream_subscriptions_for_user, get_stream_subscriptions_for_user,

View File

@ -68,3 +68,9 @@ class SAMLIdPConfigDict(TypedDict, total=False):
extra_attrs: List[str] extra_attrs: List[str]
x509cert: str x509cert: str
x509cert_path: str x509cert_path: str
class FullNameInfo(TypedDict):
    """Row describing one user who may be mentioned in a message.

    The keys match the UserProfile columns selected by
    get_possible_mentions_info ("id", "full_name", "email").
    """

    # UserProfile primary key.
    id: int
    email: str
    # Name as stored on the user; callers lowercase it before using it as a key.
    full_name: str

View File

@ -167,7 +167,7 @@ from zerver.lib.events import (
fetch_initial_state_data, fetch_initial_state_data,
post_process_state, post_process_state,
) )
from zerver.lib.markdown import MentionData from zerver.lib.mention import MentionData
from zerver.lib.message import render_markdown from zerver.lib.message import render_markdown
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import ( from zerver.lib.test_helpers import (

View File

@ -25,11 +25,9 @@ from zerver.lib.emoji import get_emoji_url
from zerver.lib.exceptions import MarkdownRenderingException from zerver.lib.exceptions import MarkdownRenderingException
from zerver.lib.markdown import ( from zerver.lib.markdown import (
MarkdownListPreprocessor, MarkdownListPreprocessor,
MentionData,
clear_state_for_testing, clear_state_for_testing,
content_has_emoji_syntax, content_has_emoji_syntax,
fetch_tweet_data, fetch_tweet_data,
get_possible_mentions_info,
get_tweet_id, get_tweet_id,
image_preview_enabled, image_preview_enabled,
markdown_convert, markdown_convert,
@ -41,7 +39,12 @@ from zerver.lib.markdown import (
) )
from zerver.lib.markdown.fenced_code import FencedBlockPreprocessor from zerver.lib.markdown.fenced_code import FencedBlockPreprocessor
from zerver.lib.mdiff import diff_strings from zerver.lib.mdiff import diff_strings
from zerver.lib.mention import possible_mentions, possible_user_group_mentions from zerver.lib.mention import (
MentionData,
get_possible_mentions_info,
possible_mentions,
possible_user_group_mentions,
)
from zerver.lib.message import render_markdown from zerver.lib.message import render_markdown
from zerver.lib.request import JsonableError from zerver.lib.request import JsonableError
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase

View File

@ -20,7 +20,7 @@ from zerver.lib.actions import (
do_update_message, do_update_message,
) )
from zerver.lib.avatar import avatar_url from zerver.lib.avatar import avatar_url
from zerver.lib.markdown import MentionData from zerver.lib.mention import MentionData
from zerver.lib.message import ( from zerver.lib.message import (
MessageDict, MessageDict,
get_first_visible_message_id, get_first_visible_message_id,