| # Copyright 2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Module for LocalAgentProcess.""" |
| import argparse |
| from concurrent import futures |
| import configparser |
| import enum |
| import importlib |
| import json |
| import os |
| import shutil |
| import signal |
| import sys |
| import threading |
| import time |
| import traceback |
| from typing import Any, Dict, List, Optional, Tuple |
| |
| import gazoo_device |
| |
| from local_agent import errors |
| from local_agent import ams_client |
| from local_agent import suite_session_manager |
| from local_agent import version |
| from local_agent import logger as logger_module |
| from local_agent.translation_layer import translation_layer |
| |
| |
# Module-wide logger obtained from local_agent.logger; every function and
# method in this file logs through it.
logger = logger_module.get_logger()
| |
# ========================= Constants / Configs ========================= #
# Command-line flags that make the process print versions and exit.
VERSION_FLAGS = {"-v", "-version", "--version"}
_BYTES_PER_MB = 1048576  # 1024 * 1024
# Artifact archives at or above this size are rejected instead of uploaded
# (see _compress_artifacts_and_upload).
APP_ENGINE_DATA_SIZE_LIMIT = 31 * _BYTES_PER_MB  # 31 MB
# Derived from the numeric limit so the label can never drift from it.
APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE = (
    f'{APP_ENGINE_DATA_SIZE_LIMIT // _BYTES_PER_MB} MB')
# Used when the user config does not set ARTIFACTS_DIR.
DEFAULT_ARTIFACTS_DIR = '/tmp/local_agent_artifacts'
# JSON file holding the agent_id / agent_secret obtained at linking time.
AUTH_FILE = os.path.expanduser('~/.config/google/matter_local_agent_auth.json')
# Used when no --user_config argument is given.
DEFAULT_USER_CONFIG = os.path.expanduser(
    '~/.config/google/local_agent_config.ini')
# JSON-RPC method names with dedicated handlers; every other method is
# treated as a device query/control request.
_START_TEST_SUITE_METHOD = 'startTestSuite'
_END_TEST_SUITE_METHOD = 'endTestSuite'
_COMMISSION_TO_GOOGLE_FABRIC_METHOD = 'commissionToGoogleFabric'
# Section and option names read from the user config file.
_USER_CONFIG_ROOT_KEY = 'ServerConfig'
_USER_CONFIG_AMS_HOST = 'AMS_HOST'
_USER_CONFIG_AMS_PORT = 'AMS_PORT'
_USER_CONFIG_AMS_SCHEME = 'AMS_SCHEME'
_USER_CONFIG_ARTIFACTS_DIR = 'ARTIFACTS_DIR'
_USER_CONFIG_MATTER_DEVICE_CONTROLLERS = 'MatterDeviceControllerPackages'
# Fallback AMS connection settings used when the user config omits them.
# A None port presumably lets AmsClient pick the scheme default — confirm.
_EAP_AMS_HOST = 'matter-test-suite-eap.withgoogle.com'
_EAP_AMS_PORT = None
_EAP_AMS_SCHEME = 'https'
# ======================================================================= #
| |
| |
class RpcRequestType(enum.Enum):
    """Kinds of incoming JSON-RPC requests.

    The member selects which component handles a request; the mapping from
    method names lives in rpc_request_type().
    """
    START_TEST_SUITE = 1
    END_TEST_SUITE = 2
    DEVICE_QUERY_CONTROL = 3
    COMMISSION_TO_GOOGLE_FABRIC = 4
| |
| |
# ======================== Module level functions ========================== #
def rpc_request_type(method: str) -> RpcRequestType:
    """Maps a JSON-RPC method name to the request type that handles it.

    Args:
        method: JSON-RPC request method.

    Returns:
        The dedicated request type for the test-suite lifecycle and
        fabric-commissioning methods; DEVICE_QUERY_CONTROL for anything else.
    """
    dedicated_methods = (
        (_START_TEST_SUITE_METHOD, RpcRequestType.START_TEST_SUITE),
        (_END_TEST_SUITE_METHOD, RpcRequestType.END_TEST_SUITE),
        (_COMMISSION_TO_GOOGLE_FABRIC_METHOD,
         RpcRequestType.COMMISSION_TO_GOOGLE_FABRIC),
    )
    for candidate, request_type in dedicated_methods:
        if method == candidate:
            return request_type
    return RpcRequestType.DEVICE_QUERY_CONTROL
# ========================================================================== #
| |
| |
class LocalAgentProcess:
    """Local Agent Process.

    A continuously running process which constantly sends GET
    requests to the Agent Management Service (AMS) for incoming RPC
    requests, executes them on a thread pool, and periodically reports
    the agent's status back to the AMS.
    """

    _REPORT_INFO_INTERVAL_SECONDS = 30
    _REPORT_INFO_THREAD_TERMINATION_WAIT_SECONDS = 10
    _POLL_RPC_COOL_DOWN_SECONDS = 1
    _POLL_RPC_THREAD_TERMINATION_WAIT_SECONDS = 10
    _MAIN_THREAD_KEEP_ALIVE_COOLDOWN_SECONDS = 1

    _MAX_WORKERS_FOR_RPC_EXECUTION_THREAD_POOL = 5

    # Status for local agent, defined in the AMS.
    # Note that we don't need 'OFFLINE' status because that is determined by
    # the AMS.
    _STATUS_RUNNING = 'RUNNING'
    _STATUS_IDLE = 'IDLE'

    def __init__(self, client: ams_client.AmsClient, artifacts_dir: str):
        """Initializes local agent.

        Args:
            client: The AmsClient instance used for all AMS communication.
            artifacts_dir: Root directory under which per-test-suite
                artifacts are collected.
        """
        self._ams_client = client
        self._artifacts_dir = artifacts_dir

        # Threads
        self._termination_event = threading.Event()
        self._rpc_polling_thread = None
        self._info_reporting_thread = None
        self._rpc_execution_thread_pool = futures.ThreadPoolExecutor(
            max_workers=self._MAX_WORKERS_FOR_RPC_EXECUTION_THREAD_POOL)

        # IDs of RPC execution futures that have not completed yet. A
        # non-empty set makes _report_info report the RUNNING status.
        self._rpc_execution_future_ids = set()

        # Translation Layer
        self._translator = translation_layer.TranslationLayer(client)

        # Suite session manager
        self._suite_mgr = suite_session_manager.SuiteSessionManager(
            artifacts_fn=self._compress_artifacts_and_upload,
            artifact_root_dir=artifacts_dir,
            create_devices_fn=self._translator.create_devices,
            close_devices_fn=self._translator.close_devices)

    def run(self) -> None:
        """Runs the local agent, starting polling JSON-RPC and reporting info.

        Sets up credentials first (interactive linking if needed), then
        starts two threads: one polling JSON-RPC from AMS and one reporting
        info to AMS. The main thread stays in a keep-alive loop and
        terminates the agent as soon as either thread dies.
        """
        if not self._setup_credentials():
            logger.warning('Local Agent linking failed, exit the process.')
            return
        logger.info('Local Agent is linked successfully.')

        # Register termination signal handler
        signal.signal(signal.SIGINT, self._terminate)

        self._translator.start(termination_event=self._termination_event)
        self._suite_mgr.start(termination_event=self._termination_event)
        self._start_info_reporting_thread()
        self._start_rpc_polling_thread()
        while not self._termination_event.is_set():
            # If any top level thread is dead, terminate the local agent.
            if self._terminate_if_thread_not_running(
                    self._info_reporting_thread):
                break
            if self._terminate_if_thread_not_running(
                    self._rpc_polling_thread):
                break

            time.sleep(self._MAIN_THREAD_KEEP_ALIVE_COOLDOWN_SECONDS)

    def _terminate_if_thread_not_running(
            self, target_thread: Optional[threading.Thread]) -> bool:
        """Terminates local agent if the target thread is not running.

        Args:
            target_thread: The thread to check running.

        Returns:
            True if termination procedure initiated, i.e., the thread is not
            running.
        """
        if target_thread is None or not target_thread.is_alive():
            logger.error('Thread is dead or not even started.')
            self._terminate(None, None)
            return True
        return False

    def _start_info_reporting_thread(self) -> None:
        """Starts the _report_info job in a thread (once)."""
        if self._info_reporting_thread is None:
            self._info_reporting_thread = threading.Thread(
                target=self._report_info, name='Info-reporting-thread')
        if not self._info_reporting_thread.is_alive():
            self._info_reporting_thread.start()

    def _start_rpc_polling_thread(self) -> None:
        """Starts the _poll_rpc job in a thread (once)."""
        if self._rpc_polling_thread is None:
            self._rpc_polling_thread = threading.Thread(
                target=self._poll_rpc, name='RPC-polling-thread')
        if not self._rpc_polling_thread.is_alive():
            self._rpc_polling_thread.start()

    def _start_rpc_execution_thread(self, rpc_request: Dict[str, Any]) -> None:
        """Submits a _execute_rpc job to the RPC execution thread pool.

        The future's id is tracked until the completion callback removes it.
        """
        future = self._rpc_execution_thread_pool.submit(self._execute_rpc,
                                                        rpc_request)
        # Add the id before registering the callback: if the job has already
        # finished, add_done_callback invokes the callback immediately.
        self._rpc_execution_future_ids.add(id(future))
        future.add_done_callback(self._callback_for_rpc_execution_complete)

    def _report_info(self) -> None:
        """Reports local agent information back to AMS periodically.

        Note that this method contains an infinite loop, and is designed to be
        run in a separate thread, instead of the main thread.
        Typically we should not invoke this method directly, and instead we
        use _start_info_reporting_thread method.

        The information being reported includes:
        - The devices connected to this local agent.
        - The version of this local agent.
        - GDM version in use.
        - The status of this local agent.
        """
        while True:
            logger.info('Reporting status to AMS.')

            devices = self._translator.detect_devices()
            status = (self._STATUS_RUNNING if self._rpc_execution_future_ids
                      else self._STATUS_IDLE)
            local_agent_info = {
                'devices': devices,
                'gdmVersion': gazoo_device.__version__,
                'status': status,
                'version': version.__version__,
            }
            try:
                self._ams_client.report_info(local_agent_info)
            except errors.ApiTimeoutError:
                logger.warning('Report info API timed out.')
            except errors.ApiError as e:
                logger.warning('Report status failed. %s', e)
            except errors.UnlinkedError:
                logger.warning('The local agent is unlinked.')
                self._clean_up_and_terminate_agent(remove_auth_file=True)
                break

            # wait() returns True as soon as termination is requested.
            if self._termination_event.wait(self._REPORT_INFO_INTERVAL_SECONDS):
                break

        logger.info('Stopped reporting info because stop event is set.')

    def _poll_rpc(self) -> None:
        """Polls AMS for JSON-RPC requests until termination is requested.

        Each received request is removed from the AMS queue and then handed
        to the RPC execution thread pool. Note that this method loops until
        the termination event is set, and is designed to be run in a
        separate thread, instead of the main thread.
        """
        while not self._termination_event.is_set():
            logger.info('Polling JSON-RPC requests from AMS.')

            try:
                rpc_request = self._ams_client.get_rpc_request_from_ams()
            except errors.ApiTimeoutError:
                logger.warning('Get RPC request API timed out.')
                rpc_request = None
            except errors.ApiError as e:
                logger.warning('Failed to get RPC request from AMS: %s', e)
                rpc_request = None
            except errors.UnlinkedError:
                logger.warning('The local agent is unlinked.')
                self._clean_up_and_terminate_agent(remove_auth_file=True)
                break

            if rpc_request is not None:
                try:
                    self._ams_client.remove_rpc_request_from_ams(rpc_request)
                except (errors.ApiTimeoutError, errors.ApiError):
                    logger.exception(
                        'Failed to remove JSON-RPC request from AMS. '
                        'Terminating the local agent process.')
                    self._clean_up_and_terminate_agent()
                    break
                self._start_rpc_execution_thread(rpc_request)

            # Wait on the termination event instead of sleeping so that a
            # shutdown request interrupts the cool-down immediately.
            self._termination_event.wait(self._POLL_RPC_COOL_DOWN_SECONDS)
        logger.info('Stopped polling RPC because stop event is set.')

    def _execute_rpc(self, rpc_request: Dict[str, Any]) -> None:
        """Executes the JSON-RPC request, and sends result back to AMS.

        The response is dropped if the translation layer reports that the
        RPC has already timed out.
        """
        rpc_id = rpc_request.get('id')
        logger.info('Executing JSON-RPC: %s', rpc_id)

        rpc_response = self._handle_rpc_request(rpc_request)

        if self._translator.is_rpc_timeout(rpc_id):
            logger.warning('RPC %s has timed out, ignoring response.', rpc_id)
        else:
            try:
                self._ams_client.send_rpc_response(rpc_response)
            except (errors.ApiTimeoutError, errors.ApiError):
                logger.exception('Failed to send JSON-RPC response to AMS.')

    def _handle_rpc_request(
            self, rpc_request: Dict[str, Any]) -> Dict[str, Any]:
        """Handles the JSON-RPC request and returns the response.

        Any exception raised while dispatching (including a request without
        a 'method') is converted into a JSON-RPC error response, so the AMS
        always receives a reply.

        Args:
            rpc_request: JSON-RPC request.

        Returns:
            JSON-RPC response. On failure, an error object whose 'code' is the
            exception's err_code (or errors.DEFAULT_ERROR_CODE) and whose
            'message' is the full stack trace for the default code, otherwise
            str(exception).
        """
        try:
            req_type = rpc_request_type(rpc_request['method'])

            if req_type == RpcRequestType.START_TEST_SUITE:
                resp = self._suite_mgr.start_test_suite(rpc_request)

            elif req_type == RpcRequestType.END_TEST_SUITE:
                resp = self._suite_mgr.end_test_suite(rpc_request)

            elif req_type == RpcRequestType.DEVICE_QUERY_CONTROL:
                resp = self._translator.dispatch_to_cmd_handler(rpc_request)

            elif req_type == RpcRequestType.COMMISSION_TO_GOOGLE_FABRIC:
                resp = self._translator.commission_to_google_fabric(rpc_request)

            else:
                raise errors.InvalidRPCError(
                    f'Invalid RPC request type {req_type}.')

        except Exception as e:  # Top-level boundary: report every failure.
            logger.exception('Error when handling JSON-RPC.')
            err_code = getattr(e, 'err_code', errors.DEFAULT_ERROR_CODE)
            stack_trace = traceback.format_exc()
            err_msg = (stack_trace if err_code == errors.DEFAULT_ERROR_CODE
                       else str(e))
            # .get(): a malformed request must not raise again while we are
            # building its error response.
            return {
                'id': rpc_request.get('id'),
                'jsonrpc': '2.0',
                'error': {'code': err_code, 'message': err_msg},
            }

        return resp

    def _callback_for_rpc_execution_complete(self,
                                             future: futures.Future) -> None:
        """Callback function when an RPC execution is complete.

        What this callback does:
        1) We remove the id(future) from the self._rpc_execution_future_ids.
        2) We log the exception (with traceback) if there is one.

        This callback should be registered to an RPC execution future object
        using Future.add_done_callback() instead of being called directly.

        Args:
            future: The Future object when we submit an RPC execution job to
                the ThreadPoolExecutor.
        """
        self._rpc_execution_future_ids.discard(id(future))
        # Future.exception() raises CancelledError on a cancelled future.
        if future.cancelled():
            logger.warning('RPC execution was cancelled.')
            return
        exc = future.exception()
        if exc is not None:
            logger.error('RPC execution encounters exception: %s', exc,
                         exc_info=exc)

    def _terminate(self, sig_num, frame) -> None:
        """Termination procedure for local agent. A signal handler.

        We set the termination event and stop the 2 top-level threads:
        info-reporting and rpc-polling. Also shutdown the ThreadPoolExecutor
        for RPC execution without waiting for running jobs.

        Args:
            sig_num: Signal number passed to a signal handler. See:
                https://docs.python.org/3/library/signal.html#signal.signal
            frame: Current stack frame passed to a signal handler. See:
                https://docs.python.org/3/library/signal.html#signal.signal
        """
        del sig_num, frame  # Unused.

        logger.warning('Terminating local agent process.')

        self._clean_up_and_terminate_agent()

        thread_and_wait_time = (
            (self._rpc_polling_thread,
             self._POLL_RPC_THREAD_TERMINATION_WAIT_SECONDS),
            (self._info_reporting_thread,
             self._REPORT_INFO_THREAD_TERMINATION_WAIT_SECONDS))

        for thread, wait_time in thread_and_wait_time:
            if thread is None:
                # The thread wasn't even created. Skipping.
                continue
            logger.info('Waiting for thread %s to stop. '
                        '(Timeout = %s seconds)', thread.name, wait_time)
            thread.join(timeout=wait_time)
            if thread.is_alive():
                logger.error('Thread %s still alive after waiting %s seconds.',
                             thread.name, wait_time)

        self._rpc_execution_thread_pool.shutdown(wait=False)

        logger.warning('Local agent process terminated.')

    def _read_auths(self) -> Tuple[str, str]:
        """Reads agent auths from local config.

        Reads the stored agent_id and agent_secret locally.

        Returns:
            Tuple of agent_id and agent_secret.

        Raises:
            FileNotFoundError: The auth file does not exist.
            ValueError: The auth file is not valid JSON.
            KeyError: The auth file lacks agent_id or agent_secret.
        """
        with open(AUTH_FILE, 'r', encoding='utf-8') as fstream:
            auths = json.load(fstream)
        return auths['agent_id'], auths['agent_secret']

    def _write_auths(self, agent_id: str, agent_secret: str) -> None:
        """Writes agent auths into local config.

        Stores the agent_id and agent_secret into the credential auth file
        locally. Creates the parent directory if needed. A newly created
        file is readable by the owner only, since it holds a secret.

        Args:
            agent_id: local agent id.
            agent_secret: local agent secret.
        """
        auth_dir = os.path.dirname(AUTH_FILE)
        if auth_dir:
            # ~/.config/google may not exist on a fresh machine.
            os.makedirs(auth_dir, exist_ok=True)
        fd = os.open(AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w', encoding='utf-8') as fstream:
            auths = {'agent_id': agent_id, 'agent_secret': agent_secret}
            json.dump(auths, fstream)

    def _setup_credentials(self) -> bool:
        """Sets up credentials.

        Reads credentials from a local auth file. If credentials are missing,
        rejected, or the auth file is malformed, starts the interactive
        linking process. That process prompts for a linking code until
        registration succeeds or the user aborts (Ctrl-C / end of input).

        Returns:
            True if credentials are set up successfully. False otherwise.
        """
        try:
            agent_id, agent_secret = self._read_auths()
            self._ams_client.set_local_agent_credentials(
                local_agent_id=agent_id,
                local_agent_secret=agent_secret)
            return True
        except (
                FileNotFoundError, errors.CredentialsError, errors.UnlinkedError):
            pass
        except (ValueError, KeyError):
            # json.JSONDecodeError is a ValueError; a missing field raises
            # KeyError. Treat a corrupt auth file like a missing one.
            logger.warning('The auth file %s is malformed.', AUTH_FILE)

        logger.info('Start the linking process')
        while True:
            try:
                linking_code = input('Linking Code:')
                agent_id, agent_secret = self._ams_client.register(
                    linking_code=linking_code)
                self._write_auths(agent_id, agent_secret)
                return True
            except errors.CredentialsError as e:
                # We don't use logger.exception here in order not to
                # overwhelm the user interface.
                logger.warning(
                    'Invalid linking code. Please retry. (%s)', e)
            except errors.ApiTimeoutError as e:
                logger.warning(
                    'Register API timed out. Please retry. (%s)', e)
            except errors.ApiError as e:
                logger.warning(
                    'Agent registration failed. Please retry. (%s)', e)
            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin is closed, so no linking code can be read.
                return False

    def _compress_artifacts_and_upload(
            self,
            test_suite_id: str,
            test_result_id: Optional[str] = None) -> None:
        """Compresses the artifacts and uploads if needed.

        Detaches the suite's log file handler, compresses the suite's
        artifact directory into <dir>.tar.gz, removes the directory, and
        uploads the archive to AMS if the test result ID is provided.

        Args:
            test_suite_id: Test suite ID.
            test_result_id: Test result ID.

        Raises:
            RuntimeError: The archive is at or above
                APP_ENGINE_DATA_SIZE_LIMIT, so it cannot be uploaded. Upload
                API failures are logged, not raised.
        """
        logger.info('Compressing artifacts for %s', test_suite_id)

        test_suite_dir = os.path.join(self._artifacts_dir, test_suite_id)

        # remove logging handler so the log file is complete and closed
        # before it is archived.
        local_agent_log = os.path.join(test_suite_dir, 'local_agent.log')
        logger_module.remove_file_handler(local_agent_log)

        # compress the artifacts directory and remove it after the compression
        shutil.make_archive(test_suite_dir, 'gztar', test_suite_dir)
        shutil.rmtree(test_suite_dir)

        if test_result_id is not None:
            logger.info('Uploading artifacts for %s', test_suite_id)

            # check against file size limit.
            artifacts_name = test_suite_dir + '.tar.gz'
            if os.stat(artifacts_name).st_size >= APP_ENGINE_DATA_SIZE_LIMIT:
                raise RuntimeError(
                    f'The file size of {artifacts_name} is larger than '
                    f'{APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE}.')

            # upload the artifact
            try:
                self._ams_client.upload_artifact(artifacts_name,
                                                 test_result_id=test_result_id)
            except (errors.ApiTimeoutError, errors.ApiError):
                logger.exception('Failed to upload artifact.')

    def _clean_up_and_terminate_agent(
            self, remove_auth_file: bool = False) -> None:
        """Cleanup method for local agent.

        Cleans up suite session and sets the terminate event. Removes the
        auth file if remove_auth_file is true. May be called concurrently
        from both worker threads, so the file removal tolerates the file
        already being gone. The event is set even if clean-up fails, so
        every loop still stops.

        Args:
            remove_auth_file: To remove the auth file or not.
        """
        if remove_auth_file:
            try:
                os.remove(AUTH_FILE)
            except FileNotFoundError:
                pass
        try:
            self._suite_mgr.clean_up()
        finally:
            self._termination_event.set()
| |
| |
def read_config() -> Tuple[Dict[str, Any], List[Tuple[str, str]]]:
    """Reads user data from configuration file.

    The config file is in INI format (parsed with configparser). The user can
    specify the path with the -u/--user_config command line argument;
    otherwise DEFAULT_USER_CONFIG is used. The parsed argument is removed from
    sys.argv so later argument handling does not see it.

    The configuration file is not required. An empty server config and an
    empty controller list are returned when there's no such config file.

    Raises:
        ValueError: If the config file has no ServerConfig section.
        configparser.Error: If the config file cannot be parsed.

    Returns:
        Tuple of (ServerConfig section, list of (package_name, value) pairs
        from the MatterDeviceControllerPackages section).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-u', '--user_config', type=str, required=False,
                        default=DEFAULT_USER_CONFIG,
                        help='Local Agent user config file.')
    args, leftover = parser.parse_known_args(sys.argv[1:])
    sys.argv[1:] = leftover

    if not os.path.exists(args.user_config):
        return {}, []

    config = configparser.ConfigParser(allow_no_value=True)
    config.read(args.user_config)

    if _USER_CONFIG_ROOT_KEY not in config:
        raise ValueError(
            f'Invalid config file, no section {_USER_CONFIG_ROOT_KEY}. '
            'Please refer to example_config.ini for reference.')

    matter_controllers_config = []
    if _USER_CONFIG_MATTER_DEVICE_CONTROLLERS in config:
        # ConfigParser lower-cases option names by default, but these option
        # names are Python package names, which are case-sensitive. Re-read
        # with a case-preserving parser for this section only; ServerConfig
        # keeps its case-insensitive lookups.
        case_preserving = configparser.ConfigParser(allow_no_value=True)
        case_preserving.optionxform = str
        case_preserving.read(args.user_config)
        matter_controllers_config = case_preserving.items(
            _USER_CONFIG_MATTER_DEVICE_CONTROLLERS)

    return config[_USER_CONFIG_ROOT_KEY], matter_controllers_config
| |
| |
def register_extension_controllers(
        matter_controllers_config: List[Tuple[str, str]]) -> None:
    """Imports each listed controller package and registers it with GDM.

    Args:
        matter_controllers_config: (package_name, value) pairs; only the
            package name is used.
    """
    for entry in matter_controllers_config:
        package_name, _ = entry
        gazoo_device.register(importlib.import_module(package_name))
| |
| |
def exit_if_query_module_versions() -> None:
    """Logs the Local Agent and GDM versions and exits if a version flag is set.

    Does nothing unless one of VERSION_FLAGS appears in sys.argv.
    """
    if not VERSION_FLAGS.intersection(sys.argv):
        return
    banner = ''.join((
        f'\n******* Local Agent version {version.__version__} *******',
        f'\n******* GDM version {gazoo_device.__version__} *******',
    ))
    logger.info(banner)
    sys.exit(0)
| |
| |
def main() -> None:
    """Entry point: reads config, registers controllers, runs the agent.

    Settings missing from the user config fall back to the EAP AMS
    defaults and DEFAULT_ARTIFACTS_DIR.
    """
    exit_if_query_module_versions()

    server_config, controller_entries = read_config()

    host = server_config.get(_USER_CONFIG_AMS_HOST, _EAP_AMS_HOST)
    port = server_config.get(_USER_CONFIG_AMS_PORT, _EAP_AMS_PORT)
    scheme = server_config.get(_USER_CONFIG_AMS_SCHEME, _EAP_AMS_SCHEME)
    artifacts_root = server_config.get(
        _USER_CONFIG_ARTIFACTS_DIR, DEFAULT_ARTIFACTS_DIR)

    register_extension_controllers(controller_entries)

    agent = LocalAgentProcess(
        client=ams_client.AmsClient(host=host, port=port, scheme=scheme),
        artifacts_dir=artifacts_root)
    agent.run()