blob: ae13009c07c585a48de189554bb03e5b5af1d174 [file] [log] [blame]
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for local_agent.ams_client module."""
import http
import json
import requests
import time
import unittest
from unittest import mock
from local_agent import ams_client
from local_agent import errors
# Placeholder connection parameters for the AmsClient under test. Tests patch
# the HTTP layer, so nothing is expected to listen at this host/port.
_FAKE_HOST, _FAKE_PORT = 'fake-host', 8000
class AmsClientTest(unittest.TestCase):
  """Unit tests for ams_client.AmsClient and its module-level helpers.

  HTTP traffic is intercepted by patching requests.sessions.Session.request
  (or a higher-level AmsClient method). time.sleep is patched in setUp so
  tests never actually sleep, e.g. between the retries exercised below.
  """

  def setUp(self):
    super().setUp()
    self.sut = ams_client.AmsClient(host=_FAKE_HOST, port=_FAKE_PORT)
    # Patch time.sleep for every test; undone automatically via addCleanup.
    sleep_patcher = mock.patch.object(time, 'sleep')
    sleep_patcher.start()
    self.addCleanup(sleep_patcher.stop)

  @mock.patch.object(ams_client.AmsClient, 'set_local_agent_credentials')
  @mock.patch.object(requests.sessions.Session, 'request')
  def test_register_success(self, mock_request, mock_set_credentials):
    """Verifies register stores the credentials returned by AMS."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.OK
    mock_response.json.return_value = {
        'result': {
            'agentId': 'the-id',
            'agentSecret': 'the-secret'}}

    self.sut.register(linking_code='the-linking-code')

    mock_set_credentials.assert_called_once_with(
        local_agent_id='the-id',
        local_agent_secret='the-secret')

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_register_api_timeout(self, mock_request):
    """Verifies register raises ApiTimeoutError when request timed out."""
    mock_request.side_effect = requests.exceptions.Timeout

    with self.assertRaises(errors.ApiTimeoutError):
      self.sut.register(linking_code='the-linking-code')

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_register_api_return_status_bad_request(self, mock_request):
    """Verifies register raises CredentialsError when API responds 400."""
    mock_request.return_value.status_code = http.HTTPStatus.BAD_REQUEST

    with self.assertRaises(errors.CredentialsError):
      self.sut.register(linking_code='the-linking-code')

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_register_api_error_message_included(self, mock_request):
    """Verifies register includes AMS error message in its exception."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.BAD_REQUEST
    mock_response.json.return_value = {'errorMessage': 'the-message'}

    with self.assertRaisesRegex(errors.CredentialsError, 'the-message'):
      self.sut.register(linking_code='the-linking-code')

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_register_api_return_status_not_ok(self, mock_request):
    """Verifies register raises ApiError when API fails."""
    mock_request.return_value.status_code = http.HTTPStatus.NOT_FOUND

    with self.assertRaises(errors.ApiError):
      self.sut.register(linking_code='the-linking-code')

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_set_local_agent_credentials_get_auth_token_success(
      self, mock_request):
    """Verifies set_local_agent_credentials successfully gets auth token."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.CREATED
    mock_response.json.return_value = {
        'result': {
            'authToken': 'the-auth-token',
        }
    }

    self.sut.set_local_agent_credentials('agent-id', 'agent-secret')

    # The token must have been fetched over HTTP, not fabricated locally.
    mock_request.assert_called()

  @mock.patch.object(ams_client.AmsClient, '_request_wrapper')
  def test_get_auth_token_not_refreshing_auth(self, mock_request_wrapper):
    """Verifies _get_auth_token sets refresh_auth to False."""
    mock_response = mock_request_wrapper.return_value
    mock_response.status_code = http.HTTPStatus.CREATED
    mock_response.json.return_value = {'result': {'authToken': 'abc'}}

    self.sut._get_auth_token()

    self.assertFalse(
        mock_request_wrapper.call_args.kwargs['refresh_auth'],
        msg='_get_auth_token needs to pass refresh_auth=False to '
        '_request_wrapper to avoid infinite recursion')

  @mock.patch.object(ams_client, 'extract_error_message_from_api_response')
  @mock.patch.object(ams_client.AmsClient, '_request_wrapper')
  def test_get_auth_token_unlinked_error(
      self, mock_request_wrapper, mock_extract_err_msg):
    """Verifies _get_auth_token raises UnlinkedError."""
    mock_response = mock_request_wrapper.return_value
    mock_response.status_code = http.HTTPStatus.BAD_REQUEST
    mock_extract_err_msg.return_value = 'Invalid agent id'
    error_regex = 'Local agent is unlinked'

    with self.assertRaisesRegex(errors.UnlinkedError, error_regex):
      self.sut._get_auth_token()

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_set_local_agent_credentials_raises_when_api_timeout(
      self, mock_request):
    """Verifies set_local_agent_credentials raises when API timed out."""
    mock_request.side_effect = requests.exceptions.Timeout

    with self.assertRaises(errors.ApiTimeoutError):
      self.sut.set_local_agent_credentials('agent-id', 'agent-secret')

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_set_local_agent_credentials_raises_when_api_error(
      self, mock_request):
    """Verifies set_local_agent_credentials raises when API has error."""
    mock_request.return_value.status_code = http.HTTPStatus.NOT_FOUND

    with self.assertRaises(errors.CredentialsError):
      self.sut.set_local_agent_credentials('agent-id', 'agent-secret')

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_set_local_agent_credentials_include_ams_error_message(
      self, mock_request):
    """Verifies set_local_agent_credentials includes AMS error message."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.NOT_FOUND
    mock_response.json.return_value = {'errorMessage': 'the-message'}

    with self.assertRaisesRegex(errors.CredentialsError, 'the-message'):
      self.sut.set_local_agent_credentials('agent-id', 'agent-secret')

  @mock.patch.object(requests.sessions.Session,
                     'request',
                     side_effect=requests.exceptions.Timeout)
  def test_report_info_api_timeout(self, mock_request):
    """Verifies report_info raises ApiTimeoutError when API timed out."""
    with self.assertRaises(errors.ApiTimeoutError):
      self.sut.report_info({})

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_report_info_api_response_error(self, mock_request):
    """Verifies report_info raises exception when API response has error."""
    mock_request.return_value.status_code = http.HTTPStatus.BAD_REQUEST

    with self.assertRaisesRegex(errors.ApiError,
                                'Report info API failed: status 400'):
      self.sut.report_info({})

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_report_info_api_response_error_include_ams_error_message(
      self, mock_request):
    """Verifies report_info includes AMS error message in its exception."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.BAD_REQUEST
    mock_response.json.return_value = {'errorMessage': 'the-message'}

    with self.assertRaisesRegex(errors.ApiError, 'the-message'):
      self.sut.report_info({})

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_report_info_successful(self, mock_request):
    """Verifies report_info succeeds and sends info dict to AMS."""
    mock_request.return_value.status_code = http.HTTPStatus.OK
    local_agent_info = {'hi': 'hello'}

    self.sut.report_info(local_agent_info)

    self.assertIn(
        'json',
        mock_request.call_args.kwargs,
        'Local agent info should be passed as json arg to requests call.')
    self.assertEqual(local_agent_info,
                     mock_request.call_args.kwargs['json'])

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_get_rpc_request_from_ams_successful_with_request(
      self, mock_request):
    """Verifies get_rpc_request_from_ams gets a request from AMS."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.OK
    mock_response.json.return_value = {'result': {'hi': 'hello'}}

    self.assertEqual(self.sut.get_rpc_request_from_ams(),
                     {'hi': 'hello'})

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_get_rpc_request_from_ams_successful_no_request(
      self, mock_request):
    """Verifies get_rpc_request_from_ams gets no request from AMS."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.NO_CONTENT

    self.assertIsNone(self.sut.get_rpc_request_from_ams())

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_get_rpc_request_from_ams_raises_when_api_timeout(
      self, mock_request):
    """Verifies get_rpc_request_from_ams raises when API timed out."""
    mock_request.side_effect = requests.exceptions.Timeout

    with self.assertRaises(errors.ApiTimeoutError):
      self.sut.get_rpc_request_from_ams()

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_get_rpc_request_from_ams_raise_when_api_response_error(
      self, mock_request):
    """Verifies get_rpc_request_from_ams raises when response has error."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR

    with self.assertRaises(errors.ApiError):
      self.sut.get_rpc_request_from_ams()

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_get_rpc_request_from_ams_when_api_error_has_exception_message(
      self, mock_request):
    """Verifies get_rpc_request_from_ams error message when ApiError."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
    mock_response.json.return_value = {'errorMessage': 'message-from-ams'}

    with self.assertRaisesRegex(errors.ApiError, r'500.*message-from-ams'):
      self.sut.get_rpc_request_from_ams()

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_remove_rpc_request_from_ams_successful(self, mock_request):
    """Verifies remove_rpc_request_from_ams succeeds with a single call."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.OK

    self.sut.remove_rpc_request_from_ams({'the': 'request'})

    # A successful response must not trigger any retry.
    mock_request.assert_called_once()

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_remove_rpc_request_from_ams_raises_when_api_timeout(
      self, mock_request):
    """Verifies remove_rpc_request_from_ams raises if API timed out."""
    mock_request.side_effect = requests.exceptions.Timeout

    with self.assertRaises(errors.ApiTimeoutError):
      self.sut.remove_rpc_request_from_ams({'the': 'request'})

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_remove_rpc_request_from_ams_raises_when_api_fails(
      self, mock_request):
    """Verifies remove_rpc_request_from_ams raises if API has error."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
    mock_response.json.return_value = {'errorMessage': 'msg-from-ams'}

    with self.assertRaisesRegex(errors.ApiError, r'500.*msg-from-ams'):
      self.sut.remove_rpc_request_from_ams({'the': 'request'})

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_send_rpc_response_successful(self, mock_request):
    """Verifies send_rpc_response sends RPC response to AMS successfully."""
    mock_request.return_value.status_code = http.HTTPStatus.OK

    self.sut.send_rpc_response({'the': 'response'})

    # A successful response must not trigger any retry.
    mock_request.assert_called_once()

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_send_rpc_response_raise_api_error(self, mock_request):
    """Verifies send_rpc_response raises ApiError if incorrect status code."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
    mock_response.json.return_value = {'errorMessage': 'the-ams-msg'}

    with self.assertRaisesRegex(errors.ApiError, '500.*the-ams-msg'):
      self.sut.send_rpc_response({'the': 'response'})

    # Verify we have retried: 1 initial attempt + 3 retries.
    self.assertEqual(4, mock_request.call_count)

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_send_rpc_response_raise_when_api_timed_out(self, mock_request):
    """Verifies send_rpc_response raises ApiTimeoutError if API timed out."""
    mock_request.side_effect = requests.exceptions.Timeout

    with self.assertRaises(errors.ApiTimeoutError):
      self.sut.send_rpc_response({'the': 'response'})

    # Verify we have retried: 1 initial attempt + 3 retries.
    self.assertEqual(4, mock_request.call_count)

  @mock.patch.object(ams_client.AmsClient, '_get_auth_token')
  @mock.patch.object(requests.sessions.Session, 'request')
  def test_send_rpc_response_refresh_auth_if_first_attempt_has_401(
      self, mock_request, mock_get_auth_token):
    """Verifies send_rpc_response will refresh auth token."""
    mock_401_response = mock.Mock()
    mock_401_response.status_code = http.HTTPStatus.UNAUTHORIZED
    mock_401_response.json.return_value = {}
    mock_500_response = mock.Mock()
    mock_500_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
    mock_500_response.json.return_value = {}
    # The 401 triggers a token refresh; the remaining four responses consume
    # the full attempt budget, ending on a 500.
    mock_request.side_effect = (
        mock_401_response,
        requests.exceptions.Timeout,
        requests.exceptions.Timeout,
        requests.exceptions.Timeout,
        mock_500_response,
    )

    with self.assertRaisesRegex(errors.ApiError, '500'):
      self.sut.send_rpc_response({'the': 'response'})

    # Verify we have refreshed the auth token, and that does not count as
    # a retry.
    self.assertEqual(1, mock_get_auth_token.call_count)
    self.assertEqual(5, mock_request.call_count)

  @mock.patch.object(requests.sessions.Session, 'request')
  @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
  def test_upload_artifact_successful(self, mock_open, mock_request):
    """Verifies upload_artifact succeeds and contacts AMS."""
    mock_request.return_value.status_code = http.HTTPStatus.OK

    self.sut.upload_artifact('the/file/path', 'the-test-result-id')

    mock_request.assert_called()

  @mock.patch.object(requests.sessions.Session, 'request')
  @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
  def test_upload_artifact_api_timed_out(self, mock_open, mock_request):
    """Verifies upload_artifact raises ApiTimeoutError if API timed out."""
    # side_effect makes every call raise, so no return_value setup is needed.
    mock_request.side_effect = requests.exceptions.Timeout

    with self.assertRaises(errors.ApiTimeoutError):
      self.sut.upload_artifact('the/file/path', 'the-test-result-id')

  @mock.patch.object(requests.sessions.Session, 'request')
  @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
  def test_upload_artifact_api_response_error(self, mock_open, mock_request):
    """Verifies upload_artifact raises when API response is error."""
    mock_response = mock_request.return_value
    mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
    mock_response.json.return_value = {'errorMessage': 'the-ams-err-msg'}

    with self.assertRaisesRegex(errors.ApiError, '500.*the-ams-err-msg'):
      self.sut.upload_artifact('the/file/path', 'the-test-result-id')

  @mock.patch.object(requests.sessions.Session, 'request')
  def test_request_wrapper_no_timeout_and_invalid_num_retries(
      self, mock_request):
    """Verifies request_wrapper no timeout field and invalid num_retries."""
    mock_response = mock.Mock(status_code=http.HTTPStatus.OK)
    mock_request.return_value = mock_response

    response = self.sut._request_wrapper(num_retries=-1)

    self.assertEqual(mock_response, response)

  def test_extract_error_message_from_api_response_success(self):
    """Verifies extract_error_message_from_api_response on success."""
    fake_response = mock.Mock()
    fake_response.json.return_value = {'errorMessage': 'error'}

    self.assertEqual(
        'error',
        ams_client.extract_error_message_from_api_response(fake_response))

  @mock.patch.object(ams_client, 'logger')
  def test_extract_error_message_from_api_response_decode_error(
      self, mock_logger):
    """Verifies extract_error_message_from_api_response decode error."""
    fake_response = mock.Mock()
    # An empty string is a valid "document" for building the error; no
    # mocked document object is needed.
    fake_response.json.side_effect = json.decoder.JSONDecodeError('', '', 0)

    self.assertIsNone(
        ams_client.extract_error_message_from_api_response(fake_response))
    mock_logger.warning.assert_called_once_with(
        'API response cannot be parsed as JSON.')

  @mock.patch.object(ams_client, 'logger')
  def test_extract_error_message_from_api_response_key_error(
      self, mock_logger):
    """Verifies extract_error_message_from_api_response key error."""
    fake_response = mock.Mock()
    fake_response.json.side_effect = KeyError()

    self.assertIsNone(
        ams_client.extract_error_message_from_api_response(fake_response))
    mock_logger.warning.assert_called_once_with(
        'API response does not have errorMessage field')
if __name__ == '__main__':
  # Run every TestCase in this module when the file is executed directly.
  unittest.main()