diff --git a/zerver/lib/actions.py b/zerver/lib/actions.py index 1eaf3b4e8f..b780df0707 100644 --- a/zerver/lib/actions.py +++ b/zerver/lib/actions.py @@ -96,8 +96,9 @@ from zerver.lib.export import get_realm_exports_serialized from zerver.lib.external_accounts import DEFAULT_EXTERNAL_ACCOUNTS from zerver.lib.hotspots import get_next_hotspots from zerver.lib.i18n import get_language_name -from zerver.lib.markdown import MentionData, topic_links +from zerver.lib.markdown import topic_links from zerver.lib.markdown import version as markdown_version +from zerver.lib.mention import MentionData from zerver.lib.message import ( MessageDict, SendMessageRequest, diff --git a/zerver/lib/markdown/__init__.py b/zerver/lib/markdown/__init__.py index a50b0639e1..664d30cbd6 100644 --- a/zerver/lib/markdown/__init__.py +++ b/zerver/lib/markdown/__init__.py @@ -1,14 +1,13 @@ # Zulip's main Markdown implementation. See docs/subsystems/markdown.md for # detailed documentation on our Markdown syntax. import datetime -import functools import html import logging import re import time import urllib import urllib.parse -from collections import defaultdict, deque +from collections import deque from dataclasses import dataclass from typing import ( Any, @@ -39,7 +38,6 @@ import markdown.treeprocessors import markdown.util import requests from django.conf import settings -from django.db.models import Q from markdown.blockparser import BlockParser from markdown.extensions import codehilite, nl2br, sane_lists, tables from tlds import tld_set @@ -52,7 +50,7 @@ from zerver.lib.emoji import EMOTICON_RE, codepoint_to_name, name_to_codepoint, from zerver.lib.exceptions import MarkdownRenderingException from zerver.lib.markdown import fenced_code from zerver.lib.markdown.fenced_code import FENCE_RE -from zerver.lib.mention import possible_mentions, possible_user_group_mentions +from zerver.lib.mention import MentionData, get_stream_name_info from zerver.lib.subdomains import is_static_or_current_realm_url from zerver.lib.tex import render_tex from zerver.lib.thumbnail import user_uploads_or_external @@ -61,15 +59,7 @@ from zerver.lib.timezone import common_timezones from zerver.lib.types import LinkifierDict from zerver.lib.url_encoding import encode_stream, hash_util_encode from zerver.lib.url_preview import preview as link_preview -from zerver.models import ( - Message, - Realm, - UserGroup, - UserGroupMembership, - UserProfile, - get_active_streams, - linkifiers_for_realm, -) +from zerver.models import Message, Realm, linkifiers_for_realm ReturnT = TypeVar("ReturnT") @@ -92,12 +82,6 @@ def one_time(method: Callable[[], ReturnT]) -> Callable[[], ReturnT]: return cache_wrapper -class FullNameInfo(TypedDict): - id: int - email: str - full_name: str - - class LinkInfo(TypedDict): parent: Element title: Optional[str] @@ -2393,132 +2377,6 @@ def privacy_clean_markdown(content: str) -> str: return repr(_privacy_re.sub("x", content)) -def get_possible_mentions_info(realm_id: int, mention_texts: Set[str]) -> List[FullNameInfo]: - if not mention_texts: - return [] - - q_list = set() - - name_re = r"(?P.+)?\|(?P\d+)$" - for mention_text in mention_texts: - name_syntax_match = re.match(name_re, mention_text) - if name_syntax_match: - full_name = name_syntax_match.group("full_name") - mention_id = name_syntax_match.group("mention_id") - if full_name: - # For **name|id** mentions as mention_id - # cannot be null inside this block. - q_list.add(Q(full_name__iexact=full_name, id=mention_id)) - else: - # For **|id** syntax. - q_list.add(Q(id=mention_id)) - else: - # For **name** syntax. - q_list.add(Q(full_name__iexact=mention_text)) - - rows = ( - UserProfile.objects.filter( - realm_id=realm_id, - is_active=True, - ) - .filter( - functools.reduce(lambda a, b: a | b, q_list), - ) - .values( - "id", - "full_name", - "email", - ) - ) - return list(rows) - - -class MentionData: - def __init__(self, realm_id: int, content: str) -> None: - mention_texts, has_wildcards = possible_mentions(content) - possible_mentions_info = get_possible_mentions_info(realm_id, mention_texts) - self.full_name_info = {row["full_name"].lower(): row for row in possible_mentions_info} - self.user_id_info = {row["id"]: row for row in possible_mentions_info} - self.init_user_group_data(realm_id=realm_id, content=content) - self.has_wildcards = has_wildcards - - def message_has_wildcards(self) -> bool: - return self.has_wildcards - - def init_user_group_data(self, realm_id: int, content: str) -> None: - user_group_names = possible_user_group_mentions(content) - self.user_group_name_info = get_user_group_name_info(realm_id, user_group_names) - self.user_group_members: Dict[int, List[int]] = defaultdict(list) - group_ids = [group.id for group in self.user_group_name_info.values()] - - if not group_ids: - # Early-return to avoid the cost of hitting the ORM, - # which shows up in profiles. - return - - membership = UserGroupMembership.objects.filter(user_group_id__in=group_ids) - for info in membership.values("user_group_id", "user_profile_id"): - group_id = info["user_group_id"] - user_profile_id = info["user_profile_id"] - self.user_group_members[group_id].append(user_profile_id) - - def get_user_by_name(self, name: str) -> Optional[FullNameInfo]: - # warning: get_user_by_name is not dependable if two - # users of the same full name are mentioned. Use - # get_user_by_id where possible. - return self.full_name_info.get(name.lower(), None) - - def get_user_by_id(self, id: int) -> Optional[FullNameInfo]: - return self.user_id_info.get(id, None) - - def get_user_ids(self) -> Set[int]: - """ - Returns the user IDs that might have been mentioned by this - content. Note that because this data structure has not parsed - the message and does not know about escaping/code blocks, this - will overestimate the list of user ids. - """ - return set(self.user_id_info.keys()) - - def get_user_group(self, name: str) -> Optional[UserGroup]: - return self.user_group_name_info.get(name.lower(), None) - - def get_group_members(self, user_group_id: int) -> List[int]: - return self.user_group_members.get(user_group_id, []) - - -def get_user_group_name_info(realm_id: int, user_group_names: Set[str]) -> Dict[str, UserGroup]: - if not user_group_names: - return {} - - rows = UserGroup.objects.filter(realm_id=realm_id, name__in=user_group_names) - dct = {row.name.lower(): row for row in rows} - return dct - - -def get_stream_name_info(realm: Realm, stream_names: Set[str]) -> Dict[str, FullNameInfo]: - if not stream_names: - return {} - - q_list = {Q(name=name) for name in stream_names} - - rows = ( - get_active_streams( - realm=realm, - ) - .filter( - functools.reduce(lambda a, b: a | b, q_list), - ) - .values( - "id", - "name", - ) - ) - - dct = {row["name"]: row for row in rows} - return dct - - def do_convert( content: str, realm_alert_words_automaton: Optional[ahocorasick.Automaton] = None, diff --git a/zerver/lib/mention.py b/zerver/lib/mention.py index 076451fad7..5b93c0a2cc 100644 --- a/zerver/lib/mention.py +++ b/zerver/lib/mention.py @@ -1,5 +1,12 @@ +import functools import re -from typing import Match, Optional, Set, Tuple +from collections import defaultdict +from typing import Dict, List, Match, Optional, Set, Tuple + +from django.db.models import Q + +from zerver.lib.types import FullNameInfo +from zerver.models import Realm, UserGroup, UserGroupMembership, UserProfile, get_active_streams # Match multi-word string between @** ** or match any one-word # sequences after @ @@ -35,3 +42,129 @@ def possible_mentions(content: str) -> Tuple[Set[str], bool]: def possible_user_group_mentions(content: str) -> Set[str]: return {m.group("match") for m in USER_GROUP_MENTIONS_RE.finditer(content)} + + +def get_possible_mentions_info(realm_id: int, mention_texts: Set[str]) -> List[FullNameInfo]: + if not mention_texts: + return [] + + q_list = set() + + name_re = r"(?P.+)?\|(?P\d+)$" + for mention_text in mention_texts: + name_syntax_match = re.match(name_re, mention_text) + if name_syntax_match: + full_name = name_syntax_match.group("full_name") + mention_id = name_syntax_match.group("mention_id") + if full_name: + # For **name|id** mentions as mention_id + # cannot be null inside this block. + q_list.add(Q(full_name__iexact=full_name, id=mention_id)) + else: + # For **|id** syntax. + q_list.add(Q(id=mention_id)) + else: + # For **name** syntax. + q_list.add(Q(full_name__iexact=mention_text)) + + rows = ( + UserProfile.objects.filter( + realm_id=realm_id, + is_active=True, + ) + .filter( + functools.reduce(lambda a, b: a | b, q_list), + ) + .values( + "id", + "full_name", + "email", + ) + ) + return list(rows) + + +def get_user_group_name_info(realm_id: int, user_group_names: Set[str]) -> Dict[str, UserGroup]: + if not user_group_names: + return {} + + rows = UserGroup.objects.filter(realm_id=realm_id, name__in=user_group_names) + dct = {row.name.lower(): row for row in rows} + return dct + + +class MentionData: + def __init__(self, realm_id: int, content: str) -> None: + mention_texts, has_wildcards = possible_mentions(content) + possible_mentions_info = get_possible_mentions_info(realm_id, mention_texts) + self.full_name_info = {row["full_name"].lower(): row for row in possible_mentions_info} + self.user_id_info = {row["id"]: row for row in possible_mentions_info} + self.init_user_group_data(realm_id=realm_id, content=content) + self.has_wildcards = has_wildcards + + def message_has_wildcards(self) -> bool: + return self.has_wildcards + + def init_user_group_data(self, realm_id: int, content: str) -> None: + user_group_names = possible_user_group_mentions(content) + self.user_group_name_info = get_user_group_name_info(realm_id, user_group_names) + self.user_group_members: Dict[int, List[int]] = defaultdict(list) + group_ids = [group.id for group in self.user_group_name_info.values()] + + if not group_ids: + # Early-return to avoid the cost of hitting the ORM, + # which shows up in profiles. + return + + membership = UserGroupMembership.objects.filter(user_group_id__in=group_ids) + for info in membership.values("user_group_id", "user_profile_id"): + group_id = info["user_group_id"] + user_profile_id = info["user_profile_id"] + self.user_group_members[group_id].append(user_profile_id) + + def get_user_by_name(self, name: str) -> Optional[FullNameInfo]: + # warning: get_user_by_name is not dependable if two + # users of the same full name are mentioned. Use + # get_user_by_id where possible. + return self.full_name_info.get(name.lower(), None) + + def get_user_by_id(self, id: int) -> Optional[FullNameInfo]: + return self.user_id_info.get(id, None) + + def get_user_ids(self) -> Set[int]: + """ + Returns the user IDs that might have been mentioned by this + content. Note that because this data structure has not parsed + the message and does not know about escaping/code blocks, this + will overestimate the list of user ids. + """ + return set(self.user_id_info.keys()) + + def get_user_group(self, name: str) -> Optional[UserGroup]: + return self.user_group_name_info.get(name.lower(), None) + + def get_group_members(self, user_group_id: int) -> List[int]: + return self.user_group_members.get(user_group_id, []) + + +def get_stream_name_info(realm: Realm, stream_names: Set[str]) -> Dict[str, FullNameInfo]: + if not stream_names: + return {} + + q_list = {Q(name=name) for name in stream_names} + + rows = ( + get_active_streams( + realm=realm, + ) + .filter( + functools.reduce(lambda a, b: a | b, q_list), + ) + .values( + "id", + "name", + ) + ) + + dct = {row["name"]: row for row in rows} + return dct diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 90d2d76a5f..3abdb9b2c8 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -27,8 +27,9 @@ from zerver.lib.display_recipient import ( UserDisplayRecipient, bulk_fetch_display_recipients, ) -from zerver.lib.markdown import MentionData, markdown_convert, topic_links +from zerver.lib.markdown import markdown_convert, topic_links from zerver.lib.markdown import version as markdown_version +from zerver.lib.mention import MentionData from zerver.lib.request import JsonableError from zerver.lib.stream_subscription import ( get_stream_subscriptions_for_user, diff --git a/zerver/lib/types.py b/zerver/lib/types.py index f0555a1da5..3c326ee6d4 100644 --- a/zerver/lib/types.py +++ b/zerver/lib/types.py @@ -68,3 +68,9 @@ class SAMLIdPConfigDict(TypedDict, total=False): extra_attrs: List[str] x509cert: str x509cert_path: str + + +class FullNameInfo(TypedDict): + id: int + email: str + full_name: str diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index 8b71884956..44f7726d57 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -167,7 +167,7 @@ from zerver.lib.events import ( fetch_initial_state_data, post_process_state, ) -from zerver.lib.markdown import MentionData +from zerver.lib.mention import MentionData from zerver.lib.message import render_markdown from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_helpers import ( diff --git a/zerver/tests/test_markdown.py b/zerver/tests/test_markdown.py index 3230bfb497..480a5b8dce 100644 --- a/zerver/tests/test_markdown.py +++ b/zerver/tests/test_markdown.py @@ -25,11 +25,9 @@ from zerver.lib.emoji import get_emoji_url from zerver.lib.exceptions import MarkdownRenderingException from zerver.lib.markdown import ( MarkdownListPreprocessor, - MentionData, clear_state_for_testing, content_has_emoji_syntax, fetch_tweet_data, - get_possible_mentions_info, get_tweet_id, image_preview_enabled, markdown_convert, @@ -41,7 +39,12 @@ from zerver.lib.markdown import ( ) from zerver.lib.markdown.fenced_code import FencedBlockPreprocessor from zerver.lib.mdiff import diff_strings -from zerver.lib.mention import possible_mentions, possible_user_group_mentions +from zerver.lib.mention import ( + MentionData, + get_possible_mentions_info, + possible_mentions, + possible_user_group_mentions, +) from zerver.lib.message import render_markdown from zerver.lib.request import JsonableError from zerver.lib.test_classes import ZulipTestCase diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index 770b2731f8..49a85b801c 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -20,7 +20,7 @@ from zerver.lib.actions import ( do_update_message, ) from zerver.lib.avatar import avatar_url -from zerver.lib.markdown import MentionData +from zerver.lib.mention import MentionData from zerver.lib.message import ( MessageDict, get_first_visible_message_id,