# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for local agent process."""
import builtins
from concurrent import futures
import configparser
import os
import shutil
import sys
import tempfile
import threading
import unittest
from unittest import mock

from local_agent import ams_client
from local_agent import errors
from local_agent import local_agent
from local_agent import suite_session_manager
from local_agent.translation_layer import translation_layer


############################ Fake data for unit test ##########################
# Connection / credential fakes.
_FAKE_AMS_HOST = 'localhost'
_FAKE_AGENT_ID = 'fake-agent-id'
_FAKE_AGENT_SECRET = 'fake-agent-secret'
_FAKE_AUTH_TOKEN = 'fake-auth-token'

# Fake RPC identifiers, payloads and error messages.
_FAKE_RPC_ID = 'fake-rpc-id'
_FAKE_RPC_RESPONSE = {'fake': 'response'}
_FAKE_ERROR_MSG = 'fake-error-msg'
_FAKE_COMMISSION_ERROR_MSG = 'Unable to commission the device.'

# RPC method names placed in the 'method' field of fake requests.
_START_TEST_SUITE = 'startTestSuite'
_END_TEST_SUITE = 'endTestSuite'
_COMMISSION_TO_GOOGLE_FABRIC = 'commissionToGoogleFabric'
_LOCK_DEVICE = 'setLock'

# Miscellaneous fakes.
_FAKE_ARTIFACTS_DIR = 'fake-artifacts-dir'
_FAKE_CONTROLLER_PACKAGE = 'FAKE_CONTROLLER_PACKAGE'
###############################################################################


class LocalAgentTest(unittest.TestCase):
    """Unit tests for local agent process.

    Each test gets a fresh LocalAgentProcess; the module-level auth file and
    user config paths are redirected to per-test temp files in setUp.
    """

    def setUp(self):
        """Creates the process under test and isolates module-level paths.

        local_agent.AUTH_FILE and local_agent.DEFAULT_USER_CONFIG are patched
        (rather than plainly reassigned) so that their original values are
        restored after every test and do not leak into other tests/modules.
        """
        super().setUp()
        self.proc = local_agent.LocalAgentProcess(
            client=ams_client.AmsClient(host=_FAKE_AMS_HOST),
            artifacts_dir=_FAKE_ARTIFACTS_DIR)

        _, auth_file = self._create_temp_file_with_clean_up()
        _, user_config = self._create_temp_file_with_clean_up()
        for attribute, temp_path in (('AUTH_FILE', auth_file),
                                     ('DEFAULT_USER_CONFIG', user_config)):
            patcher = mock.patch.object(local_agent, attribute, temp_path)
            patcher.start()
            self.addCleanup(patcher.stop)

    def _create_temp_file_with_clean_up(self):
        """Creates a temp file and registers clean-up procedure.

        We use tempfile.mkstemp to create a temporary file, and clean it up
        using self.addCleanup provided by unittest package.

        Returns:
          Tuple of (file_descriptor, file_path). Exactly what is returned by a
          tempfile.mkstemp call.
        """
        fd, path = tempfile.mkstemp()
        # Cleanups run in LIFO order: register the removal first so that the
        # descriptor is closed *before* the file is deleted (deleting a file
        # that is still open fails on Windows).
        self.addCleanup(os.remove, path)
        self.addCleanup(os.close, fd)
        return fd, path

    @mock.patch.object(ams_client.AmsClient, 'set_local_agent_credentials')
    @mock.patch.object(local_agent.LocalAgentProcess,
                       '_read_auths',
                       return_value=('agent-id', 'agent-secret'))
    def test_setup_credentials_existing_credential_success(
        self,
        mock_read_auths,
        mock_set_local_agent_credentials):
        """Verifies _setup_credentials succeeds with existing credentials."""
        self.assertTrue(self.proc._setup_credentials())

    @mock.patch.object(builtins, 'input', side_effect=KeyboardInterrupt)
    @mock.patch.object(local_agent.LocalAgentProcess, '_read_auths')
    def test_setup_credentials_no_existing_credential_will_start_linking(
        self, mock_read_auths, mock_input):
        """Verifies _setup_credentials starts linking if no credentials."""
        mock_read_auths.side_effect = FileNotFoundError
        self.assertFalse(self.proc._setup_credentials())
        self.assertEqual(1, mock_input.call_count)

    @mock.patch.object(builtins, 'input', side_effect=KeyboardInterrupt)
    @mock.patch.object(ams_client.AmsClient,
                       'set_local_agent_credentials',
                       side_effect=errors.CredentialsError)
    @mock.patch.object(local_agent.LocalAgentProcess,
                       '_read_auths',
                       return_value=('agent-id', 'agent-secret'))
    def test_setup_credentials_bad_existing_credential_will_start_linking(
        self,
        mock_read_auths,
        mock_set_local_agent_credentials,
        mock_input):
        """Verifies _setup_credentials starts linking if bad credentials."""
        self.assertFalse(self.proc._setup_credentials())
        self.assertEqual(1, mock_input.call_count)

    @mock.patch.object(ams_client.AmsClient, 'register')
    @mock.patch.object(builtins, 'input', return_value='the-code')
    @mock.patch.object(local_agent.LocalAgentProcess,
                       '_read_auths',
                       side_effect=FileNotFoundError)
    def test_setup_credentials_start_linking_and_succeed(
        self, mock_read_auths, mock_input, mock_register):
        """Verifies _setup_credentials starts linking and succeeds."""
        mock_register.return_value = ('the-agent-id', 'the-agent-secret')
        self.assertTrue(self.proc._setup_credentials())
        mock_register.assert_called_once_with(linking_code='the-code')

    @mock.patch.object(ams_client.AmsClient, 'register')
    @mock.patch.object(builtins, 'input', return_value='the-code')
    @mock.patch.object(local_agent.LocalAgentProcess,
                       '_read_auths',
                       side_effect=FileNotFoundError)
    def test_setup_credentials_start_linking_and_keeps_retry(
        self,
        mock_read_auths,
        mock_input,
        mock_register):
        """Verifies _setup_credentials keeps retry when linking fails."""
        mock_register.side_effect = (
            errors.ApiTimeoutError,
            errors.ApiTimeoutError,
            errors.CredentialsError,
            errors.CredentialsError,
            errors.ApiError,
            errors.CredentialsError,
            ('the-agent-id', 'the-agent-secret'),
        )
        self.assertTrue(self.proc._setup_credentials())
        # Six failed attempts plus the final successful one.
        self.assertEqual(7, mock_input.call_count)

    def test_read_write_auth(self):
        """Verifies reading/writing auths."""
        self.proc._write_auths(_FAKE_AGENT_ID, _FAKE_AGENT_SECRET)
        self.assertEqual(
            (_FAKE_AGENT_ID, _FAKE_AGENT_SECRET), self.proc._read_auths())

    def test_read_config_with_inexistent_file(self):
        """Verifies read_config returns ({}, {}) when config file is missing."""
        with mock.patch.object(local_agent, 'DEFAULT_USER_CONFIG', ''):
            self.assertEqual(({}, {}), local_agent.read_config())

    @mock.patch.object(configparser, 'ConfigParser')
    def test_read_config_missing_root_key(self, mock_parser):
        """Verifies read_config raise ValueError when root key not present."""
        mock_config = mock.MagicMock()
        mock_parser.return_value = mock_config
        mock_config.__contains__.return_value = False
        with self.assertRaisesRegex(ValueError, 'Invalid config file'):
            local_agent.read_config()

    @mock.patch.object(configparser, 'ConfigParser')
    def test_read_config_success(self, mock_parser):
        """Verifies read_config run successfully."""
        mock_config = mock.MagicMock()
        mock_parser.return_value = mock_config
        mock_config.__contains__.return_value = True

        local_agent.read_config()

        mock_config.read.assert_called_once()

    @mock.patch.object(
        local_agent.LocalAgentProcess, '_start_info_reporting_thread')
    @mock.patch.object(
        local_agent.LocalAgentProcess, '_start_rpc_polling_thread')
    @mock.patch.object(suite_session_manager.SuiteSessionManager, 'start')
    @mock.patch.object(
        local_agent.LocalAgentProcess, '_setup_credentials', return_value=True)
    @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
    def test_run_starts_two_top_level_threads(
        self,
        mock_event_is_set,
        mock_setup_credentials,
        mock_start,
        mock_start_rpc_polling_thread,
        mock_start_info_reporting_thread):
        """Verifies run() starts RPC polling and info reporting threads."""
        self.proc.run()
        self.assertEqual(1, mock_start_rpc_polling_thread.call_count)
        self.assertEqual(1, mock_start_info_reporting_thread.call_count)

    @mock.patch.object(
        local_agent.LocalAgentProcess, '_start_info_reporting_thread')
    @mock.patch.object(
        local_agent.LocalAgentProcess, '_start_rpc_polling_thread')
    @mock.patch.object(local_agent.LocalAgentProcess,
                       '_setup_credentials',
                       return_value=False)
    def test_run_will_exit_if_cannot_setup_credentials(
        self,
        mock_setup_credentials,
        mock_start_rpc_polling_thread,
        mock_start_info_reporting_thread):
        """Verifies run() aborts if _setup_credentials failed."""
        self.proc.run()
        self.assertFalse(mock_start_rpc_polling_thread.called)
        self.assertFalse(mock_start_info_reporting_thread.called)

    @mock.patch.object(
        translation_layer.TranslationLayer, 'detect_devices', return_value=[])
    @mock.patch.object(
        ams_client.AmsClient, 'get_rpc_request_from_ams', return_value=None)
    @mock.patch.object(ams_client.AmsClient, 'report_info')
    @mock.patch.object(suite_session_manager.SuiteSessionManager, 'start')
    @mock.patch.object(
        local_agent.LocalAgentProcess, '_setup_credentials', return_value=True)
    def test_run_exits_main_thread_if_report_info_thread_is_dead(
        self,
        _,
        mock_start,
        mock_report_info,
        mock_get_rpc_request,
        mock_detect_devices):
        """Verifies run() terminates local agent if report info thread is dead.
        """
        mock_report_info.side_effect = RuntimeError()
        self.proc.run()
        # The report info thread must actually have hit the failing call.
        mock_report_info.assert_called()

    @mock.patch.object(
        translation_layer.TranslationLayer, 'detect_devices', return_value=[])
    @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
    @mock.patch.object(ams_client.AmsClient, 'report_info')
    @mock.patch.object(suite_session_manager.SuiteSessionManager, 'start')
    @mock.patch.object(
        local_agent.LocalAgentProcess, '_setup_credentials', return_value=True)
    def test_run_exits_main_thread_if_poll_rpc_thread_is_dead(
        self,
        _,
        mock_start,
        mock_report_info,
        mock_get_rpc_request,
        mock_detect_devices):
        """Verifies run() terminates local agent if poll RPC thread is dead."""
        mock_get_rpc_request.side_effect = RuntimeError()
        self.proc.run()
        # The poll RPC thread must actually have hit the failing call.
        mock_get_rpc_request.assert_called()

    @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
    @mock.patch.object(ams_client.AmsClient, 'report_info')
    @mock.patch.object(threading.Event, 'wait', return_value=True)
    def test_report_info_sends_request_to_ams(self,
                                              mock_event_wait,
                                              mock_report_info,
                                              mock_detect_devices):
        """Verifies _report_info uses AmsClient to report info."""
        mock_detect_devices.return_value = []
        self.proc._report_info()
        self.assertEqual(1, mock_report_info.call_count)
        self.assertEqual(1, mock_detect_devices.call_count)

    @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
    @mock.patch.object(ams_client.AmsClient, 'report_info')
    @mock.patch.object(threading.Event, 'wait', return_value=True)
    def test_report_info_wont_break_when_api_error(self,
                                                   mock_event_wait,
                                                   mock_report_info,
                                                   mock_detect_devices):
        """Verifies _report_info continues when an API error happens."""
        mock_report_info.side_effect = errors.ApiError
        self.proc._report_info()

    @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
    @mock.patch.object(ams_client.AmsClient, 'report_info')
    @mock.patch.object(threading.Event, 'wait', return_value=True)
    def test_report_info_wont_break_when_api_timeout(self,
                                                     mock_event_wait,
                                                     mock_report_info,
                                                     mock_detect_devices):
        """Verifies _report_info continues when an API timeout happens."""
        mock_report_info.side_effect = errors.ApiTimeoutError
        self.proc._report_info()

    @mock.patch.object(
        local_agent.LocalAgentProcess, '_clean_up_and_terminate_agent')
    @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
    @mock.patch.object(ams_client.AmsClient, 'report_info')
    def test_report_info_break_when_agent_unlinked(self,
                                                   mock_report_info,
                                                   mock_detect_devices,
                                                   mock_clean_up):
        """Verifies _report_info breaks when the local agent is unlinked."""
        mock_report_info.side_effect = errors.UnlinkedError

        self.proc._report_info()

        mock_clean_up.assert_called_once()

    @mock.patch.object(ams_client.AmsClient,
                       'get_rpc_request_from_ams',
                       return_value=None)
    @mock.patch.object(threading.Event, 'is_set')
    def test_poll_rpc_gets_request_from_ams(
        self, mock_event_is_set, mock_get_rpc_request):
        """Verifies _poll_rpc will get request from AMS in each iteration."""
        # We let there be 2 iterations.
        mock_event_is_set.side_effect = (False, False, True)
        self.proc._poll_rpc()
        self.assertEqual(2, mock_get_rpc_request.call_count)

    @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
    @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
    @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
    def test_poll_rpc_continues_if_api_error_when_getting_request(
        self,
        mock_event_is_set,
        mock_get_rpc_request,
        mock_remove_rpc):
        """Verifies _poll_rpc continues when getting RPC has ApiError."""
        mock_get_rpc_request.side_effect = errors.ApiError
        self.proc._poll_rpc()
        self.assertFalse(mock_remove_rpc.called)

    @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
    @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
    @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
    def test_poll_rpc_continues_if_api_timeout_when_getting_request(
        self,
        mock_event_is_set,
        mock_get_rpc_request,
        mock_remove_rpc):
        """Verifies _poll_rpc continues when get RPC request API timed out."""
        mock_get_rpc_request.side_effect = errors.ApiTimeoutError
        self.proc._poll_rpc()
        self.assertFalse(mock_remove_rpc.called)

    @mock.patch.object(
        local_agent.LocalAgentProcess, '_clean_up_and_terminate_agent')
    @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
    @mock.patch.object(threading.Event, 'is_set', return_value=False)
    def test_poll_rpc_raises_if_agent_unlinked_when_getting_request(
        self,
        mock_event_is_set,
        mock_get_rpc_request,
        mock_clean_up_and_terminate_agent):
        """Verifies _poll_rpc raises when get RPC request agent unlinked."""
        mock_get_rpc_request.side_effect = errors.UnlinkedError

        self.proc._poll_rpc()

        mock_clean_up_and_terminate_agent.assert_called_once()

    @mock.patch.object(local_agent.LocalAgentProcess,
                       '_start_rpc_execution_thread')
    @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
    @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
    @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
    def test_poll_rpc_removes_rpc_request_from_ams_and_executes(
        self,
        mock_event_is_set,
        mock_get_rpc_request,
        mock_remove_rpc,
        mock_start_rpc_execution):
        """Verifies _poll_rpc removes RPC request from AMS and executes it."""
        fake_rpc_request = {'hi': 'hello'}
        mock_get_rpc_request.return_value = fake_rpc_request

        self.proc._poll_rpc()

        self.assertEqual(1, mock_get_rpc_request.call_count)
        mock_remove_rpc.assert_called_once_with(fake_rpc_request)
        mock_start_rpc_execution.assert_called_once_with(fake_rpc_request)

    @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
    @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
    def test_poll_rpc_terminate_local_agent_when_remove_rpc_fails(
        self,
        mock_get_request,
        mock_remove_request):
        """Verifies _poll_rpc terminates local agent when remove RPC fails."""
        fake_rpc_request = {'hi': 'rpc-request-here'}
        mock_get_request.return_value = fake_rpc_request
        mock_remove_request.side_effect = errors.ApiError

        self.proc._poll_rpc()

        mock_remove_request.assert_called_with(fake_rpc_request)

    @mock.patch.object(futures.ThreadPoolExecutor, 'submit')
    def test_start_rpc_execution_thread_submits_to_thread_pool_executor(
        self, mock_submit):
        """Verifies _start_rpc_execution_thread submits to ThreadPoolExecutor.
        """
        fake_rpc_request = {'hi': 'hello'}
        fake_future = mock.Mock()
        mock_submit.return_value = fake_future

        self.proc._start_rpc_execution_thread(fake_rpc_request)

        self.assertEqual(
            1,
            mock_submit.call_count,
            'Should submit a task to ThreadPoolExecutor but did not.')
        self.assertIn(
            id(fake_future),
            self.proc._rpc_execution_future_ids,
            'Should keep track of the future but did not.')
        # Should register the callback to the future.
        fake_future.add_done_callback.assert_called_once_with(
            self.proc._callback_for_rpc_execution_complete)

    @mock.patch.object(futures.ThreadPoolExecutor, 'shutdown')
    def test_terminate_shutdown_pool_executor(self, mock_shutdown):
        """Verifies _terminate shuts down the ThreadPoolExecutor."""
        self.proc._terminate(None, None)
        self.assertEqual(1, mock_shutdown.call_count)

    @mock.patch.object(threading.Event, 'set')
    def test_terminate_sets_threading_event(self, mock_set):
        """Verifies _terminate sets the threading event."""
        self.proc._terminate(None, None)
        mock_set.assert_called_once()

    @mock.patch.object(local_agent, 'logger')
    def test_terminate_thread_still_alive(self, mock_logger):
        """Verifies _terminate logs an error when a thread is still alive."""
        mock_rpc_thread = mock.Mock()
        mock_rpc_thread.is_alive.return_value = True
        self.proc._rpc_polling_thread = mock_rpc_thread

        self.proc._terminate(None, None)

        mock_rpc_thread.join.assert_called_once()
        mock_logger.error.assert_called_once()

    @mock.patch.object(ams_client.AmsClient, 'upload_artifact')
    @mock.patch.object(local_agent, 'os')
    @mock.patch.object(shutil, 'rmtree')
    @mock.patch.object(shutil, 'make_archive')
    def test_compress_artifacts_and_upload_on_success(
        self, mock_make, mock_rm, mock_os, mock_ams_upload):
        """Verifies _compress_artifacts_and_upload on success."""
        mock_os.stat.return_value.st_size = 1
        with mock.patch('builtins.open',
                        new_callable=mock.mock_open):
            self.proc._compress_artifacts_and_upload('', '')
        mock_make.assert_called_once()
        mock_rm.assert_called_once()
        mock_os.stat.assert_called_once()
        mock_ams_upload.assert_called_once()

    @mock.patch.object(os, 'stat')
    @mock.patch.object(shutil, 'rmtree')
    @mock.patch.object(shutil, 'make_archive')
    def test_compress_artifacts_and_upload_too_large_file(
        self, mock_make, mock_rm, mock_stat):
        """
        Verifies _compress_artifacts_and_upload on failure with too large file.
        """
        mock_stat.return_value.st_size = (
            local_agent.APP_ENGINE_DATA_SIZE_LIMIT + 1)
        error_mesg = (
            f'larger than '
            f'{local_agent.APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE}')
        with self.assertRaisesRegex(RuntimeError, error_mesg):
            self.proc._compress_artifacts_and_upload('', '')
        self.assertEqual(1, mock_make.call_count)
        self.assertEqual(1, mock_rm.call_count)

    @mock.patch.object(ams_client.AmsClient, 'upload_artifact')
    @mock.patch.object(os, 'stat')
    @mock.patch.object(shutil, 'rmtree')
    @mock.patch.object(shutil, 'make_archive')
    def test_compress_artifacts_and_upload_uploading_timed_out(
        self, mock_make, mock_rm, mock_stat, mock_ams_client_upload):
        """
        Verifies _compress_artifacts_and_upload on failure due to API timed out.
        """
        mock_stat.return_value.st_size = 1
        mock_ams_client_upload.side_effect = errors.ApiTimeoutError
        with mock.patch('builtins.open',
                        new_callable=mock.mock_open):
            self.proc._compress_artifacts_and_upload('', '')
        self.assertEqual(1, mock_ams_client_upload.call_count)

    @mock.patch.object(ams_client.AmsClient, 'upload_artifact')
    @mock.patch.object(os, 'stat')
    @mock.patch.object(shutil, 'rmtree')
    @mock.patch.object(shutil, 'make_archive')
    def test_compress_artifacts_and_upload_uploading_api_error(
        self, mock_make, mock_rm, mock_stat, mock_ams_client_upload):
        """
        Verifies _compress_artifacts_and_upload on failure due to API error.
        """
        mock_stat.return_value.st_size = 1
        mock_ams_client_upload.side_effect = errors.ApiError
        with mock.patch('builtins.open',
                        new_callable=mock.mock_open):
            self.proc._compress_artifacts_and_upload('', '')
        self.assertEqual(1, mock_ams_client_upload.call_count)

    @mock.patch.object(local_agent, 'LocalAgentProcess')
    @mock.patch.object(local_agent, 'register_extension_controllers')
    @mock.patch.object(local_agent, 'read_config')
    @mock.patch.object(local_agent, 'exit_if_query_module_versions')
    def test_main_entry(self, mock_query, mock_read, mock_register, mock_proc):
        """Verifies local agent main entry on success."""
        mock_read.return_value = {}, []

        local_agent.main()

        mock_query.assert_called_once()
        mock_read.assert_called_once()
        mock_register.assert_called_once()
        mock_proc.assert_called_once()
        mock_proc.return_value.run.assert_called_once()

    @mock.patch.object(translation_layer.TranslationLayer, 'is_rpc_timeout')
    @mock.patch.object(ams_client.AmsClient, 'send_rpc_response')
    @mock.patch.object(local_agent.LocalAgentProcess, '_handle_rpc_request')
    def test_execute_rpc_executes_and_sends_result_to_ams(
        self,
        mock_handle_request,
        mock_send_rpc_response,
        mock_is_rpc_timeout):
        """Verifies _execute_rpc executes the RPC and sends the result."""
        fake_rpc_request = {'the': 'rpc-request'}
        fake_rpc_response = {'the': 'rpc-response'}
        mock_handle_request.return_value = fake_rpc_response
        mock_is_rpc_timeout.return_value = False

        self.proc._execute_rpc(fake_rpc_request)

        mock_handle_request.assert_called_once_with(fake_rpc_request)
        mock_send_rpc_response.assert_called_once_with(fake_rpc_response)

    @mock.patch.object(translation_layer.TranslationLayer, 'is_rpc_timeout')
    @mock.patch.object(local_agent, 'logger')
    @mock.patch.object(ams_client.AmsClient, 'send_rpc_response')
    @mock.patch.object(local_agent.LocalAgentProcess, '_handle_rpc_request')
    def test_execute_rpc_fail_to_send_rpc_response(
        self,
        mock_handle_request,
        mock_send_rpc_response,
        mock_logger,
        mock_is_rpc_timeout):
        """Verifies _execute_rpc continues when fail to send RPC response."""
        fake_rpc_request = {'the': 'rpc-request'}
        fake_rpc_response = {'the': 'rpc-response'}
        mock_handle_request.return_value = fake_rpc_response
        mock_send_rpc_response.side_effect = errors.ApiError
        mock_is_rpc_timeout.return_value = False

        self.proc._execute_rpc(fake_rpc_request)

        mock_send_rpc_response.assert_called_once()
        mock_logger.exception.assert_called_once()

    @mock.patch.object(translation_layer.TranslationLayer, 'is_rpc_timeout')
    @mock.patch.object(local_agent, 'logger')
    @mock.patch.object(local_agent.LocalAgentProcess, '_handle_rpc_request')
    def test_execute_rpc_not_sending_timeout_rpc_response(
        self,
        mock_handle_request,
        mock_logger,
        mock_is_rpc_timeout):
        """Verifies _execute_rpc not sending timeout RPC response."""
        fake_rpc_request = {'id': _FAKE_RPC_ID}
        mock_is_rpc_timeout.return_value = True

        self.proc._execute_rpc(fake_rpc_request)

        mock_logger.warning.assert_called_once()

    @mock.patch.object(
        suite_session_manager.SuiteSessionManager, 'start_test_suite')
    def test_handle_rpc_request_start_suite(self, mock_start):
        """Verifies handle_rpc_request to start suite on success."""
        mock_start.return_value = _FAKE_RPC_RESPONSE
        fake_rpc_request = {'method': _START_TEST_SUITE}

        rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

        self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
        mock_start.assert_called_once_with(fake_rpc_request)

    @mock.patch.object(
        suite_session_manager.SuiteSessionManager, 'end_test_suite')
    def test_handle_rpc_request_end_suite(self, mock_end):
        """Verifies handle_rpc_request to end suite on success."""
        mock_end.return_value = _FAKE_RPC_RESPONSE
        fake_rpc_request = {'method': _END_TEST_SUITE}

        rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

        self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
        mock_end.assert_called_once_with(fake_rpc_request)

    @mock.patch.object(
        translation_layer.TranslationLayer, 'dispatch_to_cmd_handler')
    def test_handle_rpc_request_device_control(self, mock_dispatch):
        """Verifies handle_rpc_request to control device on success."""
        mock_dispatch.return_value = _FAKE_RPC_RESPONSE
        fake_rpc_request = {'method': _LOCK_DEVICE}

        rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

        self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
        mock_dispatch.assert_called_once_with(fake_rpc_request)

    @mock.patch.object(
        translation_layer.TranslationLayer, 'commission_to_google_fabric')
    def test_handle_rpc_request_commission_to_google_fabric_on_success(
        self, mock_commission):
        """Verifies handle_rpc_request commissions to Google fabric."""
        mock_commission.return_value = _FAKE_RPC_RESPONSE
        fake_rpc_request = {'method': _COMMISSION_TO_GOOGLE_FABRIC}

        rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

        self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
        mock_commission.assert_called_once_with(fake_rpc_request)

    @mock.patch.object(
        translation_layer.TranslationLayer, 'commission_to_google_fabric')
    def test_handle_rpc_request_returns_a_commissioning_error(
        self, mock_commission):
        """Verifies handle_rpc_request returns a commissioning error."""
        mock_commission.side_effect = errors.InvalidRPCError(
            _FAKE_COMMISSION_ERROR_MSG)
        fake_rpc_request = {
            'id': _FAKE_RPC_ID,
            'method': _COMMISSION_TO_GOOGLE_FABRIC,
        }

        rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

        self.assertIn(
            _FAKE_COMMISSION_ERROR_MSG, rpc_response['error']['message'])

    @mock.patch.object(local_agent, 'rpc_request_type', return_value='')
    def test_handle_rpc_request_invalid_rpc_type(self, mock_req_type):
        """Verifies handle_rpc_request on failure with invalid RPC type."""
        error_msg = 'Invalid RPC request type'

        rpc_response = self.proc._handle_rpc_request({'id': '', 'method': ''})

        self.assertIn(error_msg, rpc_response['error']['message'])

    @mock.patch.object(
        translation_layer.TranslationLayer, 'dispatch_to_cmd_handler')
    def test_handle_rpc_unexpected_errors(self, mock_dispatch):
        """Verifies handle_rpc_request on failure with unexpected errors."""
        mock_dispatch.side_effect = RuntimeError(_FAKE_ERROR_MSG)
        fake_rpc_request = {'id': _FAKE_RPC_ID, 'method': _LOCK_DEVICE}

        rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

        self.assertIn(_FAKE_ERROR_MSG, rpc_response['error']['message'])

    @mock.patch.object(os, 'remove')
    @mock.patch.object(os.path, 'exists', return_value=True)
    @mock.patch.object(suite_session_manager.SuiteSessionManager, 'clean_up')
    def test_clean_up_and_terminate_agent_on_success(
        self, mock_clean_up, mock_exists, mock_rm):
        """Verifies _clean_up_and_terminate_agent on success."""
        self.proc._clean_up_and_terminate_agent(remove_auth_file=True)
        mock_clean_up.assert_called_once()
        mock_rm.assert_called_once()

    @mock.patch.object(local_agent, 'logger')
    def test_callback_for_rpc_execution_complete(self, mock_logger):
        """Verifies _callback_for_rpc_execution_complete on success."""
        mock_future = mock.Mock()
        self.proc._rpc_execution_future_ids.add(id(mock_future))

        self.proc._callback_for_rpc_execution_complete(mock_future)

        mock_future.exception.assert_called_once()
        mock_logger.error.assert_called_once()

    @mock.patch.object(local_agent, 'gazoo_device')
    @mock.patch.object(local_agent, 'importlib')
    def test_register_extension_controllers_on_success(
        self, mock_importlib, mock_gazoo_device):
        """Verifies register_extension_controllers on success."""
        matter_controllers_config = [(_FAKE_CONTROLLER_PACKAGE, '')]

        local_agent.register_extension_controllers(matter_controllers_config)

        mock_importlib.import_module.assert_called_once()
        mock_gazoo_device.register.assert_called_once()

    @mock.patch.object(sys, 'exit')
    def test_exit_if_query_module_versions(self, mock_exit):
        """Verifies exit_if_query_module_versions on success."""
        with mock.patch.object(sys, 'argv', ['-v']):
            local_agent.exit_if_query_module_versions()
            mock_exit.assert_called_once()


if __name__ == '__main__':
    # Stop at the first failing test so the earliest failure is reported.
    unittest.main(failfast=True)
