blob: 5400fa10daa10c69558020e9606196e7ced4b58f [file] [log] [blame]
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for LocalAgentProcess."""
import argparse
from concurrent import futures
import configparser
import enum
import importlib
import json
import os
import shutil
import signal
import sys
import threading
import time
import traceback
from typing import Any, Dict, List, Optional, Tuple
import gazoo_device
from local_agent import errors
from local_agent import ams_client
from local_agent import suite_session_manager
from local_agent import version
from local_agent import logger as logger_module
from local_agent.translation_layer import translation_layer
# Module-level logger obtained from local_agent's logger module; used by every
# function and method in this file.
logger = logger_module.get_logger()
# ========================= Constants / Configs ========================= #
# Command line flags that make the process print versions and exit.
VERSION_FLAGS = {'-v', '-version', '--version'}

# Upper bound on the size of an uploaded artifact archive (31 MiB).
APP_ENGINE_DATA_SIZE_LIMIT = 31 * 1024 * 1024
APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE = (
    f'{APP_ENGINE_DATA_SIZE_LIMIT // (1024 * 1024)} MB')

# Local filesystem locations.
DEFAULT_ARTIFACTS_DIR = '/tmp/local_agent_artifacts'
AUTH_FILE = os.path.expanduser(
    '~/.config/google/matter_local_agent_auth.json')
DEFAULT_USER_CONFIG = os.path.expanduser(
    '~/.config/google/local_agent_config.ini')

# JSON-RPC method names handled outside the device translation layer.
_START_TEST_SUITE_METHOD = 'startTestSuite'
_END_TEST_SUITE_METHOD = 'endTestSuite'
_COMMISSION_TO_GOOGLE_FABRIC_METHOD = 'commissionToGoogleFabric'

# Section and key names inside the user INI config file.
_USER_CONFIG_ROOT_KEY = 'ServerConfig'
_USER_CONFIG_AMS_HOST = 'AMS_HOST'
_USER_CONFIG_AMS_PORT = 'AMS_PORT'
_USER_CONFIG_AMS_SCHEME = 'AMS_SCHEME'
_USER_CONFIG_ARTIFACTS_DIR = 'ARTIFACTS_DIR'
_USER_CONFIG_MATTER_DEVICE_CONTROLLERS = 'MatterDeviceControllerPackages'

# Default AMS endpoint used when the user config does not override it.
_EAP_AMS_HOST = 'matter-test-suite-eap.withgoogle.com'
_EAP_AMS_PORT = None
_EAP_AMS_SCHEME = 'https'
# ======================================================================= #
class RpcRequestType(enum.Enum):
    """Kinds of JSON-RPC requests the local agent distinguishes."""

    # Values are spelled out; they match what enum.auto() assigns in order.
    START_TEST_SUITE = 1
    END_TEST_SUITE = 2
    DEVICE_QUERY_CONTROL = 3
    COMMISSION_TO_GOOGLE_FABRIC = 4
# ======================== Module level functions ========================== #
def rpc_request_type(method: str) -> RpcRequestType:
    """Maps a JSON-RPC method name to its request type.

    Args:
        method: JSON-RPC request method.

    Returns:
        The matching request type; any method that is not one of the known
        suite/commissioning methods is classified as DEVICE_QUERY_CONTROL.
    """
    # Ordered (method name, type) pairs, compared with == exactly as before.
    special_methods = (
        (_START_TEST_SUITE_METHOD, RpcRequestType.START_TEST_SUITE),
        (_END_TEST_SUITE_METHOD, RpcRequestType.END_TEST_SUITE),
        (_COMMISSION_TO_GOOGLE_FABRIC_METHOD,
         RpcRequestType.COMMISSION_TO_GOOGLE_FABRIC),
    )
    for known_method, request_type in special_methods:
        if method == known_method:
            return request_type
    return RpcRequestType.DEVICE_QUERY_CONTROL
# ========================================================================== #
class LocalAgentProcess:
    """Local Agent Process.

    A continuously running process which constantly sends GET
    requests to the Agent Management Service (AMS) for incoming RPC
    requests, executes them in a thread pool, and periodically reports
    local agent information back to AMS.
    """

    _REPORT_INFO_INTERVAL_SECONDS = 30
    _REPORT_INFO_THREAD_TERMINATION_WAIT_SECONDS = 10
    _POLL_RPC_COOL_DOWN_SECONDS = 1
    _POLL_RPC_THREAD_TERMINATION_WAIT_SECONDS = 10
    _MAIN_THREAD_KEEP_ALIVE_COOLDOWN_SECONDS = 1
    _MAX_WORKERS_FOR_RPC_EXECUTION_THREAD_POOL = 5

    # Status for local agent, defined in the AMS.
    # Note that we don't need 'OFFLINE' status because that is determined by
    # the AMS.
    _STATUS_RUNNING = 'RUNNING'
    _STATUS_IDLE = 'IDLE'

    def __init__(self, client: ams_client.AmsClient, artifacts_dir: str):
        """Initializes local agent.

        Args:
            client: The AmsClient instance.
            artifacts_dir: Artifacts directory.
        """
        self._ams_client = client
        self._artifacts_dir = artifacts_dir

        # Threads
        self._termination_event = threading.Event()
        self._rpc_polling_thread = None
        self._info_reporting_thread = None
        self._rpc_execution_thread_pool = futures.ThreadPoolExecutor(
            max_workers=self._MAX_WORKERS_FOR_RPC_EXECUTION_THREAD_POOL)
        # IDs of submitted RPC execution Futures that have not completed yet.
        # A non-empty set makes _report_info report the RUNNING status.
        self._rpc_execution_future_ids = set()

        # Translation Layer
        self._translator = translation_layer.TranslationLayer(client)

        # Suite session manager
        self._suite_mgr = suite_session_manager.SuiteSessionManager(
            artifacts_fn=self._compress_artifacts_and_upload,
            artifact_root_dir=artifacts_dir,
            create_devices_fn=self._translator.create_devices,
            close_devices_fn=self._translator.close_devices)

    def run(self) -> None:
        """Runs the local agent, starting polling JSON-RPC and reporting info.

        We start two threads, one for polling JSON-RPC from AMS and the other
        for reporting info to AMS. The main thread then stays alive until the
        termination event is set or one of those threads dies.
        """
        if not self._setup_credentials():
            logger.warning('Local Agent linking failed, exit the process.')
            return
        logger.info('Local Agent is linked successfully.')

        # Register termination signal handler
        signal.signal(signal.SIGINT, self._terminate)

        self._translator.start(termination_event=self._termination_event)
        self._suite_mgr.start(termination_event=self._termination_event)
        self._start_info_reporting_thread()
        self._start_rpc_polling_thread()

        while not self._termination_event.is_set():
            # If any top level thread is dead, terminate the local agent.
            if self._terminate_if_thread_not_running(
                    self._info_reporting_thread):
                break
            if self._terminate_if_thread_not_running(
                    self._rpc_polling_thread):
                break
            time.sleep(self._MAIN_THREAD_KEEP_ALIVE_COOLDOWN_SECONDS)

    def _terminate_if_thread_not_running(
            self, target_thread: Optional[threading.Thread]) -> bool:
        """Terminates local agent if the target thread is not running.

        Args:
            target_thread: The thread to check running.

        Returns:
            True if termination procedure initiated, i.e., the thread is not
            running.
        """
        if target_thread is None or not target_thread.is_alive():
            logger.error('Thread is dead or not even started.')
            self._terminate(None, None)
            return True
        return False

    def _start_info_reporting_thread(self) -> None:
        """Starts the _report_info job in a thread (created on first call)."""
        if self._info_reporting_thread is None:
            self._info_reporting_thread = threading.Thread(
                target=self._report_info, name='Info-reporting-thread')
        if not self._info_reporting_thread.is_alive():
            self._info_reporting_thread.start()

    def _start_rpc_polling_thread(self) -> None:
        """Starts the _poll_rpc job in a thread (created on first call)."""
        if self._rpc_polling_thread is None:
            self._rpc_polling_thread = threading.Thread(
                target=self._poll_rpc, name='RPC-polling-thread')
        if not self._rpc_polling_thread.is_alive():
            self._rpc_polling_thread.start()

    def _start_rpc_execution_thread(self, rpc_request: Dict[str, Any]) -> None:
        """Submits a _execute_rpc job to RPC execution thread pool.

        Args:
            rpc_request: JSON-RPC request to execute.
        """
        future = self._rpc_execution_thread_pool.submit(self._execute_rpc,
                                                        rpc_request)
        # Record the id before registering the callback: if the future is
        # already done, add_done_callback invokes the callback immediately.
        self._rpc_execution_future_ids.add(id(future))
        future.add_done_callback(self._callback_for_rpc_execution_complete)

    def _report_info(self) -> None:
        """Reports local agent information back to AMS periodically.

        Note that this method contains an infinite loop, and is designed to be
        run in a separate thread, instead of the main thread.
        Typically we should not invoke this method directly, and instead we
        use _start_info_reporting_thread method.

        The information being reported includes:
            - The devices connected to this local agent.
            - The version of this local agent.
            - GDM version in use.
            - The status of this local agent.

        The loop ends when the termination event is set, or when AMS reports
        that the agent is unlinked (the auth file is then removed).
        """
        while True:
            logger.info('Reporting status to AMS.')
            devices = self._translator.detect_devices()
            status = (self._STATUS_RUNNING if self._rpc_execution_future_ids
                      else self._STATUS_IDLE)
            local_agent_info = {
                'devices': devices,
                'gdmVersion': gazoo_device.__version__,
                'status': status,
                'version': version.__version__,
            }
            try:
                self._ams_client.report_info(local_agent_info)
            except errors.ApiTimeoutError:
                logger.warning('Report info API timed out.')
            except errors.ApiError as e:
                logger.warning('Report status failed. %s', e)
            except errors.UnlinkedError:
                logger.warning('The local agent is unlinked.')
                self._clean_up_and_terminate_agent(remove_auth_file=True)
                break
            if self._termination_event.wait(self._REPORT_INFO_INTERVAL_SECONDS):
                break
        logger.info('Stopped reporting info because stop event is set.')

    def _poll_rpc(self) -> None:
        """Polls AMS for JSON-RPC requests.

        Note that this method contains an infinite loop, and is designed to be
        run in a separate thread, instead of the main thread.

        Each fetched request is removed from AMS before it is submitted for
        execution. Failing to remove a request terminates the agent.
        """
        while not self._termination_event.is_set():
            logger.info('Polling JSON-RPC requests from AMS.')
            try:
                rpc_request = self._ams_client.get_rpc_request_from_ams()
            except errors.ApiTimeoutError:
                logger.warning('Get RPC request API timed out.')
                rpc_request = None
            except errors.ApiError as e:
                logger.warning('Failed to get RPC request from AMS: %s', e)
                rpc_request = None
            except errors.UnlinkedError:
                logger.warning('The local agent is unlinked.')
                self._clean_up_and_terminate_agent(remove_auth_file=True)
                break

            if rpc_request is not None:
                try:
                    self._ams_client.remove_rpc_request_from_ams(rpc_request)
                except (errors.ApiTimeoutError, errors.ApiError):
                    logger.exception(
                        'Failed to remove JSON-RPC request from AMS. '
                        'Terminating the local agent process.')
                    self._clean_up_and_terminate_agent()
                    break
                self._start_rpc_execution_thread(rpc_request)

            # Waiting on the event (instead of sleeping) lets termination
            # interrupt the cool down immediately.
            self._termination_event.wait(self._POLL_RPC_COOL_DOWN_SECONDS)
        logger.info('Stopped polling RPC because stop event is set.')

    def _execute_rpc(self, rpc_request: Dict[str, Any]) -> None:
        """Executes the JSON-RPC request, and sends result back to AMS.

        The response is not sent if the translator reports that the RPC has
        already timed out.

        Args:
            rpc_request: JSON-RPC request.
        """
        rpc_id = rpc_request.get('id')
        logger.info('Executing JSON-RPC: %s', rpc_id)
        rpc_response = self._handle_rpc_request(rpc_request)
        if self._translator.is_rpc_timeout(rpc_id):
            logger.warning('RPC %s has timed out, ignoring response.', rpc_id)
            return
        try:
            self._ams_client.send_rpc_response(rpc_response)
        except (errors.ApiTimeoutError, errors.ApiError):
            logger.exception('Failed to send JSON-RPC response to AMS.')

    def _handle_rpc_request(
            self, rpc_request: Dict[str, Any]) -> Dict[str, Any]:
        """Handles the JSON-RPC request and returns the response.

        Any exception raised while handling (including a request without a
        'method' field) is converted into a JSON-RPC error response. The
        error code is taken from the exception's err_code attribute when
        present. With the default code, the message is the stack trace.

        Args:
            rpc_request: JSON-RPC request.

        Returns:
            JSON-RPC response.
        """
        try:
            # Inside the try so a malformed request still yields an error
            # response instead of escaping to the executor.
            req_type = rpc_request_type(rpc_request['method'])
            if req_type == RpcRequestType.START_TEST_SUITE:
                return self._suite_mgr.start_test_suite(rpc_request)
            if req_type == RpcRequestType.END_TEST_SUITE:
                return self._suite_mgr.end_test_suite(rpc_request)
            if req_type == RpcRequestType.DEVICE_QUERY_CONTROL:
                return self._translator.dispatch_to_cmd_handler(rpc_request)
            if req_type == RpcRequestType.COMMISSION_TO_GOOGLE_FABRIC:
                # TODO(b/282592569): Move commission_to_google_fabric out of translator.
                return self._translator.commission_to_google_fabric(rpc_request)
            raise errors.InvalidRPCError(
                f'Invalid RPC request type {req_type}.')
        except Exception as e:
            logger.exception('Error when handling JSON-RPC.')
            # .get() so a request without an id cannot raise here.
            err_resp = {'id': rpc_request.get('id'), 'jsonrpc': '2.0'}
            err_code = getattr(e, 'err_code', errors.DEFAULT_ERROR_CODE)
            stack_trace = traceback.format_exc()
            err_msg = (stack_trace if err_code == errors.DEFAULT_ERROR_CODE
                       else str(e))
            err_resp['error'] = {'code': err_code, 'message': err_msg}
            return err_resp

    def _callback_for_rpc_execution_complete(self,
                                             future: futures.Future) -> None:
        """Callback function when an RPC execution is complete.

        What this callback does:
            1) We remove the id(future) from the self._rpc_execution_future_ids.
            2) We log the exception if there is one.

        This callback should be registered to an RPC execution future object
        using Future.add_done_callback() instead of being called directly.

        Args:
            future: The Future object when we submit an RPC execution job to
                the ThreadPoolExecutor.
        """
        self._rpc_execution_future_ids.discard(id(future))
        # Future.exception() raises CancelledError on a cancelled future.
        if future.cancelled():
            logger.warning('RPC execution was cancelled.')
            return
        exc = future.exception()
        if exc is not None:
            logger.error('RPC execution encounters exception: %s', exc)

    def _terminate(self, sig_num, frame) -> None:
        """Termination procedure for local agent. A signal handler.

        We set the termination event and stop the 2 top-level threads:
        info-reporting and rpc-polling. Also shutdown the ThreadPoolExecutor
        for RPC execution (without waiting for running jobs).

        Args:
            sig_num: Signal number passed to a signal handler. See:
                https://docs.python.org/3/library/signal.html#signal.signal
            frame: Current stack frame passed to a signal handler. See:
                https://docs.python.org/3/library/signal.html#signal.signal
        """
        del sig_num, frame  # Unused.
        logger.warning('Terminating local agent process.')
        self._clean_up_and_terminate_agent()

        thread_and_wait_time = (
            (self._rpc_polling_thread,
             self._POLL_RPC_THREAD_TERMINATION_WAIT_SECONDS),
            (self._info_reporting_thread,
             self._REPORT_INFO_THREAD_TERMINATION_WAIT_SECONDS))
        for thread, wait_time in thread_and_wait_time:
            if thread is None:
                # The thread wasn't even created. Skipping.
                continue
            logger.info('Waiting for thread %s to stop. (Timeout = %s seconds)',
                        thread.name, wait_time)
            thread.join(timeout=wait_time)
            if thread.is_alive():
                logger.error('Thread %s still alive after waiting %s seconds.',
                             thread.name, wait_time)

        self._rpc_execution_thread_pool.shutdown(wait=False)
        logger.warning('Local agent process terminated.')

    def _read_auths(self) -> Tuple[str, str]:
        """Read agent auths from local config.

        Reads the stored agent_id and agent_secret locally.

        Returns:
            Tuple of agent_id and agent_secret.

        Raises:
            FileNotFoundError: The auth file does not exist.
            json.JSONDecodeError: The auth file is not valid JSON.
            KeyError: The auth file lacks agent_id or agent_secret.
        """
        with open(AUTH_FILE, 'r', encoding='utf-8') as fstream:
            auths = json.load(fstream)
        return auths['agent_id'], auths['agent_secret']

    def _write_auths(self, agent_id: str, agent_secret: str) -> None:
        """Writes agent auths into local config.

        Stores the agent_id and agent_secret into the credential auth file
        locally, creating its parent directory if needed.

        Args:
            agent_id: local agent id.
            agent_secret: local agent secret.
        """
        auth_dir = os.path.dirname(AUTH_FILE)
        if auth_dir:
            # ~/.config/google may not exist on a freshly set up machine.
            os.makedirs(auth_dir, exist_ok=True)
        # The file holds a secret: create it readable/writable by the owner
        # only (mode applies when the file is newly created).
        fd = os.open(AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w', encoding='utf-8') as fstream:
            auths = {'agent_id': agent_id, 'agent_secret': agent_secret}
            json.dump(auths, fstream)

    def _setup_credentials(self) -> bool:
        """Sets up credentials.

        Reads credentials from a local auth file. If the file is missing,
        is not valid JSON, lacks the expected keys, or setting the stored
        credentials raises CredentialsError/UnlinkedError, starts the
        interactive linking process.

        Returns:
            True if credentials are set up successfully. False otherwise
            (the user interrupted linking, or stdin reached EOF).
        """
        try:
            agent_id, agent_secret = self._read_auths()
            self._ams_client.set_local_agent_credentials(
                local_agent_id=agent_id,
                local_agent_secret=agent_secret)
            return True
        except (FileNotFoundError, json.JSONDecodeError, KeyError,
                errors.CredentialsError, errors.UnlinkedError):
            logger.info('Start the linking process')

        while True:
            try:
                linking_code = input('Linking Code:')
                agent_id, agent_secret = self._ams_client.register(
                    linking_code=linking_code)
                self._write_auths(agent_id, agent_secret)
                return True
            except errors.CredentialsError as e:
                # We don't use logger.exception here in order not to
                # overwhelm the user interface.
                logger.warning(
                    'Invalid linking code. Please retry. (%s)', e)
            except errors.ApiTimeoutError as e:
                logger.warning(
                    'Register API timed out. Please retry. (%s)', e)
            except errors.ApiError as e:
                logger.warning(
                    'Agent registration failed. Please retry. (%s)', e)
            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin closed, so no linking code can ever arrive.
                return False

    def _compress_artifacts_and_upload(
            self,
            test_suite_id: str,
            test_result_id: Optional[str] = None) -> None:
        """Compresses the artifacts and uploads if needed.

        Compresses the artifact directory into <dir>.tar.gz, removes the
        directory, and uploads the archive to AMS if the test result ID is
        provided.

        Args:
            test_suite_id: Test suite ID.
            test_result_id: Test result ID.

        Raises:
            RuntimeError: The archive is at or above the upload size limit.
        """
        logger.info('Compressing artifacts for %s', test_suite_id)
        test_suite_dir = os.path.join(self._artifacts_dir, test_suite_id)

        # Detach the suite's log file handler before the directory is archived
        # and removed.
        local_agent_log = os.path.join(test_suite_dir, 'local_agent.log')
        logger_module.remove_file_handler(local_agent_log)

        # compress the artifacts directory and remove it after the compression
        shutil.make_archive(test_suite_dir, 'gztar', test_suite_dir)
        shutil.rmtree(test_suite_dir)

        if test_result_id is None:
            return

        logger.info('Uploading artifacts for %s', test_suite_id)
        # check against file size limit.
        artifacts_name = test_suite_dir + '.tar.gz'
        if os.stat(artifacts_name).st_size >= APP_ENGINE_DATA_SIZE_LIMIT:
            raise RuntimeError(
                f'The file size of {artifacts_name} is larger than '
                f'{APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE}.')
        # upload the artifact
        try:
            self._ams_client.upload_artifact(artifacts_name,
                                             test_result_id=test_result_id)
        except (errors.ApiTimeoutError, errors.ApiError):
            logger.exception('Failed to upload artifact.')

    def _clean_up_and_terminate_agent(
            self, remove_auth_file: bool = False) -> None:
        """Cleanup method for local agent.

        Cleans up suite session and sets the terminate event.
        Removes the auth file if remove_auth_file is true.

        Args:
            remove_auth_file: To remove the auth file or not.
        """
        if remove_auth_file:
            # Both the polling and the reporting thread may detect unlinking
            # and get here concurrently, so an already-removed file is fine.
            try:
                os.remove(AUTH_FILE)
            except FileNotFoundError:
                pass
        self._suite_mgr.clean_up()
        self._termination_event.set()
def read_config() -> Tuple[Dict[str, Any], List[Tuple[str, str]]]:
    """Reads user data from the INI configuration file.

    The config file is parsed with configparser (INI format). User can
    specify the path to their config file with -u/--user_config. If not
    provided, we use the default path, DEFAULT_USER_CONFIG. The recognized
    flag is removed from sys.argv; all other arguments are left in place.

    The configuration file is not required. When there's no such config
    file, an empty server config and an empty controller list are returned.

    Raises:
        ValueError: If the config file has no ServerConfig section.
        configparser.Error: If the config file cannot be parsed.

    Returns:
        User configuration data for AMS and extension Matter device controllers
        (a list of (package name, value) pairs from the
        MatterDeviceControllerPackages section, empty if absent).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-u', '--user_config', type=str, required=False,
                        default=DEFAULT_USER_CONFIG,
                        help='Local Agent user config file.')
    args, leftover = parser.parse_known_args(sys.argv[1:])
    sys.argv[1:] = leftover

    if not os.path.exists(args.user_config):
        # Second element is a list to match the annotated return type.
        return {}, []

    # allow_no_value lets the controller section list bare package names.
    config = configparser.ConfigParser(allow_no_value=True)
    config.read(args.user_config)

    if _USER_CONFIG_ROOT_KEY not in config:
        raise ValueError(
            f'Invalid config file, no section {_USER_CONFIG_ROOT_KEY}. '
            'Please refer to example_config.ini for reference.')

    matter_controllers_config = (
        config.items(_USER_CONFIG_MATTER_DEVICE_CONTROLLERS)
        if _USER_CONFIG_MATTER_DEVICE_CONTROLLERS in config else [])

    return config[_USER_CONFIG_ROOT_KEY], matter_controllers_config
def register_extension_controllers(
        matter_controllers_config: List[Tuple[str, str]]) -> None:
    """Imports each configured controller package and registers it in GDM.

    Args:
        matter_controllers_config: (package name, value) pairs; only the
            package name is used. Packages are imported and registered in
            the given order.
    """
    for package_name, _unused_value in matter_controllers_config:
        gazoo_device.register(importlib.import_module(package_name))
def exit_if_query_module_versions() -> None:
    """Logs the Local Agent and GDM versions and exits if they were requested.

    Exits with status 0 when any of VERSION_FLAGS appears in sys.argv;
    otherwise returns without side effects.
    """
    if VERSION_FLAGS.isdisjoint(sys.argv):
        return
    banner = ''.join((
        f'\n******* Local Agent version {version.__version__} *******',
        f'\n******* GDM version {gazoo_device.__version__} *******',
    ))
    logger.info(banner)
    sys.exit(0)
def main() -> None:
    """Entry point of the Local Agent.

    Handles version queries, loads the optional user config (falling back to
    the EAP AMS endpoint and default artifacts directory), registers extension
    controllers, then runs the agent process.
    """
    exit_if_query_module_versions()
    user_config, matter_controllers_config = read_config()

    setting = user_config.get
    ams_host = setting(_USER_CONFIG_AMS_HOST, _EAP_AMS_HOST)
    ams_port = setting(_USER_CONFIG_AMS_PORT, _EAP_AMS_PORT)
    ams_scheme = setting(_USER_CONFIG_AMS_SCHEME, _EAP_AMS_SCHEME)
    artifacts_dir = setting(_USER_CONFIG_ARTIFACTS_DIR, DEFAULT_ARTIFACTS_DIR)

    register_extension_controllers(matter_controllers_config)

    LocalAgentProcess(
        client=ams_client.AmsClient(
            host=ams_host, port=ams_port, scheme=ams_scheme),
        artifacts_dir=artifacts_dir,
    ).run()