| # Copyright 2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Unit tests for local agent process.""" |
| import builtins |
| from concurrent import futures |
| import configparser |
| import os |
| import shutil |
| import sys |
| import tempfile |
| import threading |
| import unittest |
| from unittest import mock |
| |
| from local_agent import ams_client |
| from local_agent import errors |
| from local_agent import local_agent |
| from local_agent import suite_session_manager |
| from local_agent.translation_layer import translation_layer |
| |
| |
# ---------------------------------------------------------------------------
# Fake data shared by the unit tests below.
# ---------------------------------------------------------------------------

# AMS connection and agent credentials.
_FAKE_AMS_HOST = 'localhost'
_FAKE_AGENT_ID = 'fake-agent-id'
_FAKE_AGENT_SECRET = 'fake-agent-secret'
_FAKE_AUTH_TOKEN = 'fake-auth-token'

# RPC payloads and error messages.
_FAKE_RPC_RESPONSE = {'fake': 'response'}
_FAKE_RPC_ID = 'fake-rpc-id'
_FAKE_ERROR_MSG = 'fake-error-msg'
_FAKE_COMMISSION_ERROR_MSG = 'Unable to commission the device.'

# RPC method names, as used in the 'method' field of fake RPC requests.
_START_TEST_SUITE = 'startTestSuite'
_END_TEST_SUITE = 'endTestSuite'
_COMMISSION_TO_GOOGLE_FABRIC = 'commissionToGoogleFabric'
_LOCK_DEVICE = 'setLock'

# Artifacts directory and extension-controller package fixtures.
_FAKE_ARTIFACTS_DIR = 'fake-artifacts-dir'
_FAKE_CONTROLLER_PACKAGE = 'FAKE_CONTROLLER_PACKAGE'
| |
| |
class LocalAgentTest(unittest.TestCase):
  """Unit tests for local agent process."""

  def setUp(self):
    super().setUp()
    self.proc = local_agent.LocalAgentProcess(
        client=ams_client.AmsClient(host=_FAKE_AMS_HOST),
        artifacts_dir=_FAKE_ARTIFACTS_DIR)

    # Redirect the module-level auth/config paths to throw-away temp files.
    # Patching (instead of plain assignment) restores the real values once the
    # test finishes, so the redirection cannot leak into other test modules.
    _, auth_file = self._create_temp_file_with_clean_up()
    _, user_config = self._create_temp_file_with_clean_up()
    self._patch_local_agent_attr('AUTH_FILE', auth_file)
    self._patch_local_agent_attr('DEFAULT_USER_CONFIG', user_config)

  def _patch_local_agent_attr(self, name, value):
    """Replaces local_agent.<name> with value until the test finishes.

    Args:
      name: Name of the module-level attribute of local_agent to patch.
      value: Replacement value.
    """
    patcher = mock.patch.object(local_agent, name, value)
    patcher.start()
    self.addCleanup(patcher.stop)

  def _create_temp_file_with_clean_up(self):
    """Creates a temp file and registers clean-up procedure.

    We use tempfile.mkstemp to create a temporary file, and clean it up
    using self.addCleanup provided by unittest package.

    Returns:
      Tuple of (file_descriptor, file_path). Exactly what is returned by a
      tempfile.mkstemp call.
    """
    fd, path = tempfile.mkstemp()
    # Cleanups run in LIFO order. Registering os.remove first means the
    # descriptor is closed before the file is deleted; deleting a file that
    # still has an open descriptor fails on Windows.
    self.addCleanup(os.remove, path)
    self.addCleanup(os.close, fd)
    return fd, path

  @mock.patch.object(ams_client.AmsClient, 'set_local_agent_credentials')
  @mock.patch.object(local_agent.LocalAgentProcess,
                     '_read_auths',
                     return_value=('agent-id', 'agent-secret'))
  def test_setup_credentials_existing_credential_success(
      self,
      mock_read_auths,
      mock_set_local_agent_credentials):
    """Verifies _setup_credentials succeeds with existing credentials."""
    self.assertTrue(self.proc._setup_credentials())
    mock_set_local_agent_credentials.assert_called_once()

  @mock.patch.object(builtins, 'input', side_effect=KeyboardInterrupt)
  @mock.patch.object(local_agent.LocalAgentProcess, '_read_auths')
  def test_setup_credentials_no_existing_credential_will_start_linking(
      self, mock_read_auths, mock_input):
    """Verifies _setup_credentials starts linking if no credentials."""
    mock_read_auths.side_effect = FileNotFoundError
    self.assertFalse(self.proc._setup_credentials())
    self.assertEqual(1, mock_input.call_count)

  @mock.patch.object(builtins, 'input', side_effect=KeyboardInterrupt)
  @mock.patch.object(ams_client.AmsClient,
                     'set_local_agent_credentials',
                     side_effect=errors.CredentialsError)
  @mock.patch.object(local_agent.LocalAgentProcess,
                     '_read_auths',
                     return_value=('agent-id', 'agent-secret'))
  def test_setup_credentials_bad_existing_credential_will_start_linking(
      self,
      mock_read_auths,
      mock_set_local_agent_credentials,
      mock_input):
    """Verifies _setup_credentials starts linking if bad credentials."""
    self.assertFalse(self.proc._setup_credentials())
    self.assertEqual(1, mock_input.call_count)

  @mock.patch.object(ams_client.AmsClient, 'register')
  @mock.patch.object(builtins, 'input', return_value='the-code')
  @mock.patch.object(local_agent.LocalAgentProcess,
                     '_read_auths',
                     side_effect=FileNotFoundError)
  def test_setup_credentials_start_linking_and_succeed(
      self, mock_read_auths, mock_input, mock_register):
    """Verifies _setup_credentials starts linking and succeeds."""
    mock_register.return_value = ('the-agent-id', 'the-agent-secret')
    self.assertTrue(self.proc._setup_credentials())
    mock_register.assert_called_once_with(linking_code='the-code')

  @mock.patch.object(ams_client.AmsClient, 'register')
  @mock.patch.object(builtins, 'input', return_value='the-code')
  @mock.patch.object(local_agent.LocalAgentProcess,
                     '_read_auths',
                     side_effect=FileNotFoundError)
  def test_setup_credentials_start_linking_and_keeps_retry(
      self,
      mock_read_auths,
      mock_input,
      mock_register):
    """Verifies _setup_credentials keeps retry when linking fails."""
    mock_register.side_effect = (
        errors.ApiTimeoutError,
        errors.ApiTimeoutError,
        errors.CredentialsError,
        errors.CredentialsError,
        errors.ApiError,
        errors.CredentialsError,
        ('the-agent-id', 'the-agent-secret'),
    )
    self.assertTrue(self.proc._setup_credentials())
    self.assertEqual(7, mock_input.call_count)

  def test_read_write_auth(self):
    """Verifies reading/writing auths."""
    self.proc._write_auths(_FAKE_AGENT_ID, _FAKE_AGENT_SECRET)
    self.assertEqual(
        (_FAKE_AGENT_ID, _FAKE_AGENT_SECRET), self.proc._read_auths())

  def test_read_config_with_inexistent_file(self):
    """Verifies read_config returns {} when config file doesn't exist."""
    # Patch rather than assign so the module attribute is restored afterwards.
    with mock.patch.object(local_agent, 'DEFAULT_USER_CONFIG', ''):
      self.assertEqual(({}, {}), local_agent.read_config())

  @mock.patch.object(configparser, 'ConfigParser')
  def test_read_config_missing_root_key(self, mock_parser):
    """Verifies read_config raise ValueError when root key not present."""
    mock_config = mock.MagicMock()
    mock_parser.return_value = mock_config
    mock_config.__contains__.return_value = False
    with self.assertRaisesRegex(ValueError, 'Invalid config file'):
      local_agent.read_config()

  @mock.patch.object(configparser, 'ConfigParser')
  def test_read_config_success(self, mock_parser):
    """Verifies read_config run successfully."""
    mock_config = mock.MagicMock()
    mock_parser.return_value = mock_config
    mock_config.__contains__.return_value = True

    local_agent.read_config()

    mock_config.read.assert_called_once()

  @mock.patch.object(
      local_agent.LocalAgentProcess, '_start_info_reporting_thread')
  @mock.patch.object(
      local_agent.LocalAgentProcess, '_start_rpc_polling_thread')
  @mock.patch.object(suite_session_manager.SuiteSessionManager, 'start')
  @mock.patch.object(
      local_agent.LocalAgentProcess, '_setup_credentials', return_value=True)
  @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
  def test_run_starts_two_top_level_threads(
      self,
      mock_event_is_set,
      mock_setup_credentials,
      mock_start,
      mock_start_rpc_polling_thread,
      mock_start_info_reporting_thread):
    """Verifies run() starts RPC polling and info reporting threads."""
    self.proc.run()
    self.assertEqual(1, mock_start_rpc_polling_thread.call_count)
    self.assertEqual(1, mock_start_info_reporting_thread.call_count)

  @mock.patch.object(
      local_agent.LocalAgentProcess, '_start_info_reporting_thread')
  @mock.patch.object(
      local_agent.LocalAgentProcess, '_start_rpc_polling_thread')
  @mock.patch.object(local_agent.LocalAgentProcess,
                     '_setup_credentials',
                     return_value=False)
  def test_run_will_exit_if_cannot_setup_credentials(
      self,
      mock_setup_credentials,
      mock_start_rpc_polling_thread,
      mock_start_info_reporting_thread):
    """Verifies run() aborts if _setup_credentials failed."""
    self.proc.run()
    self.assertFalse(mock_start_rpc_polling_thread.called)
    self.assertFalse(mock_start_info_reporting_thread.called)

  @mock.patch.object(
      translation_layer.TranslationLayer, 'detect_devices', return_value=[])
  @mock.patch.object(
      ams_client.AmsClient, 'get_rpc_request_from_ams', return_value=None)
  @mock.patch.object(ams_client.AmsClient, 'report_info')
  @mock.patch.object(suite_session_manager.SuiteSessionManager, 'start')
  @mock.patch.object(
      local_agent.LocalAgentProcess, '_setup_credentials', return_value=True)
  def test_run_exits_main_thread_if_report_info_thread_is_dead(
      self,
      _,
      mock_start,
      mock_report_info,
      mock_get_rpc_request,
      mock_detect_devices):
    """Verifies run() terminates local agent if report info thread is dead."""
    mock_report_info.side_effect = RuntimeError()

    self.proc.run()

    # The reporting thread can only have died by hitting the failing call.
    mock_report_info.assert_called()

  @mock.patch.object(
      translation_layer.TranslationLayer, 'detect_devices', return_value=[])
  @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
  @mock.patch.object(ams_client.AmsClient, 'report_info')
  @mock.patch.object(suite_session_manager.SuiteSessionManager, 'start')
  @mock.patch.object(
      local_agent.LocalAgentProcess, '_setup_credentials', return_value=True)
  def test_run_exits_main_thread_if_poll_rpc_thread_is_dead(
      self,
      _,
      mock_start,
      mock_report_info,
      mock_get_rpc_request,
      mock_detect_devices):
    """Verifies run() terminates local agent if poll RPC thread is dead."""
    mock_get_rpc_request.side_effect = RuntimeError()

    self.proc.run()

    # The polling thread can only have died by hitting the failing call.
    mock_get_rpc_request.assert_called()

  @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
  @mock.patch.object(ams_client.AmsClient, 'report_info')
  @mock.patch.object(threading.Event, 'wait', return_value=True)
  def test_report_info_sends_request_to_ams(
      self, mock_event_wait, mock_report_info, mock_detect_devices):
    """Verifies _report_info uses AmsClient to report info."""
    mock_detect_devices.return_value = []
    self.proc._report_info()
    self.assertEqual(1, mock_report_info.call_count)
    self.assertEqual(1, mock_detect_devices.call_count)

  @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
  @mock.patch.object(ams_client.AmsClient, 'report_info')
  @mock.patch.object(threading.Event, 'wait', return_value=True)
  def test_report_info_wont_break_when_api_error(
      self, mock_event_wait, mock_report_info, mock_detect_devices):
    """Verifies _report_info continues when an API error happens."""
    mock_detect_devices.return_value = []
    mock_report_info.side_effect = errors.ApiError

    self.proc._report_info()

    self.assertEqual(1, mock_report_info.call_count)

  @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
  @mock.patch.object(ams_client.AmsClient, 'report_info')
  @mock.patch.object(threading.Event, 'wait', return_value=True)
  def test_report_info_wont_break_when_api_timeout(
      self, mock_event_wait, mock_report_info, mock_detect_devices):
    """Verifies _report_info continues when an API call times out."""
    mock_detect_devices.return_value = []
    mock_report_info.side_effect = errors.ApiTimeoutError

    self.proc._report_info()

    self.assertEqual(1, mock_report_info.call_count)

  @mock.patch.object(
      local_agent.LocalAgentProcess, '_clean_up_and_terminate_agent')
  @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
  @mock.patch.object(ams_client.AmsClient, 'report_info')
  def test_report_info_break_when_agent_unlinked(
      self, mock_report_info, mock_detect_devices, mock_clean_up):
    """Verifies _report_info breaks when the local agent is unlinked."""
    mock_detect_devices.return_value = []
    mock_report_info.side_effect = errors.UnlinkedError

    self.proc._report_info()

    mock_clean_up.assert_called_once()

  @mock.patch.object(ams_client.AmsClient,
                     'get_rpc_request_from_ams',
                     return_value=None)
  @mock.patch.object(threading.Event, 'is_set')
  def test_poll_rpc_gets_request_from_ams(
      self, mock_event_is_set, mock_get_rpc_request):
    """Verifies _poll_rpc will get request from AMS in each iteration."""
    # We let there be 2 iterations.
    mock_event_is_set.side_effect = (False, False, True)
    self.proc._poll_rpc()
    self.assertEqual(2, mock_get_rpc_request.call_count)

  @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
  @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
  @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
  def test_poll_rpc_continues_if_api_error_when_getting_request(
      self,
      mock_event_is_set,
      mock_get_rpc_request,
      mock_remove_rpc):
    """Verifies _poll_rpc continues when getting RPC has ApiError."""
    mock_get_rpc_request.side_effect = errors.ApiError
    self.proc._poll_rpc()
    self.assertEqual(1, mock_get_rpc_request.call_count)
    self.assertFalse(mock_remove_rpc.called)

  @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
  @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
  @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
  def test_poll_rpc_continues_if_api_timeout_when_getting_request(
      self,
      mock_event_is_set,
      mock_get_rpc_request,
      mock_remove_rpc):
    """Verifies _poll_rpc continues when get RPC request API timed out."""
    mock_get_rpc_request.side_effect = errors.ApiTimeoutError
    self.proc._poll_rpc()
    self.assertEqual(1, mock_get_rpc_request.call_count)
    self.assertFalse(mock_remove_rpc.called)

  @mock.patch.object(
      local_agent.LocalAgentProcess, '_clean_up_and_terminate_agent')
  @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
  @mock.patch.object(threading.Event, 'is_set', return_value=False)
  def test_poll_rpc_raises_if_agent_unlinked_when_getting_request(
      self,
      mock_event_is_set,
      mock_get_rpc_request,
      mock_clean_up_and_terminate_agent):
    """Verifies _poll_rpc cleans up and terminates when agent is unlinked."""
    mock_get_rpc_request.side_effect = errors.UnlinkedError

    self.proc._poll_rpc()

    mock_clean_up_and_terminate_agent.assert_called_once()

  @mock.patch.object(local_agent.LocalAgentProcess,
                     '_start_rpc_execution_thread')
  @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
  @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
  @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
  def test_poll_rpc_removes_rpc_request_from_ams_and_executes(
      self,
      mock_event_is_set,
      mock_get_rpc_request,
      mock_remove_rpc,
      mock_start_rpc_execution):
    """Verifies _poll_rpc removes RPC request from AMS and executes it."""
    fake_rpc_request = {'hi': 'hello'}
    mock_get_rpc_request.return_value = fake_rpc_request

    self.proc._poll_rpc()

    self.assertEqual(1, mock_get_rpc_request.call_count)
    mock_remove_rpc.assert_called_once_with(fake_rpc_request)
    mock_start_rpc_execution.assert_called_once_with(fake_rpc_request)

  @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
  @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
  def test_poll_rpc_terminate_local_agent_when_remove_rpc_fails(
      self,
      mock_get_request,
      mock_remove_request):
    """Verifies _poll_rpc terminates local agent when remove RPC fails."""
    fake_rpc_request = {'hi': 'rpc-request-here'}
    mock_get_request.return_value = fake_rpc_request
    mock_remove_request.side_effect = errors.ApiError

    self.proc._poll_rpc()

    # Polling must stop after the first failed removal instead of retrying.
    mock_remove_request.assert_called_once_with(fake_rpc_request)

  @mock.patch.object(futures.ThreadPoolExecutor, 'submit')
  def test_start_rpc_execution_thread_submits_to_thread_pool_executor(
      self, mock_submit):
    """Verifies _start_rpc_execution_thread submits to ThreadPoolExecutor."""
    fake_rpc_request = {'hi': 'hello'}
    fake_future = mock.Mock()
    mock_submit.return_value = fake_future

    self.proc._start_rpc_execution_thread(fake_rpc_request)

    self.assertEqual(
        1,
        mock_submit.call_count,
        'Should submit a task to ThreadPoolExecutor but did not.')
    self.assertIn(
        id(fake_future),
        self.proc._rpc_execution_future_ids,
        'Should keep track of the future but did not.')
    # Should register the callback to the future.
    fake_future.add_done_callback.assert_called_once_with(
        self.proc._callback_for_rpc_execution_complete)

  @mock.patch.object(futures.ThreadPoolExecutor, 'shutdown')
  def test_terminate_shutdown_pool_executor(self, mock_shutdown):
    """Verifies _terminate shuts down the ThreadPoolExecutor."""
    self.proc._terminate(None, None)
    self.assertEqual(1, mock_shutdown.call_count)

  @mock.patch.object(threading.Event, 'set')
  def test_terminate_sets_threading_event(self, mock_set):
    """Verifies _terminate sets the threading event."""
    self.proc._terminate(None, None)
    mock_set.assert_called_once()

  @mock.patch.object(local_agent, 'logger')
  def test_terminate_thread_still_alive(self, mock_logger):
    """Verifies _terminate logs an error when a thread is still alive."""
    mock_rpc_thread = mock.Mock()
    mock_rpc_thread.is_alive.return_value = True
    self.proc._rpc_polling_thread = mock_rpc_thread

    self.proc._terminate(None, None)

    mock_rpc_thread.join.assert_called_once()
    mock_logger.error.assert_called_once()

  @mock.patch.object(ams_client.AmsClient, 'upload_artifact')
  @mock.patch.object(local_agent, 'os')
  @mock.patch.object(shutil, 'rmtree')
  @mock.patch.object(shutil, 'make_archive')
  def test_compress_artifacts_and_upload_on_success(
      self, mock_make, mock_rm, mock_os, mock_ams_upload):
    """Verifies _compress_artifacts_and_upload on success."""
    mock_os.stat.return_value.st_size = 1
    with mock.patch('builtins.open', new_callable=mock.mock_open):
      self.proc._compress_artifacts_and_upload('', '')
    mock_make.assert_called_once()
    mock_rm.assert_called_once()
    mock_os.stat.assert_called_once()
    mock_ams_upload.assert_called_once()

  @mock.patch.object(os, 'stat')
  @mock.patch.object(shutil, 'rmtree')
  @mock.patch.object(shutil, 'make_archive')
  def test_compress_artifacts_and_upload_too_large_file(
      self, mock_make, mock_rm, mock_stat):
    """Verifies _compress_artifacts_and_upload fails on a too-large archive."""
    mock_stat.return_value.st_size = (
        local_agent.APP_ENGINE_DATA_SIZE_LIMIT + 1)
    error_msg = (
        'larger than '
        f'{local_agent.APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE}')

    with self.assertRaisesRegex(RuntimeError, error_msg):
      self.proc._compress_artifacts_and_upload('', '')

    # These checks must live outside the assertRaises block: statements after
    # the raising call inside it would never execute.
    self.assertEqual(1, mock_make.call_count)
    self.assertEqual(1, mock_rm.call_count)

  @mock.patch.object(ams_client.AmsClient, 'upload_artifact')
  @mock.patch.object(os, 'stat')
  @mock.patch.object(shutil, 'rmtree')
  @mock.patch.object(shutil, 'make_archive')
  def test_compress_artifacts_and_upload_uploading_timed_out(
      self, mock_make, mock_rm, mock_stat, mock_ams_client_upload):
    """Verifies _compress_artifacts_and_upload tolerates an upload timeout."""
    mock_stat.return_value.st_size = 1
    mock_ams_client_upload.side_effect = errors.ApiTimeoutError
    with mock.patch('builtins.open', new_callable=mock.mock_open):
      self.proc._compress_artifacts_and_upload('', '')
    self.assertEqual(1, mock_ams_client_upload.call_count)

  @mock.patch.object(ams_client.AmsClient, 'upload_artifact')
  @mock.patch.object(os, 'stat')
  @mock.patch.object(shutil, 'rmtree')
  @mock.patch.object(shutil, 'make_archive')
  def test_compress_artifacts_and_upload_uploading_api_error(
      self, mock_make, mock_rm, mock_stat, mock_ams_client_upload):
    """Verifies _compress_artifacts_and_upload tolerates an upload API error."""
    mock_stat.return_value.st_size = 1
    mock_ams_client_upload.side_effect = errors.ApiError
    with mock.patch('builtins.open', new_callable=mock.mock_open):
      self.proc._compress_artifacts_and_upload('', '')
    self.assertEqual(1, mock_ams_client_upload.call_count)

  @mock.patch.object(local_agent, 'LocalAgentProcess')
  @mock.patch.object(local_agent, 'register_extension_controllers')
  @mock.patch.object(local_agent, 'read_config')
  @mock.patch.object(local_agent, 'exit_if_query_module_versions')
  def test_main_entry(self, mock_query, mock_read, mock_register, mock_proc):
    """Verifies local agent main entry on success."""
    mock_read.return_value = {}, []

    local_agent.main()

    mock_query.assert_called_once()
    mock_read.assert_called_once()
    mock_register.assert_called_once()
    mock_proc.assert_called_once()
    mock_proc.return_value.run.assert_called_once()

  @mock.patch.object(translation_layer.TranslationLayer, 'is_rpc_timeout')
  @mock.patch.object(ams_client.AmsClient, 'send_rpc_response')
  @mock.patch.object(local_agent.LocalAgentProcess, '_handle_rpc_request')
  def test_execute_rpc_executes_and_sends_result_to_ams(
      self,
      mock_handle_request,
      mock_send_rpc_response,
      mock_is_rpc_timeout):
    """Verifies _execute_rpc executes the RPC and sends the result."""
    fake_rpc_request = {'the': 'rpc-request'}
    fake_rpc_response = {'the': 'rpc-response'}
    mock_handle_request.return_value = fake_rpc_response
    mock_is_rpc_timeout.return_value = False

    self.proc._execute_rpc(fake_rpc_request)

    mock_handle_request.assert_called_once_with(fake_rpc_request)
    mock_send_rpc_response.assert_called_once_with(fake_rpc_response)

  @mock.patch.object(translation_layer.TranslationLayer, 'is_rpc_timeout')
  @mock.patch.object(local_agent, 'logger')
  @mock.patch.object(ams_client.AmsClient, 'send_rpc_response')
  @mock.patch.object(local_agent.LocalAgentProcess, '_handle_rpc_request')
  def test_execute_rpc_fail_to_send_rpc_response(
      self,
      mock_handle_request,
      mock_send_rpc_response,
      mock_logger,
      mock_is_rpc_timeout):
    """Verifies _execute_rpc continues when fail to send RPC response."""
    fake_rpc_request = {'the': 'rpc-request'}
    fake_rpc_response = {'the': 'rpc-response'}
    mock_handle_request.return_value = fake_rpc_response
    mock_send_rpc_response.side_effect = errors.ApiError
    mock_is_rpc_timeout.return_value = False

    self.proc._execute_rpc(fake_rpc_request)

    mock_send_rpc_response.assert_called_once()
    mock_logger.exception.assert_called_once()

  @mock.patch.object(translation_layer.TranslationLayer, 'is_rpc_timeout')
  @mock.patch.object(local_agent, 'logger')
  @mock.patch.object(local_agent.LocalAgentProcess, '_handle_rpc_request')
  def test_execute_rpc_not_sending_timeout_rpc_response(
      self,
      mock_handle_request,
      mock_logger,
      mock_is_rpc_timeout):
    """Verifies _execute_rpc not sending timeout RPC response."""
    fake_rpc_request = {'id': _FAKE_RPC_ID}
    mock_is_rpc_timeout.return_value = True

    self.proc._execute_rpc(fake_rpc_request)

    mock_logger.warning.assert_called_once()

  @mock.patch.object(
      suite_session_manager.SuiteSessionManager, 'start_test_suite')
  def test_handle_rpc_request_start_suite(self, mock_start):
    """Verifies handle_rpc_request to start suite on success."""
    mock_start.return_value = _FAKE_RPC_RESPONSE
    fake_rpc_request = {'method': _START_TEST_SUITE}

    rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

    self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
    mock_start.assert_called_once_with(fake_rpc_request)

  @mock.patch.object(
      suite_session_manager.SuiteSessionManager, 'end_test_suite')
  def test_handle_rpc_request_end_suite(self, mock_end):
    """Verifies handle_rpc_request to end suite on success."""
    mock_end.return_value = _FAKE_RPC_RESPONSE
    fake_rpc_request = {'method': _END_TEST_SUITE}

    rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

    self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
    mock_end.assert_called_once_with(fake_rpc_request)

  @mock.patch.object(
      translation_layer.TranslationLayer, 'dispatch_to_cmd_handler')
  def test_handle_rpc_request_device_control(self, mock_dispatch):
    """Verifies handle_rpc_request to control device on success."""
    mock_dispatch.return_value = _FAKE_RPC_RESPONSE
    fake_rpc_request = {'method': _LOCK_DEVICE}

    rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

    self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
    mock_dispatch.assert_called_once_with(fake_rpc_request)

  @mock.patch.object(
      translation_layer.TranslationLayer, 'commission_to_google_fabric')
  def test_handle_rpc_request_commission_to_google_fabric_on_success(
      self, mock_commission):
    """Verifies handle_rpc_request commissions a device to Google fabric."""
    mock_commission.return_value = _FAKE_RPC_RESPONSE
    fake_rpc_request = {'method': _COMMISSION_TO_GOOGLE_FABRIC}

    rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

    self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
    mock_commission.assert_called_once_with(fake_rpc_request)

  @mock.patch.object(
      translation_layer.TranslationLayer, 'commission_to_google_fabric')
  def test_handle_rpc_request_returns_a_commissioning_error(
      self, mock_commission):
    """Verifies handle_rpc_request returns a commissioning error."""
    mock_commission.side_effect = errors.InvalidRPCError(
        _FAKE_COMMISSION_ERROR_MSG)
    fake_rpc_request = {
        'id': _FAKE_RPC_ID,
        'method': _COMMISSION_TO_GOOGLE_FABRIC,
    }

    rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

    self.assertIn(
        _FAKE_COMMISSION_ERROR_MSG, rpc_response['error']['message'])

  @mock.patch.object(local_agent, 'rpc_request_type', return_value='')
  def test_handle_rpc_request_invalid_rpc_type(self, mock_req_type):
    """Verifies handle_rpc_request on failure with invalid RPC type."""
    error_msg = 'Invalid RPC request type'

    rpc_response = self.proc._handle_rpc_request({'id': '', 'method': ''})

    self.assertIn(error_msg, rpc_response['error']['message'])

  @mock.patch.object(
      translation_layer.TranslationLayer, 'dispatch_to_cmd_handler')
  def test_handle_rpc_unexpected_errors(self, mock_dispatch):
    """Verifies handle_rpc_request on failure with unexpected errors."""
    mock_dispatch.side_effect = RuntimeError(_FAKE_ERROR_MSG)
    fake_rpc_request = {'id': _FAKE_RPC_ID, 'method': _LOCK_DEVICE}

    rpc_response = self.proc._handle_rpc_request(fake_rpc_request)

    self.assertIn(_FAKE_ERROR_MSG, rpc_response['error']['message'])

  @mock.patch.object(os, 'remove')
  @mock.patch.object(os.path, 'exists', return_value=True)
  @mock.patch.object(suite_session_manager.SuiteSessionManager, 'clean_up')
  def test_clean_up_and_terminate_agent_on_success(
      self, mock_clean_up, mock_exists, mock_rm):
    """Verifies _clean_up_and_terminate_agent on success."""
    self.proc._clean_up_and_terminate_agent(remove_auth_file=True)
    mock_clean_up.assert_called_once()
    mock_rm.assert_called_once()

  @mock.patch.object(local_agent, 'logger')
  def test_callback_for_rpc_execution_complete(self, mock_logger):
    """Verifies _callback_for_rpc_execution_complete on success."""
    mock_future = mock.Mock()
    self.proc._rpc_execution_future_ids.add(id(mock_future))

    self.proc._callback_for_rpc_execution_complete(mock_future)

    mock_future.exception.assert_called_once()
    mock_logger.error.assert_called_once()

  @mock.patch.object(local_agent, 'gazoo_device')
  @mock.patch.object(local_agent, 'importlib')
  def test_register_extension_controllers_on_success(
      self, mock_importlib, mock_gazoo_device):
    """Verifies register_extension_controllers on success."""
    matter_controllers_config = [(_FAKE_CONTROLLER_PACKAGE, '')]

    local_agent.register_extension_controllers(matter_controllers_config)

    mock_importlib.import_module.assert_called_once()
    mock_gazoo_device.register.assert_called_once()

  @mock.patch.object(sys, 'exit')
  def test_exit_if_query_module_versions(self, mock_exit):
    """Verifies exit_if_query_module_versions on success."""
    with mock.patch.object(sys, 'argv', ['-v']):
      local_agent.exit_if_query_module_versions()
    mock_exit.assert_called_once()
| |
| |
if __name__ == '__main__':
  # Run every test case in this module, stopping at the first failure.
  unittest.main(failfast=True)