# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for local_agent.ams_client module."""
import http
import json
import requests
import time
import unittest
from unittest import mock

from local_agent import ams_client
from local_agent import errors


class AmsClientTest(unittest.TestCase):
    """Unit tests for ams_client.AmsClient and its module-level helpers.

    HTTP traffic is intercepted by patching requests.sessions.Session.request
    (or AmsClient._request_wrapper), so no real network calls are made.
    """

    def setUp(self):
        super().setUp()
        # Patch out time.sleep so any waiting done by the code under test
        # (e.g. between retry attempts) returns immediately.
        sleep_patcher = mock.patch.object(time, 'sleep')
        sleep_patcher.start()
        self.addCleanup(sleep_patcher.stop)

    @mock.patch.object(ams_client.AmsClient, 'set_local_agent_credentials')
    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_success(self, mock_request, mock_set_credentials):
        """Verifies register succeeds and stores the returned credentials."""
        sut = ams_client.AmsClient()
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.OK
        mock_response.json.return_value = {
            'result': {
                'agentId': 'the-id',
                'agentSecret': 'the-secret'}}

        sut.register(linking_code='the-linking-code')

        mock_set_credentials.assert_called_once_with(
            local_agent_id='the-id',
            local_agent_secret='the-secret')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_api_timeout(self, mock_request):
        """Verifies register raises ApiTimeoutError when request timed out."""
        sut = ams_client.AmsClient()
        mock_request.side_effect = requests.exceptions.Timeout
        with self.assertRaises(errors.ApiTimeoutError):
            sut.register(linking_code='the-linking-code')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_api_return_status_bad_request(self, mock_request):
        """Verifies register raises CredentialsError when API responds 400."""
        sut = ams_client.AmsClient()
        mock_request.return_value.status_code = http.HTTPStatus.BAD_REQUEST
        with self.assertRaises(errors.CredentialsError):
            sut.register(linking_code='the-linking-code')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_api_error_message_included(self, mock_request):
        """Verifies register includes AMS error message in its exception."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.BAD_REQUEST
        mock_response.json.return_value = {'errorMessage': 'the-message'}
        sut = ams_client.AmsClient()

        with self.assertRaisesRegex(errors.CredentialsError, 'the-message'):
            sut.register(linking_code='the-linking-code')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_register_api_return_status_not_ok(self, mock_request):
        """Verifies register raises ApiError when API fails."""
        sut = ams_client.AmsClient()
        mock_request.return_value.status_code = http.HTTPStatus.NOT_FOUND
        with self.assertRaises(errors.ApiError):
            sut.register(linking_code='the-linking-code')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_set_local_agent_credentials_get_auth_token_success(self, mock_request):
        """Verifies set_local_agent_credentials successfully gets auth token."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.CREATED
        mock_response.json.return_value = {
            'result': {
                'authToken': 'the-auth-token',
            }
        }
        sut = ams_client.AmsClient()

        sut.set_local_agent_credentials('agent-id', 'agent-secret')

        # The auth token must actually be requested from AMS, and a
        # successful (201) response must not trigger any retry.
        mock_request.assert_called_once()

    @mock.patch.object(ams_client.AmsClient, '_request_wrapper')
    def test_get_auth_token_not_refreshing_auth(self, mock_request_wrapper):
        """Verifies _get_auth_token sets refresh_auth to False."""
        mock_response = mock_request_wrapper.return_value
        mock_response.status_code = http.HTTPStatus.CREATED
        mock_response.json.return_value = {'result': {'authToken': 'abc'}}
        sut = ams_client.AmsClient()

        sut._get_auth_token()

        self.assertFalse(
            mock_request_wrapper.call_args.kwargs['refresh_auth'],
            msg='_get_auth_token needs to pass refresh_auth=False to '
                '_request_wrapper to avoid infinite recursion')

    @mock.patch.object(ams_client, 'extract_error_message_from_api_response')
    @mock.patch.object(ams_client.AmsClient, '_request_wrapper')
    def test_get_auth_token_unlinked_error(
        self, mock_request_wrapper, mock_extract_err_msg):
        """Verifies _get_auth_token raises UnlinkedError."""
        mock_response = mock_request_wrapper.return_value
        mock_response.status_code = http.HTTPStatus.BAD_REQUEST
        mock_extract_err_msg.return_value = 'Invalid agent id'
        sut = ams_client.AmsClient()

        error_regex = 'Local agent is unlinked'
        with self.assertRaisesRegex(errors.UnlinkedError, error_regex):
            sut._get_auth_token()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_set_local_agent_credentials_raises_when_api_timeout(self,
                                                                 mock_request):
        """Verifies set_local_agent_credentials raises when API timed out."""
        mock_request.side_effect = requests.exceptions.Timeout
        sut = ams_client.AmsClient()
        with self.assertRaises(errors.ApiTimeoutError):
            sut.set_local_agent_credentials('agent-id', 'agent-secret')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_set_local_agent_credentials_raises_when_api_error(self,
                                                               mock_request):
        """Verifies set_local_agent_credentials raises when API has error."""
        mock_request.return_value.status_code = http.HTTPStatus.NOT_FOUND
        sut = ams_client.AmsClient()
        with self.assertRaises(errors.CredentialsError):
            sut.set_local_agent_credentials('agent-id', 'agent-secret')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_set_local_agent_credentials_include_ams_error_message(
        self,
        mock_request):
        """Verifies set_local_agent_credentials includes AMS error message."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.NOT_FOUND
        mock_response.json.return_value = {'errorMessage': 'the-message'}
        sut = ams_client.AmsClient()

        with self.assertRaisesRegex(errors.CredentialsError, 'the-message'):
            sut.set_local_agent_credentials('agent-id', 'agent-secret')

    @mock.patch.object(requests.sessions.Session,
                       'request',
                       side_effect=requests.exceptions.Timeout)
    def test_report_info_api_timeout(self, mock_request):
        """Verifies report_info raises ApiTimeoutError when API timed out."""
        sut = ams_client.AmsClient()
        with self.assertRaises(errors.ApiTimeoutError):
            sut.report_info({})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_report_info_api_response_error(self, mock_request):
        """Verifies report_info raises ApiError when API response has error."""
        mock_request.return_value.status_code = http.HTTPStatus.BAD_REQUEST
        sut = ams_client.AmsClient()
        with self.assertRaisesRegex(errors.ApiError,
                                    'Report info API failed: status 400'):
            sut.report_info({})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_report_info_api_response_error_include_ams_error_message(
        self,
        mock_request):
        """Verifies report_info's ApiError includes the AMS error message."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.BAD_REQUEST
        mock_response.json.return_value = {'errorMessage': 'the-message'}
        sut = ams_client.AmsClient()

        with self.assertRaisesRegex(errors.ApiError, 'the-message'):
            sut.report_info({})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_report_info_successful(self, mock_request):
        """Verifies report_info succeeds and sends info dict to AMS."""
        mock_request.return_value.status_code = http.HTTPStatus.OK
        local_agent_info = {'hi': 'hello'}
        sut = ams_client.AmsClient()

        sut.report_info(local_agent_info)

        self.assertIn(
            'json',
            mock_request.call_args.kwargs,
            'Local agent info should be passed as json arg to requests call.')
        self.assertEqual(local_agent_info,
                         mock_request.call_args.kwargs['json'])

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_successful_with_request(self, mock_request):
        """Verifies get_rpc_request_from_ams gets a request from AMS."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.OK
        mock_response.json.return_value = {'result': {'hi': 'hello'}}
        sut = ams_client.AmsClient()

        self.assertEqual(sut.get_rpc_request_from_ams(),
                         {'hi': 'hello'})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_successful_no_request(self,
                                                            mock_request):
        """Verifies get_rpc_request_from_ams gets no request from AMS."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.NO_CONTENT
        sut = ams_client.AmsClient()

        self.assertIsNone(sut.get_rpc_request_from_ams())

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_raises_when_api_timeout(self,
                                                              mock_request):
        """Verifies get_rpc_request_from_ams raises when API timed out."""
        mock_request.side_effect = requests.exceptions.Timeout
        sut = ams_client.AmsClient()

        with self.assertRaises(errors.ApiTimeoutError):
            sut.get_rpc_request_from_ams()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_raise_when_api_response_error(
        self, mock_request):
        """Verifies get_rpc_request_from_ams raises when response has error."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        sut = ams_client.AmsClient()

        with self.assertRaises(errors.ApiError):
            sut.get_rpc_request_from_ams()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_get_rpc_request_from_ams_when_api_error_has_exception_message(
        self, mock_request):
        """Verifies get_rpc_request_from_ams error message when ApiError."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        mock_response.json.return_value = {'errorMessage': 'message-from-ams'}
        sut = ams_client.AmsClient()

        with self.assertRaisesRegex(errors.ApiError, r'500.*message-from-ams'):
            sut.get_rpc_request_from_ams()

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_remove_rpc_request_from_ams_successful(self, mock_request):
        """Verifies remove_rpc_request_from_ams succeeds."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.OK
        sut = ams_client.AmsClient()
        sut.remove_rpc_request_from_ams({'the': 'request'})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_remove_rpc_request_from_ams_raises_when_api_timeout(
        self, mock_request):
        """Verifies remove_rpc_request_from_ams raises if API timed out."""
        mock_request.side_effect = requests.exceptions.Timeout
        sut = ams_client.AmsClient()
        with self.assertRaises(errors.ApiTimeoutError):
            sut.remove_rpc_request_from_ams({'the': 'request'})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_remove_rpc_request_from_ams_raises_when_api_fails(
        self, mock_request):
        """Verifies remove_rpc_request_from_ams raises if API has error."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        mock_response.json.return_value = {'errorMessage': 'msg-from-ams'}
        sut = ams_client.AmsClient()

        with self.assertRaisesRegex(errors.ApiError, r'500.*msg-from-ams'):
            sut.remove_rpc_request_from_ams({'the': 'request'})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_send_rpc_response_successful(self, mock_request):
        """Verifies send_rpc_response sends RPC response to AMS successfully.
        """
        mock_request.return_value.status_code = http.HTTPStatus.OK
        sut = ams_client.AmsClient()
        sut.send_rpc_response({'the': 'response'})

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_send_rpc_response_raise_api_error(self, mock_request):
        """Verifies send_rpc_response raises ApiError if incorrect status code.
        """
        mock_response = mock_request.return_value
        mock_response.status_code = (
            http.HTTPStatus.INTERNAL_SERVER_ERROR)
        mock_response.json.return_value = {'errorMessage': 'the-ams-msg'}
        sut = ams_client.AmsClient()

        with self.assertRaisesRegex(errors.ApiError, '500.*the-ams-msg'):
            sut.send_rpc_response({'the': 'response'})
        # Verify we have retried: 1 initial attempt + 3 retries.
        self.assertEqual(4, mock_request.call_count)

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_send_rpc_response_raise_when_api_timed_out(self, mock_request):
        """Verifies send_rpc_response raises ApiTimeoutError if API timed out.
        """
        mock_request.side_effect = requests.exceptions.Timeout
        sut = ams_client.AmsClient()

        with self.assertRaises(errors.ApiTimeoutError):
            sut.send_rpc_response({'the': 'response'})
        # Verify we have retried: 1 initial attempt + 3 retries.
        self.assertEqual(4, mock_request.call_count)

    @mock.patch.object(ams_client.AmsClient, '_get_auth_token')
    @mock.patch.object(requests.sessions.Session, 'request')
    def test_send_rpc_response_refresh_auth_if_first_attempt_has_401(
        self, mock_request, mock_get_auth_token):
        """Verifies send_rpc_response will refresh auth token."""
        mock_401_response = mock.Mock()
        mock_401_response.status_code = http.HTTPStatus.UNAUTHORIZED
        mock_401_response.json.return_value = {}
        mock_500_response = mock.Mock()
        mock_500_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        mock_500_response.json.return_value = {}
        # A 401 triggers an auth refresh; the four attempts after it exhaust
        # the normal retry budget and surface the final 500 as ApiError.
        mock_request.side_effect = (
            mock_401_response,
            requests.exceptions.Timeout,
            requests.exceptions.Timeout,
            requests.exceptions.Timeout,
            mock_500_response,
        )
        sut = ams_client.AmsClient()

        with self.assertRaisesRegex(errors.ApiError, '500'):
            sut.send_rpc_response({'the': 'response'})
        # Verify we have refreshed the auth token, and that does not count as
        # a retry.
        self.assertEqual(1, mock_get_auth_token.call_count)
        self.assertEqual(5, mock_request.call_count)

    @mock.patch.object(requests.sessions.Session, 'request')
    @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
    def test_upload_artifact_successful(self, mock_open, mock_request):
        """Verifies upload_artifact succeeds."""
        mock_request.return_value.status_code = http.HTTPStatus.OK
        sut = ams_client.AmsClient()
        sut.upload_artifact('the/file/path', 'the-test-result-id')

    @mock.patch.object(requests.sessions.Session, 'request')
    @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
    def test_upload_artifact_api_timed_out(self, mock_open, mock_request):
        """Verifies upload_artifact raises ApiTimeoutError if API timed out."""
        # side_effect raises before any return value is produced, so no
        # response status code needs to be configured here.
        mock_request.side_effect = requests.exceptions.Timeout
        sut = ams_client.AmsClient()
        with self.assertRaises(errors.ApiTimeoutError):
            sut.upload_artifact('the/file/path', 'the-test-result-id')

    @mock.patch.object(requests.sessions.Session, 'request')
    @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
    def test_upload_artifact_api_response_error(self, mock_open, mock_request):
        """Verifies upload_artifact raises when API response is error."""
        mock_response = mock_request.return_value
        mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
        mock_response.json.return_value = {'errorMessage': 'the-ams-err-msg'}
        sut = ams_client.AmsClient()

        with self.assertRaisesRegex(errors.ApiError, '500.*the-ams-err-msg'):
            sut.upload_artifact('the/file/path', 'the-test-result-id')

    @mock.patch.object(requests.sessions.Session, 'request')
    def test_request_wrapper_no_timeout_and_invalid_num_retries(
        self, mock_request):
        """Verifies request_wrapper no timeout field and invalid num_retries."""
        mock_response = mock.Mock(status_code=http.HTTPStatus.OK)
        mock_request.return_value = mock_response
        sut = ams_client.AmsClient()

        response = sut._request_wrapper(num_retries=-1)

        self.assertEqual(mock_response, response)

    def test_extract_error_message_from_api_response_success(self):
        """Verifies extract_error_message_from_api_response on success."""
        fake_response = mock.Mock()
        fake_response.json.return_value = {'errorMessage': 'error'}

        self.assertEqual(
            'error',
            ams_client.extract_error_message_from_api_response(fake_response))

    @mock.patch.object(ams_client, 'logger')
    def test_extract_error_message_from_api_response_decode_error(
        self, mock_logger):
        """Verifies extract_error_message_from_api_response decode error."""
        fake_response = mock.Mock()
        # A real (empty) document lets JSONDecodeError compute its line and
        # column itself; no hand-configured Mock document is needed.
        fake_response.json.side_effect = json.JSONDecodeError(
            'Expecting value', '', 0)

        self.assertIsNone(
            ams_client.extract_error_message_from_api_response(fake_response))

        mock_logger.warning.assert_called_once_with(
            'API response cannot be parsed as JSON.')

    @mock.patch.object(ams_client, 'logger')
    def test_extract_error_message_from_api_response_key_error(
        self, mock_logger):
        """Verifies extract_error_message_from_api_response key error."""
        fake_response = mock.Mock()
        fake_response.json.side_effect = KeyError()

        self.assertIsNone(
            ams_client.extract_error_message_from_api_response(fake_response))

        mock_logger.warning.assert_called_once_with(
            'API response does not have errorMessage field')


if __name__ == '__main__':
    # Run every test case in this module when it is executed as a script.
    unittest.main(module=__name__)
