# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for translation layer."""
import collections
import unittest
from unittest import mock

from local_agent import errors as agent_errors
from local_agent.translation_layer import gdm_manager
from local_agent.translation_layer import translation_layer
from mobly.controllers import android_device
from ui_automator import errors as ua_errors
from ui_automator import ui_automator


# ---------------------------------------------------------------------------
# Fake data shared by the unit tests below.
# ---------------------------------------------------------------------------

# Identifiers for the RPC and the device under test.
_FAKE_RPC_ID = 'fake-rpc-id'
_FAKE_DEVICE_ID = 'fake-device-id'
_FAKE_DEVICE_TYPE = 'fake-device-type'
_FAKE_CAPABILITY = 'fake-device-capability'

# JSON-RPC method names exercised by the tests.
_SET_ON = 'setOn'
_SET_LOCK = 'setLock'
_COMMISSION_TO_GOOGLE_FABRIC = 'commissionToGoogleFabric'

# Inputs for the commissionToGoogleFabric tests.
_FAKE_SERIAL = 'fake-serial'
_FAKE_PAIRING_CODE = '34970112332'
_FAKE_MATTER_DEVICE_NAME = 'fake-matter-device-name'
_FAKE_GHA_ROOM = 'Office'
# ---------------------------------------------------------------------------


def rpc_request(method, params):
    """Builds a JSON-RPC 2.0 request dict whose id is always 0.

    Args:
        method: Name of the RPC method, stored under 'method'.
        params: Method parameters, stored as-is under 'params'.

    Returns:
        A dict with the keys 'jsonrpc', 'id', 'method' and 'params'.
    """
    request = dict(jsonrpc='2.0', id=0)
    request.update(method=method, params=params)
    return request


class TranslationLayerTest(unittest.TestCase):
    """Unit tests for local agent translation layer.

    setUp builds a fresh TranslationLayer around a mock client for every test.
    Module-level attributes of translation_layer are only replaced through
    mock.patch so they are restored once each test finishes.
    """

    def setUp(self):
        super().setUp()
        self.mock_client = mock.Mock()
        self.translator = translation_layer.TranslationLayer(
            client=self.mock_client)

    @mock.patch.object(gdm_manager.GdmManager, 'create_devices')
    def test_01_create_devices(self, mock_create):
        """Verifies create_devices method on success."""
        fake_dut_ids = [_FAKE_DEVICE_ID]
        fake_suite_dir = ''
        self.translator.create_devices(fake_dut_ids, fake_suite_dir)
        mock_create.assert_called_once_with(fake_dut_ids, fake_suite_dir)

    @mock.patch.object(gdm_manager.GdmManager, 'close_open_devices')
    def test_02_close_devices(self, mock_close):
        """Verifies close_devices method on success."""
        self.translator.close_devices()
        mock_close.assert_called_once()

    @mock.patch.object(gdm_manager.GdmManager, 'detect_devices')
    def test_03_detect_devices(self, mock_detect):
        """Verifies detect_devices on success."""
        self.translator.detect_devices()
        mock_detect.assert_called_once()

    @mock.patch.object(translation_layer.TranslationLayer, '_get_cmd_handler')
    @mock.patch.object(gdm_manager.GdmManager, 'check_device_connected')
    def test_04_dispatch_to_cmd_handler_on_success(self, mock_check, mock_get):
        """Verifies dispatch_to_cmd_handler on success."""
        self.translator._busy_devices.clear()
        set_on = rpc_request(_SET_ON, {'dutDeviceId': _FAKE_DEVICE_ID})
        self.translator.dispatch_to_cmd_handler(set_on)
        mock_check.assert_called_once()
        mock_get.assert_called_once()

    def test_04_dispatch_to_cmd_handler_on_failure_no_device(self):
        """Verifies dispatch_to_cmd_handler on failure with no device params."""
        invalid_rpc_request = rpc_request(_SET_ON, {})
        error_msg = 'no dutDeviceId in params'
        with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
            self.translator.dispatch_to_cmd_handler(invalid_rpc_request)

    @mock.patch.object(translation_layer, 'issubclass', return_value=True)
    def test_05_validate_handlers_cls_map_on_success(self, mock_issubclass):
        """Verifies validate_handlers_cls_map on success."""
        fake_handlers = [mock.Mock(SUPPORTED_METHODS={_SET_ON,}),
                         mock.Mock(SUPPORTED_METHODS={_SET_LOCK,})]
        translation_layer.validate_handlers_cls_map(fake_handlers)
        self.assertEqual(2, mock_issubclass.call_count)

    @mock.patch.object(translation_layer, 'issubclass', return_value=False)
    def test_05_validate_handlers_cls_map_on_failure_not_subclass(
        self, mock_issubclass):
        """Verifies validate_handlers_cls_map rejects a non-handler class."""
        error_msg = 'is not a subclass of BaseCommandHandler.'
        with self.assertRaisesRegex(agent_errors.HandlerInvalidError, error_msg):
            translation_layer.validate_handlers_cls_map([mock.Mock(__name__='')])

    @mock.patch.object(translation_layer, 'issubclass', return_value=True)
    def test_05_validate_handlers_cls_map_on_failure_collision(
        self, mock_issubclass):
        """Verifies validate_handlers_cls_map detects duplicate methods."""
        fake_handlers = [mock.Mock(SUPPORTED_METHODS={_SET_ON,}, __name__='1'),
                         mock.Mock(SUPPORTED_METHODS={_SET_ON,}, __name__='2')]
        error_msg = f'Handlers 1 and 2 have duplicate methods: {_SET_ON}.'
        with self.assertRaisesRegex(agent_errors.HandlersCollisionError, error_msg):
            translation_layer.validate_handlers_cls_map(fake_handlers)
        self.assertEqual(2, mock_issubclass.call_count)

    def test_06_get_cmd_handler_on_success_already_exists(self):
        """Verifies get_cmd_handler on success with existing handler."""
        method = _SET_ON
        self.translator._cmd_handlers[_FAKE_DEVICE_ID] = (
            mock.Mock(SUPPORTED_METHODS={method,}))
        handler = self.translator._get_cmd_handler(_FAKE_DEVICE_ID, method)
        self.assertIn(method, handler.SUPPORTED_METHODS)

    def test_06_get_cmd_handler_on_failure_method_not_found(self):
        """Verifies get_cmd_handler on failure method not found."""
        self.translator._cmd_handlers[_FAKE_DEVICE_ID] = (
            mock.Mock(SUPPORTED_METHODS={_SET_LOCK,}))
        error_msg = 'No matching command handler'
        with self.assertRaisesRegex(agent_errors.HandlerNotFoundError, error_msg):
            self.translator._get_cmd_handler(_FAKE_DEVICE_ID, _SET_ON)

    @mock.patch.object(gdm_manager.GdmManager, 'get_device_instance')
    @mock.patch.object(gdm_manager.GdmManager, 'get_device_type')
    def test_06_get_cmd_handler_on_success_not_exists(
        self, mock_get_type, mock_get_inst):
        """Verifies get_cmd_handler on success without existing handler."""
        method = _SET_ON
        self.translator._cmd_handlers.clear()
        self.translator._handlers_cls_map[_FAKE_DEVICE_TYPE] = (
            {method: mock.Mock()})
        mock_get_type.return_value = _FAKE_DEVICE_TYPE
        handler = self.translator._get_cmd_handler(_FAKE_DEVICE_ID, method)
        self.assertEqual(
            self.translator._cmd_handlers[_FAKE_DEVICE_ID], handler)
        mock_get_type.assert_called_once()
        mock_get_inst.assert_called_once()

    @mock.patch.object(gdm_manager.GdmManager, 'get_device_type')
    def test_06_get_cmd_handler_on_failure_invalid_device_type(
        self, mock_get_type):
        """Verifies get_cmd_handler on failure with invalid device type."""
        # Start without a cached handler, as the sibling tests do, so the
        # lookup goes through get_device_type.
        self.translator._cmd_handlers.clear()
        fake_not_exist_type = 'not-exists-type'
        mock_get_type.return_value = fake_not_exist_type
        error_msg = f'device type: {fake_not_exist_type} is not implemented'
        with self.assertRaisesRegex(agent_errors.HandlerNotFoundError, error_msg):
            # The first argument is a device id (see the other tests).
            self.translator._get_cmd_handler(_FAKE_DEVICE_ID, '')
        mock_get_type.assert_called_once()

    @mock.patch.object(gdm_manager.GdmManager, 'get_device_type')
    def test_06_get_cmd_handler_on_failure_invalid_method(self, mock_get_type):
        """Verifies get_cmd_handler on failure with invalid method."""
        self.translator._cmd_handlers.clear()
        self.translator._handlers_cls_map[_FAKE_DEVICE_TYPE] = (
            {_SET_ON: mock.Mock()})
        mock_get_type.return_value = _FAKE_DEVICE_TYPE
        error_msg = f'method: {_SET_LOCK} is not implemented'
        with self.assertRaisesRegex(agent_errors.HandlerNotFoundError, error_msg):
            self.translator._get_cmd_handler(_FAKE_DEVICE_ID, _SET_LOCK)
        mock_get_type.assert_called_once()

    def test_07_device_operation_handler_on_failure_is_busy(self):
        """Verifies device_operation_handler raises when device is still busy."""
        self.translator._busy_devices = {_FAKE_DEVICE_ID,}
        error_msg = f'{_FAKE_DEVICE_ID} is still busy'
        with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
            with self.translator.device_operation_handler(
                _FAKE_DEVICE_ID, _FAKE_RPC_ID):
                pass

    def test_07_device_operation_handler_on_failure_duplicate_rpc(self):
        """Verifies device_operation_handler raises on a duplicate RPC."""
        self.translator._rpc_execution_start_time[_FAKE_RPC_ID] = 0
        error_msg = f'RPC {_FAKE_RPC_ID} is already executing.'
        with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
            with self.translator.device_operation_handler(
                _FAKE_DEVICE_ID, _FAKE_RPC_ID):
                pass

    @mock.patch.object(translation_layer, 'validate_handlers_cls_map')
    def test_08_update_handlers_cls_map_on_success_creation(
        self, mock_validate):
        """Verifies update_handlers_cls_map creates a new mapping on success."""
        fake_handler = mock.Mock(SUPPORTED_METHODS={_SET_ON,})
        fake_capabilities_map = {_FAKE_CAPABILITY: fake_handler}
        # Patch rather than assign the module-level map. Patching restores the
        # real map when the block exits, so the fake cannot leak into other
        # tests.
        with mock.patch.object(
            translation_layer, 'GDM_CAPABILITIES_TO_COMMAND_HANDLERS',
            fake_capabilities_map):
            self.translator.update_handlers_cls_map(
                _FAKE_DEVICE_TYPE, [_FAKE_CAPABILITY])
        handler = self.translator._handlers_cls_map[_FAKE_DEVICE_TYPE][_SET_ON]
        self.assertEqual(fake_handler, handler)
        mock_validate.assert_called_once()

    @mock.patch.object(translation_layer, 'validate_handlers_cls_map')
    def test_08_update_handlers_cls_map_does_nothing(self, mock_validate):
        """Verifies update_handlers_cls_map skips an already-mapped type."""
        self.translator._handlers_cls_map[_FAKE_DEVICE_TYPE][_SET_ON] = None
        self.translator.update_handlers_cls_map(
            _FAKE_DEVICE_TYPE, [_FAKE_CAPABILITY])
        mock_validate.assert_not_called()

    def test_09_check_rpc_timeout(self):
        """Verifies _check_rpc_timeout catches timeout RPCs."""
        fake_termination_event = mock.Mock()
        # wait() returns False once, then True. Assuming the checker loops
        # until the event is set, this limits it to a single pass.
        fake_termination_event.wait.side_effect = [False, True]
        self.translator._termination_event = fake_termination_event
        self.translator._timeout_rpc.clear()
        # Keyed by RPC id. A start time of 0 presumably makes the RPC look
        # long overdue.
        self.translator._rpc_execution_start_time = {_FAKE_RPC_ID: 0}

        self.translator._check_rpc_timeout()

        self.mock_client.send_rpc_response.assert_called_once()
        self.assertTrue(self.translator.is_rpc_timeout(_FAKE_RPC_ID))

    def test_commission_to_google_fabric_raises_an_error_without_pairing_code(
        self,
    ):
        """Verifies commission_to_google_fabric raises without a pairingCode."""
        invalid_rpc_request = rpc_request(_COMMISSION_TO_GOOGLE_FABRIC, {})
        error_msg = 'Invalid rpc command, no pairingCode in params was found.'

        with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
            self.translator.commission_to_google_fabric(invalid_rpc_request)

    def test_commission_to_google_fabric_raises_an_error_without_device_name(
        self,
    ):
        """Verifies commission_to_google_fabric raises without a deviceName."""
        invalid_rpc_request = rpc_request(
            _COMMISSION_TO_GOOGLE_FABRIC,
            {
                'pairingCode': _FAKE_PAIRING_CODE,
            },
        )
        error_msg = 'Invalid rpc command, no deviceName in params was found.'

        with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
            self.translator.commission_to_google_fabric(invalid_rpc_request)

    def test_commission_to_google_fabric_raises_an_error_without_gha_room(self):
        """Verifies commission_to_google_fabric raises without a ghaRoom."""
        invalid_rpc_request = rpc_request(
            _COMMISSION_TO_GOOGLE_FABRIC,
            {
                'pairingCode': _FAKE_PAIRING_CODE,
                'deviceName': _FAKE_MATTER_DEVICE_NAME,
            },
        )
        error_msg = 'Invalid rpc command, no ghaRoom in params was found.'

        with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
            self.translator.commission_to_google_fabric(invalid_rpc_request)

    @mock.patch.object(android_device, 'AndroidDevice')
    @mock.patch.object(android_device, 'get_all_instances')
    @mock.patch.object(translation_layer, 'logger')
    @mock.patch.object(ui_automator.UIAutomator, 'commission_device')
    def test_commission_to_google_fabric_raises_an_error_when_commissioning_fails(
        self,
        mock_commission_device,
        mock_logger,
        mock_get_all_instances,
        mock_android_device,
    ):
        """Verifies a snippet commissioning failure is reported in the result.

        The MoblySnippetError is not propagated. Instead, the response carries
        isCommissioned=False and the error message in errorLog.
        """
        # AndroidDevice is patched, so this is its mock instance.
        mock_get_all_instances.return_value = [mock_android_device.return_value]
        expected_error_message = (
            'Unable to continue automated commissioning process on'
            f' device({_FAKE_SERIAL}).'
        )
        mock_commission_device.side_effect = ua_errors.MoblySnippetError(
            expected_error_message
        )

        response = self.translator.commission_to_google_fabric(
            rpc_request(
                _COMMISSION_TO_GOOGLE_FABRIC,
                {
                    'pairingCode': _FAKE_PAIRING_CODE,
                    'deviceName': _FAKE_MATTER_DEVICE_NAME,
                    'ghaRoom': _FAKE_GHA_ROOM,
                },
            )
        )

        self.assertEqual(
            response,
            {
                'jsonrpc': '2.0',
                'id': 0,
                'result': {
                    'isCommissioned': False,
                    'errorLog': expected_error_message,
                },
            },
        )
        mock_commission_device.assert_called_once_with(
            pairing_code=_FAKE_PAIRING_CODE,
            device_name=_FAKE_MATTER_DEVICE_NAME,
            gha_room=_FAKE_GHA_ROOM,
        )
        mock_logger.info.assert_called_once_with(
            f'Completed request for 0, pairingCode: {_FAKE_PAIRING_CODE},'
            f' deviceName: {_FAKE_MATTER_DEVICE_NAME}, ghaRoom: {_FAKE_GHA_ROOM}'
        )

    @mock.patch.object(android_device, 'AndroidDevice')
    @mock.patch.object(android_device, 'get_all_instances')
    @mock.patch.object(translation_layer, 'logger')
    @mock.patch.object(ui_automator.UIAutomator, 'commission_device')
    def test_commission_to_google_fabric_on_success(
        self,
        mock_commission_device,
        mock_logger,
        mock_get_all_instances,
        mock_android_device,
    ):
        """Verifies commission_to_google_fabric succeeds with all params."""
        # AndroidDevice is patched, so this is its mock instance.
        mock_get_all_instances.return_value = [mock_android_device.return_value]

        response = self.translator.commission_to_google_fabric(
            rpc_request(
                _COMMISSION_TO_GOOGLE_FABRIC,
                {
                    'pairingCode': _FAKE_PAIRING_CODE,
                    'deviceName': _FAKE_MATTER_DEVICE_NAME,
                    'ghaRoom': _FAKE_GHA_ROOM,
                },
            )
        )

        mock_commission_device.assert_called_once_with(
            pairing_code=_FAKE_PAIRING_CODE,
            device_name=_FAKE_MATTER_DEVICE_NAME,
            gha_room=_FAKE_GHA_ROOM,
        )
        mock_logger.info.assert_called_once_with(
            f'Completed request for 0, pairingCode: {_FAKE_PAIRING_CODE},'
            f' deviceName: {_FAKE_MATTER_DEVICE_NAME}, ghaRoom: {_FAKE_GHA_ROOM}'
        )
        self.assertEqual(
            response,
            {'jsonrpc': '2.0', 'id': 0, 'result': {'isCommissioned': True}},
        )

if __name__ == '__main__':
    # Run the suite directly, stopping at the first failure or error.
    unittest.main(failfast=True)
