| # Copyright 2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Unit tests for local_agent.ams_client module.""" |
| import http |
| import json |
| import requests |
| import time |
| import unittest |
| from unittest import mock |
| |
| from local_agent import ams_client |
| from local_agent import errors |
| |
# Placeholder host/port used only to construct the AmsClient under test.
_FAKE_HOST = 'fake-host'
_FAKE_PORT = 8000
| |
| |
class AmsClientTest(unittest.TestCase):
    """Unit tests for ams_client.AmsClient and its module-level helpers."""

    def setUp(self):
        super().setUp()
        self.sut = ams_client.AmsClient(host=_FAKE_HOST, port=_FAKE_PORT)
        # Patch time.sleep so that any retry delay in the client does not
        # slow the tests down.
        sleep_patcher = mock.patch.object(time, 'sleep')
        sleep_patcher.start()
        self.addCleanup(sleep_patcher.stop)

    @mock.patch.object(ams_client.AmsClient, 'set_local_agent_credentials')
    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_success(self, mock_request, mock_set_credentials):
        """Verifies register passes the AMS-issued ID/secret to credentials."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.OK
        mock_response.json.return_value = {
            'result': {
                'agentId': 'the-id',
                'agentSecret': 'the-secret',
            }
        }

        self.sut.register(linking_code='the-linking-code')

        mock_request.assert_called()
        mock_set_credentials.assert_called_once_with(
            local_agent_id='the-id',
            local_agent_secret='the-secret')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_api_timeout(self, mock_request):
        """Verifies register raises ApiTimeoutError when the request times out."""
        mock_request.side_effect = requests.exceptions.Timeout
        with self.assertRaises(errors.ApiTimeoutError):
            self.sut.register(linking_code='the-linking-code')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_api_return_status_bad_request(self, mock_request):
        """Verifies register raises CredentialsError when AMS responds 400."""
        mock_request.return_value.status_code = http.HTTPStatus.BAD_REQUEST
        with self.assertRaises(errors.CredentialsError):
            self.sut.register(linking_code='the-linking-code')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_api_error_message_included(self, mock_request):
        """Verifies register includes the AMS error message in its exception."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.BAD_REQUEST
        mock_response.json.return_value = {'errorMessage': 'the-message'}

        with self.assertRaisesRegex(errors.CredentialsError, 'the-message'):
            self.sut.register(linking_code='the-linking-code')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_api_return_status_not_ok(self, mock_request):
        """Verifies register raises ApiError when AMS responds 404."""
        mock_request.return_value.status_code = http.HTTPStatus.NOT_FOUND
        with self.assertRaises(errors.ApiError):
            self.sut.register(linking_code='the-linking-code')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_set_local_agent_credentials_get_auth_token_success(
            self, mock_request):
        """Verifies set_local_agent_credentials requests an auth token."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.CREATED
        mock_response.json.return_value = {
            'result': {
                'authToken': 'the-auth-token',
            }
        }

        self.sut.set_local_agent_credentials('agent-id', 'agent-secret')

        # Without this the test would pass even if no request were made.
        mock_request.assert_called()

    @mock.patch.object(ams_client.AmsClient, '_request_wrapper')
    def test_get_auth_token_not_refreshing_auth(self, mock_request_wrapper):
        """Verifies _get_auth_token passes refresh_auth=False."""
        mock_response = mock_request_wrapper.return_value
        mock_response.status_code = http.HTTPStatus.CREATED
        mock_response.json.return_value = {'result': {'authToken': 'abc'}}

        self.sut._get_auth_token()

        call_kwargs = mock_request_wrapper.call_args.kwargs
        # Fail with a readable message rather than a KeyError if the
        # argument is missing entirely.
        self.assertIn(
            'refresh_auth', call_kwargs,
            msg='_get_auth_token must pass refresh_auth as a keyword argument')
        self.assertFalse(
            call_kwargs['refresh_auth'],
            msg='_get_auth_token needs to pass refresh_auth=False to '
                '_request_wrapper to avoid infinite recursion')

    @mock.patch.object(ams_client, 'extract_error_message_from_api_response')
    @mock.patch.object(ams_client.AmsClient, '_request_wrapper')
    def test_get_auth_token_unlinked_error(
            self, mock_request_wrapper, mock_extract_err_msg):
        """Verifies _get_auth_token raises UnlinkedError on 'Invalid agent id'."""
        mock_response = mock_request_wrapper.return_value
        mock_response.status_code = http.HTTPStatus.BAD_REQUEST
        mock_extract_err_msg.return_value = 'Invalid agent id'

        error_regex = 'Local agent is unlinked'
        with self.assertRaisesRegex(errors.UnlinkedError, error_regex):
            self.sut._get_auth_token()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_set_local_agent_credentials_raises_when_api_timeout(
            self, mock_request):
        """Verifies set_local_agent_credentials raises when the API times out."""
        mock_request.side_effect = requests.exceptions.Timeout
        with self.assertRaises(errors.ApiTimeoutError):
            self.sut.set_local_agent_credentials('agent-id', 'agent-secret')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_set_local_agent_credentials_raises_when_api_error(
            self, mock_request):
        """Verifies set_local_agent_credentials raises on an API error."""
        mock_request.return_value.status_code = http.HTTPStatus.NOT_FOUND
        with self.assertRaises(errors.CredentialsError):
            self.sut.set_local_agent_credentials('agent-id', 'agent-secret')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_set_local_agent_credentials_include_ams_error_message(
            self, mock_request):
        """Verifies set_local_agent_credentials includes the AMS message."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.NOT_FOUND
        mock_response.json.return_value = {'errorMessage': 'the-message'}

        with self.assertRaisesRegex(errors.CredentialsError, 'the-message'):
            self.sut.set_local_agent_credentials('agent-id', 'agent-secret')

    @mock.patch.object(requests.sessions.Session,
                       'request',
                       side_effect=requests.exceptions.Timeout)
    def test_report_info_api_timeout(self, mock_request):
        """Verifies report_info raises ApiTimeoutError when the API times out."""
        with self.assertRaises(errors.ApiTimeoutError):
            self.sut.report_info({})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_report_info_api_response_error(self, mock_request):
        """Verifies report_info raises ApiError when the API responds 400."""
        mock_request.return_value.status_code = http.HTTPStatus.BAD_REQUEST
        with self.assertRaisesRegex(errors.ApiError,
                                    'Report info API failed: status 400'):
            self.sut.report_info({})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_report_info_api_response_error_include_ams_error_message(
            self, mock_request):
        """Verifies report_info's ApiError includes the AMS error message."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.BAD_REQUEST
        mock_response.json.return_value = {'errorMessage': 'the-message'}

        with self.assertRaisesRegex(errors.ApiError, 'the-message'):
            self.sut.report_info({})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_report_info_successful(self, mock_request):
        """Verifies report_info succeeds and sends the info dict to AMS."""
        mock_request.return_value.status_code = http.HTTPStatus.OK
        local_agent_info = {'hi': 'hello'}

        self.sut.report_info(local_agent_info)

        self.assertIn(
            'json',
            mock_request.call_args.kwargs,
            'Local agent info should be passed as json arg to requests call.')
        self.assertEqual(local_agent_info,
                         mock_request.call_args.kwargs['json'])

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_successful_with_request(
            self, mock_request):
        """Verifies get_rpc_request_from_ams returns the request's 'result'."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.OK
        mock_response.json.return_value = {'result': {'hi': 'hello'}}

        self.assertEqual(self.sut.get_rpc_request_from_ams(),
                         {'hi': 'hello'})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_successful_no_request(
            self, mock_request):
        """Verifies get_rpc_request_from_ams returns None on 204 No Content."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.NO_CONTENT

        self.assertIsNone(self.sut.get_rpc_request_from_ams())

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_raises_when_api_timeout(
            self, mock_request):
        """Verifies get_rpc_request_from_ams raises when the API times out."""
        mock_request.side_effect = requests.exceptions.Timeout

        with self.assertRaises(errors.ApiTimeoutError):
            self.sut.get_rpc_request_from_ams()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_raise_when_api_response_error(
            self, mock_request):
        """Verifies get_rpc_request_from_ams raises on an error response."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR

        with self.assertRaises(errors.ApiError):
            self.sut.get_rpc_request_from_ams()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_when_api_error_has_exception_message(
            self, mock_request):
        """Verifies get_rpc_request_from_ams ApiError has status and message."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        mock_response.json.return_value = {'errorMessage': 'message-from-ams'}

        with self.assertRaisesRegex(errors.ApiError, r'500.*message-from-ams'):
            self.sut.get_rpc_request_from_ams()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_remove_rpc_request_from_ams_successful(self, mock_request):
        """Verifies remove_rpc_request_from_ams succeeds on a 200 response."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.OK

        self.sut.remove_rpc_request_from_ams({'the': 'request'})

        mock_request.assert_called()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_remove_rpc_request_from_ams_raises_when_api_timeout(
            self, mock_request):
        """Verifies remove_rpc_request_from_ams raises if the API times out."""
        mock_request.side_effect = requests.exceptions.Timeout

        with self.assertRaises(errors.ApiTimeoutError):
            self.sut.remove_rpc_request_from_ams({'the': 'request'})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_remove_rpc_request_from_ams_raises_when_api_fails(
            self, mock_request):
        """Verifies remove_rpc_request_from_ams raises on an API error."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        mock_response.json.return_value = {'errorMessage': 'msg-from-ams'}

        with self.assertRaisesRegex(errors.ApiError, r'500.*msg-from-ams'):
            self.sut.remove_rpc_request_from_ams({'the': 'request'})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_send_rpc_response_successful(self, mock_request):
        """Verifies send_rpc_response sends the RPC response to AMS."""
        mock_request.return_value.status_code = http.HTTPStatus.OK

        self.sut.send_rpc_response({'the': 'response'})

        mock_request.assert_called()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_send_rpc_response_raise_api_error(self, mock_request):
        """Verifies send_rpc_response raises ApiError on an error status."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        mock_response.json.return_value = {'errorMessage': 'the-ams-msg'}

        with self.assertRaisesRegex(errors.ApiError, r'500.*the-ams-msg'):
            self.sut.send_rpc_response({'the': 'response'})
        # Verify we have retried: 4 request attempts in total.
        self.assertEqual(4, mock_request.call_count)

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_send_rpc_response_raise_when_api_timed_out(self, mock_request):
        """Verifies send_rpc_response raises ApiTimeoutError on timeouts."""
        mock_request.side_effect = requests.exceptions.Timeout

        with self.assertRaises(errors.ApiTimeoutError):
            self.sut.send_rpc_response({'the': 'response'})
        # Verify we have retried: 4 request attempts in total.
        self.assertEqual(4, mock_request.call_count)

    @mock.patch.object(ams_client.AmsClient, '_get_auth_token')
    @mock.patch.object(requests.sessions.Session, 'request')
    def test_send_rpc_response_refresh_auth_if_first_attempt_has_401(
            self, mock_request, mock_get_auth_token):
        """Verifies send_rpc_response refreshes the auth token after a 401."""
        mock_401_response = mock.Mock()
        mock_401_response.status_code = http.HTTPStatus.UNAUTHORIZED
        mock_401_response.json.return_value = {}
        mock_500_response = mock.Mock()
        mock_500_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        mock_500_response.json.return_value = {}
        mock_request.side_effect = (
            mock_401_response,
            requests.exceptions.Timeout,
            requests.exceptions.Timeout,
            requests.exceptions.Timeout,
            mock_500_response,
        )

        with self.assertRaisesRegex(errors.ApiError, '500'):
            self.sut.send_rpc_response({'the': 'response'})
        # Verify we have refreshed the auth token, and that the 401 attempt
        # does not count as a retry (hence 5 calls instead of 4).
        self.assertEqual(1, mock_get_auth_token.call_count)
        self.assertEqual(5, mock_request.call_count)

    @mock.patch.object(requests.sessions.Session, 'request')
    @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
    def test_upload_artifact_successful(self, unused_mock_open, mock_request):
        """Verifies upload_artifact succeeds on a 200 response."""
        mock_request.return_value.status_code = http.HTTPStatus.OK

        self.sut.upload_artifact('the/file/path', 'the-test-result-id')

        mock_request.assert_called()

    @mock.patch.object(requests.sessions.Session, 'request')
    @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
    def test_upload_artifact_api_timed_out(self, unused_mock_open,
                                           mock_request):
        """Verifies upload_artifact raises ApiTimeoutError on API timeout."""
        # side_effect raises before any return value is produced, so no
        # status_code needs to be configured here.
        mock_request.side_effect = requests.exceptions.Timeout

        with self.assertRaises(errors.ApiTimeoutError):
            self.sut.upload_artifact('the/file/path', 'the-test-result-id')

    @mock.patch.object(requests.sessions.Session, 'request')
    @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
    def test_upload_artifact_api_response_error(self, unused_mock_open,
                                                mock_request):
        """Verifies upload_artifact raises ApiError on an error response."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        mock_response.json.return_value = {'errorMessage': 'the-ams-err-msg'}

        with self.assertRaisesRegex(errors.ApiError, r'500.*the-ams-err-msg'):
            self.sut.upload_artifact('the/file/path', 'the-test-result-id')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_request_wrapper_no_timeout_and_invalid_num_retries(
            self, mock_request):
        """Verifies _request_wrapper works without timeout and num_retries=-1."""
        mock_response = mock.Mock(status_code=http.HTTPStatus.OK)
        mock_request.return_value = mock_response

        response = self.sut._request_wrapper(num_retries=-1)

        self.assertIs(mock_response, response)

    def test_extract_error_message_from_api_response_success(self):
        """Verifies extract_error_message_from_api_response returns the text."""
        fake_response = mock.Mock()
        fake_response.json.return_value = {'errorMessage': 'error'}

        self.assertEqual(
            'error',
            ams_client.extract_error_message_from_api_response(fake_response))

    @mock.patch.object(ams_client, 'logger')
    def test_extract_error_message_from_api_response_decode_error(
            self, mock_logger):
        """Verifies extract_error_message_from_api_response on bad JSON."""
        fake_response = mock.Mock()
        # A real (empty) document is enough for JSONDecodeError to compute
        # its line/column info; no mocked document is needed.
        fake_response.json.side_effect = json.decoder.JSONDecodeError(
            'Expecting value', '', 0)

        self.assertIsNone(
            ams_client.extract_error_message_from_api_response(fake_response))

        mock_logger.warning.assert_called_once_with(
            'API response cannot be parsed as JSON.')

    @mock.patch.object(ams_client, 'logger')
    def test_extract_error_message_from_api_response_key_error(
            self, mock_logger):
        """Verifies extract_error_message_from_api_response on KeyError."""
        fake_response = mock.Mock()
        fake_response.json.side_effect = KeyError()

        self.assertIsNone(
            ams_client.extract_error_message_from_api_response(fake_response))

        mock_logger.warning.assert_called_once_with(
            'API response does not have errorMessage field')
| |
| |
if __name__ == '__main__':
    # Run every test in this module when it is executed directly as a script.
    unittest.main()