outgoing_webhook: Add a requests session on the webhook.

The session object provides a common place to set headers on all
requests, no matter which implementation.

Because the `headers` attribute of Session is not a true static
attribute, but rather exposed via overriding `__getstate__`, `mock`'s
autospec cannot know about it, and thus throws an error; in tests that
mock the Session, we thus must explicitly set the `session.headers`.
This commit is contained in:
Alex Vandiver 2021-03-26 20:59:27 -07:00 committed by Tim Abbott
parent be100154dd
commit cb3e6df8b9
3 changed files with 72 additions and 75 deletions

View File

@ -5,7 +5,7 @@ from typing import Any, AnyStr, Dict, Optional
import requests import requests
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from requests import Response from requests import Response, Session
from version import ZULIP_VERSION from version import ZULIP_VERSION
from zerver.decorator import JsonableError from zerver.decorator import JsonableError
@ -30,6 +30,7 @@ class OutgoingWebhookServiceInterface(metaclass=abc.ABCMeta):
self.token: str = token self.token: str = token
self.user_profile: UserProfile = user_profile self.user_profile: UserProfile = user_profile
self.service_name: str = service_name self.service_name: str = service_name
self.session: Session = Session()
@abc.abstractmethod @abc.abstractmethod
def make_request(self, base_url: str, event: Dict[str, Any]) -> Optional[Response]: def make_request(self, base_url: str, event: Dict[str, Any]) -> Optional[Response]:
@ -70,7 +71,7 @@ class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface):
headers = { headers = {
"User-Agent": user_agent, "User-Agent": user_agent,
} }
return requests.request("POST", base_url, json=request_data, headers=headers) return self.session.post(base_url, json=request_data, headers=headers)
def process_success(self, response_json: Dict[str, Any]) -> Optional[Dict[str, Any]]: def process_success(self, response_json: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if "response_not_required" in response_json and response_json["response_not_required"]: if "response_not_required" in response_json and response_json["response_not_required"]:
@ -112,7 +113,7 @@ class SlackOutgoingWebhookService(OutgoingWebhookServiceInterface):
("trigger_word", event["trigger"]), ("trigger_word", event["trigger"]),
("service_id", event["user_profile_id"]), ("service_id", event["user_profile_id"]),
] ]
return requests.request("POST", base_url, data=request_data) return self.session.post(base_url, data=request_data)
def process_success(self, response_json: Dict[str, Any]) -> Optional[Dict[str, Any]]: def process_success(self, response_json: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if "text" in response_json: if "text" in response_json:

View File

@ -103,23 +103,14 @@ class TestGenericOutgoingWebhookService(ZulipTestCase):
} }
test_url = "https://example.com/example" test_url = "https://example.com/example"
response = mock.Mock(spec=requests.Response) with mock.patch.object(self.handler, "session") as session:
response.status_code = 200
expect_200 = mock.patch("requests.request", return_value=response)
with expect_200 as mock_request:
self.handler.make_request( self.handler.make_request(
test_url, test_url,
event, event,
) )
mock_request.assert_called_once() session.post.assert_called_once()
self.assertEqual( self.assertEqual(session.post.call_args[0], (test_url,))
mock_request.call_args[0], request_data = session.post.call_args[1]["json"]
(
"POST",
test_url,
),
)
request_data = mock_request.call_args[1]["json"]
validate_against_openapi_schema(request_data, "/zulip-outgoing-webhook", "post", "200") validate_against_openapi_schema(request_data, "/zulip-outgoing-webhook", "post", "200")
self.assertEqual(request_data["data"], "@**test**") self.assertEqual(request_data["data"], "@**test**")
@ -199,20 +190,14 @@ class TestSlackOutgoingWebhookService(ZulipTestCase):
def test_make_request_stream_message(self) -> None: def test_make_request_stream_message(self) -> None:
test_url = "https://example.com/example" test_url = "https://example.com/example"
with mock.patch("requests.request") as mock_request: with mock.patch.object(self.handler, "session") as session:
self.handler.make_request( self.handler.make_request(
test_url, test_url,
self.stream_message_event, self.stream_message_event,
) )
mock_request.assert_called_once() session.post.assert_called_once()
self.assertEqual( self.assertEqual(session.post.call_args[0], (test_url,))
mock_request.call_args[0], request_data = session.post.call_args[1]["data"]
(
"POST",
test_url,
),
)
request_data = mock_request.call_args[1]["data"]
self.assertEqual(request_data[0][1], "abcdef") # token self.assertEqual(request_data[0][1], "abcdef") # token
self.assertEqual(request_data[1][1], "zulip") # team_id self.assertEqual(request_data[1][1], "zulip") # team_id
@ -229,12 +214,12 @@ class TestSlackOutgoingWebhookService(ZulipTestCase):
@mock.patch("zerver.lib.outgoing_webhook.fail_with_message") @mock.patch("zerver.lib.outgoing_webhook.fail_with_message")
def test_make_request_private_message(self, mock_fail_with_message: mock.Mock) -> None: def test_make_request_private_message(self, mock_fail_with_message: mock.Mock) -> None:
test_url = "https://example.com/example" test_url = "https://example.com/example"
with mock.patch("requests.request") as mock_request: with mock.patch.object(self.handler, "session") as session:
response = self.handler.make_request( response = self.handler.make_request(
test_url, test_url,
self.private_message_event, self.private_message_event,
) )
mock_request.assert_not_called() session.post.assert_not_called()
self.assertIsNone(response) self.assertIsNone(response)
self.assertTrue(mock_fail_with_message.called) self.assertTrue(mock_fail_with_message.called)

View File

@ -25,15 +25,15 @@ class ResponseMock:
self.text = content.decode() self.text = content.decode()
def request_exception_error(http_method: Any, final_url: Any, **request_kwargs: Any) -> Any: def request_exception_error(final_url: Any, **request_kwargs: Any) -> Any:
raise requests.exceptions.RequestException("I'm a generic exception :(") raise requests.exceptions.RequestException("I'm a generic exception :(")
def timeout_error(http_method: Any, final_url: Any, **request_kwargs: Any) -> Any: def timeout_error(final_url: Any, **request_kwargs: Any) -> Any:
raise requests.exceptions.Timeout("Time is up!") raise requests.exceptions.Timeout("Time is up!")
def connection_error(http_method: Any, final_url: Any, **request_kwargs: Any) -> Any: def connection_error(final_url: Any, **request_kwargs: Any) -> Any:
raise requests.exceptions.ConnectionError() raise requests.exceptions.ConnectionError()
@ -74,31 +74,30 @@ class DoRestCallTests(ZulipTestCase):
mock_event = self.mock_event(bot_user) mock_event = self.mock_event(bot_user)
service_handler = GenericOutgoingWebhookService("token", bot_user, "service") service_handler = GenericOutgoingWebhookService("token", bot_user, "service")
response = ResponseMock(200, orjson.dumps(dict(content="whatever")))
expect_200 = mock.patch("requests.request", return_value=response)
expect_send_response = mock.patch("zerver.lib.outgoing_webhook.send_response_message") expect_send_response = mock.patch("zerver.lib.outgoing_webhook.send_response_message")
with expect_200, expect_send_response as mock_send: with mock.patch.object(
service_handler, "session"
) as session, expect_send_response as mock_send:
session.post.return_value = ResponseMock(200, orjson.dumps(dict(content="whatever")))
do_rest_call("", mock_event, service_handler) do_rest_call("", mock_event, service_handler)
self.assertTrue(mock_send.called) self.assertTrue(mock_send.called)
for service_class in [GenericOutgoingWebhookService, SlackOutgoingWebhookService]: for service_class in [GenericOutgoingWebhookService, SlackOutgoingWebhookService]:
handler = service_class("token", bot_user, "service") handler = service_class("token", bot_user, "service")
with mock.patch.object(handler, "session") as session:
with expect_200: session.post.return_value = ResponseMock(200)
do_rest_call("", mock_event, handler) do_rest_call("", mock_event, handler)
session.post.assert_called_once()
# TODO: assert something interesting here?
def test_retry_request(self) -> None: def test_retry_request(self) -> None:
bot_user = self.example_user("outgoing_webhook_bot") bot_user = self.example_user("outgoing_webhook_bot")
mock_event = self.mock_event(bot_user) mock_event = self.mock_event(bot_user)
service_handler = GenericOutgoingWebhookService("token", bot_user, "service") service_handler = GenericOutgoingWebhookService("token", bot_user, "service")
response = ResponseMock(500) with mock.patch.object(service_handler, "session") as session, self.assertLogs(
with mock.patch("requests.request", return_value=response), self.assertLogs(
level="WARNING" level="WARNING"
) as m: ) as m:
session.post.return_value = ResponseMock(500)
final_response = do_rest_call("", mock_event, service_handler) final_response = do_rest_call("", mock_event, service_handler)
assert final_response is not None assert final_response is not None
@ -123,11 +122,12 @@ The webhook got a response with status code *500*.""",
mock_event = self.mock_event(bot_user) mock_event = self.mock_event(bot_user)
service_handler = GenericOutgoingWebhookService("token", bot_user, "service") service_handler = GenericOutgoingWebhookService("token", bot_user, "service")
response = ResponseMock(400)
expect_400 = mock.patch("requests.request", return_value=response)
expect_fail = mock.patch("zerver.lib.outgoing_webhook.fail_with_message") expect_fail = mock.patch("zerver.lib.outgoing_webhook.fail_with_message")
with expect_400, expect_fail as mock_fail, self.assertLogs(level="WARNING") as m: with mock.patch.object(
service_handler, "session"
) as session, expect_fail as mock_fail, self.assertLogs(level="WARNING") as m:
session.post.return_value = ResponseMock(400)
final_response = do_rest_call("", mock_event, service_handler) final_response = do_rest_call("", mock_event, service_handler)
assert final_response is not None assert final_response is not None
@ -155,7 +155,8 @@ The webhook got a response with status code *400*.""",
mock_event = self.mock_event(bot_user) mock_event = self.mock_event(bot_user)
service_handler = GenericOutgoingWebhookService("token", bot_user, "service") service_handler = GenericOutgoingWebhookService("token", bot_user, "service")
with mock.patch("requests.sessions.Session.send") as mock_send: session = service_handler.session
with mock.patch.object(session, "send") as mock_send:
mock_send.return_value = ResponseMock(200) mock_send.return_value = ResponseMock(200)
final_response = do_rest_call("https://example.com/", mock_event, service_handler) final_response = do_rest_call("https://example.com/", mock_event, service_handler)
assert final_response is not None assert final_response is not None
@ -176,8 +177,8 @@ The webhook got a response with status code *400*.""",
bot_user_email = self.example_user_map["outgoing_webhook_bot"] bot_user_email = self.example_user_map["outgoing_webhook_bot"]
def helper(side_effect: Any, error_text: str) -> None: def helper(side_effect: Any, error_text: str) -> None:
with mock.patch.object(service_handler, "session") as session:
with mock.patch("requests.request", side_effect=side_effect): session.post.side_effect = side_effect
do_rest_call("", mock_event, service_handler) do_rest_call("", mock_event, service_handler)
bot_owner_notification = self.get_last_message() bot_owner_notification = self.get_last_message()
@ -204,15 +205,15 @@ The webhook got a response with status code *400*.""",
mock_event = self.mock_event(bot_user) mock_event = self.mock_event(bot_user)
service_handler = GenericOutgoingWebhookService("token", bot_user, "service") service_handler = GenericOutgoingWebhookService("token", bot_user, "service")
expect_request_exception = mock.patch(
"requests.request", side_effect=request_exception_error
)
expect_logging_exception = self.assertLogs(level="ERROR") expect_logging_exception = self.assertLogs(level="ERROR")
expect_fail = mock.patch("zerver.lib.outgoing_webhook.fail_with_message") expect_fail = mock.patch("zerver.lib.outgoing_webhook.fail_with_message")
# Don't think that we should catch and assert whole log output(which is actually a very big error traceback). # Don't think that we should catch and assert whole log output(which is actually a very big error traceback).
# We are already asserting bot_owner_notification.content which verifies exception did occur. # We are already asserting bot_owner_notification.content which verifies exception did occur.
with expect_request_exception, expect_logging_exception, expect_fail as mock_fail: with mock.patch.object(
service_handler, "session"
) as session, expect_logging_exception, expect_fail as mock_fail:
session.post.side_effect = request_exception_error
do_rest_call("", mock_event, service_handler) do_rest_call("", mock_event, service_handler)
self.assertTrue(mock_fail.called) self.assertTrue(mock_fail.called)
@ -271,8 +272,11 @@ class TestOutgoingWebhookMessaging(ZulipTestCase):
sender = self.example_user("hamlet") sender = self.example_user("hamlet")
with mock.patch("requests.request") as mock_request: session = mock.Mock(spec=requests.Session)
mock_request.return_value = ResponseMock(200) session.headers = {}
session.post.return_value = ResponseMock(200)
with mock.patch("zerver.lib.outgoing_webhook.Session") as sessionmaker:
sessionmaker.return_value = session
self.send_personal_message( self.send_personal_message(
sender, sender,
bot, bot,
@ -280,9 +284,10 @@ class TestOutgoingWebhookMessaging(ZulipTestCase):
) )
url_token_tups = set() url_token_tups = set()
for item in mock_request.call_args_list: session.post.assert_called()
for item in session.post.call_args_list:
args = item[0] args = item[0]
base_url = args[1] base_url = args[0]
kwargs = item[1] kwargs = item[1]
request_data = kwargs["json"] request_data = kwargs["json"]
tup = (base_url, request_data["token"]) tup = (base_url, request_data["token"])
@ -299,18 +304,19 @@ class TestOutgoingWebhookMessaging(ZulipTestCase):
}, },
) )
@mock.patch( def test_pm_to_outgoing_webhook_bot(self) -> None:
"requests.request",
return_value=ResponseMock(
200, orjson.dumps({"response_string": "Hidley ho, I'm a webhook responding!"})
),
)
def test_pm_to_outgoing_webhook_bot(self, mock_requests_request: mock.Mock) -> None:
bot_owner = self.example_user("othello") bot_owner = self.example_user("othello")
bot = self.create_outgoing_bot(bot_owner) bot = self.create_outgoing_bot(bot_owner)
sender = self.example_user("hamlet") sender = self.example_user("hamlet")
self.send_personal_message(sender, bot, content="foo") session = mock.Mock(spec=requests.Session)
session.headers = {}
session.post.return_value = ResponseMock(
200, orjson.dumps({"response_string": "Hidley ho, I'm a webhook responding!"})
)
with mock.patch("zerver.lib.outgoing_webhook.Session") as sessionmaker:
sessionmaker.return_value = session
self.send_personal_message(sender, bot, content="foo")
last_message = self.get_last_message() last_message = self.get_last_message()
self.assertEqual(last_message.content, "Hidley ho, I'm a webhook responding!") self.assertEqual(last_message.content, "Hidley ho, I'm a webhook responding!")
self.assertEqual(last_message.sender_id, bot.id) self.assertEqual(last_message.sender_id, bot.id)
@ -329,11 +335,15 @@ class TestOutgoingWebhookMessaging(ZulipTestCase):
sender = self.example_user("hamlet") sender = self.example_user("hamlet")
realm = get_realm("zulip") realm = get_realm("zulip")
response = ResponseMock(407) session = mock.Mock(spec=requests.Session)
expect_407 = mock.patch("requests.request", return_value=response) session.headers = {}
session.post.return_value = ResponseMock(407)
expect_fail = mock.patch("zerver.lib.outgoing_webhook.fail_with_message") expect_fail = mock.patch("zerver.lib.outgoing_webhook.fail_with_message")
with expect_407, expect_fail as mock_fail, self.assertLogs(level="WARNING"): with mock.patch(
"zerver.lib.outgoing_webhook.Session"
) as sessionmaker, expect_fail as mock_fail, self.assertLogs(level="WARNING"):
sessionmaker.return_value = session
message_id = self.send_personal_message(sender, bot, content="foo") message_id = self.send_personal_message(sender, bot, content="foo")
# create message dict to get the message url # create message dict to get the message url
@ -363,19 +373,20 @@ class TestOutgoingWebhookMessaging(ZulipTestCase):
) )
self.assertTrue(mock_fail.called) self.assertTrue(mock_fail.called)
@mock.patch( def test_stream_message_to_outgoing_webhook_bot(self) -> None:
"requests.request",
return_value=ResponseMock(
200, orjson.dumps({"response_string": "Hidley ho, I'm a webhook responding!"})
),
)
def test_stream_message_to_outgoing_webhook_bot(self, mock_requests_request: mock.Mock) -> None:
bot_owner = self.example_user("othello") bot_owner = self.example_user("othello")
bot = self.create_outgoing_bot(bot_owner) bot = self.create_outgoing_bot(bot_owner)
self.send_stream_message( session = mock.Mock(spec=requests.Session)
bot_owner, "Denmark", content=f"@**{bot.full_name}** foo", topic_name="bar" session.headers = {}
session.post.return_value = ResponseMock(
200, orjson.dumps({"response_string": "Hidley ho, I'm a webhook responding!"})
) )
with mock.patch("zerver.lib.outgoing_webhook.Session") as sessionmaker:
sessionmaker.return_value = session
self.send_stream_message(
bot_owner, "Denmark", content=f"@**{bot.full_name}** foo", topic_name="bar"
)
last_message = self.get_last_message() last_message = self.get_last_message()
self.assertEqual(last_message.content, "Hidley ho, I'm a webhook responding!") self.assertEqual(last_message.content, "Hidley ho, I'm a webhook responding!")
self.assertEqual(last_message.sender_id, bot.id) self.assertEqual(last_message.sender_id, bot.id)