semgrep: Detect some unsafe uses of markupsafe.Markup.

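Build HTML with Markup("…{var}…").format(…), which HTML-escapes its
arguments itself, instead of escaping by hand and wrapping the result
in Markup(). With that convention, Semgrep can flag mistakes such as
Markup("…{var}…".format(…)) and Markup(f"…{var}…"), where the values
are interpolated before Markup sees them and so are never escaped.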

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2023-03-20 23:10:20 -07:00 committed by Tim Abbott
parent f66d952c57
commit afa218fa2a
6 changed files with 36 additions and 28 deletions

View File

@ -1,7 +1,6 @@
import re import re
import sys import sys
from datetime import datetime from datetime import datetime
from html import escape
from typing import Any, Collection, Dict, List, Optional, Sequence from typing import Any, Collection, Dict, List, Optional, Sequence
from urllib.parse import urlencode from urllib.parse import urlencode
@ -63,46 +62,42 @@ def user_activity_link(email: str, user_profile_id: int) -> Markup:
from analytics.views.user_activity import get_user_activity from analytics.views.user_activity import get_user_activity
url = reverse(get_user_activity, kwargs=dict(user_profile_id=user_profile_id)) url = reverse(get_user_activity, kwargs=dict(user_profile_id=user_profile_id))
email_link = f'<a href="{escape(url)}">{escape(email)}</a>' return Markup('<a href="{url}">{email}</a>').format(url=url, email=email)
return Markup(email_link)
def realm_activity_link(realm_str: str) -> Markup: def realm_activity_link(realm_str: str) -> Markup:
from analytics.views.realm_activity import get_realm_activity from analytics.views.realm_activity import get_realm_activity
url = reverse(get_realm_activity, kwargs=dict(realm_str=realm_str)) url = reverse(get_realm_activity, kwargs=dict(realm_str=realm_str))
realm_link = f'<a href="{escape(url)}">{escape(realm_str)}</a>' return Markup('<a href="{url}">{realm_str}</a>').format(url=url, realm_str=realm_str)
return Markup(realm_link)
def realm_stats_link(realm_str: str) -> Markup: def realm_stats_link(realm_str: str) -> Markup:
from analytics.views.stats import stats_for_realm from analytics.views.stats import stats_for_realm
url = reverse(stats_for_realm, kwargs=dict(realm_str=realm_str)) url = reverse(stats_for_realm, kwargs=dict(realm_str=realm_str))
stats_link = f'<a href="{escape(url)}"><i class="fa fa-pie-chart"></i></a>' return Markup('<a href="{url}"><i class="fa fa-pie-chart"></i></a>').format(url=url)
return Markup(stats_link)
def realm_support_link(realm_str: str) -> Markup: def realm_support_link(realm_str: str) -> Markup:
support_url = reverse("support") support_url = reverse("support")
query = urlencode({"q": realm_str}) query = urlencode({"q": realm_str})
url = append_url_query_string(support_url, query) url = append_url_query_string(support_url, query)
support_link = f'<a href="{escape(url)}">{escape(realm_str)}</a>' return Markup('<a href="{url}">{realm_str}</a>').format(url=url, realm_str=realm_str)
return Markup(support_link)
def realm_url_link(realm_str: str) -> Markup: def realm_url_link(realm_str: str) -> Markup:
url = get_realm(realm_str).uri url = get_realm(realm_str).uri
realm_link = f'<a href="{escape(url)}"><i class="fa fa-home"></i></a>' return Markup('<a href="{url}"><i class="fa fa-home"></i></a>').format(url=url)
return Markup(realm_link)
def remote_installation_stats_link(server_id: int, hostname: str) -> Markup: def remote_installation_stats_link(server_id: int, hostname: str) -> Markup:
from analytics.views.stats import stats_for_remote_installation from analytics.views.stats import stats_for_remote_installation
url = reverse(stats_for_remote_installation, kwargs=dict(remote_server_id=server_id)) url = reverse(stats_for_remote_installation, kwargs=dict(remote_server_id=server_id))
stats_link = f'<a href="{escape(url)}"><i class="fa fa-pie-chart"></i>{escape(hostname)}</a>' return Markup('<a href="{url}"><i class="fa fa-pie-chart"></i>{hostname}</a>').format(
return Markup(stats_link) url=url, hostname=hostname
)
def get_user_activity_summary(records: Collection[UserActivity]) -> Dict[str, Any]: def get_user_activity_summary(records: Collection[UserActivity]) -> Dict[str, Any]:

View File

@ -38,7 +38,7 @@ if settings.BILLING_ENABLED:
) )
def get_realm_day_counts() -> Dict[str, Dict[str, str]]: def get_realm_day_counts() -> Dict[str, Dict[str, Markup]]:
query = SQL( query = SQL(
""" """
select select
@ -78,7 +78,7 @@ def get_realm_day_counts() -> Dict[str, Dict[str, str]]:
min_cnt = min(raw_cnts[1:]) min_cnt = min(raw_cnts[1:])
max_cnt = max(raw_cnts[1:]) max_cnt = max(raw_cnts[1:])
def format_count(cnt: int, style: Optional[str] = None) -> str: def format_count(cnt: int, style: Optional[str] = None) -> Markup:
if style is not None: if style is not None:
good_bad = style good_bad = style
elif cnt == min_cnt: elif cnt == min_cnt:
@ -88,9 +88,11 @@ def get_realm_day_counts() -> Dict[str, Dict[str, str]]:
else: else:
good_bad = "neutral" good_bad = "neutral"
return f'<td class="number {good_bad}">{cnt}</td>' return Markup('<td class="number {good_bad}">{cnt}</td>').format(
good_bad=good_bad, cnt=cnt
)
cnts = format_count(raw_cnts[0], "neutral") + "".join(map(format_count, raw_cnts[1:])) cnts = format_count(raw_cnts[0], "neutral") + Markup().join(map(format_count, raw_cnts[1:]))
result[string_id] = dict(cnts=cnts) result[string_id] = dict(cnts=cnts)
return result return result
@ -304,7 +306,8 @@ def user_activity_intervals() -> Tuple[Markup, Dict[str, float]]:
day_end = timestamp_to_datetime(time.time()) day_end = timestamp_to_datetime(time.time())
day_start = day_end - timedelta(hours=24) day_start = day_end - timedelta(hours=24)
output = "Per-user online duration for the last 24 hours:\n" output = Markup()
output += "Per-user online duration for the last 24 hours:\n"
total_duration = timedelta(0) total_duration = timedelta(0)
all_intervals = ( all_intervals = (
@ -335,7 +338,7 @@ def user_activity_intervals() -> Tuple[Markup, Dict[str, float]]:
for string_id, realm_intervals in itertools.groupby(all_intervals, by_string_id): for string_id, realm_intervals in itertools.groupby(all_intervals, by_string_id):
realm_duration = timedelta(0) realm_duration = timedelta(0)
output += f"<hr>{string_id}\n" output += Markup("<hr>") + f"{string_id}\n"
for email, intervals in itertools.groupby(realm_intervals, by_email): for email, intervals in itertools.groupby(realm_intervals, by_email):
duration = timedelta(0) duration = timedelta(0)
for interval in intervals: for interval in intervals:
@ -352,7 +355,7 @@ def user_activity_intervals() -> Tuple[Markup, Dict[str, float]]:
output += f"\nTotal duration: {total_duration}\n" output += f"\nTotal duration: {total_duration}\n"
output += f"\nTotal duration in minutes: {total_duration.total_seconds() / 60.}\n" output += f"\nTotal duration in minutes: {total_duration.total_seconds() / 60.}\n"
output += f"Total duration amortized to a month: {total_duration.total_seconds() * 30. / 60.}" output += f"Total duration amortized to a month: {total_duration.total_seconds() * 30. / 60.}"
content = Markup("<pre>" + output + "</pre>") content = Markup("<pre>{}</pre>").format(output)
return content, realm_minutes return content, realm_minutes

View File

@ -42,17 +42,30 @@ rules:
- zerver/migrations/0387_reupload_realmemoji_again.py - zerver/migrations/0387_reupload_realmemoji_again.py
- pgroonga/migrations/0002_html_escape_subject.py - pgroonga/migrations/0002_html_escape_subject.py
- id: html-format
languages: [python]
pattern-either:
- pattern: markupsafe.Markup(... .format(...))
- pattern: markupsafe.Markup(f"...")
- pattern: markupsafe.Markup(... + ...)
severity: ERROR
message: "Do not write an HTML injection vulnerability please"
- id: sql-format - id: sql-format
languages: [python] languages: [python]
pattern-either: pattern-either:
- pattern: ... .execute("...".format(...)) - pattern: ... .execute("...".format(...))
- pattern: ... .execute(f"...") - pattern: ... .execute(f"...")
- pattern: ... .execute(... + ...)
- pattern: psycopg2.sql.SQL(... .format(...)) - pattern: psycopg2.sql.SQL(... .format(...))
- pattern: psycopg2.sql.SQL(f"...") - pattern: psycopg2.sql.SQL(f"...")
- pattern: psycopg2.sql.SQL(... + ...)
- pattern: django.db.migrations.RunSQL(..., "..." .format(...), ...) - pattern: django.db.migrations.RunSQL(..., "..." .format(...), ...)
- pattern: django.db.migrations.RunSQL(..., f"...", ...) - pattern: django.db.migrations.RunSQL(..., f"...", ...)
- pattern: django.db.migrations.RunSQL(..., ... + ..., ...)
- pattern: django.db.migrations.RunSQL(..., [..., "..." .format(...), ...], ...) - pattern: django.db.migrations.RunSQL(..., [..., "..." .format(...), ...], ...)
- pattern: django.db.migrations.RunSQL(..., [..., f"...", ...], ...) - pattern: django.db.migrations.RunSQL(..., [..., f"...", ...], ...)
- pattern: django.db.migrations.RunSQL(..., [..., ... + ..., ...], ...)
severity: ERROR severity: ERROR
message: "Do not write a SQL injection vulnerability please" message: "Do not write a SQL injection vulnerability please"

View File

@ -51,7 +51,7 @@ if settings.BILLING_ENABLED:
# We don't mark this error for translation, because it's displayed # We don't mark this error for translation, because it's displayed
# only to MIT users. # only to MIT users.
MIT_VALIDATION_ERROR = ( MIT_VALIDATION_ERROR = Markup(
"That user does not exist at MIT or is a" "That user does not exist at MIT or is a"
' <a href="https://ist.mit.edu/email-lists">mailing list</a>.' ' <a href="https://ist.mit.edu/email-lists">mailing list</a>.'
" If you want to sign up an alias for Zulip," " If you want to sign up an alias for Zulip,"
@ -76,7 +76,7 @@ def email_is_not_mit_mailing_list(email: str) -> None:
if e.rcode == DNS.Status.NXDOMAIN: if e.rcode == DNS.Status.NXDOMAIN:
# This error is Markup only because 1. it needs to render HTML # This error is Markup only because 1. it needs to render HTML
# 2. It's not formatted with any user input. # 2. It's not formatted with any user input.
raise ValidationError(Markup(MIT_VALIDATION_ERROR)) raise ValidationError(MIT_VALIDATION_ERROR)
else: else:
raise AssertionError("Unexpected DNS error") raise AssertionError("Unexpected DNS error")

View File

@ -6,7 +6,6 @@ from datetime import datetime
from typing import IO, Any, Callable, Iterator, List, Optional, Tuple from typing import IO, Any, Callable, Iterator, List, Optional, Tuple
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from markupsafe import Markup
from PIL import GifImagePlugin, Image, ImageOps, PngImagePlugin from PIL import GifImagePlugin, Image, ImageOps, PngImagePlugin
from PIL.Image import DecompressionBombError from PIL.Image import DecompressionBombError
@ -52,7 +51,7 @@ def sanitize_name(value: str) -> str:
value = re.sub(r"[^\w\s.-]", "", value).strip() value = re.sub(r"[^\w\s.-]", "", value).strip()
value = re.sub(r"[-\s]+", "-", value) value = re.sub(r"[-\s]+", "-", value)
assert value not in {"", ".", ".."} assert value not in {"", ".", ".."}
return Markup(value) return value
class BadImageError(JsonableError): class BadImageError(JsonableError):

View File

@ -19,12 +19,10 @@ from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, HttpRes
from django.shortcuts import redirect, render from django.shortcuts import redirect, render
from django.template.response import SimpleTemplateResponse, TemplateResponse from django.template.response import SimpleTemplateResponse, TemplateResponse
from django.urls import reverse from django.urls import reverse
from django.utils.html import escape
from django.utils.http import url_has_allowed_host_and_scheme from django.utils.http import url_has_allowed_host_and_scheme
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from django.views.decorators.http import require_safe from django.views.decorators.http import require_safe
from markupsafe import Markup
from social_django.utils import load_backend, load_strategy from social_django.utils import load_backend, load_strategy
from two_factor.forms import BackupTokenForm from two_factor.forms import BackupTokenForm
from two_factor.views import LoginView as BaseTwoFactorLoginView from two_factor.views import LoginView as BaseTwoFactorLoginView
@ -719,8 +717,8 @@ def update_login_page_context(request: HttpRequest, context: Dict[str, Any]) ->
return return
try: try:
validate_email(deactivated_email) validate_email(deactivated_email)
context["deactivated_account_error"] = Markup( context["deactivated_account_error"] = DEACTIVATED_ACCOUNT_ERROR.format(
DEACTIVATED_ACCOUNT_ERROR.format(username=escape(deactivated_email)) username=deactivated_email
) )
except ValidationError: except ValidationError:
logging.info("Invalid email in is_deactivated param to login page: %s", deactivated_email) logging.info("Invalid email in is_deactivated param to login page: %s", deactivated_email)