[Local Agent] Release local agent packages.
Change-Id: I7ccb58bb837a0684924021a6e8680fad513c1215
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..63cd555
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,27 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to Local Agent. There are just a few small guidelines you need to follow.
+The contributor workflow has not been fully ironed out yet, so make sure to start by sending an email to smarthome-testsuite-contrib@google.com with an outline of your planned change. We will review your proposal and advise on the next steps you will need to take.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement (CLA). You (or your employer) retain the copyright to your
+contribution; this simply gives us permission to use and redistribute your
+contributions as part of the project. Head over to
+<https://cla.developers.google.com/> to see your current agreements on file or
+to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code Reviews
+
+All submissions, including submissions by project members, require review. We
+use Gerrit code reviews on Git at Google (googlesource.com) for this purpose.
+
+## Community Guidelines
+
+This project follows
+[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5a09ec5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,63 @@
+# Local Agent
+
+## Overview
+**Test Suite for Smart Home with Matter** is a developer platform that helps third-party partners build products using the [Matter](https://buildwithmatter.com/) technology, which integrates with the Google Nest ecosystem.
+
+To support partners in testing and certifying their products, **Test Suite for Smart Home with Matter** provides the tools and services that allow partners to comprehensively test their Matter products against the Google ecosystem in a scalable way.
+
+Local Agent, one of the major components of the Smart Home Partner Test Suite, is a software package installed on host machines in a partner's laboratory. A local agent is required for running the Auto test suite: it commands and controls DUTs and performs Matter protocol operations at the testing site. Note that partners who only aim to run Manual test suites do not need the local agent.
+
+In order to use the local agent, partners will first have to enable Matter testability on their DUTs; for details, please contact your **Google Developer Support**.
+
+Note that this is not an officially supported Google product.
+
+
+## Features
+
+1. Registration: the linking process that links a user's local agent to the backend.
+2. Service Sync: reporting status, polling RPC requests, and sending RPC responses to the backend.
+3. Device Control: interacting with DUTs for command and control, including collecting any available logs.
+
+
+## Set up
+
+1. Clone the repo (to a desired location):
+ ```
+ $ git clone https://testsuite-smarthome-matter.googlesource.com/{repo-base}
+ ```
+
+2. Create a dedicated virtual environment for developing the local agent if you don't have one:
+ ```
+ $ python3 -m venv <venv-name>
+ ```
+
+   **Important**: Make sure you're running Python ```>= 3.7```; otherwise, some packages are not supported.
+
+3. Activate the virtual environment:
+ ```
+ $ source <venv-name>/bin/activate
+ ```
+
+4. Install the local agent package in your virtual environment:
+ ```
+ $ cd /path/to/{repo-base}/local-agent
+ $ pip install .
+ ```
+
+5. Install the Pigweed Python packages in your virtual environment:
+
+   **Important**: One of [Gazoo Device Manager](https://github.com/google/gazoo-device)'s dependencies, [Pigweed](https://pigweed.dev/), is not available on PyPI yet. You will have to manually download and install the Pigweed wheels in your virtual environment. You can use the [snippet scripts on GitHub](https://github.com/google/gazoo-device#install) for downloading and installation.
+
+
+> For development, you may want to use `pip install -e .` in step 4 above.
+> (See documentation of the `-e` option [here](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-e).)
+
+
+## Usage
+
+1. You should be running the Smart Home test suite from the frontend web app in order to link and use the local agent.
+
+2. On the web app's Test Plan page, you should see a ```Local Agent Setup``` button. Click on it, and if you haven’t linked your local agent before, you should see a linking code. Copy the linking code.
+
+3. On your host machine, activate the virtual environment where you installed the local agent, and start the local agent by simply running ```$ local-agent```.
+
+4. Once you’ve started your local agent, you should see a prompt that asks you for the linking code. Paste the linking code you copied from the web app. The local agent should start the linking process, and it will also start the device detection process, which should detect your DUTs that are connected to your host machine.
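The service-sync behavior described in the Features and Usage sections can be sketched with the `AmsClient` methods defined in `local_agent/ams_client.py` in this change. This is a simplified sketch, not the agent's actual loop: real dispatch and error handling are omitted, and `handle_request` is a hypothetical callback that turns a JSON-RPC request dict into a response dict.

```python
def poll_once(client, handle_request):
    """One iteration of the service-sync loop (simplified sketch)."""
    request = client.get_rpc_request_from_ams()
    if request is None:
        return None  # AMS has no queued JSON-RPC request
    # Remove the request first so the next poll doesn't fetch it again.
    client.remove_rpc_request_from_ams(request)
    response = handle_request(request)  # build a JSON-RPC response dict
    client.send_rpc_response(response)
    return response
```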
diff --git a/example_config.ini b/example_config.ini
new file mode 100644
index 0000000..c369c81
--- /dev/null
+++ b/example_config.ini
@@ -0,0 +1,18 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+[LocalAgentConfig]
+AMS_HOST = localhost
+AMS_PORT = 8080
+AMS_SCHEME = http
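As a sketch of how these keys fit together, a file of this shape parses with Python's standard `configparser`. The section and key names come from the file above; how the local agent actually loads its config may differ.

```python
import configparser

# Parse a config with the same shape as example_config.ini above.
parser = configparser.ConfigParser()
parser.read_string("""\
[LocalAgentConfig]
AMS_HOST = localhost
AMS_PORT = 8080
AMS_SCHEME = http
""")

cfg = parser['LocalAgentConfig']
# Assemble an AMS base URL from the three keys.
base_url = f"{cfg['AMS_SCHEME']}://{cfg['AMS_HOST']}:{cfg.getint('AMS_PORT')}"
print(base_url)  # http://localhost:8080
```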
diff --git a/local_agent/__init__.py b/local_agent/__init__.py
new file mode 100644
index 0000000..d46dbae
--- /dev/null
+++ b/local_agent/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/local_agent/ams_client.py b/local_agent/ams_client.py
new file mode 100644
index 0000000..3382a43
--- /dev/null
+++ b/local_agent/ams_client.py
@@ -0,0 +1,444 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for AmsClient class."""
+import http
+import json
+import time
+from typing import Any, Dict, Optional, Tuple
+
+import immutabledict
+import requests
+
+from local_agent import errors
+from local_agent import logger as logger_module
+
+
+logger = logger_module.get_logger()
+
+# TODO(b/194648316): Use production host when ready to launch.
+# The default AMS server configuration.
+_AMS_SCHEME = 'https'
+_AMS_HOST = 'chip-testsuite-experimental.appspot.com'
+_AMS_PORT = None
+
+# _ENDPOINTS stores the endpoint definitions.
+# _TIMEOUTS stores the timeout config for the endpoints.
+_ENDPOINTS = immutabledict.immutabledict({
+ 'REGISTER': '/tsb/api/local-agent/register',
+ 'AUTH': '/tsb/api/local-agent/auth',
+ 'REPORT_INFO': '/tsb/api/local-agent/report-info',
+ 'GET_RPC_REQUEST': '/tsb/api/local-agent/rpc/request',
+ 'DELETE_RPC_REQUEST': '/tsb/api/local-agent/rpc/request/{rpc_id}',
+ 'SEND_RPC_RESPONSE': '/tsb/api/local-agent/rpc/response',
+ 'UPLOAD_ARTIFACT': '/tsb/api/local-agent/artifacts',
+})
+_TIMEOUTS = immutabledict.immutabledict({})
+_DEFAULT_TIMEOUT = 5.0
+_DEFAULT_RETRY = 3
+_DEFAULT_RETRY_INTERVAL = 1.0
+_UNLINKED_ERR_MSG = 'Invalid agent id'
+
+
+def extract_error_message_from_api_response(
+ response: requests.models.Response) -> Optional[str]:
+ """Extracts error message from an AMS API response.
+
+ Args:
+ response: The AMS API response.
+
+ Returns:
+ The error message as string, or None if no AMS error message found.
+ """
+ try:
+ return response.json()['errorMessage']
+ except json.decoder.JSONDecodeError:
+ logger.warning('API response cannot be parsed as JSON.')
+ except KeyError:
+ logger.warning('API response does not have errorMessage field')
+ return None
+
+
+def _local_agent_is_unlinked(
+ response_status_code: int, error_msg: Optional[str]) -> bool:
+ """Determines whether local agent is unlinked.
+
+ Args:
+ response_status_code: The AMS API response status code.
+ error_msg: The AMS API response error message.
+
+ Returns:
+ True if the local agent is unlinked, false otherwise.
+ """
+ return (response_status_code == http.HTTPStatus.BAD_REQUEST
+ and error_msg is not None
+ and _UNLINKED_ERR_MSG in error_msg)
+
+
+class AmsClient:
+ """Client class to communicate with Agent Management Service (AMS).
+
+ Local Agent uses this client class to communicate with AMS, i.e., to send
+ APIs to AMS and receive responses.
+ """
+
+ def __init__(self,
+ host: Optional[str] = None,
+ port: Optional[int] = None,
+ scheme: Optional[str] = None):
+ """Initializes the AMS client.
+
+ The host, port, and scheme parameters are all optional. If not
+ provided, we will use the defaults.
+
+ Args:
+ host: The host address of AMS server.
+ port: The port of AMS server.
+ scheme: Either 'http' or 'https'.
+ """
+ host = host or _AMS_HOST
+ port = port or _AMS_PORT
+ scheme = scheme or _AMS_SCHEME
+ if port is not None:
+ self._base_url = f'{scheme}://{host}:{port}'
+ else:
+ self._base_url = f'{scheme}://{host}'
+
+ # Credentials
+ self._local_agent_id = None
+ self._local_agent_secret = None
+
+ # HTTP request session.
+ self._session = requests.Session()
+
+ def set_local_agent_credentials(self,
+ local_agent_id: int,
+ local_agent_secret: str) -> None:
+ """Sets the local agent ID and local agent secret.
+
+ If the local agent has been registered before, use this method to set
+ the credentials instead of calling register() again.
+
+ Args:
+ local_agent_id: The agent ID.
+ local_agent_secret: The agent secret.
+ """
+ self._local_agent_id = local_agent_id
+ self._local_agent_secret = local_agent_secret
+ self._get_auth_token()
+
+ def register(self, linking_code: str) -> Tuple[str, str]:
+ """Registers a local agent with linking code.
+
+        This method registers a new local agent and retrieves the local
+        agent credentials. If you already have credentials for your local
+        agent, use set_local_agent_credentials() instead.
+
+ Args:
+ linking_code: The linking code obtained from frontend UI.
+
+ Returns:
+ The local agent ID and local agent secret.
+
+ Raises:
+ ApiTimeoutError: If the register API timed out.
+ CredentialsError: If the linking code is not valid.
+ ApiError: If the registration API failed.
+ """
+ endpoint = self._base_url + _ENDPOINTS['REGISTER']
+ timeout = _TIMEOUTS.get('REGISTER') or _DEFAULT_TIMEOUT
+ post_data = {'linkingCode': linking_code}
+
+ try:
+ response = self._request_wrapper(
+ method='POST',
+ url=endpoint,
+ json=post_data,
+ timeout=timeout)
+ except errors.ApiTimeoutError:
+ logger.error('Register API timed out.')
+ raise
+
+ if response.status_code == http.HTTPStatus.BAD_REQUEST:
+ ams_err_msg = extract_error_message_from_api_response(response)
+ raise errors.CredentialsError(
+ f'Not a valid linking code: {linking_code}, '
+ f'AMS error message: {ams_err_msg}')
+ if response.status_code != http.HTTPStatus.OK:
+ ams_err_msg = extract_error_message_from_api_response(response)
+ raise errors.ApiError(
+ f'Registration failed with status {response.status_code}, '
+ f'AMS error message: {ams_err_msg}')
+
+ result = response.json()['result']
+ self.set_local_agent_credentials(
+ local_agent_id=result['agentId'],
+ local_agent_secret=result['agentSecret'])
+ return result['agentId'], result['agentSecret']
+
+ def report_info(self, info: Dict[str, Any]) -> None:
+ """Reports local agent info.
+        Calls the REPORT_INFO API to send the local agent's info to AMS.
+ Calling REPORT_INFO API to send the info of local agent to AMS.
+
+ Args:
+ info: The information of local agent.
+
+ Raises:
+ ApiTimeoutError: If the report info API timed out.
+ ApiError: If API response status is not 200 OK.
+ """
+ endpoint = self._base_url + _ENDPOINTS['REPORT_INFO']
+        timeout = _TIMEOUTS.get('REPORT_INFO') or _DEFAULT_TIMEOUT
+ try:
+ response = self._request_wrapper(
+ method='POST',
+ url=endpoint,
+ json=info,
+ timeout=timeout)
+ except errors.ApiTimeoutError:
+ logger.error('Report info API timed out.')
+ raise
+ if response.status_code != http.HTTPStatus.OK:
+ ams_err_msg = extract_error_message_from_api_response(response)
+ raise errors.ApiError(
+ f'Report info API failed: status {response.status_code}, '
+ f'AMS error message: {ams_err_msg}')
+
+ def get_rpc_request_from_ams(self) -> Optional[Dict[str, Any]]:
+ """Gets an RPC request from AMS, returns None if there isn't any.
+
+ Returns:
+ The JSON-RPC request object, a dictionary. Or None if there's no
+ request from AMS.
+
+ Raises:
+ ApiTimeoutError: If the get JSON-RPC request API timed out.
+ ApiError: If API response has an error status code.
+ """
+ endpoint = self._base_url + _ENDPOINTS['GET_RPC_REQUEST']
+ timeout = _TIMEOUTS.get('GET_RPC_REQUEST') or _DEFAULT_TIMEOUT
+
+ try:
+ response = self._request_wrapper(
+ method='GET',
+ url=endpoint,
+ timeout=timeout)
+ except errors.ApiTimeoutError:
+ logger.error('Get RPC request API timed out.')
+ raise
+ if response.status_code == http.HTTPStatus.OK:
+ return response.json().get('result')
+ elif response.status_code == http.HTTPStatus.NO_CONTENT:
+ return None
+ else:
+ ams_err_msg = extract_error_message_from_api_response(response)
+ raise errors.ApiError(
+ f'Get RPC request API failed: status {response.status_code}, '
+ f'AMS error message: {ams_err_msg}')
+
+ def remove_rpc_request_from_ams(self,
+ rpc_request: Dict[str, Any]) -> None:
+ """Removes the RPC request from AMS.
+
+ Local agent should call this method to remove a JSON-RPC request from
+ AMS after retrieving the request during polling. Otherwise, the next
+ iteration of the polling loop will get the same JSON-RPC request.
+
+ Args:
+ rpc_request: The JSON-RPC request to remove from AMS.
+
+ Raises:
+ ApiTimeoutError: If the delete JSON-RPC request API timed out.
+ ApiError: If API response has an unexpected status code.
+ """
+ rpc_id = rpc_request.get('id')
+ endpoint = (self._base_url +
+ _ENDPOINTS['DELETE_RPC_REQUEST'].format(rpc_id=rpc_id))
+ timeout = _TIMEOUTS.get('DELETE_RPC_REQUEST') or _DEFAULT_TIMEOUT
+
+ try:
+ response = self._request_wrapper(
+ method='DELETE',
+ url=endpoint,
+ timeout=timeout)
+ except errors.ApiTimeoutError:
+ logger.error('Remove RPC request API timed out.')
+ raise
+ if response.status_code != http.HTTPStatus.OK:
+ ams_err_msg = extract_error_message_from_api_response(response)
+ raise errors.ApiError(
+ f'Remove RPC API failed: status {response.status_code}, '
+ f'AMS error message: {ams_err_msg}')
+
+ def send_rpc_response(self, rpc_response: Dict[str, Any]) -> None:
+ """Sends a JSON-RPC response to AMS.
+
+        The request will be retried up to 3 times if the API does not succeed.
+
+ Args:
+ rpc_response: The JSON-RPC response to send.
+
+ Raises:
+ ApiTimeoutError: If API timed out.
+ ApiError: If API response has unexpected status code.
+ """
+ endpoint = self._base_url + _ENDPOINTS['SEND_RPC_RESPONSE']
+ timeout = _TIMEOUTS.get('SEND_RPC_RESPONSE') or _DEFAULT_TIMEOUT
+
+ try:
+ response = self._request_wrapper(
+ method='POST',
+ url=endpoint,
+ json=rpc_response,
+ timeout=timeout)
+ except errors.ApiTimeoutError:
+ logger.error('Send RPC response API timed out.')
+ raise
+
+ if response.status_code != http.HTTPStatus.OK:
+ ams_err_msg = extract_error_message_from_api_response(response)
+ raise errors.ApiError(
+ f'Send RPC response API failed. Status: '
+ f'{response.status_code}, AMS error message: {ams_err_msg}')
+
+ def upload_artifact(self, artifact_path: str, test_result_id: str) -> None:
+ """Uploads a test artifact (file) to AMS.
+
+ Args:
+ artifact_path: Path to the artifact.
+ test_result_id: ID of the test result this artifact belongs to.
+
+ Raises:
+ ApiTimeoutError: If API timed out.
+ ApiError: If API response has unexpected status code.
+ """
+ endpoint = self._base_url + _ENDPOINTS['UPLOAD_ARTIFACT']
+ timeout = _TIMEOUTS.get('UPLOAD_ARTIFACT') or _DEFAULT_TIMEOUT
+
+ try:
+ with open(artifact_path, 'rb') as artifact_stream:
+ response = self._request_wrapper(
+ method='POST',
+ url=endpoint,
+ files={'file': artifact_stream},
+ data={'testResultId': test_result_id},
+ timeout=timeout)
+ except errors.ApiTimeoutError:
+ logger.error('Upload artifact API timed out.')
+ raise
+ if response.status_code != http.HTTPStatus.OK:
+ ams_err_msg = extract_error_message_from_api_response(response)
+ raise errors.ApiError(
+ f'Upload artifact API failed. Status: {response.status_code}, '
+ f'AMS error message: {ams_err_msg}')
+
+ def _get_auth_token(self) -> None:
+ """Gets auth token, and sets the HTTP session header.
+
+ Raises:
+ ApiTimeoutError: If the auth API timed out.
+ CredentialsError: If cannot get auth token with the local agent
+ credentials.
+ """
+ endpoint = self._base_url + _ENDPOINTS['AUTH']
+ timeout = _TIMEOUTS.get('AUTH') or _DEFAULT_TIMEOUT
+ post_data = {'agentId': self._local_agent_id,
+ 'agentSecret': self._local_agent_secret}
+
+ try:
+ response = self._request_wrapper(
+ refresh_auth=False,
+ method='POST',
+ url=endpoint,
+ json=post_data,
+ timeout=timeout)
+ except errors.ApiTimeoutError:
+ logger.error('Auth API timed out.')
+ raise
+
+ if response.status_code != http.HTTPStatus.CREATED:
+ ams_err_msg = extract_error_message_from_api_response(response)
+ if _local_agent_is_unlinked(response.status_code, ams_err_msg):
+ raise errors.UnlinkedError(
+ f'Local agent is unlinked: {ams_err_msg}.')
+ else:
+ raise errors.CredentialsError(
+ f'Cannot get auth token. Status {response.status_code}, '
+ f'AMS error message: {ams_err_msg}')
+
+ auth_token = response.json()['result']['authToken']
+ self._session.headers.update({'Authorization': auth_token})
+
+ def _request_wrapper(self,
+ *,
+ num_retries: int = _DEFAULT_RETRY,
+ refresh_auth: bool = True,
+ **kwargs) -> requests.models.Response:
+ """Sends a request and provides retry functionality.
+
+ We will retry when:
+ 1) the API response has a 4xx or 5xx status code, or
+ 2) the request has timed out.
+
+ If the retries exhausted, we return the last response, or raise
+ ApiTimeoutError if the last request timed out.
+
+ Args:
+ num_retries: The number of retries if API did not succeed. If a
+ negative integer or zero is provided, we don't retry.
+ refresh_auth: If True, we will refresh the auth token when the
+ first response has a 401 UNAUTHORIZED status. Refreshing auth
+ token is not considered a retry.
+ **kwargs: Other keyword arguments that will be passed to the
+ requests.sessions.Session.request method.
+
+ Returns:
+ The API response.
+
+ Raises:
+ ApiTimeoutError: If the last request timed out.
+ """
+ # Ensure we set a timeout.
+ if 'timeout' not in kwargs:
+ kwargs['timeout'] = _DEFAULT_TIMEOUT
+
+ if num_retries <= 0:
+ num_retries = 0
+
+ current_try = 0
+ last_exception = None
+
+ while current_try <= num_retries:
+ current_try += 1
+ try:
+ response = self._session.request(**kwargs)
+ except requests.exceptions.Timeout:
+ last_exception = errors.ApiTimeoutError()
+ else:
+ last_exception = None
+ status_code = response.status_code
+ if not (400 <= status_code <= 599):
+ return response
+ if (status_code == http.HTTPStatus.UNAUTHORIZED and
+ refresh_auth and
+ current_try == 1):
+ # Only refresh token when first attempt is UNAUTHORIZED.
+ # And refresh token will not be considered a retry.
+ self._get_auth_token()
+ current_try -= 1
+ time.sleep(_DEFAULT_RETRY_INTERVAL)
+
+ if last_exception is not None:
+ raise last_exception
+ return response
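The retry policy implemented by `_request_wrapper` above can be sketched in isolation. This is a simplified standalone model: `send` stands in for the HTTP session call, integer status codes stand in for response objects, and the auth-refresh branch is omitted. It mirrors the exit rules of the method above: retry on timeout or 4xx/5xx, return early on success, return the last error response once retries are exhausted, and raise only if the final attempt timed out.

```python
import time


def request_with_retry(send, num_retries=3, interval=0.0):
    """Retry on timeout or 4xx/5xx, mirroring _request_wrapper's exit rules."""
    last_timed_out = False
    response = None
    # num_retries retries after the initial attempt, as in _request_wrapper.
    for _ in range(max(num_retries, 0) + 1):
        try:
            response = send()
        except TimeoutError:
            last_timed_out = True
        else:
            last_timed_out = False
            if not 400 <= response <= 599:
                return response  # success: return immediately
        time.sleep(interval)
    if last_timed_out:
        # Corresponds to raising ApiTimeoutError: the final attempt timed out.
        raise TimeoutError('request timed out after retries')
    return response  # retries exhausted: hand back the last error response
```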
diff --git a/local_agent/errors.py b/local_agent/errors.py
new file mode 100644
index 0000000..317533f
--- /dev/null
+++ b/local_agent/errors.py
@@ -0,0 +1,95 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for Local Agent related errors.
+
+The error subclasses are intended to make it easier to distinguish between and
+handle different types of error exceptions.
+"""
+DEFAULT_ERROR_CODE = 0
+
+
+class CredentialsExpiredError(Exception):
+ """Raised when the user's credentials (auth token) are expired"""
+ err_code = 1
+
+
+class ServiceError(Exception):
+ """Raised when a service error occurs"""
+ err_code = 2
+
+
+class InvalidRPCError(Exception):
+ """Raised when an invalid rpc command is given"""
+ err_code = 3
+
+
+class DeviceNotConnectedError(Exception):
+ """Raised when querying an unconnected device"""
+ err_code = 4
+
+
+class HandlerNotFoundError(Exception):
+ """Raised when no correct handler is found"""
+ err_code = 5
+
+
+class HandlersCollisionError(Exception):
+ """Raised when 2 registered handlers share the same method when they are instantiated"""
+ err_code = 6
+
+
+class HandlerInvalidError(Exception):
+ """Raised when the registered handler is not a subclass of BaseCommandHandler"""
+ err_code = 7
+
+
+class DeviceNotOpenError(Exception):
+ """Raised when accessing a device which is not open yet."""
+ err_code = 8
+
+
+class InvalidTestSuiteSessionError(Exception):
+ """Raised when encountering an invalid testsuite session."""
+ err_code = 9
+
+
+class AmsClientError(Exception):
+ """Base exception for any AmsClient exceptions."""
+ err_code = 10
+
+
+class CredentialsError(AmsClientError):
+ """Exception for invalid local agent credentials."""
+ err_code = 11
+
+
+class ApiTimeoutError(AmsClientError):
+ """Exception for AMS API timed out."""
+ err_code = 12
+
+
+class ApiError(AmsClientError):
+ """Exception for unexpected status code in API response."""
+ err_code = 13
+
+
+class UnlinkedError(CredentialsExpiredError):
+ """Raised when the local agent is unlinked from AMS."""
+ err_code = 14
+
+
+class RpcTimeOutError(InvalidRPCError):
+ """Raised when the RPC request handling times out."""
+ err_code = 15
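The `err_code` class attribute on these exceptions is what lets the agent turn any caught exception into a JSON-RPC 2.0 error object (see the error branch of `LocalAgentProcess._handle_rpc_request` below). A minimal, self-contained illustration of that mapping — `to_jsonrpc_error` is a hypothetical helper, not a function in this module:

```python
from typing import Any, Dict

DEFAULT_ERROR_CODE = 0


class InvalidRPCError(Exception):
    """Raised when an invalid rpc command is given."""
    err_code = 3


def to_jsonrpc_error(request_id: Any, exc: Exception) -> Dict[str, Any]:
    """Builds a JSON-RPC 2.0 error response from an exception.

    Exceptions without an err_code attribute fall back to
    DEFAULT_ERROR_CODE, signalling an unclassified error.
    """
    code = getattr(exc, 'err_code', DEFAULT_ERROR_CODE)
    return {
        'id': request_id,
        'jsonrpc': '2.0',
        'error': {'code': code, 'message': str(exc)},
    }
```

So `to_jsonrpc_error(7, InvalidRPCError('bad'))` yields code 3, while a plain `ValueError` maps to the default code 0.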
diff --git a/local_agent/local_agent.py b/local_agent/local_agent.py
new file mode 100644
index 0000000..cc97bf6
--- /dev/null
+++ b/local_agent/local_agent.py
@@ -0,0 +1,577 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for LocalAgentProcess."""
+import argparse
+from concurrent import futures
+import configparser
+import enum
+import json
+import os
+import shutil
+import signal
+import sys
+import threading
+import time
+import traceback
+from typing import Any, Dict, Optional, Tuple
+
+import gazoo_device
+
+from local_agent import errors
+from local_agent import ams_client
+from local_agent import suite_session_manager
+from local_agent import version
+from local_agent import logger as logger_module
+from local_agent.translation_layer import translation_layer
+
+
+logger = logger_module.get_logger()
+
+# ========================= Constants / Configs ========================= #
+APP_ENGINE_DATA_SIZE_LIMIT = 31 * 1048576 # 31 MB
+APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE = '31 MB'
+DEFAULT_ARTIFACTS_DIR = '/tmp/local_agent_artifacts'
+AUTH_FILE = os.path.expanduser('~/.config/google/matter_local_agent_auth.json')
+DEFAULT_USER_CONFIG = os.path.expanduser(
+ '~/.config/google/local_agent_config.ini')
+_START_TEST_SUITE_METHOD = 'startTestSuite'
+_END_TEST_SUITE_METHOD = 'endTestSuite'
+_USER_CONFIG_ROOT_KEY = 'LocalAgentConfig'
+_USER_CONFIG_AMS_HOST = 'AMS_HOST'
+_USER_CONFIG_AMS_PORT = 'AMS_PORT'
+_USER_CONFIG_AMS_SCHEME = 'AMS_SCHEME'
+_USER_CONFIG_ARTIFACTS_DIR = 'ARTIFACTS_DIR'
+# ======================================================================= #
+
+
+class RpcRequestType(enum.Enum):
+ """RPC request type enum."""
+ START_TEST_SUITE = enum.auto()
+ END_TEST_SUITE = enum.auto()
+ DEVICE_QUERY_CONTROL = enum.auto()
+
+# ======================== Module level functions ========================== #
+def rpc_request_type(method: str) -> RpcRequestType:
+ """RPC request type selector.
+
+ Args:
+ method: JSON-RPC request method.
+
+ Returns:
+ Request type.
+ """
+ if method == _START_TEST_SUITE_METHOD:
+ return RpcRequestType.START_TEST_SUITE
+ elif method == _END_TEST_SUITE_METHOD:
+ return RpcRequestType.END_TEST_SUITE
+ else:
+ return RpcRequestType.DEVICE_QUERY_CONTROL
+# ========================================================================== #
+
+
+class LocalAgentProcess:
+ """Local Agent Process.
+
+ A continuously running process which constantly sends GET
+ requests to the Agent Management Service for incoming RPC
+ requests.
+ """
+
+ _REPORT_INFO_INTERVAL_SECONDS = 30
+ _REPORT_INFO_THREAD_TERMINATION_WAIT_SECONDS = 10
+ _POLL_RPC_COOL_DOWN_SECONDS = 1
+ _POLL_RPC_THREAD_TERMINATION_WAIT_SECONDS = 10
+ _MAIN_THREAD_KEEP_ALIVE_COOLDOWN_SECONDS = 1
+
+ _MAX_WORKERS_FOR_RPC_EXECUTION_THREAD_POOL = 5
+
+ # Status for local agent, defined in the AMS.
+ # Note that we don't need 'OFFLINE' status because that is determined by
+ # the AMS.
+ _STATUS_RUNNING = 'RUNNING'
+ _STATUS_IDLE = 'IDLE'
+
+ def __init__(self, client: ams_client.AmsClient, artifacts_dir: str):
+ """Initializes local agent.
+
+ Args:
+ client: The AmsClient instance.
+ artifacts_dir: Artifacts directory.
+ """
+ self._ams_client = client
+ self._artifacts_dir = artifacts_dir
+
+ # Threads
+ self._termination_event = threading.Event()
+ self._rpc_polling_thread = None
+ self._info_reporting_thread = None
+ self._rpc_execution_thread_pool = futures.ThreadPoolExecutor(
+ max_workers=self._MAX_WORKERS_FOR_RPC_EXECUTION_THREAD_POOL)
+
+ # Store IDs of running Futures.
+ self._rpc_execution_future_ids = set()
+
+ # Translation Layer
+ self._translator = translation_layer.TranslationLayer(client)
+
+ # Suite session manager
+ self._suite_mgr = suite_session_manager.SuiteSessionManager(
+ artifacts_fn=self._compress_artifacts_and_upload,
+ artifact_root_dir=artifacts_dir,
+ create_devices_fn=self._translator.create_devices,
+ close_devices_fn=self._translator.close_devices)
+
+ def run(self) -> None:
+ """Runs the local agent, starting polling JSON-RPC and reporting info.
+
+ We start two threads, one of polling JSON-RPC from AMS and the other
+ for reporting info to AMS.
+ """
+ if not self._setup_credentials():
+ logger.warning('Local Agent linking failed, exit the process.')
+ return
+ logger.info('Local Agent is linked successfully.')
+
+ # Register termination signal handler
+ signal.signal(signal.SIGINT, self._terminate)
+
+ self._translator.start(termination_event=self._termination_event)
+ self._suite_mgr.start(termination_event=self._termination_event)
+ self._start_info_reporting_thread()
+ self._start_rpc_polling_thread()
+ while not self._termination_event.is_set():
+ # If any top level thread is dead, terminate the local agent.
+ if self._terminate_if_thread_not_running(
+ self._info_reporting_thread):
+ break
+ if self._terminate_if_thread_not_running(
+ self._rpc_polling_thread):
+ break
+
+ time.sleep(self._MAIN_THREAD_KEEP_ALIVE_COOLDOWN_SECONDS)
+
+ def _terminate_if_thread_not_running(
+ self, target_thread: Optional[threading.Thread]) -> bool:
+ """Terminates local agent if the target thread is not running.
+
+ Args:
+ target_thread: The thread whose liveness to check.
+
+ Returns:
+ True if the termination procedure was initiated, i.e., the thread
+ is not running.
+ """
+ if target_thread is None or not target_thread.is_alive():
+ logger.error('Thread is dead or not even started.')
+ self._terminate(None, None)
+ return True
+ return False
+
+ def _start_info_reporting_thread(self) -> None:
+ """Starts the _report_info job in a thread."""
+ if self._info_reporting_thread is None:
+ self._info_reporting_thread = threading.Thread(
+ target=self._report_info, name='Info-reporting-thread')
+ if not self._info_reporting_thread.is_alive():
+ self._info_reporting_thread.start()
+
+ def _start_rpc_polling_thread(self) -> None:
+ """Starts the _poll_rpc job in a thread."""
+ if self._rpc_polling_thread is None:
+ self._rpc_polling_thread = threading.Thread(
+ target=self._poll_rpc, name='RPC-polling-thread')
+ if not self._rpc_polling_thread.is_alive():
+ self._rpc_polling_thread.start()
+
+ def _start_rpc_execution_thread(self, rpc_request: Dict[str, Any]) -> None:
+ """Submits a _execute_rpc job to RPC execution thread pool."""
+ future = self._rpc_execution_thread_pool.submit(self._execute_rpc,
+ rpc_request)
+ self._rpc_execution_future_ids.add(id(future))
+ future.add_done_callback(self._callback_for_rpc_execution_complete)
+
+ def _report_info(self) -> None:
+ """Reports local agent information back to AMS periodically.
+
+ Note that this method contains an infinite loop, and is designed to be
+ run in a separate thread, instead of the main thread.
+ Typically we should not invoke this method directly; instead we use
+ the _start_info_reporting_thread method.
+
+ The information being reported includes:
+ - The devices connected to this local agent.
+ - The version of this local agent.
+ - GDM version in use.
+ - The status of this local agent.
+ """
+ while True:
+ logger.info('Reporting status to AMS.')
+
+ devices = self._translator.detect_devices()
+ status = (self._STATUS_RUNNING if self._rpc_execution_future_ids
+ else self._STATUS_IDLE)
+ local_agent_info = {
+ 'devices': devices,
+ 'gdmVersion': gazoo_device.__version__,
+ 'status': status,
+ 'version': version.__version__,
+ }
+ try:
+ self._ams_client.report_info(local_agent_info)
+ except errors.ApiTimeoutError:
+ logger.warning('Report info API timed out.')
+ except errors.ApiError as e:
+ logger.warning('Report status failed. %s', e)
+ except errors.UnlinkedError:
+ logger.warning('The local agent is unlinked.')
+ self._clean_up_and_terminate_agent(remove_auth_file=True)
+ break
+
+ if self._termination_event.wait(self._REPORT_INFO_INTERVAL_SECONDS):
+ break
+
+ logger.info('Stopped reporting info because stop event is set.')
+
+ def _poll_rpc(self) -> None:
+ """Polls AMS for JSON-RPC requests.
+
+ Note that this method contains an infinite loop, and is designed to be
+ run in a separate thread, instead of the main thread.
+ """
+ while not self._termination_event.is_set():
+ logger.info('Polling JSON-RPC requests from AMS.')
+
+ try:
+ rpc_request = self._ams_client.get_rpc_request_from_ams()
+ except errors.ApiTimeoutError:
+ logger.warning('Get RPC request API timed out.')
+ rpc_request = None
+ except errors.ApiError as e:
+ logger.warning(f'Failed to get RPC request from AMS: {e}')
+ rpc_request = None
+ except errors.UnlinkedError:
+ logger.warning('The local agent is unlinked.')
+ self._clean_up_and_terminate_agent(remove_auth_file=True)
+ break
+
+ if rpc_request is not None:
+ try:
+ self._ams_client.remove_rpc_request_from_ams(rpc_request)
+ except (errors.ApiTimeoutError, errors.ApiError):
+ logger.exception(
+ 'Failed to remove JSON-RPC request from AMS. '
+ 'Terminating the local agent process.')
+ self._clean_up_and_terminate_agent()
+ break
+ self._start_rpc_execution_thread(rpc_request)
+
+ time.sleep(self._POLL_RPC_COOL_DOWN_SECONDS)
+ logger.info('Stopped polling RPC because stop event is set.')
+
+ def _execute_rpc(self, rpc_request: Dict[str, Any]) -> None:
+ """Executes the JSON-RPC request, and sends result back to AMS."""
+ rpc_id = rpc_request.get('id')
+ logger.info(f'Executing JSON-RPC: {rpc_id}')
+
+ rpc_response = self._handle_rpc_request(rpc_request)
+
+ if self._translator.is_rpc_timeout(rpc_id):
+ logger.warning(f'RPC {rpc_id} has timed out, ignoring response.')
+ else:
+ try:
+ self._ams_client.send_rpc_response(rpc_response)
+ except (errors.ApiTimeoutError, errors.ApiError):
+ logger.exception('Failed to send JSON-RPC response to AMS.')
+
+ def _handle_rpc_request(
+ self, rpc_request: Dict[str, Any]) -> Dict[str, Any]:
+ """Handles the JSON-RPC request and returns the response.
+
+ Args:
+ rpc_request: JSON-RPC request.
+
+ Returns:
+ JSON-RPC response.
+ """
+ req_type = rpc_request_type(rpc_request['method'])
+
+ try:
+ if req_type == RpcRequestType.START_TEST_SUITE:
+ resp = self._suite_mgr.start_test_suite(rpc_request)
+
+ elif req_type == RpcRequestType.END_TEST_SUITE:
+ resp = self._suite_mgr.end_test_suite(rpc_request)
+
+ elif req_type == RpcRequestType.DEVICE_QUERY_CONTROL:
+ resp = self._translator.dispatch_to_cmd_handler(rpc_request)
+
+ else:
+ raise errors.InvalidRPCError(
+ f'Invalid RPC request type {req_type}.')
+
+ except Exception as e:
+ logger.exception('Error when handling JSON-RPC.')
+ err_resp = {'id': rpc_request['id'], 'jsonrpc': '2.0'}
+ err_code = getattr(e, 'err_code', errors.DEFAULT_ERROR_CODE)
+ stack_trace = traceback.format_exc()
+ err_msg = stack_trace if err_code == errors.DEFAULT_ERROR_CODE else str(e)
+ err_resp['error'] = {'code': err_code, 'message': err_msg}
+ return err_resp
+
+ return resp
+
+ def _callback_for_rpc_execution_complete(self,
+ future: futures.Future) -> None:
+ """Callback function when an RPC execution is complete.
+
+ What this callback does:
+ 1) We remove the id(future) from the self._rpc_execution_future_ids.
+ 2) We log the exception if there is one.
+
+ This callback should be registered to an RPC execution future object
+ using Future.add_done_callback() instead of being called directly.
+
+ Args:
+ future: The Future object when we submit an RPC execution job to
+ the ThreadPoolExecutor.
+ """
+ future_id = id(future)
+ self._rpc_execution_future_ids.remove(future_id)
+ exc = future.exception()
+ if exc is not None:
+ logger.error('RPC execution raised an exception: %s', exc)
+
+ def _terminate(self, sig_num, frame) -> None:
+ """Termination procedure for local agent. A signal handler.
+
+ We set the termination event and stop the 2 top-level threads:
+ info-reporting and rpc-polling. Also shutdown the ThreadPoolExecutor
+ for RPC execution.
+
+ Args:
+ sig_num: Signal number passed to a signal handler. See:
+ https://docs.python.org/3/library/signal.html#signal.signal
+ frame: Current stack frame passed to a signal handler. See:
+ https://docs.python.org/3/library/signal.html#signal.signal
+ """
+ del sig_num, frame # Unused.
+
+ logger.warning('Terminating local agent process.')
+
+ self._clean_up_and_terminate_agent()
+
+ thread_and_wait_time = (
+ (self._rpc_polling_thread,
+ self._POLL_RPC_THREAD_TERMINATION_WAIT_SECONDS),
+ (self._info_reporting_thread,
+ self._REPORT_INFO_THREAD_TERMINATION_WAIT_SECONDS))
+
+ for thread, wait_time in thread_and_wait_time:
+ if thread is None:
+ # The thread wasn't even created. Skipping.
+ continue
+ logger.info(f'Waiting for thread {thread.name} to stop. '
+ f'(Timeout = {wait_time} seconds)')
+ thread.join(timeout=wait_time)
+ if thread.is_alive():
+ logger.error('Thread %s still alive after waiting %s seconds.',
+ thread.name, wait_time)
+
+ self._rpc_execution_thread_pool.shutdown(wait=False)
+
+ logger.warning('Local agent process terminated.')
+
+ def _read_auths(self) -> Tuple[str, str]:
+ """Read agent auths from local config.
+
+ Reads the stored agent_id and agent_secret locally.
+
+ Returns:
+ Tuple of agent_id and agent_secret.
+ """
+ with open(AUTH_FILE, 'r') as fstream:
+ auths = json.load(fstream)
+ return auths['agent_id'], auths['agent_secret']
+
+ def _write_auths(self, agent_id: str, agent_secret: str) -> None:
+ """Writes agent auths into local config.
+
+ Stores the agent_id and agent_secret into the credential auth file
+ locally.
+
+ Args:
+ agent_id: local agent id.
+ agent_secret: local agent secret.
+ """
+ with open(AUTH_FILE, 'w') as fstream:
+ auths = {'agent_id': agent_id, 'agent_secret': agent_secret}
+ json.dump(auths, fstream)
+
+ def _setup_credentials(self) -> bool:
+ """Sets up credentials.
+
+ Reads credentials from the local auth file. If the credentials are not
+ available or are expired, starts the linking process.
+
+ Returns:
+ True if credentials are set up successfully. False otherwise.
+ """
+ credentials_set_up = False
+ try:
+ agent_id, agent_secret = self._read_auths()
+ self._ams_client.set_local_agent_credentials(
+ local_agent_id=agent_id,
+ local_agent_secret=agent_secret)
+ credentials_set_up = True
+ except (
+ FileNotFoundError, errors.CredentialsError, errors.UnlinkedError):
+ logger.info('Start the linking process')
+ while True:
+ try:
+ linking_code = input('Linking Code:')
+ agent_id, agent_secret = self._ams_client.register(
+ linking_code=linking_code)
+ self._write_auths(agent_id, agent_secret)
+ credentials_set_up = True
+ break
+ except errors.CredentialsError as e:
+ # We don't use logger.exception here in order not to
+ # overwhelm the user interface.
+ logger.warning(
+ 'Invalid linking code. Please retry. (%s)', e)
+ except errors.ApiTimeoutError as e:
+ logger.warning(
+ 'Register API timed out. Please retry. (%s)', e)
+ except errors.ApiError as e:
+ logger.warning(
+ 'Agent registration failed. Please retry. (%s)', e)
+ except KeyboardInterrupt:
+ break
+ return credentials_set_up
+
+ def _compress_artifacts_and_upload(
+ self,
+ test_suite_id: str,
+ test_result_id: Optional[str] = None) -> None:
+ """Compresses the artifacts and uploads if needed.
+
+ Compresses the artifact directory, and uploads the artifacts
+ to AMS if the test result ID is provided.
+
+ Args:
+ test_suite_id: Test suite ID.
+ test_result_id: Test result ID.
+
+ Raises:
+ RuntimeError: Uploading fails.
+ """
+ logger.info(f'Compressing artifacts for {test_suite_id}')
+
+ test_suite_dir = os.path.join(self._artifacts_dir, test_suite_id)
+
+ # remove logging handler
+ local_agent_log = os.path.join(
+ self._artifacts_dir, test_suite_id, 'local_agent.log')
+ logger_module.remove_file_handler(local_agent_log)
+
+ # compress the artifacts directory and remove it after the compression
+ shutil.make_archive(test_suite_dir, 'gztar', test_suite_dir)
+ shutil.rmtree(test_suite_dir)
+
+ if test_result_id is not None:
+ logger.info(f'Uploading artifacts for {test_suite_id}')
+
+ # check against file size limit.
+ artifacts_name = test_suite_dir + '.tar.gz'
+ if os.stat(artifacts_name).st_size >= APP_ENGINE_DATA_SIZE_LIMIT:
+ raise RuntimeError(
+ f'The file size of {artifacts_name} is larger than '
+ f'{APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE}.')
+
+ # upload the artifact
+ try:
+ self._ams_client.upload_artifact(artifacts_name,
+ test_result_id=test_result_id)
+ except (errors.ApiTimeoutError, errors.ApiError):
+ logger.exception('Failed to upload artifact.')
+
+ def _clean_up_and_terminate_agent(
+ self, remove_auth_file: bool = False) -> None:
+ """Cleanup method for local agent.
+
+ Cleans up suite session and sets the terminate event.
+ Removes the auth file if remove_auth_file is true.
+
+ Args:
+ remove_auth_file: To remove the auth file or not.
+ """
+ if remove_auth_file:
+ if os.path.exists(AUTH_FILE):
+ os.remove(AUTH_FILE)
+ self._suite_mgr.clean_up()
+ self._termination_event.set()
+
+
+def read_config() -> Dict[str, Any]:
+ """Reads user data from configuration file.
+
+ The config file should be in INI format. Users can specify the path to
+ their config file via a command line argument. If not provided, we use
+ the default path, DEFAULT_USER_CONFIG.
+
+ The configuration file is not required. An empty dict is returned when
+ there's no such config file.
+
+ Raises:
+ ValueError: If the config file has no LocalAgentConfig section.
+
+ Returns:
+ User configuration data.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-u', '--user_config', type=str, required=False,
+ default=DEFAULT_USER_CONFIG,
+ help='Local Agent user config file.')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+ sys.argv[1:] = leftover
+
+ if not os.path.exists(args.user_config):
+ return {}
+
+ config = configparser.ConfigParser()
+ config.read(args.user_config)
+
+ if _USER_CONFIG_ROOT_KEY not in config:
+ raise ValueError(
+ f'Invalid config file, no section {_USER_CONFIG_ROOT_KEY}.'
+ 'Please refer to example_config.ini for reference.')
+
+ return config[_USER_CONFIG_ROOT_KEY]
+
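For reference, here is a hypothetical config in the INI layout `read_config` expects (section name `LocalAgentConfig`, keys as defined by the `_USER_CONFIG_*` constants). Note that `configparser` returns every value as a string, so a port read this way still needs converting:

```python
import configparser

# Hypothetical example values; substitute your real AMS deployment.
_EXAMPLE_CONFIG = """
[LocalAgentConfig]
AMS_HOST = localhost
AMS_PORT = 8000
AMS_SCHEME = http
ARTIFACTS_DIR = /tmp/local_agent_artifacts
"""

config = configparser.ConfigParser()
config.read_string(_EXAMPLE_CONFIG)
section = config['LocalAgentConfig']
```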
+
+def main() -> None:
+ """Main entry of Local Agent."""
+ user_config = read_config()
+ ams_host = user_config.get(_USER_CONFIG_AMS_HOST)
+ ams_port = user_config.get(_USER_CONFIG_AMS_PORT)
+ ams_scheme = user_config.get(_USER_CONFIG_AMS_SCHEME)
+ artifacts_dir = (
+ user_config.get(_USER_CONFIG_ARTIFACTS_DIR, DEFAULT_ARTIFACTS_DIR))
+
+ client = ams_client.AmsClient(host=ams_host,
+ port=ams_port,
+ scheme=ams_scheme)
+ proc = LocalAgentProcess(client=client, artifacts_dir=artifacts_dir)
+ proc.run()
diff --git a/local_agent/logger.py b/local_agent/logger.py
new file mode 100644
index 0000000..155d388
--- /dev/null
+++ b/local_agent/logger.py
@@ -0,0 +1,113 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for Local Agent logger."""
+import logging
+from logging import handlers
+from typing import Optional
+
+
+log_handler_map = {}
+logger = logging.getLogger(name='local_agent')
+
+_DEFAULT_LOG_FILE = 'local_agent.log'
+
+# Example:
+# 20210725 18:23:09.509 W tmp.py:23 This is the log message.
+_LOG_FORMAT_PARTS = ('%(asctime)s.%(msecs)03d', # Time.
+ '%(levelname).1s', # Level.
+ '%(filename)s:%(lineno)d', # Where.
+ '%(message)s')
+_LOG_FORMAT = ' '.join(_LOG_FORMAT_PARTS)
+_DATE_FORMAT = '%Y%m%d %X'
+
+
+def create_handler(log_file: Optional[str] = None) -> logging.Handler:
+ """Creates the logging handler for a specific log file.
+
+ If the log_file argument is provided, we create a RotatingFileHandler;
+ otherwise we create a StreamHandler.
+
+ Args:
+ log_file: Path to the log file.
+
+ Returns:
+ The logging handler created.
+ """
+ if log_file is not None:
+ handler = handlers.RotatingFileHandler(log_file,
+ maxBytes=100000,
+ backupCount=10)
+ else:
+ handler = logging.StreamHandler()
+ handler.setLevel(logging.DEBUG)
+ handler.setFormatter(logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT))
+ return handler
+
+
+def add_file_handler(log_file: str) -> None:
+ """Adds a handler to the global logger for logging to the given file.
+
+ We use a global dict to prevent adding multiple file handlers for one
+ single file.
+
+ Args:
+ log_file: Path of the file to add handler for.
+ """
+ if log_file in log_handler_map:
+ return
+ handler = create_handler(log_file)
+ log_handler_map[log_file] = handler
+ logger.addHandler(handler)
+
+
+def remove_file_handler(log_file: str) -> None:
+ """Removes the logging handler of a specific file.
+
+ If there isn't a handler for the file, no-op.
+
+ Args:
+ log_file: Path of the log file.
+ """
+ if log_file not in log_handler_map:
+ return
+ handler = log_handler_map[log_file]
+ logger.removeHandler(handler)
+ del log_handler_map[log_file]
+
+
+def setup_logger() -> None:
+ """Sets up the logger for logging in file and console."""
+ logger.setLevel(logging.DEBUG)
+
+ file_handler = create_handler(_DEFAULT_LOG_FILE)
+ stream_handler = create_handler()
+
+ logger.addHandler(file_handler)
+ logger.addHandler(stream_handler)
+
+
+def get_logger() -> logging.Logger:
+ """Gets the local agent logger for logging.
+
+ This getter helps modules in the local agent package obtain the local
+ agent logger without having to import the "logger" member directly.
+
+ Returns:
+ The local agent logger.
+ """
+ return logger
+
+
+setup_logger()
diff --git a/local_agent/suite_session_manager.py b/local_agent/suite_session_manager.py
new file mode 100644
index 0000000..a720803
--- /dev/null
+++ b/local_agent/suite_session_manager.py
@@ -0,0 +1,218 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for SuiteSessionManager class."""
+import datetime
+import os
+import time
+import threading
+from typing import Any, Callable, Dict, List, Optional
+
+from local_agent import errors as agent_errors
+from local_agent import logger as logger_module
+
+
+logger = logger_module.get_logger()
+
+# ============= Constants ============= #
+_ARTIFACT_EXPIRE_DAYS = 30 # days
+_SUITE_SESSION_TIME_OUT = 5400 # 90 mins in seconds
+_SUITE_SESSION_TIME_OUT_HUMAN_READABLE = '90 mins'
+_SUITE_SESSION_TIME_OUT_INTERVAL_SECONDS = 30
+# ===================================== #
+
+
+class SuiteSessionManager:
+ """Suite Session Manager for test suite session management."""
+
+ def __init__(self,
+ artifacts_fn: Callable[[str, str], None],
+ artifact_root_dir: str,
+ create_devices_fn: Callable[[List[str], str], None],
+ close_devices_fn: Callable[[], None]):
+ # Artifacts related
+ self._compress_artifacts_and_upload = artifacts_fn
+ self._artifact_root_dir = artifact_root_dir
+
+ # Tracks test suite related
+ self._termination_event = None
+ self._suite_timeout_checker = threading.Thread(
+ target=self._check_suite_timeout, daemon=True)
+ self._ongoing_test_suite_id = None
+ self._ongoing_test_suite_start_time = None
+
+ # Device control related
+ self._create_devices = create_devices_fn
+ self._close_devices = close_devices_fn
+
+ def start(self, termination_event: threading.Event) -> None:
+ """Starts the suite session manager by enabling the background threads.
+
+ Args:
+ termination_event: The termination threading event for the thread.
+ """
+ self._termination_event = termination_event
+ self._suite_timeout_checker.start()
+
+ def start_test_suite(self, rpc_request: Dict[str, str]) -> Dict[str, Any]:
+ """Subroutine for handling START_TEST_SUITE command.
+
+ Args:
+ rpc_request: JSON-RPC request.
+
+ Raises:
+ InvalidRPCError: Invalid RPC.
+ InvalidTestSuiteSessionError: Previous test suite has not ended.
+
+ Returns:
+ RPC response.
+ """
+ dut_ids = rpc_request['params'].get('dutDeviceIds')
+ if dut_ids is None:
+ raise agent_errors.InvalidRPCError(
+ 'Invalid rpc command, no dutDeviceIds')
+ suite_id = rpc_request['params'].get('id')
+ force_start = rpc_request['params'].get('forceStart')
+
+ # Support single suite execution only at a time.
+ if self._ongoing_test_suite_id is not None:
+ if force_start:
+ self.clean_up()
+ else:
+ raise agent_errors.InvalidTestSuiteSessionError(
+ f'The previous test suite {self._ongoing_test_suite_id}'
+ ' has not ended yet.')
+
+ self._ongoing_test_suite_id = suite_id
+ self._ongoing_test_suite_start_time = time.time()
+ test_suite_dir = self._initialize_artifact_directory(
+ self._ongoing_test_suite_id)
+
+ self._create_devices(dut_ids, test_suite_dir)
+
+ logger.info(f'Start a new test suite: {self._ongoing_test_suite_id}')
+ return {'id': rpc_request['id'], 'jsonrpc': '2.0', 'result': {}}
+
+ def end_test_suite(self, rpc_request: Dict[str, str]) -> Dict[str, Any]:
+ """Subroutine for handling END_TEST_SUITE request.
+
+ Args:
+ rpc_request: JSON-RPC request.
+
+ Raises:
+ InvalidTestSuiteSessionError: Invalid testsuite name.
+
+ Returns:
+ RPC response.
+ """
+ suite_id = rpc_request['params'].get('id')
+ test_result_id = rpc_request['params'].get('testResultId')
+
+ if suite_id == self._ongoing_test_suite_id:
+ self.clean_up(test_result_id)
+ else:
+ # Invalid testsuite session
+ raise agent_errors.InvalidTestSuiteSessionError(
+ f'Session {suite_id} has never started before.')
+
+ return {'id': rpc_request['id'], 'jsonrpc': '2.0', 'result': {}}
+
+ def clean_up(self, test_result_id: Optional[str] = None) -> None:
+ """Cleans up the ongoing test suite and removes outdated artifacts.
+
+ Args:
+ test_result_id: Test result ID.
+ """
+ if self._ongoing_test_suite_id is not None:
+ logger.info(
+ f'Cleaning up the test suite: {self._ongoing_test_suite_id}')
+ self._close_devices()
+ self._compress_artifacts_and_upload(
+ test_suite_id=self._ongoing_test_suite_id,
+ test_result_id=test_result_id)
+
+ elapsed_time = time.time() - self._ongoing_test_suite_start_time
+ logger.info(
+ f'Cleaned up test suite {self._ongoing_test_suite_id}. '
+ f'Suite session elapsed time = {elapsed_time} seconds.')
+ self._ongoing_test_suite_id = None
+ self._ongoing_test_suite_start_time = None
+
+ self._remove_outdated_artifacts()
+
+ def _initialize_artifact_directory(self, ongoing_test_suite_id: str) -> str:
+ """Creates a directory for artifacts of a test suite session.
+
+ Args:
+ ongoing_test_suite_id: ID of the ongoing test suite.
+
+ Returns:
+ Artifact directory path.
+ """
+ test_suite_dir = os.path.join(
+ self._artifact_root_dir, ongoing_test_suite_id)
+ if not os.path.exists(test_suite_dir):
+ os.makedirs(test_suite_dir)
+
+ local_agent_log = os.path.join(test_suite_dir, 'local_agent.log')
+ logger_module.add_file_handler(local_agent_log)
+
+ logger.info(f'Start collecting log for {ongoing_test_suite_id}')
+
+ return test_suite_dir
+
+ def _remove_outdated_artifacts(self) -> None:
+ """Deletes the outdated local artifacts (older than 1 month)."""
+ if not os.path.exists(self._artifact_root_dir):
+ return
+ today = datetime.datetime.today()
+ for name in os.listdir(self._artifact_root_dir):
+ artifact = os.path.join(self._artifact_root_dir, name)
+ if os.path.isfile(artifact):
+ created_time = os.stat(artifact).st_ctime
+ delta = today - datetime.datetime.fromtimestamp(created_time)
+ if delta.days >= _ARTIFACT_EXPIRE_DAYS:
+ try:
+ os.remove(artifact)
+ logger.info(
+ f'Deleted outdated artifact: {artifact}')
+ except OSError as e:
+ logger.warning(
+ 'Failed to remove the outdated artifact'
+ f' {artifact}: {str(e)}')
+
+ def _check_suite_timeout(self) -> None:
+ """Checks if the current test suite has timed out.
+
+ Practically, no suite in Rainier runs longer than 1 hour; if the
+ current suite has been running for over 90 minutes, we force a
+ clean-up of the current suite session.
+ """
+ while (self._termination_event is not None and
+ not self._termination_event.wait(
+ _SUITE_SESSION_TIME_OUT_INTERVAL_SECONDS)):
+
+ if self._ongoing_test_suite_id is None:
+ logger.info('No ongoing test suite.')
+ continue
+
+ logger.info('Checking if suite session has timed out.')
+ suite_elapsed_time = (
+ time.time() - self._ongoing_test_suite_start_time)
+ if suite_elapsed_time >= _SUITE_SESSION_TIME_OUT:
+ logger.warning(
+ f'Suite {self._ongoing_test_suite_id} has timed out '
+ f'(over {_SUITE_SESSION_TIME_OUT_HUMAN_READABLE}), '
+ 'forcing suite session clean up.')
+ self.clean_up()
diff --git a/local_agent/tests/README.md b/local_agent/tests/README.md
new file mode 100644
index 0000000..72f506f
--- /dev/null
+++ b/local_agent/tests/README.md
@@ -0,0 +1,59 @@
+# Local Agent Tests
+
+This folder contains unit and functional (integration with backend) tests for Local Agent.
+
+You'll need to install the extra test dependencies to run the tests:
+```
+$ source <venv-name>/bin/activate
+$ cd /path/to/{repo-base}/local-agent
+$ pip install ".[test]"
+```
+
+## Unit Test
+
+Run all the unit tests and show the coverage report.
+
+1. Activate the virtual environment.
+2. Execute the coverage tool:
+ ```
+ cd /path/to/{repo-base}/local-agent/local_agent/tests/unit_tests/
+ coverage run --source={LOCAL_AGENT_PACKAGE_PATH}/local_agent/ -m unittest discover && coverage report
+ ```
+ where {LOCAL_AGENT_PACKAGE_PATH} can be obtained by the following command:
+ ```
+ $ pip show local-agent | grep -e "^Location" | awk -F " " '{print $2}'
+ ```
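+   Optionally, the two steps above can be combined into a single snippet (a sketch, assuming the ```local-agent``` package is installed in the active virtual environment):
+   ```
+   LOCAL_AGENT_PACKAGE_PATH=$(pip show local-agent | grep -e "^Location" | awk '{print $2}')
+   coverage run --source=${LOCAL_AGENT_PACKAGE_PATH}/local_agent/ -m unittest discover && coverage report
+   ```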
+
+## Real AMS Test
+
+**Important**: If you're NOT an internal Googler or tester, you won't be able to run this test.
+
+Run the local agent with a real Test Suite Backend (AMS).
+
+
+1. Backend setup: If you're an internal Googler or tester, please contact the local agent author for backend deployment; otherwise, you won't be able to deploy the backend.
+
+2. Edit the local agent config file: fill in the ```AMS_HOST``` and ```AMS_PORT``` fields with the AMS configuration. Contact the local agent author if you're not sure about the backend configuration. See ```example_config.yaml``` for reference.
+
+3. Activate the virtual environment.
+
+4. Connect your Matter devices with the above-mentioned testabilities to the host machine. (You can still run the AMS test without connecting any devices.)
+
+5. Start the fake front end process. \
+Use ```-host [tsb_host]``` and ```-port [tsb_port]``` to indicate the host and port of the deployed backend; the fake front end defaults to ```host=127.0.0.1``` and ```port=8080``` if the arguments are not provided. \
+For example, to run with real devices:
+ ```
+ python fake_front_end.py
+ ```
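+   Run against a remote backend (hypothetical host shown; substitute your deployment's address):
+   ```
+   python fake_front_end.py -host ams.example.com -port 8080
+   ```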
+
+6. Start the local agent process:
+ ```
+ local-agent -u [config file location]
+ ```
+    or just put the config under ```~/.config/google/local_agent_config.yaml```, then start the process without any arguments:
+ ```
+ local-agent
+ ```
+
+7. The Local Agent should prompt the user to enter the linking code, which is displayed in the console of the ```fake_front_end``` process.
diff --git a/local_agent/tests/ams_tests/fake_front_end.py b/local_agent/tests/ams_tests/fake_front_end.py
new file mode 100644
index 0000000..a5e036f
--- /dev/null
+++ b/local_agent/tests/ams_tests/fake_front_end.py
@@ -0,0 +1,489 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fake front end for local agent and real AMS integration test."""
+import argparse
+import http
+import logging
+import signal
+import sys
+import threading
+import time
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+import immutabledict
+import requests
+
+import fake_test_suite
+
+
+_LOCAL_TSB_HOST = '127.0.0.1'
+_LOCAL_TSB_PORT = 8080
+_TEST_PROJECT = 'test-project'
+_POLLING_PERIOD = 1 # seconds
+_STATUS_POLLING_PERIOD = 30 # seconds
+_COOL_DOWN_SEC = 1 # seconds
+_DETECTION_COOL_DOWN_SEC = 3 # To wait for GDM detection to complete.
+
+
+logger = logging.getLogger(__name__)
+
+# ======================== TSB endpoints ========================== #
+TEST_SUITE_AUTH = '/tsb/api/test-suite/auth'
+TEST_SUITE_PROJECTS = '/tsb/api/test-suite/projects'
+LINKING_CODE = '/tsb/api/test-suite/local-agent/linking-code'
+AGENT_STATUS = '/tsb/api/test-suite/local-agent/info'
+AGENT_RPC = '/tsb/api/test-suite/local-agent/rpc'
+UNLINK_AGENT = '/tsb/api/test-suite/local-agent/unlink'
+RPC_METADATA = f'{AGENT_RPC}/metadata'
+# ================================================================= #
+
+
+# ======================== Constants ========================== #
+ALL_SUITE_CLASSES = (
+ fake_test_suite.BrightnessSuite,
+ fake_test_suite.ColorSuite,
+ fake_test_suite.DeviceCommonSuite,
+ fake_test_suite.LightOnOffSuite,
+ fake_test_suite.LockUnlockSuite,
+)
+GDM_CAPABILITY_TO_HG_TRAIT = immutabledict.immutabledict({
+ 'pw_rpc_common': 'Common',
+ 'pw_rpc_light': 'OnOff',
+ 'pw_rpc_lock': 'LockUnlock',
+})
+# ============================================================= #
+
+
+# ======================== Module level functions ========================== #
+def setup_logger() -> None:
+ """Sets up the logger for logging."""
+ logger.setLevel(logging.DEBUG)
+ handler = logging.StreamHandler()
+ handler.setLevel(logging.DEBUG)
+ handler.setFormatter(
+ logging.Formatter('[%(asctime)s %(levelname)s] %(message)s'))
+ logger.addHandler(handler)
+
+
+def parse_args() -> Tuple[str, Optional[int]]:
+ """Sets up the parser for argument parsing.
+
+ Returns:
+ Tuple: TSB host, TSB port
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '-host', '--tsb_host', type=str, required=False,
+ default=_LOCAL_TSB_HOST, help='TSB host')
+ parser.add_argument(
+ '-port', '--tsb_port', type=int, required=False,
+ default=_LOCAL_TSB_PORT, help='TSB port')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+ sys.argv[1:] = leftover
+
+ tsb_host = args.tsb_host
+ tsb_port = args.tsb_port if tsb_host == _LOCAL_TSB_HOST else None
+
+ return tsb_host, tsb_port
+
+
+def raise_exception(response: requests.models.Response, err_msg: str) -> None:
+ """Raises exception for HTTP response status code != 200.
+
+ Args:
+ response: HTTP response.
+ err_msg: Error message.
+ """
+ err_msg = f'{err_msg}. Status: {response.status_code}'
+ try:
+ ams_err_msg = response.json()['errorMessage']
+    except (ValueError, KeyError, TypeError):
+ ams_err_msg = ''
+ if ams_err_msg:
+ err_msg += f', AMS error message: {ams_err_msg}'
+ raise RuntimeError(err_msg)
+# ========================================================================== #
+
+
+class TSBService:
+ """TSB endpoint service."""
+
+ def __init__(self, tsb_host: str, tsb_port: Optional[int]):
+ if tsb_port is None:
+ self._base_url = f'http://{tsb_host}'
+ else:
+ self._base_url = f'http://{tsb_host}:{tsb_port}'
+ auth_token = self.get_auth_token()
+ self._auth = {'Authorization': auth_token}
+
+ def get_auth_token(self) -> str:
+ """Service to obtain the test suite user's auth token.
+
+ Returns:
+ Test suite user auth token.
+
+ Raises:
+ RuntimeError: HTTP response status code is not 200.
+ """
+ url = self._base_url + TEST_SUITE_AUTH
+ resp = requests.post(url, json={'idToken': 'debug-token'})
+ if resp.status_code != http.HTTPStatus.OK:
+ raise_exception(resp, 'Failed to get auth token')
+ auth_token = resp.json().get('authToken')
+ return auth_token
+
+ def get_or_create_test_project(self) -> None:
+ """Service to create test project if needed.
+
+ Raises:
+ RuntimeError: HTTP response status code is not 200.
+ """
+ url = self._base_url + TEST_SUITE_PROJECTS
+ resp = requests.get(url, headers=self._auth)
+ if resp.status_code != http.HTTPStatus.OK:
+ raise_exception(resp, 'Failed to get test project')
+
+ for project in resp.json().get('result'):
+ if project['id'] == _TEST_PROJECT:
+ return
+ resp = requests.post(
+ url, headers=self._auth, json={'projectIds': [_TEST_PROJECT]})
+ if resp.status_code != http.HTTPStatus.CREATED:
+ raise_exception(resp, 'Failed to create test project')
+
+ def get_linking_code(self) -> str:
+ """Service to obtain the linking code.
+
+ Returns:
+ Linking code.
+
+ Raises:
+ RuntimeError: HTTP response status code is not 200.
+ """
+ url = self._base_url + LINKING_CODE
+ resp = requests.post(
+ url, headers=self._auth, json={'projectId': _TEST_PROJECT})
+ if resp.status_code != http.HTTPStatus.OK:
+ raise_exception(resp, 'Failed to get linking code')
+ linking_code = resp.json()['result'].get('code')
+ return linking_code
+
+ def get_agent_status(self) -> Optional[Dict[str, Any]]:
+ """Service to retrieve the local agent status.
+
+ Returns:
+ Local Agent status dict or None if agent is not linked.
+
+ Raises:
+ RuntimeError: HTTP response status code is not 200 nor 404.
+ """
+ url = self._base_url + AGENT_STATUS + f'?projectId={_TEST_PROJECT}'
+ resp = requests.get(url, headers=self._auth)
+ if resp.status_code == http.HTTPStatus.NOT_FOUND:
+ return None
+ elif resp.status_code == http.HTTPStatus.OK:
+ return resp.json()['result']
+ else:
+ raise_exception(resp, 'Failed to get agent status')
+
+ def send_rpc_request(self, rpc_request: Dict[str, Any]) -> str:
+ """Service to send RPC request to the AMS.
+
+ Args:
+ rpc_request: JSON RPC request.
+
+ Returns:
+ JSON-RPC id.
+
+ Raises:
+ RuntimeError: HTTP response status code is not 200.
+ """
+ url = self._base_url + AGENT_RPC
+ resp = requests.post(url,
+ headers=self._auth, json=rpc_request)
+ if resp.status_code != http.HTTPStatus.OK:
+ raise_exception(resp, 'Failed to send RPC request')
+ rpc_id = resp.json()['result']['id']
+ return rpc_id
+
+ def get_rpc_metadata(self, rpc_id: str) -> Dict[str, Optional[int]]:
+ """Service to get RPC metadata with given rpc_id.
+
+ Args:
+ rpc_id: JSON-RPC id.
+
+ Returns:
+ RPC metadata.
+
+ Raises:
+ RuntimeError: HTTP response status code is not 200.
+ """
+ url = (self._base_url + RPC_METADATA +
+ f'?projectId={_TEST_PROJECT}&rpcId={rpc_id}')
+ resp = requests.get(url, headers=self._auth)
+ if resp.status_code != http.HTTPStatus.OK:
+ raise_exception(resp, 'Failed to get RPC metadata')
+ metadata = resp.json()['result']
+ return metadata
+
+ def get_rpc_response(self, rpc_id: str) -> Dict[str, Any]:
+ """Service to get RPC response with given rpc_id.
+
+ Args:
+ rpc_id: JSON-RPC id.
+
+ Returns:
+ RPC response.
+
+ Raises:
+ RuntimeError: HTTP response status code is not 200.
+ """
+ url = (self._base_url + AGENT_RPC +
+ f'?projectId={_TEST_PROJECT}&rpcId={rpc_id}')
+ resp = requests.get(url, headers=self._auth)
+ if resp.status_code != http.HTTPStatus.OK:
+ raise_exception(resp, 'Failed to get RPC response')
+ rpc_response = resp.json()['result']
+ return rpc_response
+
+ def unlink_agent(self) -> None:
+ """API to unlink local agent. Raises RuntimeError if failed."""
+ url = self._base_url + UNLINK_AGENT
+ post_data = {'projectId': _TEST_PROJECT}
+ resp = requests.post(url, json=post_data, headers=self._auth)
+ if resp.status_code != http.HTTPStatus.OK:
+ raise_exception(resp, 'Failed to unlink local agent')
+
+
+class FakeFrontEnd:
+ """Fake front end module."""
+
+ def __init__(self, host: str, port: int):
+ # Registers termination signal handler
+ signal.signal(signal.SIGINT, self._terminate)
+
+ # Retrieves auth token and creates test project
+ self._tsb_service = TSBService(host, port)
+ self._tsb_service.get_or_create_test_project()
+
+ # Local Agent status
+ self._local_agent_status = None
+
+ # Worker threads: executing test plan and retrieving agent status.
+ self._termination_event = threading.Event()
+ self._test_plan_worker = threading.Thread(
+ target=self._create_and_execute_test_plan, daemon=True)
+ self._agent_status_worker = threading.Thread(
+ target=self._retrieve_agent_status, daemon=True)
+
+    def _terminate(self, sig_num: int, frame: Any) -> None:
+ """Signal handler upon receiving a SIGINT.
+
+ Args:
+ sig_num: Signal number passed to the handler.
+ frame: Current stack frame passed to the handler.
+ """
+ del sig_num, frame # Unused
+        logger.warning('Terminating fake front end process.')
+ self._termination_event.set()
+
+ def run(self) -> None:
+ """Runs fake front end.
+
+        Simulates the front end behaviors:
+ 1. Links the local agent.
+ 2. Sends RPC requests to TSB, polls the metadata and gets response.
+ 3. Retrieves the local agent status simultaneously.
+ 4. Unlinks the agent after the test is completed.
+ """
+ if self._link_agent():
+
+ time.sleep(_DETECTION_COOL_DOWN_SEC)
+ self._test_plan_worker.start()
+ self._agent_status_worker.start()
+
+ while not self._termination_event.is_set():
+ time.sleep(_POLLING_PERIOD)
+
+ self._unlink_agent()
+
+ def _checks_if_agent_is_linked(self) -> bool:
+ """Returns if local agent is linked.
+
+ Returns:
+ True if agent is linked, false otherwise.
+ """
+ status = self._tsb_service.get_agent_status()
+ return status is not None and status.get('status') != 'OFFLINE'
+
+ def _checks_if_response_is_stored(self, rpc_id: str) -> bool:
+ """Returns if the rpc response of rpc_id is stored.
+
+ Returns:
+ True if the rpc response is stored, false otherwise.
+ """
+ metadata = self._tsb_service.get_rpc_metadata(rpc_id=rpc_id)
+ resp_timestamp = metadata.get('responseStoredTimestamp')
+ return resp_timestamp is not None
+
+ def _link_agent(self) -> bool:
+ """Links local agent.
+
+ Retrieves linking code and checks if agent is linked.
+
+ Returns:
+ True if agent is linked, false otherwise.
+ """
+ linking_code = self._tsb_service.get_linking_code()
+ print(f'\033[1m******************** Linking Code: {linking_code} ****'
+ '****************\033[0m')
+ while (not self._termination_event.is_set() and
+ not self._checks_if_agent_is_linked()):
+ logger.info(f'No agent is linked, sleep {_POLLING_PERIOD} sec...')
+ time.sleep(_POLLING_PERIOD)
+ if not self._termination_event.is_set():
+ logger.info('The local agent is linked.')
+ return True
+ return False
+
+ def _run_rpc_requests(
+ self, rpc_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ """Runs RPC request.
+
+ Simulates the FE behavior: sends the rpc request to BE, polls for the
+ rpc metadata, once it gets updated, retrieves the rpc response from BE.
+
+ Args:
+ rpc_request: JSON-RPC request.
+
+ Returns:
+ JSON-RPC response or None if fake front end is interrupted by SIGINT.
+ """
+ rpc_id = self._tsb_service.send_rpc_request(rpc_request=rpc_request)
+ logger.info(f'Sent RPC request: {rpc_request}.')
+
+ while (not self._termination_event.is_set() and
+ not self._checks_if_response_is_stored(rpc_id)):
+ logger.info(
+ f'RPC response not available, sleep {_POLLING_PERIOD} sec...')
+ time.sleep(_POLLING_PERIOD)
+
+ # Interrupted by SIGINT
+ if self._termination_event.is_set():
+ return None
+
+ logger.info('Metadata is updated.')
+ rpc_response = self._tsb_service.get_rpc_response(rpc_id=rpc_id)
+
+ return rpc_response
+
+ def _unlink_agent(self) -> None:
+ """Unlinks local agent."""
+ logger.info('Unlinking local agent.')
+ self._tsb_service.unlink_agent()
+ logger.info('Local agent unlinked.')
+
+ def _create_and_execute_test_plan(self) -> None:
+ """Creates and executes test plan."""
+ while self._local_agent_status is None:
+            logger.info('Local Agent status not available yet.')
+ time.sleep(_POLLING_PERIOD)
+ connected_devices = []
+ for device_info in self._local_agent_status['devices']:
+ device_id = device_info['deviceId']
+ hg_traits = self._get_hg_traits(device_info['capabilities'])
+ connected_devices.append((device_id, hg_traits))
+
+ for device_id, hg_traits in connected_devices:
+ suites = self._generate_suites(device_id, hg_traits)
+ for suite in suites:
+ logger.info(
+ f'Executes suite {suite} for device {device_id} ...')
+ self._run_suite(suite)
+
+ def _get_hg_traits(self, capabilities: List[str]) -> List[str]:
+ """Maps the GDM capability to the corresponding HG trait.
+
+ Args:
+ capabilities: List of GDM capability.
+
+ Returns:
+ List of HG traits.
+ """
+ hg_traits = []
+ for capability in capabilities:
+ hg_trait = GDM_CAPABILITY_TO_HG_TRAIT.get(capability)
+ if hg_trait is not None:
+ hg_traits.append(hg_trait)
+ return hg_traits
+
+ def _retrieve_agent_status(self) -> None:
+ """Retrieves local agent status."""
+ while not self._termination_event.is_set():
+ status = self._tsb_service.get_agent_status()
+ if status is not None:
+ logger.info(f'Retrieves local agent status: {status}')
+ self._local_agent_status = status
+ time.sleep(_STATUS_POLLING_PERIOD)
+
+    def _generate_suites(
+            self, device_id: str, device_traits: List[str]) -> List[Any]:
+        """Generates test suites based on device traits.
+
+        In reality, the FE sends device info to the BE for test plan/suite
+        generation. To reduce maintenance effort, the fake FE here simply
+        maps the corresponding fake suites directly.
+
+        Args:
+            device_id: GDM device id.
+            device_traits: List of HG device traits.
+
+        Returns:
+            The list of suites applicable to the device traits.
+        """
+ suites = []
+ for suite_class in ALL_SUITE_CLASSES:
+ if suite_class.is_applicable_to(device_traits):
+ suites.append(suite_class(device_id))
+ return suites
+
+ def _run_suite(self, suite: Any) -> None:
+ """Runs suite procedures.
+
+ Args:
+ suite: Suite instance.
+ """
+ device_ids = [suite.device_id]
+ start_suite_rpc, end_suite_rpc = (
+ fake_test_suite.generate_start_end_suite_rpc(device_ids))
+
+ all_procedures = [start_suite_rpc] + suite.procedures + [end_suite_rpc]
+
+ for rpc_request in all_procedures:
+ logger.info(f'Runs RPC request {rpc_request}')
+ rpc_response = self._run_rpc_requests(rpc_request=rpc_request)
+ logger.info(f'Retrieves RPC response: {rpc_response}')
+ time.sleep(_COOL_DOWN_SEC)
+
+
+def main() -> None:
+ """Main entry of fake front end."""
+ setup_logger()
+ tsb_host, tsb_port = parse_args()
+ fake_front_end = FakeFrontEnd(host=tsb_host, port=tsb_port)
+ fake_front_end.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/local_agent/tests/ams_tests/fake_test_suite.py b/local_agent/tests/ams_tests/fake_test_suite.py
new file mode 100644
index 0000000..fda9f4f
--- /dev/null
+++ b/local_agent/tests/ams_tests/fake_test_suite.py
@@ -0,0 +1,272 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fake test suite for local agent and real AMS integration test."""
+import secrets
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+
+_TEST_PROJECT = 'test-project'
+_TEST_BRIGHTNESS_LEVEL = 150
+_TEST_COLOR_HUE = 50
+_TEST_COLOR_SATURATION = 30
+
+# ======================== Supported Methods ========================== #
+# See go/rainier-amsprocedure-api for more details
+
+# Basic
+_REBOOT = 'setReboot'
+_FACTORY_RESET = 'setFactoryReset'
+
+# On / Off
+_TURN_ON_LIGHT = 'setOn'
+_TURN_OFF_LIGHT = 'setOff'
+_GET_LIGHT_STATE = 'getOnOff'
+
+# Lock / Unlock
+_LOCK_DEVICE = 'setLock'
+_UNLOCK_DEVICE = 'setUnlock'
+_GET_IS_LOCKED = 'getIsLocked'
+
+# Brightness
+_GET_LIGHT_BRIGHTNESS = 'getBrightness'
+_SET_LIGHT_BRIGHTNESS = 'setBrightness'
+
+# Color
+_GET_LIGHT_COLOR = 'getColor'
+_SET_LIGHT_COLOR = 'setColor'
+
+# Test suite management
+_START_SUITE = 'startTestSuite'
+_END_SUITE = 'endTestSuite'
+# ============================================================= #
+
+# ====================== Helper methods ======================= #
+# TODO(b/194029399) Retrieves correct test result id.
+def generate_start_end_suite_rpc(
+ device_ids: List[str]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """Generates startTestSuite and endTestSuite RPC request.
+
+ Args:
+ device_ids: List of GDM device ids.
+
+ Returns:
+ Tuple: startTestSuite request and endTestSuite request.
+ """
+ fake_suite_id = secrets.token_hex(4)
+ start_suite_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _START_SUITE,
+ 'params': {'id': fake_suite_id,
+ 'dutDeviceIds': device_ids}}
+ end_suite_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _END_SUITE,
+ 'params': {'id': fake_suite_id}}
+ return start_suite_rpc, end_suite_rpc
+# ============================================================= #
+
+
+class BaseSuite:
+ """Base class for all suite templates."""
+
+ NAME = ''
+ REQUIRED_TRAITS = set()
+
+ def __init__(self, device_id: str):
+ """Suite constructor.
+
+ Args:
+ device_id: GDM device id.
+ """
+ self._device_id = device_id
+ self._procedures = []
+
+ @classmethod
+ def is_applicable_to(cls, device_traits: Set[str]) -> bool:
+ """Whether the suite is applicable to the device traits.
+
+ Args:
+ device_traits: The device traits on HG.
+
+ Returns:
+ True if it's applicable, false otherwise.
+ """
+ return cls.REQUIRED_TRAITS.issubset(device_traits)
+
+ def __str__(self) -> str:
+ """Returns the suite name."""
+ return self.NAME
+
+ @property
+ def device_id(self) -> str:
+ """Returns the device id."""
+ return self._device_id
+
+ @property
+ def procedures(self) -> List[Dict[str, Any]]:
+ """Returns the procedures."""
+ return self._procedures
+
+
+class LightOnOffSuite(BaseSuite):
+ """Light on off test suite."""
+
+ NAME = 'Light OnOff Suite'
+ REQUIRED_TRAITS = {'OnOff',}
+
+ def __init__(self, device_id: str):
+ """Light on off test suite constructor."""
+
+ super().__init__(device_id=device_id)
+
+ self._set_on_light_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _TURN_ON_LIGHT,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._set_off_light_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _TURN_OFF_LIGHT,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._get_state_light_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _GET_LIGHT_STATE,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._procedures = [
+ self._set_on_light_rpc,
+ self._get_state_light_rpc,
+ self._set_off_light_rpc,
+ self._get_state_light_rpc,
+ ]
+
+
+class LockUnlockSuite(BaseSuite):
+ """Lock unlock test suite."""
+
+ NAME = 'Lock Unlock Suite'
+ REQUIRED_TRAITS = {'LockUnlock',}
+
+ def __init__(self, device_id: str):
+ """Lock unlock test suite constructor."""
+
+ super().__init__(device_id=device_id)
+
+ self._set_lock_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _LOCK_DEVICE,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._set_unlock_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _UNLOCK_DEVICE,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._get_locked_state_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _GET_IS_LOCKED,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._procedures = [
+ self._set_lock_rpc,
+ self._get_locked_state_rpc,
+ self._set_unlock_rpc,
+ self._get_locked_state_rpc,
+ ]
+
+
+class DeviceCommonSuite(BaseSuite):
+ """Device common suite."""
+
+ NAME = 'Device Common Suite'
+ REQUIRED_TRAITS = {'Common',}
+
+ def __init__(self, device_id: str):
+ """Device common test suite constructor."""
+
+ super().__init__(device_id=device_id)
+
+ self._set_reboot_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _REBOOT,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._set_factory_reset_rpc = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _FACTORY_RESET,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._procedures = [
+ self._set_reboot_rpc,
+ self._set_factory_reset_rpc,
+ ]
+
+
+class BrightnessSuite(BaseSuite):
+ """Lighting brightness suite."""
+
+ NAME = 'Brightness Suite'
+ REQUIRED_TRAITS = {'OnOff',}
+
+ def __init__(self, device_id: str):
+ """Brightness test suite constructor."""
+
+ super().__init__(device_id=device_id)
+
+ self._set_brightness = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _SET_LIGHT_BRIGHTNESS,
+ 'params': {'dutDeviceId': self._device_id,
+                       'level': _TEST_BRIGHTNESS_LEVEL}}
+
+ self._get_brightness = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _GET_LIGHT_BRIGHTNESS,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._procedures = [
+ self._set_brightness,
+ self._get_brightness,
+ ]
+
+
+class ColorSuite(BaseSuite):
+ """Lighting color suite."""
+
+ NAME = 'Color Suite'
+ REQUIRED_TRAITS = {'OnOff',}
+
+ def __init__(self, device_id: str):
+ """Color test suite constructor."""
+
+ super().__init__(device_id=device_id)
+
+ self._set_color = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _SET_LIGHT_COLOR,
+ 'params': {'dutDeviceId': self._device_id,
+ 'hue': _TEST_COLOR_HUE,
+ 'saturation': _TEST_COLOR_SATURATION}}
+
+ self._get_color = {
+ 'projectId': _TEST_PROJECT,
+ 'method': _GET_LIGHT_COLOR,
+ 'params': {'dutDeviceId': self._device_id}}
+
+ self._procedures = [
+ self._set_color,
+ self._get_color,
+ ]
diff --git a/local_agent/tests/unit_tests/test_ams_client.py b/local_agent/tests/unit_tests/test_ams_client.py
new file mode 100644
index 0000000..f99fc1c
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_ams_client.py
@@ -0,0 +1,436 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for local_agent.ams_client module."""
+import http
+import json
+import requests
+import time
+import unittest
+from unittest import mock
+
+from local_agent import ams_client
+from local_agent import errors
+
+
+class AmsClientTest(unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ sleep_patcher = mock.patch.object(time, 'sleep')
+ sleep_patcher.start()
+ self.addCleanup(sleep_patcher.stop)
+
+ @mock.patch.object(ams_client.AmsClient, 'set_local_agent_credentials')
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_register_success(self, mock_request, mock_set_credentials):
+ """Verifies register successful."""
+ sut = ams_client.AmsClient()
+ mock_response = mock_request.return_value
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ 'result': {
+ 'agentId': 'the-id',
+ 'agentSecret': 'the-secret'}}
+
+ sut.register(linking_code='the-linking-code')
+
+ mock_set_credentials.assert_called_once_with(
+ local_agent_id='the-id',
+ local_agent_secret='the-secret')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_register_api_timeout(self, mock_request):
+ """Verifies register raises ApiTimeoutError when request timed out."""
+ sut = ams_client.AmsClient()
+ mock_request.side_effect = requests.exceptions.Timeout
+ with self.assertRaises(errors.ApiTimeoutError):
+ sut.register(linking_code='the-linking-code')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_register_api_return_status_bad_request(self, mock_request):
+ """Verifies register raises CredentialError when API response 400."""
+ sut = ams_client.AmsClient()
+ mock_request.return_value.status_code = http.HTTPStatus.BAD_REQUEST
+ with self.assertRaises(errors.CredentialsError):
+ sut.register(linking_code='the-linking-code')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_register_api_error_message_included(self, mock_request):
+ """Verifies register includes AMS error message in its exception."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.BAD_REQUEST
+ mock_response.json.return_value = {'errorMessage': 'the-message'}
+ sut = ams_client.AmsClient()
+
+ with self.assertRaisesRegex(errors.CredentialsError, 'the-message'):
+ sut.register(linking_code='the-linking-code')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_register_api_return_status_not_ok(self, mock_request):
+ """Verifies register raises ApiError when API fails."""
+ sut = ams_client.AmsClient()
+ mock_request.return_value.status_code = http.HTTPStatus.NOT_FOUND
+ with self.assertRaises(errors.ApiError):
+ sut.register(linking_code='the-linking-code')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_set_local_agent_credentials_get_auth_token_success(self, mock_request):
+ """Verifies set_local_agent_credentials successfully get auth token."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.CREATED
+ mock_response.json.return_value = {
+ 'result': {
+ 'authToken': 'the-auth-token',
+ }
+ }
+ sut = ams_client.AmsClient()
+ sut.set_local_agent_credentials('agent-id', 'agent-secret')
+
+ @mock.patch.object(ams_client.AmsClient, '_request_wrapper')
+ def test_get_auth_token_not_refreshing_auth(self, mock_request_wrapper):
+ """Verifies _get_auth_token sets refresh_auth to False."""
+ mock_response = mock_request_wrapper.return_value
+ mock_response.status_code = http.HTTPStatus.CREATED
+ mock_response.json.return_value = {'result': {'authToken': 'abc'}}
+ sut = ams_client.AmsClient()
+
+ sut._get_auth_token()
+
+ self.assertFalse(
+ mock_request_wrapper.call_args.kwargs['refresh_auth'],
+ msg='_get_auth_token needs to pass refresh_auth=False to '
+ '_request_wrapper to avoid infinite recursion')
+
+ @mock.patch.object(ams_client, 'extract_error_message_from_api_response')
+ @mock.patch.object(ams_client.AmsClient, '_request_wrapper')
+ def test_get_auth_token_unlinked_error(
+ self, mock_request_wrapper, mock_extract_err_msg):
+ """Verifies _get_auth_token raises UnlinkedError."""
+ mock_response = mock_request_wrapper.return_value
+ mock_response.status_code = http.HTTPStatus.BAD_REQUEST
+ mock_extract_err_msg.return_value = 'Invalid agent id'
+ sut = ams_client.AmsClient()
+
+ error_regex = 'Local agent is unlinked'
+ with self.assertRaisesRegex(errors.UnlinkedError, error_regex):
+ sut._get_auth_token()
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_set_local_agent_credentials_raises_when_api_timeout(self,
+ mock_request):
+ """Verifies set_local_agent_credentials raises when API timed out."""
+ mock_request.side_effect = requests.exceptions.Timeout
+ sut = ams_client.AmsClient()
+ with self.assertRaises(errors.ApiTimeoutError):
+ sut.set_local_agent_credentials('agent-id', 'agent-secret')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_set_local_agent_credentials_raises_when_api_error(self,
+ mock_request):
+ """Verifies set_local_agent_credentials raises when API has error."""
+ mock_request.return_value.status_code = http.HTTPStatus.NOT_FOUND
+ sut = ams_client.AmsClient()
+ with self.assertRaises(errors.CredentialsError):
+ sut.set_local_agent_credentials('agent-id', 'agent-secret')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_set_local_agent_credentials_include_ams_error_message(
+ self,
+ mock_request):
+ """Verifies set_local_agent_credentials includes AMS error message."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.NOT_FOUND
+ mock_response.json.return_value = {'errorMessage': 'the-message'}
+ sut = ams_client.AmsClient()
+
+ with self.assertRaisesRegex(errors.CredentialsError, 'the-message'):
+ sut.set_local_agent_credentials('agent-id', 'agent-secret')
+
+ @mock.patch.object(requests.sessions.Session,
+ 'request',
+ side_effect=requests.exceptions.Timeout)
+ def test_report_info_api_timeout(self, mock_request):
+ """Verifies report_info API timed out."""
+ sut = ams_client.AmsClient()
+ with self.assertRaises(errors.ApiTimeoutError):
+ sut.report_info({})
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_report_info_api_response_error(self, mock_request):
+ """Verifies report_info raise exception when API response has error."""
+ mock_request.return_value.status_code = http.HTTPStatus.BAD_REQUEST
+ sut = ams_client.AmsClient()
+ with self.assertRaisesRegex(errors.ApiError,
+ 'Report info API failed: status 400'):
+ sut.report_info({})
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_report_info_api_response_error_include_ams_error_message(
+ self,
+ mock_request):
+        """Verifies report_info includes the AMS error message on API error."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.BAD_REQUEST
+ mock_response.json.return_value = {'errorMessage': 'the-message'}
+ sut = ams_client.AmsClient()
+
+ with self.assertRaisesRegex(errors.ApiError, 'the-message'):
+ sut.report_info({})
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_report_info_successful(self, mock_request):
+ """Verifies report_info succeeds and sends info dict to AMS."""
+ mock_request.return_value.status_code = http.HTTPStatus.OK
+ local_agent_info = {'hi': 'hello'}
+ sut = ams_client.AmsClient()
+
+ sut.report_info(local_agent_info)
+
+ self.assertIn(
+ 'json',
+ mock_request.call_args.kwargs,
+ 'Local agent info should be passed as json arg to requests call.')
+ self.assertEqual(local_agent_info,
+ mock_request.call_args.kwargs['json'])
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_get_rpc_request_from_ams_successful_with_request(self, mock_request):
+ """Verifies get_rpc_request_from_ams gets a request from AMS."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.OK
+ mock_response.json.return_value = {'result': {'hi': 'hello'}}
+ sut = ams_client.AmsClient()
+
+ self.assertEqual(sut.get_rpc_request_from_ams(),
+ {'hi': 'hello'})
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_get_rpc_request_from_ams_successful_no_request(self,
+ mock_request):
+ """Verifies get_rpc_request_from_ams gets no request from AMS."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.NO_CONTENT
+ sut = ams_client.AmsClient()
+
+ self.assertIsNone(sut.get_rpc_request_from_ams())
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_get_rpc_request_from_ams_raises_when_api_timeout(self,
+ mock_request):
+ """Verifies get_rpc_request_from_ams raises when API timed out."""
+ mock_request.side_effect = requests.exceptions.Timeout
+ sut = ams_client.AmsClient()
+
+ with self.assertRaises(errors.ApiTimeoutError):
+ sut.get_rpc_request_from_ams()
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_get_rpc_request_from_ams_raise_when_api_response_error(
+ self, mock_request):
+ """Verifies get_rpc_request_from_ams raises when response has error."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
+ sut = ams_client.AmsClient()
+
+ with self.assertRaises(errors.ApiError):
+ sut.get_rpc_request_from_ams()
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_get_rpc_request_from_ams_when_api_error_has_exception_message(
+ self, mock_request):
+        """Verifies get_rpc_request_from_ams includes AMS message in ApiError."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
+ mock_response.json.return_value = {'errorMessage': 'message-from-ams'}
+ sut = ams_client.AmsClient()
+
+ with self.assertRaisesRegex(errors.ApiError, r'500.*message-from-ams'):
+ sut.get_rpc_request_from_ams()
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_remove_rpc_request_from_ams_successful(self, mock_request):
+ """Verifies remove_rpc_request_from_ams succeeds."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.OK
+ sut = ams_client.AmsClient()
+ sut.remove_rpc_request_from_ams({'the': 'request'})
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_remove_rpc_request_from_ams_raises_when_api_timeout(
+ self, mock_request):
+ """Verifies remove_rpc_request_from_ams raises if API timed out."""
+ mock_request.side_effect = requests.exceptions.Timeout
+ sut = ams_client.AmsClient()
+ with self.assertRaises(errors.ApiTimeoutError):
+ sut.remove_rpc_request_from_ams({'the': 'request'})
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_remove_rpc_request_from_ams_raises_when_api_fails(
+ self, mock_request):
+ """Verifies remove_rpc_request_from_ams raises if API has error."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
+ mock_response.json.return_value = {'errorMessage': 'msg-from-ams'}
+ sut = ams_client.AmsClient()
+
+ with self.assertRaisesRegex(errors.ApiError, r'500.*msg-from-ams'):
+ sut.remove_rpc_request_from_ams({'the': 'request'})
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_send_rpc_response_successful(self, mock_request):
+ """Verifies send_rpc_response sends RPC response to AMS successfully.
+ """
+ mock_request.return_value.status_code = http.HTTPStatus.OK
+ sut = ams_client.AmsClient()
+ sut.send_rpc_response({'the': 'response'})
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_send_rpc_response_raise_api_error(self, mock_request):
+ """Verifies send_rpc_response raises ApiError if incorrect status code.
+ """
+ mock_response = mock_request.return_value
+ mock_response.status_code = (
+ http.HTTPStatus.INTERNAL_SERVER_ERROR)
+ mock_response.json.return_value = {'errorMessage': 'the-ams-msg'}
+ sut = ams_client.AmsClient()
+
+ with self.assertRaisesRegex(errors.ApiError, '500.*the-ams-msg'):
+ sut.send_rpc_response({'the': 'response'})
+ # Verify we have retried.
+ self.assertEqual(4, mock_request.call_count)
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_send_rpc_response_raise_when_api_timed_out(self, mock_request):
+ """Verifies send_rpc_response raises ApiTimeoutError if API timed out.
+ """
+ mock_request.side_effect = requests.exceptions.Timeout
+ sut = ams_client.AmsClient()
+
+ with self.assertRaises(errors.ApiTimeoutError):
+ sut.send_rpc_response({'the': 'response'})
+ # Verify we have retried.
+ self.assertEqual(4, mock_request.call_count)
+
+ @mock.patch.object(ams_client.AmsClient, '_get_auth_token')
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_send_rpc_response_refresh_auth_if_first_attempt_has_401(
+ self, mock_request, mock_get_auth_token):
+ """Verifies send_rpc_response will refresh auth token."""
+ mock_401_response = mock.Mock()
+ mock_401_response.status_code = http.HTTPStatus.UNAUTHORIZED
+ mock_401_response.json.return_value = {}
+ mock_500_response = mock.Mock()
+ mock_500_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
+ mock_500_response.json.return_value = {}
+ mock_request.side_effect = (
+ mock_401_response,
+ requests.exceptions.Timeout,
+ requests.exceptions.Timeout,
+ requests.exceptions.Timeout,
+ mock_500_response,
+ )
+ sut = ams_client.AmsClient()
+
+ with self.assertRaisesRegex(errors.ApiError, '500'):
+ sut.send_rpc_response({'the': 'response'})
+ # Verify we have refreshed the auth token, and that does not count as
+ # a retry.
+ self.assertEqual(1, mock_get_auth_token.call_count)
+ self.assertEqual(5, mock_request.call_count)
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
+ def test_upload_artifact_successful(self, mock_open, mock_request):
+ """Verifies upload_artifact succeeds."""
+ mock_request.return_value.status_code = http.HTTPStatus.OK
+ sut = ams_client.AmsClient()
+ sut.upload_artifact('the/file/path', 'the-test-result-id')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
+ def test_upload_artifact_api_timed_out(self, mock_open, mock_request):
+ """Verifies upload_artifact raises ApiTimeoutError if API timed out."""
+        # side_effect below takes precedence, so no response status is needed.
+ mock_request.side_effect = requests.exceptions.Timeout
+ sut = ams_client.AmsClient()
+ with self.assertRaises(errors.ApiTimeoutError):
+ sut.upload_artifact('the/file/path', 'the-test-result-id')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ @mock.patch.object(ams_client, 'open', new_callable=mock.mock_open)
+ def test_upload_artifact_api_response_error(self, mock_open, mock_request):
+ """Verifies upload_artifact raises when API response is error."""
+ mock_response = mock_request.return_value
+ mock_response.status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR
+ mock_response.json.return_value = {'errorMessage': 'the-ams-err-msg'}
+ sut = ams_client.AmsClient()
+
+ with self.assertRaisesRegex(errors.ApiError, '500.*the-ams-err-msg'):
+ sut.upload_artifact('the/file/path', 'the-test-result-id')
+
+ @mock.patch.object(requests.sessions.Session, 'request')
+ def test_request_wrapper_no_timeout_and_invalid_num_retries(
+ self, mock_request):
+        """Verifies _request_wrapper with no timeout field and invalid num_retries."""
+ mock_response = mock.Mock(status_code=http.HTTPStatus.OK)
+ mock_request.return_value = mock_response
+ sut = ams_client.AmsClient()
+
+ response = sut._request_wrapper(num_retries=-1)
+
+ self.assertEqual(mock_response, response)
+
+ def test_extract_error_message_from_api_response_success(self):
+ """Verifies extract_error_message_from_api_response on success."""
+ fake_response = mock.Mock()
+ fake_response.json.return_value = {'errorMessage': 'error'}
+
+ self.assertEqual(
+ 'error',
+ ams_client.extract_error_message_from_api_response(fake_response))
+
+ @mock.patch.object(ams_client, 'logger')
+ def test_extract_error_message_from_api_response_decode_error(
+ self, mock_logger):
+ """Verifies extract_error_message_from_api_response decode error."""
+ fake_response = mock.Mock()
+ fake_doc = mock.Mock()
+ fake_doc.count.return_value = 0
+ fake_doc.rfind.return_value = 0
+ decode_error = json.decoder.JSONDecodeError('', fake_doc, 0)
+ fake_response.json.side_effect = decode_error
+
+ self.assertIsNone(
+ ams_client.extract_error_message_from_api_response(fake_response))
+
+ mock_logger.warning.assert_called_once_with(
+ 'API response cannot be parsed as JSON.')
+
+ @mock.patch.object(ams_client, 'logger')
+ def test_extract_error_message_from_api_response_key_error(
+ self, mock_logger):
+ """Verifies extract_error_message_from_api_response key error."""
+ fake_response = mock.Mock()
+ fake_response.json.side_effect = KeyError()
+
+ self.assertIsNone(
+ ams_client.extract_error_message_from_api_response(fake_response))
+
+ mock_logger.warning.assert_called_once_with(
+ 'API response does not have errorMessage field')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/local_agent/tests/unit_tests/test_command_handlers/__init__.py b/local_agent/tests/unit_tests/test_command_handlers/__init__.py
new file mode 100644
index 0000000..d46dbae
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_command_handlers/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/local_agent/tests/unit_tests/test_command_handlers/test_base.py b/local_agent/tests/unit_tests/test_command_handlers/test_base.py
new file mode 100644
index 0000000..ccb9ad1
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_command_handlers/test_base.py
@@ -0,0 +1,104 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for base command handlers."""
+import inflection
+from parameterized import parameterized
+import unittest
+from unittest import mock
+
+from local_agent import errors as agent_errors
+from local_agent.translation_layer.command_handlers import base
+
+
+_FAKE_SET_OP = 'setFake'
+_FAKE_GET_OP = 'getFake'
+_FAKE_SET_RESULT = None
+_FAKE_GET_RESULT = 'fake-result'
+_FAKE_PARAM_KEY = 'fake-param-key'
+_FAKE_PARAM_VAL = 10
+_UNKNOWN_METHOD = 'unknownMethod'
+
+
+def rpc_request(method, params):
+ """Simple wrapper for JSON-RPC request."""
+ return {'jsonrpc': '2.0', 'id': 0, 'method': method, 'params': params}
+
+def rpc_response(func_result):
+ """Simple wrapper for JSON-RPC response."""
+ result = {} if func_result is None else {'value': func_result}
+ return {'id': 0, 'jsonrpc': '2.0', 'result': result}
+
+
+class BaseCommandHandlerTest(unittest.TestCase):
+ """Unit tests for BaseCommandHandler."""
+
+ def setUp(self):
+ super().setUp()
+ self.handler = base.BaseCommandHandler(None)
+
+ @parameterized.expand(
+ [(_FAKE_GET_OP, _FAKE_GET_RESULT), (_FAKE_SET_OP, _FAKE_SET_RESULT)])
+ def test_01_handle_request_on_success(self, fake_method, fake_result):
+ """Verifies handle_request on success."""
+ fake_func = mock.Mock()
+ fake_func.return_value = fake_result
+ method_name = f'_{inflection.underscore(fake_method)}'
+ setattr(self.handler, method_name, fake_func)
+ fake_rpc = rpc_request(method=fake_method, params={})
+ expected_response = rpc_response(fake_result)
+
+ response = self.handler.handle_request(fake_rpc)
+
+ self.assertEqual(expected_response, response)
+ fake_func.assert_called_once()
+
+ def test_01_handle_request_on_failure_unknown_method(self):
+ """Verifies handle_request on failure with unknown method."""
+ fake_rpc = rpc_request(method=_UNKNOWN_METHOD, params={})
+
+ error_regex = 'Unknown method'
+ with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_regex):
+ self.handler.handle_request(fake_rpc)
+
+ def test_02_validate_key_in_params_on_success(self):
+ """Verifies validate_key_in_params method on success."""
+ params = {_FAKE_PARAM_KEY: _FAKE_PARAM_VAL}
+ self.handler.validate_key_in_params(
+ params=params,
+ param_key=_FAKE_PARAM_KEY,
+ expected_type=int)
+
+ def test_02_validate_key_in_params_on_failure_missing_key(self):
+        """Verifies validate_key_in_params on failure with missing key."""
+ error_regex = f'Missing field {_FAKE_PARAM_KEY} from RPC request.'
+ with self.assertRaisesRegex(ValueError, error_regex):
+ self.handler.validate_key_in_params(
+ params={},
+ param_key=_FAKE_PARAM_KEY,
+ expected_type=int)
+
+ def test_02_validate_key_in_params_on_failure_invalid_type(self):
+ """Verifies validate_key_in_params on failure with invalid type."""
+ params = {_FAKE_PARAM_KEY: _FAKE_PARAM_VAL}
+ error_regex = f'Invalid type for {_FAKE_PARAM_KEY}.'
+ with self.assertRaisesRegex(ValueError, error_regex):
+ self.handler.validate_key_in_params(
+ params=params,
+ param_key=_FAKE_PARAM_KEY,
+ expected_type=str)
+
+
+if __name__ == '__main__':
+ unittest.main(failfast=True)
diff --git a/local_agent/tests/unit_tests/test_command_handlers/test_common.py b/local_agent/tests/unit_tests/test_command_handlers/test_common.py
new file mode 100644
index 0000000..c60724c
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_command_handlers/test_common.py
@@ -0,0 +1,56 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for common command handlers."""
+import unittest
+from unittest import mock
+
+from gazoo_device import errors
+
+from local_agent.translation_layer.command_handlers import common
+
+
+class CommonCommandHandlerTest(unittest.TestCase):
+ """Unit tests for CommonCommandHandler."""
+
+ def setUp(self):
+ super().setUp()
+ self.dut = mock.Mock()
+ self.handler = common.CommonCommandHandler(self.dut)
+
+ def test_01_set_reboot_success(self):
+ """Verifies _set_reboot method on success."""
+ self.handler._set_reboot({})
+ self.dut.reboot.assert_called_once()
+
+ def test_01_set_reboot_failure(self):
+ """Verifies _set_reboot method on failure."""
+ self.dut.reboot.side_effect = errors.DeviceError('')
+ with self.assertRaises(errors.DeviceError):
+ self.handler._set_reboot({})
+
+ def test_02_set_factory_reset_success(self):
+ """Verifies _set_factory_reset method on success."""
+ self.handler._set_factory_reset({})
+ self.dut.factory_reset.assert_called_once()
+
+ def test_02_set_factory_reset_failure(self):
+ """Verifies _set_factory_reset method on failure."""
+ self.dut.factory_reset.side_effect = errors.DeviceError('')
+ with self.assertRaises(errors.DeviceError):
+ self.handler._set_factory_reset({})
+
+
+if __name__ == '__main__':
+ unittest.main(failfast=True)
diff --git a/local_agent/tests/unit_tests/test_command_handlers/test_light.py b/local_agent/tests/unit_tests/test_command_handlers/test_light.py
new file mode 100644
index 0000000..ee02fe8
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_command_handlers/test_light.py
@@ -0,0 +1,148 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for light command handlers."""
+from parameterized import parameterized
+import unittest
+from unittest import mock
+
+from gazoo_device import errors
+
+from local_agent.translation_layer.command_handlers import light
+
+
+_FAKE_BRIGHTNESS = 100
+_FAKE_HUE = 10
+_FAKE_SATURATION = 10
+_FAKE_ERROR_MSG = 'fake-error-msg'
+
+
+class LightCommandHandlerTest(unittest.TestCase):
+ """Unit tests for LightCommandHandler."""
+
+ def setUp(self):
+ super().setUp()
+ self.dut = mock.Mock()
+ self.handler = light.LightCommandHandler(self.dut)
+
+ @parameterized.expand([(True, 'on'), (False, 'off')])
+ def test_01_get_on_off_on_success(self, light_state, expected_state):
+ """Verifies get_on_off on success."""
+ self.dut.pw_rpc_light.state = light_state
+
+ state = self.handler._get_on_off({})
+
+ self.assertEqual(expected_state, state)
+
+ def test_01_get_on_off_on_failure(self):
+ """Verifies get_on_off on failure."""
+ mock_state = mock.PropertyMock(
+ side_effect=errors.DeviceError(_FAKE_ERROR_MSG))
+ type(self.dut.pw_rpc_light).state = mock_state
+ with self.assertRaisesRegex(errors.DeviceError, _FAKE_ERROR_MSG):
+ self.handler._get_on_off({})
+
+ def test_02_set_on_on_success(self):
+ """Verifies set_on on success."""
+ self.handler._set_on({})
+ self.dut.pw_rpc_light.on.assert_called_once()
+
+ def test_02_set_on_on_failure(self):
+ """Verifies set_on on failure."""
+ self.dut.pw_rpc_light.on.side_effect = errors.DeviceError('')
+ with self.assertRaises(errors.DeviceError):
+ self.handler._set_on({})
+
+ def test_03_set_off_on_success(self):
+ """Verifies set_off on success."""
+ self.handler._set_off({})
+ self.dut.pw_rpc_light.off.assert_called_once()
+
+ def test_03_set_off_on_failure(self):
+ """Verifies set_off on failure."""
+ self.dut.pw_rpc_light.off.side_effect = errors.DeviceError('')
+ with self.assertRaises(errors.DeviceError):
+ self.handler._set_off({})
+
+ def test_04_get_brightness_on_success(self):
+ """Verifies get_brightness on success."""
+ self.dut.pw_rpc_light.brightness = _FAKE_BRIGHTNESS
+
+ brightness = self.handler._get_brightness({})
+
+ self.assertEqual(_FAKE_BRIGHTNESS, brightness)
+
+ def test_04_get_brightness_on_failure(self):
+ """Verifies get_brightness on failure."""
+ mock_state = mock.PropertyMock(
+ side_effect=errors.DeviceError(_FAKE_ERROR_MSG))
+ type(self.dut.pw_rpc_light).brightness = mock_state
+ with self.assertRaisesRegex(errors.DeviceError, _FAKE_ERROR_MSG):
+ self.handler._get_brightness({})
+
+ @mock.patch.object(light.LightCommandHandler, 'validate_key_in_params')
+ def test_05_set_brightness_on_success(self, mock_validate_key_in_params):
+ """Verifies set_brightness on success."""
+ self.handler._set_brightness({'level': _FAKE_BRIGHTNESS})
+
+ mock_validate_key_in_params.assert_called_once()
+ self.dut.pw_rpc_light.on.assert_called_once_with(level=_FAKE_BRIGHTNESS)
+
+ @mock.patch.object(light.LightCommandHandler, 'validate_key_in_params')
+ def test_05_set_brightness_on_failure(self, mock_validate_key_in_params):
+ """Verifies set_brightness on failure."""
+ self.dut.pw_rpc_light.on.side_effect = errors.DeviceError('')
+ with self.assertRaises(errors.DeviceError):
+ self.handler._set_brightness({'level': _FAKE_BRIGHTNESS})
+
+ def test_06_get_color_on_success(self):
+ """Verifies get_color on success."""
+ self.dut.pw_rpc_light.color.hue = _FAKE_HUE
+ self.dut.pw_rpc_light.color.saturation = _FAKE_SATURATION
+ expected_response = {'hue': _FAKE_HUE, 'saturation': _FAKE_SATURATION}
+
+ color = self.handler._get_color({})
+
+ self.assertEqual(expected_response, color)
+
+ def test_06_get_color_on_failure(self):
+ """Verifies get_color on failure."""
+ mock_state = mock.PropertyMock(
+ side_effect=errors.DeviceError(_FAKE_ERROR_MSG))
+ type(self.dut.pw_rpc_light.color).hue = mock_state
+ with self.assertRaisesRegex(errors.DeviceError, _FAKE_ERROR_MSG):
+ self.handler._get_color({})
+
+ @mock.patch.object(light.LightCommandHandler, 'validate_key_in_params')
+ def test_07_set_color_on_success(self, mock_validate_key_in_params):
+ """Verifies set_color on success."""
+ params = {'hue': _FAKE_HUE, 'saturation': _FAKE_SATURATION}
+
+ self.handler._set_color(params)
+
+ self.assertEqual(2, mock_validate_key_in_params.call_count)
+ self.dut.pw_rpc_light.on.assert_called_once_with(
+ hue=_FAKE_HUE, saturation=_FAKE_SATURATION)
+
+ @mock.patch.object(light.LightCommandHandler, 'validate_key_in_params')
+ def test_07_set_color_on_failure(self, mock_validate_key_in_params):
+ """Verifies set_color on failure."""
+ params = {'hue': _FAKE_HUE, 'saturation': _FAKE_SATURATION}
+ self.dut.pw_rpc_light.on.side_effect = errors.DeviceError('')
+ with self.assertRaises(errors.DeviceError):
+ self.handler._set_color(params)
+
+
+if __name__ == '__main__':
+ unittest.main(failfast=True)
diff --git a/local_agent/tests/unit_tests/test_command_handlers/test_lock.py b/local_agent/tests/unit_tests/test_command_handlers/test_lock.py
new file mode 100644
index 0000000..393e7a3
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_command_handlers/test_lock.py
@@ -0,0 +1,77 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for lock command handlers."""
+from parameterized import parameterized
+import unittest
+from unittest import mock
+
+from gazoo_device import errors
+
+from local_agent.translation_layer.command_handlers import lock
+
+
+_FAKE_ERROR_MSG = 'fake-error-msg'
+
+
+class LockCommandHandlerTest(unittest.TestCase):
+ """Unit tests for LockCommandHandler."""
+
+ def setUp(self):
+ super().setUp()
+ self.dut = mock.Mock()
+ self.handler = lock.LockCommandHandler(self.dut)
+
+ @parameterized.expand([(True,), (False,)])
+ def test_01_get_is_locked_on_success(self, locked_state):
+ """Verifies get_is_locked on success."""
+ self.dut.pw_rpc_lock.state = locked_state
+
+ is_locked = self.handler._get_is_locked({})
+
+ self.assertEqual(locked_state, is_locked)
+
+ def test_01_get_is_locked_on_failure(self):
+ """Verifies get_is_locked on failure."""
+ mock_state = mock.PropertyMock(
+ side_effect=errors.DeviceError(_FAKE_ERROR_MSG))
+ type(self.dut.pw_rpc_lock).state = mock_state
+ with self.assertRaisesRegex(errors.DeviceError, _FAKE_ERROR_MSG):
+ self.handler._get_is_locked({})
+
+ def test_02_set_lock_on_success(self):
+ """Verifies set_lock on success."""
+ self.handler._set_lock({})
+ self.dut.pw_rpc_lock.lock.assert_called_once()
+
+ def test_02_set_lock_on_failure(self):
+ """Verifies set_lock on failure."""
+ self.dut.pw_rpc_lock.lock.side_effect = errors.DeviceError('')
+ with self.assertRaises(errors.DeviceError):
+ self.handler._set_lock({})
+
+ def test_03_set_unlock_on_success(self):
+ """Verifies set_unlock on success."""
+ self.handler._set_unlock({})
+ self.dut.pw_rpc_lock.unlock.assert_called_once()
+
+ def test_03_set_unlock_on_failure(self):
+ """Verifies set_unlock on failure."""
+ self.dut.pw_rpc_lock.unlock.side_effect = errors.DeviceError('')
+ with self.assertRaises(errors.DeviceError):
+ self.handler._set_unlock({})
+
+
+if __name__ == '__main__':
+ unittest.main(failfast=True)
diff --git a/local_agent/tests/unit_tests/test_gdm_manager.py b/local_agent/tests/unit_tests/test_gdm_manager.py
new file mode 100644
index 0000000..77ae32f
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_gdm_manager.py
@@ -0,0 +1,160 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for gdm_manager."""
+import unittest
+from unittest import mock
+
+import gazoo_device
+from gazoo_device import errors as gdm_errors
+
+from local_agent import errors as agent_errors
+from local_agent.translation_layer import gdm_manager
+
+
+####################### Fake data for unit test #############################
+_FAKE_DEVICE_TYPE = 'fake-device-type'
+_FAKE_DEVICE_ID = 'fake-device-id'
+_FAKE_DEVICE_ID2 = 'fake-device-id2'
+_FAKE_DETECTION_RESULT = {
+ 'efr32-3453': {
+ 'persistent': {'device_type': 'efr32', 'serial_number': '000440173453'},
+ },
+ 'nrf52840-1234': {
+        'persistent': {'device_type': 'nrf52840', 'serial_number': '1234'},
+ },
+}
+_FAKE_DEVICE_CAPABILITIES = ['flash_build', 'switchboard']
+_FAKE_ERROR_MSG = 'error'
+##############################################################################
+
+
+class GdmManagerTest(unittest.TestCase):
+ """Unit tests for gdm_manager module."""
+
+ def setUp(self):
+ super().setUp()
+ self.mgr = gdm_manager.GdmManager(mock.Mock())
+
+ @mock.patch.object(gazoo_device.Manager, 'get_supported_device_capabilities')
+ @mock.patch.object(gazoo_device.Manager, 'is_device_connected')
+ @mock.patch.object(gazoo_device.Manager, 'get_devices')
+ @mock.patch.object(gazoo_device.Manager, 'detect')
+ def test_01_detect_devices(
+ self,
+ mock_detect,
+ mock_get_devices,
+ mock_is_connect,
+ mock_get_capabilities):
+ """Verifies detect_devices on success."""
+ mock_get_devices.return_value = _FAKE_DETECTION_RESULT
+ mock_is_connect.side_effect = [True, False]
+ mock_get_capabilities.return_value = _FAKE_DEVICE_CAPABILITIES
+ device_statuses = self.mgr.detect_devices()
+ expected_statuses = [
+ {
+ 'deviceId': 'efr32-3453',
+ 'serialNumber': '000440173453',
+ 'deviceType': 'efr32',
+ 'capabilities': ['flash_build', 'switchboard'],
+ },
+ ]
+ self.assertEqual(expected_statuses, device_statuses)
+ mock_detect.assert_called_once_with(
+ force_overwrite=True, log_directory='/tmp')
+ self.assertEqual(2, mock_is_connect.call_count)
+ self.assertEqual(1, mock_get_capabilities.call_count)
+
+ @mock.patch.object(gazoo_device.Manager, 'create_device')
+ def test_02_create_devices_on_success(self, mock_create):
+ """Verifies create_devices on success."""
+ identifiers = [_FAKE_DEVICE_ID, _FAKE_DEVICE_ID2]
+ self.mgr.create_devices(identifiers)
+ mock_create.assert_called_with(
+ identifier=_FAKE_DEVICE_ID2, log_directory=None)
+ self.assertEqual(2, mock_create.call_count)
+
+ @mock.patch.object(gazoo_device.Manager, 'is_device_connected')
+ def test_03_check_device_connected_on_success(
+ self, mock_is_device_connected):
+ """Verifies check_device_connected succeeds when device connected."""
+ mock_is_device_connected.return_value = True
+ self.mgr.check_device_connected(_FAKE_DEVICE_ID)
+ self.assertEqual(1, mock_is_device_connected.call_count)
+
+ @mock.patch.object(gazoo_device.Manager, 'is_device_connected')
+ def test_03_check_device_connected_raises_when_disconnected(
+ self, mock_is_device_connected):
+ """Verifies check_device_connected raises when device disconnected."""
+ mock_is_device_connected.return_value = False
+ with self.assertRaises(agent_errors.DeviceNotConnectedError):
+ self.mgr.check_device_connected(_FAKE_DEVICE_ID)
+ self.assertEqual(1, mock_is_device_connected.call_count)
+
+ @mock.patch.object(gazoo_device.Manager, 'is_device_connected')
+ def test_03_check_device_connected_raises_unexpected_error(
+ self, mock_is_device_connected):
+ """
+ Verifies check_device_connected raises when device operation has error.
+ """
+ mock_is_device_connected.side_effect = (
+ gdm_errors.DeviceError(_FAKE_ERROR_MSG))
+ with self.assertRaisesRegex(gdm_errors.DeviceError, _FAKE_ERROR_MSG):
+ self.mgr.check_device_connected(_FAKE_DEVICE_ID)
+ self.assertEqual(1, mock_is_device_connected.call_count)
+
+ @mock.patch.object(gazoo_device.Manager, 'get_open_device')
+ @mock.patch.object(gazoo_device.Manager, 'get_open_device_names')
+ def test_04_get_device_instance_on_success(
+ self, mock_get_open_device_names, mock_get_open_device):
+ """Verifies get_device_instance on success."""
+ mock_get_open_device_names.return_value = [_FAKE_DEVICE_ID]
+ self.mgr.get_device_instance(_FAKE_DEVICE_ID)
+ self.assertEqual(1, mock_get_open_device_names.call_count)
+ mock_get_open_device.assert_called_once_with(_FAKE_DEVICE_ID)
+
+ @mock.patch.object(gazoo_device.Manager, 'get_open_device_names')
+ def test_04_get_device_instance_failure_because_device_not_open(
+ self, mock_get_open_device_names):
+ """Verifies get_device_instance on failure."""
+ mock_get_open_device_names.return_value = []
+ error_msg = f'{_FAKE_DEVICE_ID} is not open'
+ with self.assertRaisesRegex(agent_errors.DeviceNotOpenError, error_msg):
+ self.mgr.get_device_instance(_FAKE_DEVICE_ID)
+
+ @mock.patch.object(gazoo_device.Manager, 'get_open_device')
+ @mock.patch.object(gazoo_device.Manager, 'get_open_device_names')
+ def test_05_get_device_type_on_success(
+ self, mock_get_names, mock_get_device):
+        """Verifies get_device_type on success."""
+ mock_get_names.return_value = {_FAKE_DEVICE_ID,}
+ mock_get_device.return_value.device_type = _FAKE_DEVICE_TYPE
+ dut_type = self.mgr.get_device_type(_FAKE_DEVICE_ID)
+ self.assertEqual(_FAKE_DEVICE_TYPE, dut_type)
+ self.assertEqual(1, mock_get_names.call_count)
+ self.assertEqual(1, mock_get_device.call_count)
+
+ @mock.patch.object(gazoo_device.Manager, 'get_open_device_names')
+ def test_05_get_device_type_on_failure_not_open(self, mock_get_names):
+ """Verifies _get_device_type on failure not open."""
+ mock_get_names.return_value = []
+ error_msg = f'{_FAKE_DEVICE_ID} is not open'
+ with self.assertRaisesRegex(agent_errors.DeviceNotOpenError, error_msg):
+ self.mgr.get_device_type(_FAKE_DEVICE_ID)
+
+ @mock.patch.object(gazoo_device.Manager, 'close_open_devices')
+ def test_06_close_open_devices_on_success(self, mock_close_open_devices):
+ """Verifies close_open_devices on success."""
+ self.mgr.close_open_devices()
+ self.assertEqual(1, mock_close_open_devices.call_count)
diff --git a/local_agent/tests/unit_tests/test_local_agent.py b/local_agent/tests/unit_tests/test_local_agent.py
new file mode 100644
index 0000000..eba1454
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_local_agent.py
@@ -0,0 +1,649 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for local agent process."""
+import builtins
+from concurrent import futures
+import configparser
+import os
+import shutil
+import tempfile
+import threading
+import unittest
+from unittest import mock
+
+from local_agent import ams_client
+from local_agent import errors
+from local_agent import local_agent
+from local_agent import suite_session_manager
+from local_agent.translation_layer import translation_layer
+
+
+####################### Fake data for unit test #############################
+_FAKE_AMS_HOST = 'localhost'
+_FAKE_AMS_PORT = 8000
+_FAKE_AGENT_ID = 'fake-agent-id'
+_FAKE_AGENT_SECRET = 'fake-agent-secret'
+_FAKE_AUTH_TOKEN = 'fake-auth-token'
+_FAKE_RPC_RESPONSE = {'fake': 'response'}
+_FAKE_RPC_ID = 'fake-rpc-id'
+_FAKE_ERROR_MSG = 'fake-error-msg'
+_FAKE_ARTIFACTS_DIR = 'fake-artifacts-dir'
+_START_TEST_SUITE = 'startTestSuite'
+_END_TEST_SUITE = 'endTestSuite'
+_LOCK_DEVICE = 'setLock'
+#############################################################################
+
+
+class LocalAgentTest(unittest.TestCase):
+ """Unit tests for local agent process."""
+
+ def setUp(self):
+ super().setUp()
+ self.proc = local_agent.LocalAgentProcess(
+ client=ams_client.AmsClient(), artifacts_dir=_FAKE_ARTIFACTS_DIR)
+
+ _, local_agent.AUTH_FILE = self._create_temp_file_with_clean_up()
+ _, local_agent.DEFAULT_USER_CONFIG = (
+ self._create_temp_file_with_clean_up())
+
+ def _create_temp_file_with_clean_up(self):
+ """Creates a temp file and registers clean-up procedure.
+
+ We use tempfile.mkstemp to create a temporary file, and clean it up
+ using self.addCleanup provided by unittest package.
+
+ Returns:
+ Tuple of (file_descriptor, file_path). Exactly what is returned by a
+ tempfile.mkstemp call.
+ """
+ fd, path = tempfile.mkstemp()
+ self.addCleanup(os.close, fd)
+ self.addCleanup(os.remove, path)
+ return fd, path
+
+ @mock.patch.object(ams_client.AmsClient, 'set_local_agent_credentials')
+ @mock.patch.object(local_agent.LocalAgentProcess,
+ '_read_auths',
+ return_value=('agent-id', 'agent-secret'))
+ def test_01_setup_credentials_existing_credential_success(
+ self,
+ mock_read_auths,
+ mock_set_local_agent_credentials):
+ """Verifies _setup_credentials succeeds with existing credentials."""
+ self.assertTrue(self.proc._setup_credentials())
+
+ @mock.patch.object(builtins, 'input', side_effect=KeyboardInterrupt)
+ @mock.patch.object(local_agent.LocalAgentProcess, '_read_auths')
+ def test_01_setup_credentials_no_existing_credential_will_start_linking(
+ self, mock_read_auths, mock_input):
+ """Verifies _setup_credentials starts linking if no credentials."""
+ mock_read_auths.side_effect = FileNotFoundError
+ self.assertFalse(self.proc._setup_credentials())
+ self.assertEqual(1, mock_input.call_count)
+
+ @mock.patch.object(builtins, 'input', side_effect=KeyboardInterrupt)
+ @mock.patch.object(ams_client.AmsClient,
+ 'set_local_agent_credentials',
+ side_effect=errors.CredentialsError)
+ @mock.patch.object(local_agent.LocalAgentProcess,
+ '_read_auths',
+ return_value=('agent-id', 'agent-secret'))
+ def test_01_setup_credentials_bad_existing_credential_will_start_linking(
+ self,
+ mock_read_auths,
+ mock_set_local_agent_credentials,
+ mock_input):
+ """Verifies _setup_credentials starts linking if bad credentials."""
+ self.assertFalse(self.proc._setup_credentials())
+ self.assertEqual(1, mock_input.call_count)
+
+ @mock.patch.object(ams_client.AmsClient, 'register')
+ @mock.patch.object(builtins, 'input', return_value='the-code')
+ @mock.patch.object(local_agent.LocalAgentProcess,
+ '_read_auths',
+ side_effect=FileNotFoundError)
+ def test_01_setup_credentials_start_linking_and_succeed(
+ self, mock_read_auths, mock_input, mock_register):
+ """Verifies _setup_credentials starts linking and succeeds."""
+ mock_register.return_value = ('the-agent-id', 'the-agent-secret')
+ self.assertTrue(self.proc._setup_credentials())
+ mock_register.assert_called_once_with(linking_code='the-code')
+
+ @mock.patch.object(ams_client.AmsClient, 'register')
+ @mock.patch.object(builtins, 'input', return_value='the-code')
+ @mock.patch.object(local_agent.LocalAgentProcess,
+ '_read_auths',
+ side_effect=FileNotFoundError)
+ def test_01_setup_credentials_start_linking_and_keeps_retry(
+ self,
+ mock_read_auths,
+ mock_input,
+ mock_register):
+ """Verifies _setup_credentials keeps retry when linking fails."""
+ mock_register.side_effect = (
+ errors.ApiTimeoutError,
+ errors.ApiTimeoutError,
+ errors.CredentialsError,
+ errors.CredentialsError,
+ errors.ApiError,
+ errors.CredentialsError,
+ ('the-agent-id', 'the-agent-secret'),
+ )
+ self.assertTrue(self.proc._setup_credentials())
+ self.assertEqual(7, mock_input.call_count)
+
+ def test_03_read_write_auth(self):
+ """Verifies reading/writing auths."""
+ self.proc._write_auths(_FAKE_AGENT_ID, _FAKE_AGENT_SECRET)
+ self.assertEqual(
+ (_FAKE_AGENT_ID, _FAKE_AGENT_SECRET), self.proc._read_auths())
+
+ def test_05_read_config_with_nonexistent_file(self):
+ """Verifies read_config returns {} when config file doesn't exist."""
+ local_agent.DEFAULT_USER_CONFIG = ''
+ self.assertEqual({}, local_agent.read_config())
+
+ @mock.patch.object(configparser, 'ConfigParser')
+ def test_05_read_config_missing_root_key(self, mock_parser):
+ """Verifies read_config raise ValueError when root key not present."""
+ mock_config = mock.MagicMock()
+ mock_parser.return_value = mock_config
+ mock_config.__contains__.return_value = False
+ with self.assertRaisesRegex(ValueError, 'Invalid config file'):
+ local_agent.read_config()
+
+ @mock.patch.object(configparser, 'ConfigParser')
+ def test_05_read_config_success(self, mock_parser):
+ """Verifies read_config run successfully."""
+ mock_config = mock.MagicMock()
+ mock_parser.return_value = mock_config
+ mock_config.__contains__.return_value = True
+
+ local_agent.read_config()
+
+ mock_config.read.assert_called_once()
+
+ @mock.patch.object(
+ local_agent.LocalAgentProcess, '_start_info_reporting_thread')
+ @mock.patch.object(
+ local_agent.LocalAgentProcess, '_start_rpc_polling_thread')
+ @mock.patch.object(suite_session_manager.SuiteSessionManager, 'start')
+ @mock.patch.object(
+ local_agent.LocalAgentProcess, '_setup_credentials', return_value=True)
+ @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
+ def test_06_run_starts_two_top_level_threads(
+ self,
+ mock_event_is_set,
+ mock_setup_credentials,
+ mock_start,
+ mock_start_rpc_polling_thread,
+ mock_start_info_reporting_thread):
+ """Verifies run() starts RPC polling and info reporting threads."""
+ self.proc.run()
+ self.assertEqual(1, mock_start_rpc_polling_thread.call_count)
+ self.assertEqual(1, mock_start_info_reporting_thread.call_count)
+
+ @mock.patch.object(
+ local_agent.LocalAgentProcess, '_start_info_reporting_thread')
+ @mock.patch.object(
+ local_agent.LocalAgentProcess, '_start_rpc_polling_thread')
+ @mock.patch.object(local_agent.LocalAgentProcess,
+ '_setup_credentials',
+ return_value=False)
+ def test_06_run_will_exit_if_cannot_setup_credentials(
+ self,
+ mock_setup_credentials,
+ mock_start_rpc_polling_thread,
+ mock_start_info_reporting_thread):
+ """Verifies run() aborts if _setup_credentials failed."""
+ self.proc.run()
+ self.assertFalse(mock_start_rpc_polling_thread.called)
+ self.assertFalse(mock_start_info_reporting_thread.called)
+
+ @mock.patch.object(
+ translation_layer.TranslationLayer, 'detect_devices', return_value=[])
+ @mock.patch.object(
+ ams_client.AmsClient, 'get_rpc_request_from_ams', return_value=None)
+ @mock.patch.object(ams_client.AmsClient, 'report_info')
+ @mock.patch.object(suite_session_manager.SuiteSessionManager, 'start')
+ @mock.patch.object(
+ local_agent.LocalAgentProcess, '_setup_credentials', return_value=True)
+ def test_06_run_exits_main_thread_if_report_info_thread_is_dead(
+ self,
+ _,
+ mock_start,
+ mock_report_info,
+ mock_get_rpc_request,
+ mock_detect_devices):
+ """Verifies run() terminates local agent if report info thread is dead.
+ """
+ mock_report_info.side_effect = RuntimeError()
+ self.proc.run()
+
+ @mock.patch.object(
+ translation_layer.TranslationLayer, 'detect_devices', return_value=[])
+ @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
+ @mock.patch.object(ams_client.AmsClient, 'report_info')
+ @mock.patch.object(suite_session_manager.SuiteSessionManager, 'start')
+ @mock.patch.object(
+ local_agent.LocalAgentProcess, '_setup_credentials', return_value=True)
+ def test_06_run_exits_main_thread_if_poll_rpc_thread_is_dead(
+ self,
+ _,
+ mock_start,
+ mock_report_info,
+ mock_get_rpc_request,
+ mock_detect_devices):
+ """Verifies run() terminates local agent if poll RPC thread is dead."""
+ mock_get_rpc_request.side_effect = RuntimeError()
+ self.proc.run()
+
+ @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
+ @mock.patch.object(ams_client.AmsClient, 'report_info')
+ @mock.patch.object(threading.Event, 'wait', return_value=True)
+ def test_07_report_info_sends_request_to_ams(self,
+ mock_event_wait,
+ mock_report_info,
+ mock_detect_devices):
+ """Verifies _report_info uses AmsClient to report info."""
+ mock_detect_devices.return_value = []
+ self.proc._report_info()
+ self.assertEqual(1, mock_report_info.call_count)
+ self.assertEqual(1, mock_detect_devices.call_count)
+
+ @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
+ @mock.patch.object(ams_client.AmsClient, 'report_info')
+ @mock.patch.object(threading.Event, 'wait', return_value=True)
+ def test_07_report_info_wont_break_when_api_error(self,
+ mock_event_wait,
+ mock_report_info,
+ mock_detect_devices):
+ """Verifies _report_info continues when an API error happens."""
+ mock_report_info.side_effect = errors.ApiError
+ self.proc._report_info()
+
+ @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
+ @mock.patch.object(ams_client.AmsClient, 'report_info')
+ @mock.patch.object(threading.Event, 'wait', return_value=True)
+ def test_07_report_info_wont_break_when_api_timeout(self,
+ mock_event_wait,
+ mock_report_info,
+ mock_detect_devices):
+ """Verifies _report_info continues when an API error happens."""
+ mock_report_info.side_effect = errors.ApiTimeoutError
+ self.proc._report_info()
+
+ @mock.patch.object(
+ local_agent.LocalAgentProcess, '_clean_up_and_terminate_agent')
+ @mock.patch.object(translation_layer.TranslationLayer, 'detect_devices')
+ @mock.patch.object(ams_client.AmsClient, 'report_info')
+ def test_07_report_info_break_when_agent_unlinked(self,
+ mock_report_info,
+ mock_detect_devices,
+ mock_clean_up):
+ """Verifies _report_info breaks when the local agent is unlinked."""
+ mock_report_info.side_effect = errors.UnlinkedError
+
+ self.proc._report_info()
+
+ mock_clean_up.assert_called_once()
+
+ @mock.patch.object(ams_client.AmsClient,
+ 'get_rpc_request_from_ams',
+ return_value=None)
+ @mock.patch.object(threading.Event, 'is_set')
+ def test_08_poll_rpc_gets_request_from_ams(
+ self, mock_event_is_set, mock_get_rpc_request):
+ """Verifies _poll_rpc will get request from AMS in each iteration."""
+ # We let there be 2 iterations.
+ mock_event_is_set.side_effect = (False, False, True)
+ self.proc._poll_rpc()
+ self.assertEqual(2, mock_get_rpc_request.call_count)
+
+ @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
+ @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
+ @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
+ def test_08_poll_rpc_continues_if_api_error_when_getting_request(
+ self,
+ mock_event_is_set,
+ mock_get_rpc_request,
+ mock_remove_rpc):
+ """Verifies _poll_rpc continues when getting RPC has ApiError."""
+ mock_get_rpc_request.side_effect = errors.ApiError
+ self.proc._poll_rpc()
+ self.assertFalse(mock_remove_rpc.called)
+
+ @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
+ @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
+ @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
+ def test_08_poll_rpc_continues_if_api_timeout_when_getting_request(
+ self,
+ mock_event_is_set,
+ mock_get_rpc_request,
+ mock_remove_rpc):
+ """Verifies _poll_rpc continues when get RPC request API timed out."""
+ mock_get_rpc_request.side_effect = errors.ApiTimeoutError
+ self.proc._poll_rpc()
+ self.assertFalse(mock_remove_rpc.called)
+
+ @mock.patch.object(
+ local_agent.LocalAgentProcess, '_clean_up_and_terminate_agent')
+ @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
+ @mock.patch.object(threading.Event, 'is_set', return_value=False)
+ def test_08_poll_rpc_raises_if_agent_unlinked_when_getting_request(
+ self,
+ mock_event_is_set,
+ mock_get_rpc_request,
+ mock_clean_up_and_terminate_agent):
+ """Verifies _poll_rpc raises when get RPC request agent unlinked."""
+ mock_get_rpc_request.side_effect = errors.UnlinkedError
+
+ self.proc._poll_rpc()
+
+ mock_clean_up_and_terminate_agent.assert_called_once()
+
+ @mock.patch.object(local_agent.LocalAgentProcess,
+ '_start_rpc_execution_thread')
+ @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
+ @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
+ @mock.patch.object(threading.Event, 'is_set', side_effect=(False, True))
+ def test_08_poll_rpc_removes_rpc_request_from_ams_and_executes(
+ self,
+ mock_event_is_set,
+ mock_get_rpc_request,
+ mock_remove_rpc,
+ mock_start_rpc_execution):
+ """Verifies _poll_rpc removes RPC request from AMS and executes it."""
+ fake_rpc_request = {'hi': 'hello'}
+ mock_get_rpc_request.return_value = fake_rpc_request
+
+ self.proc._poll_rpc()
+
+ self.assertEqual(1, mock_get_rpc_request.call_count)
+ mock_remove_rpc.assert_called_once_with(fake_rpc_request)
+ mock_start_rpc_execution.assert_called_once_with(fake_rpc_request)
+
+ @mock.patch.object(ams_client.AmsClient, 'remove_rpc_request_from_ams')
+ @mock.patch.object(ams_client.AmsClient, 'get_rpc_request_from_ams')
+ def test_08_poll_rpc_terminate_local_agent_when_remove_rpc_fails(
+ self,
+ mock_get_request,
+ mock_remove_request):
+ """Verifies _poll_rpc terminates local agent when remove RPC fails."""
+ mock_get_request.return_value = {'hi': 'rpc-request-here'}
+ mock_remove_request.side_effect = errors.ApiError
+
+ self.proc._poll_rpc()
+
+ @mock.patch.object(futures.ThreadPoolExecutor, 'submit')
+ def test_09_start_rpc_execution_thread_submits_to_thread_pool_executor(
+ self, mock_submit):
+ """Verifies _start_rpc_execution_thread submits to ThreadPoolExecutor.
+ """
+ fake_rpc_request = {'hi': 'hello'}
+ fake_future = mock.Mock()
+ mock_submit.return_value = fake_future
+
+ self.proc._start_rpc_execution_thread(fake_rpc_request)
+
+ self.assertEqual(
+ 1,
+ mock_submit.call_count,
+ 'Should submit a task to ThreadPoolExecutor but did not.')
+ self.assertIn(
+ id(fake_future),
+ self.proc._rpc_execution_future_ids,
+ 'Should keep track of the future but did not.')
+ # Should register the callback to the future.
+ fake_future.add_done_callback.assert_called_once_with(
+ self.proc._callback_for_rpc_execution_complete)
+
+ @mock.patch.object(futures.ThreadPoolExecutor, 'shutdown')
+ def test_10_terminate_shutdown_pool_executor(self, mock_shutdown):
+ """Verifies _terminate shuts down the ThreadPoolExecutor."""
+ self.proc._terminate(None, None)
+ self.assertEqual(1, mock_shutdown.call_count)
+
+ @mock.patch.object(threading.Event, 'set')
+ def test_10_terminate_sets_threading_event(self, mock_set):
+ """Verifies _terminate sets the threading event."""
+ self.proc._terminate(None, None)
+ mock_set.assert_called_once()
+
+ @mock.patch.object(local_agent, 'logger')
+ def test_10_terminate_thread_still_alive(self, mock_logger):
+ """Verifies _terminates on failure with still alive threads."""
+ mock_rpc_thread = mock.Mock()
+ mock_rpc_thread.is_alive.return_value = True
+ self.proc._rpc_polling_thread = mock_rpc_thread
+
+ self.proc._terminate(None, None)
+
+ mock_rpc_thread.join.assert_called_once()
+ mock_logger.error.assert_called_once()
+
+ @mock.patch.object(ams_client.AmsClient, 'upload_artifact')
+ @mock.patch.object(os, 'stat')
+ @mock.patch.object(shutil, 'rmtree')
+ @mock.patch.object(shutil, 'make_archive')
+ def test_11_compress_artifacts_and_upload_on_success(
+ self, mock_make, mock_rm, mock_stat, mock_ams_upload):
+ """Verifies _compress_artifacts_and_upload on success."""
+ mock_stat.return_value.st_size = 1
+ with mock.patch('builtins.open',
+ new_callable=mock.mock_open):
+ self.proc._compress_artifacts_and_upload('', '')
+ self.assertEqual(1, mock_make.call_count)
+ self.assertEqual(1, mock_rm.call_count)
+ self.assertEqual(1, mock_stat.call_count)
+ self.assertEqual(1, mock_ams_upload.call_count)
+
+ @mock.patch.object(os, 'stat')
+ @mock.patch.object(shutil, 'rmtree')
+ @mock.patch.object(shutil, 'make_archive')
+ def test_11_compress_artifacts_and_upload_too_large_file(
+ self, mock_make, mock_rm, mock_stat):
+ """
+ Verifies _compress_artifacts_and_upload on failure with too large file.
+ """
+ mock_stat.return_value.st_size = (
+ local_agent.APP_ENGINE_DATA_SIZE_LIMIT + 1)
+ error_msg = (
+ 'larger than '
+ f'{local_agent.APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE}')
+ with self.assertRaisesRegex(RuntimeError, error_msg):
+ self.proc._compress_artifacts_and_upload('', '')
+ self.assertEqual(1, mock_make.call_count)
+ self.assertEqual(1, mock_rm.call_count)
+
+ @mock.patch.object(ams_client.AmsClient, 'upload_artifact')
+ @mock.patch.object(os, 'stat')
+ @mock.patch.object(shutil, 'rmtree')
+ @mock.patch.object(shutil, 'make_archive')
+ def test_11_compress_artifacts_and_upload_uploading_timed_out(
+ self, mock_make, mock_rm, mock_stat, mock_ams_client_upload):
+ """
+ Verifies _compress_artifacts_and_upload on failure due to API timed out.
+ """
+ mock_stat.return_value.st_size = 1
+ mock_ams_client_upload.side_effect = errors.ApiTimeoutError
+ with mock.patch('builtins.open',
+ new_callable=mock.mock_open):
+ self.proc._compress_artifacts_and_upload('', '')
+ self.assertEqual(1, mock_ams_client_upload.call_count)
+
+ @mock.patch.object(ams_client.AmsClient, 'upload_artifact')
+ @mock.patch.object(os, 'stat')
+ @mock.patch.object(shutil, 'rmtree')
+ @mock.patch.object(shutil, 'make_archive')
+ def test_11_compress_artifacts_and_upload_uploading_api_error(
+ self, mock_make, mock_rm, mock_stat, mock_ams_client_upload):
+ """
+ Verifies _compress_artifacts_and_upload fails due to an API error.
+ """
+ mock_stat.return_value.st_size = 1
+ mock_ams_client_upload.side_effect = errors.ApiError
+ with mock.patch('builtins.open',
+ new_callable=mock.mock_open):
+ self.proc._compress_artifacts_and_upload('', '')
+ self.assertEqual(1, mock_ams_client_upload.call_count)
+
+ @mock.patch.object(local_agent, 'LocalAgentProcess')
+ @mock.patch.object(local_agent, 'read_config')
+ def test_12_main_entry(self, mock_read, mock_proc):
+ """Verifies local agent main entry on success."""
+ local_agent.main()
+ self.assertEqual(1, mock_read.call_count)
+ self.assertEqual(1, mock_proc.call_count)
+ self.assertEqual(1, mock_proc.return_value.run.call_count)
+
+ @mock.patch.object(translation_layer.TranslationLayer, 'is_rpc_timeout')
+ @mock.patch.object(ams_client.AmsClient, 'send_rpc_response')
+ @mock.patch.object(local_agent.LocalAgentProcess, '_handle_rpc_request')
+ def test_13_execute_rpc_executes_and_sends_result_to_ams(
+ self,
+ mock_handle_request,
+ mock_send_rpc_response,
+ mock_is_rpc_timeout):
+ """Verifies _execute_rpc executes the RPC and sends the result."""
+ fake_rpc_request = {'the': 'rpc-request'}
+ fake_rpc_response = {'the': 'rpc-response'}
+ mock_handle_request.return_value = fake_rpc_response
+ mock_is_rpc_timeout.return_value = False
+
+ self.proc._execute_rpc(fake_rpc_request)
+
+ mock_handle_request.assert_called_once_with(fake_rpc_request)
+ mock_send_rpc_response.assert_called_once_with(fake_rpc_response)
+
+ @mock.patch.object(translation_layer.TranslationLayer, 'is_rpc_timeout')
+ @mock.patch.object(local_agent, 'logger')
+ @mock.patch.object(ams_client.AmsClient, 'send_rpc_response')
+ @mock.patch.object(local_agent.LocalAgentProcess, '_handle_rpc_request')
+ def test_13_execute_rpc_fail_to_send_rpc_response(
+ self,
+ mock_handle_request,
+ mock_send_rpc_response,
+ mock_logger,
+ mock_is_rpc_timeout):
+ """Verifies _execute_rpc continues when fail to send RPC response."""
+ fake_rpc_request = {'the': 'rpc-request'}
+ fake_rpc_response = {'the': 'rpc-response'}
+ mock_handle_request.return_value = fake_rpc_response
+ mock_send_rpc_response.side_effect = errors.ApiError
+ mock_is_rpc_timeout.return_value = False
+
+ self.proc._execute_rpc(fake_rpc_request)
+
+ mock_send_rpc_response.assert_called_once()
+ mock_logger.exception.assert_called_once()
+
+ @mock.patch.object(translation_layer.TranslationLayer, 'is_rpc_timeout')
+ @mock.patch.object(local_agent, 'logger')
+ @mock.patch.object(local_agent.LocalAgentProcess, '_handle_rpc_request')
+ def test_13_execute_rpc_not_sending_timeout_rpc_response(
+ self,
+ mock_handle_request,
+ mock_logger,
+ mock_is_rpc_timeout):
+ """Verifies _execute_rpc not sending timeout RPC response."""
+ fake_rpc_request = {'id': _FAKE_RPC_ID}
+ mock_is_rpc_timeout.return_value = True
+
+ self.proc._execute_rpc(fake_rpc_request)
+
+ mock_logger.warning.assert_called_once()
+
+ @mock.patch.object(
+ suite_session_manager.SuiteSessionManager, 'start_test_suite')
+ def test_14_handle_rpc_request_start_suite(self, mock_start):
+ """Verifies handle_rpc_request to start suite on success."""
+ mock_start.return_value = _FAKE_RPC_RESPONSE
+ fake_rpc_request = {'method': _START_TEST_SUITE}
+
+ rpc_response = self.proc._handle_rpc_request(fake_rpc_request)
+
+ self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
+ mock_start.assert_called_once_with(fake_rpc_request)
+
+ @mock.patch.object(
+ suite_session_manager.SuiteSessionManager, 'end_test_suite')
+ def test_14_handle_rpc_request_end_suite(self, mock_end):
+ """Verifies handle_rpc_request to end suite on success."""
+ mock_end.return_value = _FAKE_RPC_RESPONSE
+ fake_rpc_request = {'method': _END_TEST_SUITE}
+
+ rpc_response = self.proc._handle_rpc_request(fake_rpc_request)
+
+ self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
+ mock_end.assert_called_once_with(fake_rpc_request)
+
+ @mock.patch.object(
+ translation_layer.TranslationLayer, 'dispatch_to_cmd_handler')
+ def test_14_handle_rpc_request_device_control(self, mock_dispatch):
+ """Verifies handle_rpc_request to control device on success."""
+ mock_dispatch.return_value = _FAKE_RPC_RESPONSE
+ fake_rpc_request = {'method': _LOCK_DEVICE}
+
+ rpc_response = self.proc._handle_rpc_request(fake_rpc_request)
+
+ self.assertEqual(_FAKE_RPC_RESPONSE, rpc_response)
+ mock_dispatch.assert_called_once_with(fake_rpc_request)
+
+ @mock.patch.object(local_agent, 'rpc_request_type', return_value='')
+ def test_14_handle_rpc_request_invalid_rpc_type(self, mock_req_type):
+ """Verifies handle_rpc_request on failure with invalid RPC type."""
+ error_msg = 'Invalid RPC request type'
+
+ rpc_response = self.proc._handle_rpc_request({'id': '', 'method': ''})
+
+ self.assertIn(error_msg, rpc_response['error']['message'])
+
+ @mock.patch.object(
+ translation_layer.TranslationLayer, 'dispatch_to_cmd_handler')
+ def test_14_handle_rpc_unexpected_errors(self, mock_dispatch):
+ """Verifies handle_rpc_request on failure with unexpected errors."""
+ mock_dispatch.side_effect = RuntimeError(_FAKE_ERROR_MSG)
+ fake_rpc_request = {'id': _FAKE_RPC_ID, 'method': _LOCK_DEVICE}
+
+ rpc_response = self.proc._handle_rpc_request(fake_rpc_request)
+
+ self.assertIn(_FAKE_ERROR_MSG, rpc_response['error']['message'])
+
+ @mock.patch.object(os, 'remove')
+ @mock.patch.object(os.path, 'exists', return_value=True)
+ @mock.patch.object(suite_session_manager.SuiteSessionManager, 'clean_up')
+ def test_15_clean_up_and_terminate_agent_on_success(
+ self, mock_clean_up, mock_exists, mock_rm):
+ """Verifies _clean_up_and_terminate_agent on success."""
+ self.proc._clean_up_and_terminate_agent(remove_auth_file=True)
+ mock_clean_up.assert_called_once()
+ mock_rm.assert_called_once()
+
+ @mock.patch.object(local_agent, 'logger')
+ def test_16_callback_for_rpc_execution_complete(self, mock_logger):
+ """Verifies _callback_for_rpc_execution_complete on success."""
+ mock_future = mock.Mock()
+ self.proc._rpc_execution_future_ids.add(id(mock_future))
+
+ self.proc._callback_for_rpc_execution_complete(mock_future)
+
+ mock_future.exception.assert_called_once()
+ mock_logger.error.assert_called_once()
+
+
+if __name__ == '__main__':
+ unittest.main(failfast=True)
diff --git a/local_agent/tests/unit_tests/test_logger.py b/local_agent/tests/unit_tests/test_logger.py
new file mode 100644
index 0000000..5b49ece
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_logger.py
@@ -0,0 +1,113 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for logger."""
+import unittest
+from unittest import mock
+import logging
+
+from local_agent import logger
+
+_FAKE_LOG_FILE = 'fake-log-file'
+
+
+class LoggerTest(unittest.TestCase):
+ """Unit test for logger module."""
+
+ @mock.patch.object(logging.handlers, 'RotatingFileHandler')
+ def test_01_create_handler_log_file_not_none(
+ self, mock_rotating_file_handler):
+ """Verifies create_handler creates logging handler for log file."""
+ mock_handler = mock.Mock()
+ mock_rotating_file_handler.return_value = mock_handler
+
+ returned_handler = logger.create_handler(log_file=_FAKE_LOG_FILE)
+
+ self.assertEqual(returned_handler, mock_handler)
+ mock_handler.setLevel.assert_called_once()
+ mock_handler.setFormatter.assert_called_once()
+
+ @mock.patch.object(logging, 'StreamHandler')
+ def test_01_create_handler_no_log_file(self, mock_stream_handler):
+ """Verifies create_handler creates logging stream handler."""
+ mock_handler = mock.Mock()
+ mock_stream_handler.return_value = mock_handler
+
+ returned_handler = logger.create_handler()
+
+ self.assertEqual(returned_handler, mock_handler)
+ mock_handler.setLevel.assert_called_once()
+ mock_handler.setFormatter.assert_called_once()
+
+ @mock.patch.object(logger, 'logger')
+ @mock.patch.object(logger, 'create_handler')
+ def test_02_add_file_handler(self, mock_create_handler, mock_logger):
+ """Verifies add_file_handler adds handler for log file."""
+ mock_handler = mock.Mock()
+ mock_create_handler.return_value = mock_handler
+
+ logger.add_file_handler(log_file=_FAKE_LOG_FILE)
+
+ mock_create_handler.assert_called_once()
+ self.assertEqual(logger.log_handler_map[_FAKE_LOG_FILE], mock_handler)
+ mock_logger.addHandler.assert_called_once_with(mock_handler)
+
+ @mock.patch.object(logger, 'create_handler')
+ def test_02_add_file_handler_already_exists(
+ self, mock_create_handler):
+ """Verifies add_file_handler does nothing for already existing file."""
+ logger.log_handler_map[_FAKE_LOG_FILE] = mock.Mock()
+
+ logger.add_file_handler(log_file=_FAKE_LOG_FILE)
+
+ self.assertEqual(0, mock_create_handler.call_count)
+
+ @mock.patch.object(logger, 'logger')
+ def test_03_remove_file_handler(self, mock_logger):
+ """Verifies remove_file_handler removes handler for log file."""
+ mock_handler = mock.Mock()
+ logger.log_handler_map[_FAKE_LOG_FILE] = mock_handler
+
+ logger.remove_file_handler(log_file=_FAKE_LOG_FILE)
+
+ self.assertNotIn(_FAKE_LOG_FILE, logger.log_handler_map)
+ mock_logger.removeHandler.assert_called_once_with(mock_handler)
+
+ @mock.patch.object(logger, 'logger')
+ def test_03_remove_file_handler_not_exists(self, mock_logger):
+ """Verifies remove_file_handler does nothing for non-existed handler."""
+ logger.log_handler_map.clear()
+
+ logger.remove_file_handler(log_file=_FAKE_LOG_FILE)
+
+ self.assertEqual(0, mock_logger.removeHandler.call_count)
+
+ @mock.patch.object(logger, 'create_handler')
+ @mock.patch.object(logger, 'logger')
+ def test_04_setup_logger(self, mock_logger, mock_create_handler):
+ """Verifies setup_logger on success."""
+ logger.setup_logger()
+
+ mock_logger.setLevel.assert_called_once()
+ self.assertEqual(2, mock_create_handler.call_count)
+ self.assertEqual(2, mock_logger.addHandler.call_count)
+
+ def test_05_get_logger(self):
+ """Verifies get_logger gets local agent logger."""
+ local_agent_logger = logger.get_logger()
+ self.assertEqual(logger.logger, local_agent_logger)
+
+
+if __name__ == '__main__':
+ unittest.main(failfast=True)
diff --git a/local_agent/tests/unit_tests/test_suite_session_manager.py b/local_agent/tests/unit_tests/test_suite_session_manager.py
new file mode 100644
index 0000000..901e816
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_suite_session_manager.py
@@ -0,0 +1,265 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for suite session manager."""
+import os
+import threading
+import time
+import unittest
+from unittest import mock
+
+from local_agent import errors as agent_errors
+from local_agent import logger as logger_module
+from local_agent import suite_session_manager
+
+
+####################### Fake data for unit test #############################
+_FAKE_DEVICE_ID = 'fake-device-id'
+_FAKE_DEVICE_ID2 = 'fake-device-id2'
+_FAKE_TEST_SUITE_ID = 'fake-test-suite-id'
+_FAKE_TEST_SUITE_ID2 = 'fake-test-suite-id2'
+_FAKE_TEST_RESULT_ID = 'fake-test-result-id'
+_START_TEST_SUITE = 'startTestSuite'
+_END_TEST_SUITE = 'endTestSuite'
+_THREADING_MODULE_PATH = 'threading.Thread'
+##############################################################################
+
+
+def rpc_request(method, params):
+ """Simple wrapper for json rpc request."""
+ return {'jsonrpc': '2.0', 'id': 0, 'method': method, 'params': params}
+
+
+class SuiteSessionManagerTest(unittest.TestCase):
+ """Unit tests for local agent suite session manager."""
+
+ def setUp(self):
+ super().setUp()
+ fake_thread = mock.patch(_THREADING_MODULE_PATH)
+ fake_thread.start()
+ self.addCleanup(fake_thread.stop)
+ self.suite_mgr = suite_session_manager.SuiteSessionManager(
+ artifacts_fn=mock.Mock(),
+ artifact_root_dir='',
+ create_devices_fn=mock.Mock(),
+ close_devices_fn=mock.Mock())
+
+ def test_01_start_on_success(self):
+ """Verifies start method on success."""
+ self.suite_mgr.start(termination_event=mock.Mock())
+ self.suite_mgr._suite_timeout_checker.start.assert_called_once()
+
+ @mock.patch.object(
+ suite_session_manager.SuiteSessionManager,
+ '_initialize_artifact_directory')
+ def test_02_start_test_suite_on_success(self, mock_init_dir):
+ """Verifies start_test_suite on success."""
+ self.suite_mgr._ongoing_test_suite_id = None
+ start_suite = rpc_request(
+ _START_TEST_SUITE,
+ {'id': _FAKE_TEST_SUITE_ID,
+ 'dutDeviceIds': [_FAKE_DEVICE_ID, _FAKE_DEVICE_ID2]})
+ expected_response = {
+ 'id': start_suite['id'], 'jsonrpc': '2.0', 'result': {}}
+
+ rpc_response = self.suite_mgr.start_test_suite(start_suite)
+
+ self.assertEqual(expected_response, rpc_response)
+ mock_init_dir.assert_called_once_with(_FAKE_TEST_SUITE_ID)
+ self.suite_mgr._create_devices.assert_called_once()
+
+ @mock.patch.object(
+ suite_session_manager.SuiteSessionManager,
+ '_initialize_artifact_directory')
+ @mock.patch.object(suite_session_manager.SuiteSessionManager, 'clean_up')
+ def test_02_start_test_suite_force_on_success(
+ self, mock_force_cleanup, mock_init_dir):
+ """Verifies start_test_suite with forceStart=True on success."""
+ self.suite_mgr._ongoing_test_suite_id = _FAKE_TEST_SUITE_ID
+ start_suite = rpc_request(
+ _START_TEST_SUITE,
+ {'id': _FAKE_TEST_SUITE_ID2,
+ 'forceStart': True,
+ 'dutDeviceIds': [_FAKE_DEVICE_ID]})
+ expected_response = {
+ 'id': start_suite['id'], 'jsonrpc': '2.0', 'result': {}}
+
+ rpc_response = self.suite_mgr.start_test_suite(start_suite)
+
+ self.assertEqual(expected_response, rpc_response)
+ mock_force_cleanup.assert_called_once()
+ mock_init_dir.assert_called_once_with(_FAKE_TEST_SUITE_ID2)
+
+ def test_02_start_test_suite_on_failure_no_device_ids(self):
+ """Verifies start_test_suite on failure with no device ids."""
+ no_device_ids = rpc_request(
+ _START_TEST_SUITE, {'id': _FAKE_TEST_SUITE_ID})
+ error_msg = 'Invalid rpc command, no dutDeviceIds'
+ with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
+ self.suite_mgr.start_test_suite(no_device_ids)
+
+ def test_02_start_test_suite_on_failure_invalid_session(self):
+ """Verifies start_test_suite on failure with incomplete session."""
+ self.suite_mgr._ongoing_test_suite_id = _FAKE_TEST_SUITE_ID
+ invalid_start = rpc_request(
+ _START_TEST_SUITE,
+ {'id': _FAKE_TEST_SUITE_ID, 'dutDeviceIds': [_FAKE_DEVICE_ID]})
+ error_msg = f'{_FAKE_TEST_SUITE_ID} has not ended yet.'
+ with self.assertRaisesRegex(
+ agent_errors.InvalidTestSuiteSessionError, error_msg):
+ self.suite_mgr.start_test_suite(invalid_start)
+
+ @mock.patch.object(suite_session_manager.SuiteSessionManager, 'clean_up')
+ def test_03_end_test_suite_on_success(self, mock_cleanup):
+ """Verifies end test suite on success."""
+ self.suite_mgr._ongoing_test_suite_id = _FAKE_TEST_SUITE_ID
+ self.suite_mgr._ongoing_test_suite_start_time = 0
+ end_suite = rpc_request(_END_TEST_SUITE, {'id': _FAKE_TEST_SUITE_ID})
+ expected_response = {
+ 'id': end_suite['id'], 'jsonrpc': '2.0', 'result': {}}
+
+ rpc_response = self.suite_mgr.end_test_suite(end_suite)
+
+ self.assertEqual(expected_response, rpc_response)
+ mock_cleanup.assert_called_once()
+
+ def test_03_end_test_suite_on_failure_invalid_session(self):
+ """Verifies end_test_suite on failure with invalid session."""
+ self.suite_mgr._ongoing_test_suite_id = None
+ invalid_end = rpc_request(_END_TEST_SUITE, {'id': _FAKE_TEST_SUITE_ID})
+ error_msg = f'Session {_FAKE_TEST_SUITE_ID} has never started before.'
+ with self.assertRaisesRegex(
+ agent_errors.InvalidTestSuiteSessionError, error_msg):
+ self.suite_mgr.end_test_suite(invalid_end)
+
+ @mock.patch.object(
+ suite_session_manager.SuiteSessionManager, '_remove_outdated_artifacts')
+ def test_04_clean_up_on_success(self, mock_remove):
+ """Verifies clean_up method on success."""
+ self.suite_mgr._ongoing_test_suite_id = _FAKE_TEST_SUITE_ID
+ self.suite_mgr._ongoing_test_suite_start_time = 0
+
+ self.suite_mgr.clean_up(test_result_id=_FAKE_TEST_RESULT_ID)
+
+ self.suite_mgr._close_devices.assert_called_once()
+ self.suite_mgr._compress_artifacts_and_upload.assert_called_once_with(
+ test_suite_id=_FAKE_TEST_SUITE_ID,
+ test_result_id=_FAKE_TEST_RESULT_ID)
+ mock_remove.assert_called_once()
+
+ @mock.patch.object(logger_module, 'add_file_handler')
+ @mock.patch.object(os, 'makedirs')
+ @mock.patch.object(os.path, 'exists', return_value=False)
+ def test_05_initialize_artifact_directory_on_success(
+ self, mock_exists, mock_mk, mock_add):
+ """Verifies initialize_artifact_directory on success."""
+ self.suite_mgr._initialize_artifact_directory(_FAKE_TEST_SUITE_ID)
+ mock_mk.assert_called_once()
+ mock_add.assert_called_once()
+
+ @mock.patch.object(logger_module, 'logger')
+ @mock.patch.object(os, 'listdir')
+ @mock.patch.object(os.path, 'isfile', return_value=True)
+ @mock.patch.object(os, 'stat')
+ @mock.patch.object(os, 'remove')
+ @mock.patch.object(os.path, 'exists', return_value=True)
+ def test_06_remove_outdated_artifacts_on_success(
+ self,
+ mock_exists,
+ mock_rm,
+ mock_stat,
+ mock_isfile,
+ mock_listdir,
+ mock_logger):
+ """Verifies _remove_outdated_artifacts removes outdated files only."""
+ mock_listdir.return_value = [
+ 'outdated_artifacts.zip',
+ 'not_outdated_artifacts.zip',
+ ]
+ mock_stat.side_effect = [
+ mock.Mock(st_ctime=0),
+ mock.Mock(st_ctime=time.time()),
+ ]
+
+ self.suite_mgr._remove_outdated_artifacts()
+
+ mock_exists.assert_called_once()
+ mock_rm.assert_called_once()
+ self.assertEqual(2, mock_isfile.call_count)
+ self.assertEqual('outdated_artifacts.zip',
+ os.path.basename(mock_rm.call_args.args[0]))
+
+ @mock.patch.object(os, 'listdir')
+ @mock.patch.object(os.path, 'isfile', return_value=True)
+ @mock.patch.object(os, 'stat')
+ @mock.patch.object(os, 'remove')
+ @mock.patch.object(os.path, 'exists', return_value=True)
+ def test_06_remove_outdated_artifacts_will_suppress_oserror(
+ self, mock_exists, mock_rm, mock_stat, mock_isfile, mock_listdir):
+ """
+ Verifies _remove_outdated_artifacts suppresses OSError
+ when cannot remove.
+ """
+ mock_listdir.return_value = ['artifacts.zip']
+ mock_stat.return_value.st_ctime = 0
+ mock_rm.side_effect = OSError
+
+ self.suite_mgr._remove_outdated_artifacts()
+
+ mock_exists.assert_called_once()
+ mock_rm.assert_called_once()
+
+ @mock.patch.object(os, 'remove')
+ @mock.patch.object(os.path, 'exists', return_value=False)
+ def test_06_remove_outdated_artifacts_not_exists(
+ self, mock_exists, mock_rm):
+ """Verifies _remove_outdated_artifacts without existing artifacts."""
+ self.suite_mgr._remove_outdated_artifacts()
+ mock_exists.assert_called_once()
+ self.assertEqual(0, mock_rm.call_count)
+
+ @mock.patch.object(suite_session_manager.SuiteSessionManager, 'clean_up')
+ @mock.patch.object(suite_session_manager.time, 'sleep')
+ def test_07_check_suite_timeout_no_test_suite(self, mock_sleep, mock_cleanup):
+ """Verifies _initialize_suite_tracker no ongoing suite."""
+ fake_terminiation_event = mock.Mock()
+ fake_terminiation_event.wait.side_effect = [False, True]
+ self.suite_mgr._termination_event = fake_terminiation_event
+ self.suite_mgr._ongoing_test_suite_id = None
+
+ self.suite_mgr._check_suite_timeout()
+
+ self.assertEqual(0, mock_cleanup.call_count)
+
+ @mock.patch.object(suite_session_manager.SuiteSessionManager, 'clean_up')
+ @mock.patch.object(suite_session_manager.time, 'time')
+ @mock.patch.object(suite_session_manager.time, 'sleep')
+ def test_07_check_suite_timeout_on_success(
+ self, mock_sleep, mock_time, mock_cleanup):
+ """Verifies _initialize_suite_tracker on success."""
+ fake_terminiation_event = mock.Mock()
+ fake_terminiation_event.wait.side_effect = [False, True]
+ self.suite_mgr._termination_event = fake_terminiation_event
+ self.suite_mgr._ongoing_test_suite_id = _FAKE_TEST_SUITE_ID
+ self.suite_mgr._ongoing_test_suite_start_time = 0
+ self.suite_mgr._busy_devices = set()
+ mock_time.return_value = suite_session_manager._SUITE_SESSION_TIME_OUT
+
+ self.suite_mgr._check_suite_timeout()
+
+ mock_cleanup.assert_called_once()
+
+
+if __name__ == '__main__':
+ unittest.main(failfast=True)
diff --git a/local_agent/tests/unit_tests/test_translation_layer.py b/local_agent/tests/unit_tests/test_translation_layer.py
new file mode 100644
index 0000000..3113780
--- /dev/null
+++ b/local_agent/tests/unit_tests/test_translation_layer.py
@@ -0,0 +1,224 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for translation layer."""
+import unittest
+from unittest import mock
+
+from local_agent import errors as agent_errors
+from local_agent.translation_layer import gdm_manager
+from local_agent.translation_layer import translation_layer
+
+
+####################### Fake data for unit test #############################
+_FAKE_RPC_ID = 'fake-rpc-id'
+_FAKE_DEVICE_TYPE = 'fake-device-type'
+_FAKE_DEVICE_ID = 'fake-device-id'
+_FAKE_CAPABILITY = 'fake-device-capability'
+_SET_ON = 'setOn'
+_SET_LOCK = 'setLock'
+##############################################################################
+
+
+def rpc_request(method, params):
+ """Simple wrapper for json rpc request."""
+ return {'jsonrpc': '2.0', 'id': 0, 'method': method, 'params': params}
+
+
+class TranslationLayerTest(unittest.TestCase):
+ """Unit tests for local agent translation layer."""
+
+ def setUp(self):
+ super().setUp()
+ self.mock_client = mock.Mock()
+ self.translator = translation_layer.TranslationLayer(
+ client=self.mock_client)
+
+ @mock.patch.object(gdm_manager.GdmManager, 'create_devices')
+ def test_01_create_devices(self, mock_create):
+ """Verifies create_devices method on success."""
+ fake_dut_ids = [_FAKE_DEVICE_ID]
+ fake_suite_dir = ''
+ self.translator.create_devices(fake_dut_ids, fake_suite_dir)
+ mock_create.assert_called_once_with(fake_dut_ids, fake_suite_dir)
+
+ @mock.patch.object(gdm_manager.GdmManager, 'close_open_devices')
+ def test_02_close_devices(self, mock_close):
+ """Verifies close_devices method on success."""
+ self.translator.close_devices()
+ mock_close.assert_called_once()
+
+ @mock.patch.object(gdm_manager.GdmManager, 'detect_devices')
+ def test_03_detect_devices(self, mock_detect):
+ """Verifies detect_devices on success."""
+ self.translator.detect_devices()
+ self.assertEqual(1, mock_detect.call_count)
+
+ @mock.patch.object(translation_layer.TranslationLayer, '_get_cmd_handler')
+ @mock.patch.object(gdm_manager.GdmManager, 'check_device_connected')
+ def test_04_dispatch_to_cmd_handler_on_success(self, mock_check, mock_get):
+ """Verifies dispatch_to_cmd_handler on success."""
+ self.translator._busy_devices.clear()
+ set_on = rpc_request(_SET_ON, {'dutDeviceId': _FAKE_DEVICE_ID})
+ self.translator.dispatch_to_cmd_handler(set_on)
+ self.assertEqual(1, mock_check.call_count)
+ self.assertEqual(1, mock_get.call_count)
+
+ def test_04_dispatch_to_cmd_handler_on_failure_no_device(self):
+ """Verifies dispatch_to_cmd_handler on failure with no device params."""
+ invalid_rpc_request = rpc_request(_SET_ON, {})
+ error_msg = 'no dutDeviceId in params'
+ with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
+ self.translator.dispatch_to_cmd_handler(invalid_rpc_request)
+
+ @mock.patch.object(translation_layer, 'issubclass', return_value=True)
+ def test_05_validate_handlers_cls_map_on_success(self, mock_issubclass):
+ """Verifies validate_handlers_cls_map on success."""
+ fake_handlers = [mock.Mock(SUPPORTED_METHODS={_SET_ON,}),
+ mock.Mock(SUPPORTED_METHODS={_SET_LOCK,})]
+ translation_layer.validate_handlers_cls_map(fake_handlers)
+ self.assertEqual(2, mock_issubclass.call_count)
+
+ @mock.patch.object(translation_layer, 'issubclass', return_value=False)
+ def test_05_validate_handlers_cls_map_on_failure_not_subclass(
+ self, mock_issubclass):
+ """Verifies validate_handlers_cls_map invalid handler."""
+ error_msg = 'is not a subclass of BaseCommandHandler.'
+ with self.assertRaisesRegex(agent_errors.HandlerInvalidError, error_msg):
+ translation_layer.validate_handlers_cls_map([mock.Mock(__name__='')])
+
+ @mock.patch.object(translation_layer, 'issubclass', return_value=True)
+ def test_05_validate_handlers_cls_map_on_failure_collision(
+ self, mock_issubclass):
+ """Verifies validate_handlers_cls_map collision."""
+ fake_handlers = [mock.Mock(SUPPORTED_METHODS={_SET_ON,}, __name__='1'),
+ mock.Mock(SUPPORTED_METHODS={_SET_ON,}, __name__='2')]
+ error_msg = f'Handlers 1 and 2 have duplicate methods: {_SET_ON}.'
+ with self.assertRaisesRegex(agent_errors.HandlersCollisionError, error_msg):
+ translation_layer.validate_handlers_cls_map(fake_handlers)
+ self.assertEqual(2, mock_issubclass.call_count)
+
+ def test_06_get_cmd_handler_on_success_already_exists(self):
+ """Verifies get_cmd_handler on success with existing handler."""
+ method = _SET_ON
+ self.translator._cmd_handlers[_FAKE_DEVICE_ID] = (
+ mock.Mock(SUPPORTED_METHODS={method,}))
+ handler = self.translator._get_cmd_handler(_FAKE_DEVICE_ID, method)
+ self.assertIn(method, handler.SUPPORTED_METHODS)
+
+ def test_06_get_cmd_handler_on_failure_method_not_found(self):
+ """Verifies get_cmd_handler on failure method not found."""
+ self.translator._cmd_handlers[_FAKE_DEVICE_ID] = (
+ mock.Mock(SUPPORTED_METHODS={_SET_LOCK,}))
+ error_msg = 'No matching command handler'
+ with self.assertRaisesRegex(agent_errors.HandlerNotFoundError, error_msg):
+ self.translator._get_cmd_handler(_FAKE_DEVICE_ID, _SET_ON)
+
+ @mock.patch.object(gdm_manager.GdmManager, 'get_device_instance')
+ @mock.patch.object(gdm_manager.GdmManager, 'get_device_type')
+ def test_06_get_cmd_handler_on_success_not_exists(
+ self, mock_get_type, mock_get_inst):
+ """Verifies get_cmd_handler on success without existing handler."""
+ method = _SET_ON
+ self.translator._cmd_handlers.clear()
+ self.translator._handlers_cls_map[_FAKE_DEVICE_TYPE] = (
+ {method: mock.Mock()})
+ mock_get_type.return_value = _FAKE_DEVICE_TYPE
+ handler = self.translator._get_cmd_handler(_FAKE_DEVICE_ID, method)
+ self.assertEqual(
+ self.translator._cmd_handlers[_FAKE_DEVICE_ID], handler)
+ self.assertEqual(1, mock_get_type.call_count)
+ self.assertEqual(1, mock_get_inst.call_count)
+
+ @mock.patch.object(gdm_manager.GdmManager, 'get_device_type')
+ def test_06_get_cmd_handler_on_failure_invalid_device_type(self, mock_get_type):
+ """Verifies get_cmd_handler on failure with invalid device type."""
+ fake_not_exist_type = 'not-exists-type'
+ mock_get_type.return_value = fake_not_exist_type
+ error_msg = f'device type: {fake_not_exist_type} is not implemented'
+ with self.assertRaisesRegex(agent_errors.HandlerNotFoundError, error_msg):
+ self.translator._get_cmd_handler(_FAKE_DEVICE_ID, '')
+ self.assertEqual(1, mock_get_type.call_count)
+
+ @mock.patch.object(gdm_manager.GdmManager, 'get_device_type')
+ def test_06_get_cmd_handler_on_failure_invalid_method(self, mock_get_type):
+ """Verifies get_cmd_handler on failure with invalid method."""
+ self.translator._cmd_handlers.clear()
+ self.translator._handlers_cls_map[_FAKE_DEVICE_TYPE] = (
+ {_SET_ON: mock.Mock()})
+ mock_get_type.return_value = _FAKE_DEVICE_TYPE
+ error_msg = f'method: {_SET_LOCK} is not implemented'
+ with self.assertRaisesRegex(agent_errors.HandlerNotFoundError, error_msg):
+ self.translator._get_cmd_handler(_FAKE_DEVICE_ID, _SET_LOCK)
+ self.assertEqual(1, mock_get_type.call_count)
+
+ def test_07_device_operation_handler_on_failure_is_busy(self):
+ """Verifies device_operation_handler raises when device is still busy"""
+ self.translator._busy_devices = {_FAKE_DEVICE_ID,}
+ error_msg = f'{_FAKE_DEVICE_ID} is still busy'
+ with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
+ with self.translator.device_operation_handler(
+ _FAKE_DEVICE_ID, _FAKE_RPC_ID):
+ pass
+
+ def test_07_device_operation_handler_on_failure_duplicate_rpc(self):
+ """Verifies device_operation_handler raises when duplicate RPC"""
+ self.translator._rpc_execution_start_time[_FAKE_RPC_ID] = 0
+ error_msg = f'RPC {_FAKE_RPC_ID} is already executing.'
+ with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_msg):
+ with self.translator.device_operation_handler(
+ _FAKE_DEVICE_ID, _FAKE_RPC_ID):
+ pass
+
+ @mock.patch.object(translation_layer, 'validate_handlers_cls_map')
+ def test_08_update_handlers_cls_map_on_success_creation(
+ self, mock_validate):
+ """Verifies _update_handlers_cls_map creation on success."""
+ fake_handler = mock.Mock(SUPPORTED_METHODS={_SET_ON,})
+ fake_capabilities_map = {
+ _FAKE_CAPABILITY: [fake_handler]}
+ translation_layer.GDM_CAPABILITIES_TO_COMMAND_HANDLERS = (
+ fake_capabilities_map)
+ self.translator.update_handlers_cls_map(
+ _FAKE_DEVICE_TYPE, [_FAKE_CAPABILITY])
+ handler = self.translator._handlers_cls_map[_FAKE_DEVICE_TYPE][_SET_ON]
+ self.assertEqual(fake_handler, handler)
+ self.assertEqual(1, mock_validate.call_count)
+
+ @mock.patch.object(translation_layer, 'validate_handlers_cls_map')
+ def test_08_update_handlers_cls_map_does_nothing(self, mock_validate):
+ """Verifies the _update_handlers_cls_map does nothing."""
+ self.translator._handlers_cls_map[_FAKE_DEVICE_TYPE][_SET_ON] = None
+ self.translator.update_handlers_cls_map(
+ _FAKE_DEVICE_TYPE, [_FAKE_CAPABILITY])
+ self.assertEqual(0, mock_validate.call_count)
+
+ def test_09_check_rpc_timeout(self):
+ """Verifies _check_rpc_timeout catches timeout RPCs."""
+ fake_terminiation_event = mock.Mock()
+ fake_terminiation_event.wait.side_effect = [False, True]
+ self.translator._termination_event = fake_terminiation_event
+ self.translator._timeout_rpc.clear()
+ self.translator._rpc_execution_start_time = {_FAKE_DEVICE_ID: 0}
+
+ self.translator._check_rpc_timeout()
+
+ self.mock_client.send_rpc_response.assert_called_once()
+ self.assertTrue(self.translator.is_rpc_timeout(_FAKE_DEVICE_ID))
+
+
+if __name__ == '__main__':
+ unittest.main(failfast=True)
diff --git a/local_agent/translation_layer/__init__.py b/local_agent/translation_layer/__init__.py
new file mode 100644
index 0000000..d46dbae
--- /dev/null
+++ b/local_agent/translation_layer/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/local_agent/translation_layer/command_handlers/__init__.py b/local_agent/translation_layer/command_handlers/__init__.py
new file mode 100644
index 0000000..d46dbae
--- /dev/null
+++ b/local_agent/translation_layer/command_handlers/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/local_agent/translation_layer/command_handlers/base.py b/local_agent/translation_layer/command_handlers/base.py
new file mode 100644
index 0000000..a472e82
--- /dev/null
+++ b/local_agent/translation_layer/command_handlers/base.py
@@ -0,0 +1,85 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for base command handler."""
+from typing import Any, Dict
+
+import inflection
+
+from local_agent import errors as agent_errors
+
+
+class BaseCommandHandler:
+ """Base handler class for all command handlers."""
+
+ SUPPORTED_METHODS = set()
+
+ def __init__(self, dut: Any) -> None:
+ """
+ Base command handler constructor.
+
+ Args:
+ dut: Gazoo device instance.
+ """
+ self.dut = dut
+
+ def handle_request(self, rpc_request: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ The required method for handling the RPC request,
+ the corresponding GDM device class operation should be called
+ accordingly. Note that we don't return the failure JSON-RPC response
+ explicitly since it's already handled by the exception handler in
+ the local agent process.
+
+ Args:
+ rpc_request: JSON-RPC request.
+
+ Returns:
+ RPC response, which is the result of executing the command.
+ """
+ rpc_id = rpc_request['id']
+ method = rpc_request['method']
+ params = rpc_request['params']
+ method_snake_case = f'_{inflection.underscore(method)}'
+ func = getattr(self, method_snake_case, None)
+ if func is None:
+ raise agent_errors.InvalidRPCError(f'Unknown method {method}.')
+ func_result = func(params)
+ if func_result is not None:
+ result = {'value': func_result}
+ else:
+ result = {}
+ return {'id': rpc_id, 'jsonrpc': '2.0', 'result': result}
+
+ @staticmethod
+ def validate_key_in_params(
+ params: Dict[str, Any], param_key: str, expected_type: Any) -> None:
+ """Verifies params contains the key with the expected type.
+
+ Args:
+ params: RPC request parameters.
+ param_key: Key of the parameter dict.
+ expected_type: Expected type of the given key.
+
+ Raises:
+ ValueError: param_key does not exist in params or the expected_type
+ does not match.
+ """
+ if param_key not in params:
+ raise ValueError(f'Missing field {param_key} from RPC request.')
+ param_value = params[param_key]
+ if not isinstance(param_value, expected_type):
+ raise ValueError(
+ f'Invalid type for {param_key}. '
+ f'Expecting {expected_type} while receiving {param_value}.')
diff --git a/local_agent/translation_layer/command_handlers/common.py b/local_agent/translation_layer/command_handlers/common.py
new file mode 100644
index 0000000..52be68c
--- /dev/null
+++ b/local_agent/translation_layer/command_handlers/common.py
@@ -0,0 +1,68 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for common command handler."""
+from typing import Any, Dict
+
+from gazoo_device import errors
+
+from local_agent import logger as logger_module
+from local_agent.translation_layer.command_handlers import base
+
+
+logger = logger_module.get_logger()
+
+PWRPC_COMMON_CAPABILITY = 'pw_rpc_common'
+
+
+class CommonCommandHandler(base.BaseCommandHandler):
+ """Common command handler.
+
+ The handler dealing with the device common operations which includes:
+ reboot, factory-reset, OTA. Note that OTA is currently not implemented
+ in the Pigweed RPC endpoint.
+
+ Smart Home Reboot Trait Schema:
+ https://developers.google.com/assistant/smarthome/traits/reboot
+ """
+
+ _REBOOT = 'setReboot'
+ _FACTORY_RESET = 'setFactoryReset'
+ SUPPORTED_METHODS = {_REBOOT, _FACTORY_RESET}
+
+ def _set_reboot(self, params: Dict[str, Any]) -> None:
+ """Reboots the device.
+
+ Raises:
+ DeviceError: rebooting fails.
+ """
+ del params
+ try:
+ self.dut.reboot()
+ except errors.DeviceError as exc:
+ logger.exception(f'Rebooting {self.dut.name} failed.')
+ raise exc
+
+ def _set_factory_reset(self, params: Dict[str, Any]) -> None:
+ """Factory resets the device.
+
+ Raises:
+ DeviceError: factory resetting fails.
+ """
+ del params
+ try:
+ self.dut.factory_reset()
+ except errors.DeviceError as exc:
+ logger.exception(f'Factory resetting {self.dut.name} failed.')
+ raise exc
diff --git a/local_agent/translation_layer/command_handlers/handler_registry.py b/local_agent/translation_layer/command_handlers/handler_registry.py
new file mode 100644
index 0000000..ef0cad0
--- /dev/null
+++ b/local_agent/translation_layer/command_handlers/handler_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Mapping GDM capabilities to their handlers.
+
+This module provides the GDM_CAPABILITIES_TO_COMMAND_HANDLERS dict,
+mapping each GDM capability to the corresponding command handler.
+"""
+import immutabledict
+
+from local_agent.translation_layer.command_handlers import common
+from local_agent.translation_layer.command_handlers import light
+from local_agent.translation_layer.command_handlers import lock
+
+
+# GDM capability -> set of acceptable command handlers
+GDM_CAPABILITIES_TO_COMMAND_HANDLERS = immutabledict.immutabledict({
+ common.PWRPC_COMMON_CAPABILITY: {common.CommonCommandHandler,},
+ light.PWRPC_LIGHT_CAPABILITY: {light.LightCommandHandler,},
+ lock.PWRPC_LOCK_CAPABILITY: {lock.LockCommandHandler,}
+})
diff --git a/local_agent/translation_layer/command_handlers/light.py b/local_agent/translation_layer/command_handlers/light.py
new file mode 100644
index 0000000..bc4a1a6
--- /dev/null
+++ b/local_agent/translation_layer/command_handlers/light.py
@@ -0,0 +1,169 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for light command handler."""
+from typing import Any, Dict
+
+from gazoo_device import errors
+
+from local_agent import logger as logger_module
+from local_agent.translation_layer.command_handlers import base
+
+
+logger = logger_module.get_logger()
+
+PWRPC_LIGHT_CAPABILITY = 'pw_rpc_light'
+
+# OnOff Response States:
+LIGHT_ON = 'on'
+LIGHT_OFF = 'off'
+
+
+class LightCommandHandler(base.BaseCommandHandler):
+ """Lighting device command handler
+
+ Smart Home OnOff Trait Schema:
+ https://developers.google.com/assistant/smarthome/traits/onoff
+ """
+
+ _GET_STATE = 'getOnOff'
+ _SET_STATE_ON = 'setOn'
+ _SET_STATE_OFF = 'setOff'
+ _GET_BRIGHTNESS = 'getBrightness'
+ _SET_BRIGHTNESS = 'setBrightness'
+ _GET_COLOR = 'getColor'
+ _SET_COLOR = 'setColor'
+
+ SUPPORTED_METHODS = {
+ _GET_STATE,
+ _SET_STATE_ON,
+ _SET_STATE_OFF,
+ _GET_BRIGHTNESS,
+ _SET_BRIGHTNESS,
+ _GET_COLOR,
+ _SET_COLOR}
+
+ def _get_on_off(self, params: Dict[str, Any]) -> str:
+ """Queries the light state of the device.
+
+ Returns:
+ The light state.
+
+ Raises:
+ DeviceError: getting light state fails.
+ """
+ del params # not used
+ try:
+ return LIGHT_ON if self.dut.pw_rpc_light.state else LIGHT_OFF
+ except errors.DeviceError as exc:
+ logger.exception(f'Getting light state of {self.dut.name} failed.')
+ raise exc
+
+ def _set_on(self, params: Dict[str, Any]) -> None:
+ """Turns on the light of the device.
+
+ Raises:
+ DeviceError: turning light on fails.
+ """
+ del params # not used
+ try:
+ self.dut.pw_rpc_light.on()
+ except errors.DeviceError as exc:
+ logger.exception(f'Turning {self.dut.name} on failed.')
+ raise exc
+
+ def _set_off(self, params: Dict[str, Any]) -> None:
+ """Turns off the light of the device.
+
+ Raises:
+ DeviceError: turning light off fails.
+ """
+ del params # not used
+ try:
+ self.dut.pw_rpc_light.off()
+ except errors.DeviceError as exc:
+ logger.exception(f'Turning {self.dut.name} off failed.')
+ raise exc
+
+ def _get_brightness(self, params: Dict[str, Any]) -> int:
+ """Queries the current brightness level of the device.
+
+ Returns:
+ The current brightness level.
+
+ Raises:
+ DeviceError: getting light brightness fails.
+ """
+ del params # not used
+ try:
+ return self.dut.pw_rpc_light.brightness
+ except errors.DeviceError as exc:
+ logger.exception(
+ f'Getting light brightness of {self.dut.name} failed.')
+ raise exc
+
+ def _set_brightness(self, params: Dict[str, Any]) -> None:
+ """Sets the current brightness level of the device.
+
+ Raises:
+ DeviceError: setting brightness level fails.
+ """
+ self.validate_key_in_params(
+ params=params, param_key='level', expected_type=int)
+
+ try:
+ self.dut.pw_rpc_light.on(level=params['level'])
+ except errors.DeviceError as exc:
+ logger.exception(
+ f'Setting light brightness of {self.dut.name} failed.')
+ raise exc
+
+ def _get_color(self, params: Dict[str, Any]) -> Dict[str, int]:
+ """Gets the current lighting color of the device.
+
+ Returns:
+ The current hue and saturation values in dict.
+
+ Raises:
+ DeviceError: getting color fails.
+ """
+ del params
+ try:
+ hue = self.dut.pw_rpc_light.color.hue
+ saturation = self.dut.pw_rpc_light.color.saturation
+ except errors.DeviceError as exc:
+ logger.exception(
+ f'Getting light color of {self.dut.name} failed.')
+ raise exc
+ return {'hue': hue, 'saturation': saturation}
+
+ def _set_color(self, params: Dict[str, Any]) -> None:
+ """Sets the lighting color to specific hue and saturation.
+
+ Raises:
+ DeviceError: setting color fails.
+ """
+ self.validate_key_in_params(
+ params=params, param_key='hue', expected_type=int)
+ self.validate_key_in_params(
+ params=params, param_key='saturation', expected_type=int)
+
+ try:
+ hue = params['hue']
+ saturation = params['saturation']
+ self.dut.pw_rpc_light.on(hue=hue, saturation=saturation)
+ except errors.DeviceError:
+ logger.exception(
+ f'Setting light color of {self.dut.name} failed.')
+ raise
diff --git a/local_agent/translation_layer/command_handlers/lock.py b/local_agent/translation_layer/command_handlers/lock.py
new file mode 100644
index 0000000..af1c707
--- /dev/null
+++ b/local_agent/translation_layer/command_handlers/lock.py
@@ -0,0 +1,81 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for lock command handler."""
+from typing import Any, Dict
+
+from gazoo_device import errors
+
+from local_agent import logger as logger_module
+from local_agent.translation_layer.command_handlers import base
+
+
+logger = logger_module.get_logger()
+
+PWRPC_LOCK_CAPABILITY = 'pw_rpc_lock'
+
+
+class LockCommandHandler(base.BaseCommandHandler):
+ """Lock device command handler
+
+ Smart Home LockUnlock Trait Schema:
+ https://developers.google.com/assistant/smarthome/traits/lockunlock
+ """
+
+ _GET_IS_LOCKED = 'getIsLocked'
+ _SET_LOCK = 'setLock'
+ _SET_UNLOCK = 'setUnlock'
+ SUPPORTED_METHODS = {_GET_IS_LOCKED, _SET_LOCK, _SET_UNLOCK}
+
+ def _get_is_locked(self, params: Dict[str, Any]) -> bool:
+ """Returns whether the device is locked.
+
+ Returns:
+ True if the device is locked, False otherwise.
+
+ Raises:
+ DeviceError: getting locked state fails.
+ """
+ del params # not used
+ try:
+ return self.dut.pw_rpc_lock.state
+ except errors.DeviceError:
+ logger.exception(f'Getting locked state of {self.dut.name} failed.')
+ raise
+
+ def _set_lock(self, params: Dict[str, Any]) -> None:
+ """Locks the device.
+
+ Raises:
+ DeviceError: locking device fails.
+ """
+ del params # not used
+ try:
+ self.dut.pw_rpc_lock.lock()
+ except errors.DeviceError:
+ logger.exception(f'Locking {self.dut.name} failed.')
+ raise
+
+ def _set_unlock(self, params: Dict[str, Any]) -> None:
+ """Unlocks the device.
+
+ Raises:
+ DeviceError: unlocking the device fails.
+ """
+ del params # not used
+ try:
+ self.dut.pw_rpc_lock.unlock()
+ except errors.DeviceError:
+ logger.exception(f'Unlocking {self.dut.name} failed.')
+ raise
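The setters in `light.py` call `self.validate_key_in_params(...)` before touching the device; that helper is defined on the base handler, which this patch does not show. A plausible sketch of what such a validator does, assuming it raises the agent's RPC error on a missing key or a wrongly typed value (the `bool`-rejection detail is an assumption, added because `isinstance(True, int)` is true in Python):

```python
from typing import Any, Dict, Type


class InvalidRPCError(Exception):
    """Stand-in for local_agent.errors.InvalidRPCError."""


def validate_key_in_params(
        params: Dict[str, Any],
        param_key: str,
        expected_type: Type[Any]) -> None:
    """Checks that params contains param_key with a value of expected_type."""
    if param_key not in params:
        raise InvalidRPCError(f'Missing key in params: {param_key}.')
    value = params[param_key]
    # bool subclasses int, so an int parameter must reject True/False.
    if (not isinstance(value, expected_type)
            or (expected_type is int and isinstance(value, bool))):
        raise InvalidRPCError(
            f'Expected {param_key} to be {expected_type.__name__}, '
            f'got {type(value).__name__}.')
```

Validating before the `try` block keeps parameter errors (the caller's fault) distinct from `DeviceError` (the device's fault), which is the split the handlers above rely on.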
diff --git a/local_agent/translation_layer/gdm_manager.py b/local_agent/translation_layer/gdm_manager.py
new file mode 100644
index 0000000..c047ecc
--- /dev/null
+++ b/local_agent/translation_layer/gdm_manager.py
@@ -0,0 +1,138 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for GDM manager."""
+import threading
+from typing import Any, Callable, Dict, List, Optional
+
+import gazoo_device
+from gazoo_device import errors as gdm_errors
+
+from local_agent import errors as agent_errors
+from local_agent import logger as logger_module
+from local_agent.translation_layer.command_handlers import light
+
+
+logger = logger_module.get_logger()
+
+# ============= Constants ============= #
+_DEVICE_DETECTION_LOG_DIR = '/tmp'
+# ===================================== #
+
+
+class GdmManager:
+ """GDM manager for device management."""
+
+ def __init__(
+ self, update_handlers_cls_map_fn: Callable[[str, Any], None]):
+ self._mgr_lock = threading.RLock()
+ self._mgr = gazoo_device.Manager()
+ self._first_detection = True
+ self._update_handlers_cls_map = update_handlers_cls_map_fn
+
+ def detect_devices(self) -> List[Dict[str, Any]]:
+ """Detects connected devices.
+
+ Returns:
+ The list of device dict, each dict includes the following fields:
+ deviceId, serialNumber, capabilities and deviceType in GDM.
+ """
+ devices = [] # List[Dict[str, Any]]
+ with self._mgr_lock:
+ self._mgr.detect(force_overwrite=self._first_detection,
+ log_directory=_DEVICE_DETECTION_LOG_DIR)
+ connected_devices = self._mgr.get_devices('all')
+ for device_id, info in connected_devices.items():
+ if not self._mgr.is_device_connected(device_id):
+ continue
+ serial_number = info['persistent']['serial_number']
+ device_type = info['persistent']['device_type']
+ capabilities = (
+ self._mgr.get_supported_device_capabilities(device_type))
+ device_dict = {
+ 'deviceId': device_id,
+ 'serialNumber': serial_number,
+ 'deviceType': device_type,
+ 'capabilities': capabilities,
+ }
+ devices.append(device_dict)
+ self._update_handlers_cls_map(device_type, capabilities)
+ self._first_detection = False
+ return devices
+
+ def create_devices(
+ self,
+ identifiers: List[str],
+ log_directory: Optional[str] = None) -> None:
+ """Creates GDM device instances.
+
+ Args:
+ identifiers: List of GDM device id.
+ log_directory: GDM log directory.
+ """
+ with self._mgr_lock:
+ for device_id in identifiers:
+ self._mgr.create_device(
+ identifier=device_id, log_directory=log_directory)
+
+ def check_device_connected(self, identifier: str) -> None:
+ """Checks if the device is connected.
+
+ Args:
+ identifier: GDM device id.
+
+ Raises:
+ DeviceNotConnectedError: When device is not connected.
+ DeviceError: When unexpected error occurs.
+ """
+ with self._mgr_lock:
+ try:
+ if not self._mgr.is_device_connected(identifier):
+ raise agent_errors.DeviceNotConnectedError(
+ f'Device {identifier} is not connected.')
+ except gdm_errors.DeviceError as e:
+ logger.warning(f'check_device_connected failed: {e}')
+ raise
+
+ def get_device_instance(self, identifier: str) -> Any:
+ """Gets the device instance in GDM.
+
+ Args:
+ identifier: GDM device id.
+
+ Returns:
+ Device instance in GDM.
+ """
+ with self._mgr_lock:
+ if identifier not in self._mgr.get_open_device_names():
+ raise agent_errors.DeviceNotOpenError(f'{identifier} is not open.')
+ return self._mgr.get_open_device(identifier)
+
+ def get_device_type(self, identifier: str) -> str:
+ """Gets the device type in GDM.
+
+ Args:
+ identifier: GDM device id.
+
+ Returns:
+ Device type in GDM.
+ """
+ with self._mgr_lock:
+ dut = self.get_device_instance(identifier)
+ return dut.device_type
+
+ def close_open_devices(self) -> None:
+ """Closes all open devices in GDM."""
+ with self._mgr_lock:
+ self._mgr.close_open_devices()
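`GdmManager` guards every operation with a single `threading.RLock` rather than a plain `Lock`, and the re-entrancy matters: `get_device_type` acquires the lock and then calls `get_device_instance`, which acquires it again on the same thread. A reduced sketch of that nesting, with toy data standing in for the GDM manager:

```python
import threading


class ReentrantExample:
    """Nested lock acquisition on one thread: fine with RLock."""

    def __init__(self) -> None:
        # A plain threading.Lock() here would deadlock in get_device_type.
        self._lock = threading.RLock()
        self._instances = {'dev-1234': {'device_type': 'lightbulb'}}

    def get_device_instance(self, identifier: str):
        with self._lock:  # second acquisition by the same thread
            return self._instances[identifier]

    def get_device_type(self, identifier: str) -> str:
        with self._lock:  # first acquisition
            return self.get_device_instance(identifier)['device_type']
```

With `RLock`, the owning thread may re-acquire freely and releases must balance acquisitions; a non-reentrant lock would block the thread against itself on the inner `with`.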
diff --git a/local_agent/translation_layer/translation_layer.py b/local_agent/translation_layer/translation_layer.py
new file mode 100644
index 0000000..04576fe
--- /dev/null
+++ b/local_agent/translation_layer/translation_layer.py
@@ -0,0 +1,282 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for Translation Layer."""
+import collections
+import contextlib
+import threading
+import time
+from typing import Any, Callable, Dict, List
+
+from local_agent import ams_client
+from local_agent import errors as agent_errors
+from local_agent import logger as logger_module
+from local_agent.translation_layer import gdm_manager
+from local_agent.translation_layer.command_handlers import base
+from local_agent.translation_layer.command_handlers.handler_registry import GDM_CAPABILITIES_TO_COMMAND_HANDLERS
+
+
+logger = logger_module.get_logger()
+
+_RPC_TIME_OUT = 900 # 15 mins in seconds
+_RPC_TIME_OUT_HUMAN_READABLE = '15 mins'
+_RPC_TIME_OUT_INTERVAL_SECONDS = 30
+
+
+# ======================== Module level functions ========================== #
+def validate_handlers_cls_map(handlers: List[Any]) -> None:
+ """Validates the handler classes for a device type.
+
+ Checks whether any two handlers in the given list claim the same
+ method (a collision) and whether every handler is a subclass of
+ BaseCommandHandler.
+
+ Args:
+ handlers: The list of command handler classes.
+
+ Raises:
+ HandlerInvalidError: Handler is not a child of BaseCommandHandler.
+ HandlersCollisionError: Two or more command handlers have the
+ same SUPPORTED_METHODS.
+ """
+ method_to_handler = {} # method -> handler
+ for handler in handlers:
+ # check if the handler is a subclass of BaseCommandHandler
+ if not issubclass(handler, base.BaseCommandHandler):
+ raise agent_errors.HandlerInvalidError(
+ f'{handler.__name__} is not a subclass of BaseCommandHandler.')
+ for method in handler.SUPPORTED_METHODS:
+ if method in method_to_handler:
+ pre_handler = method_to_handler[method]
+ raise agent_errors.HandlersCollisionError(
+ f'Handlers {pre_handler.__name__} and {handler.__name__} '
+ f'have duplicate methods: {method}.')
+ method_to_handler[method] = handler
+# ========================================================================== #
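The collision check in `validate_handlers_cls_map` can be exercised with toy handler classes. The miniature below reimplements just the duplicate-method scan (class names and methods are invented for illustration):

```python
from typing import Any, List


def find_collisions(handlers: List[Any]) -> List[str]:
    """Returns RPC methods claimed by more than one handler class."""
    method_to_handler = {}  # method -> first handler that claimed it
    collisions = []
    for handler in handlers:
        for method in handler.SUPPORTED_METHODS:
            if method in method_to_handler:
                collisions.append(method)
            method_to_handler[method] = handler
    return collisions


class OnOffHandler:
    SUPPORTED_METHODS = {'setOn', 'setOff'}


class LevelHandler:
    SUPPORTED_METHODS = {'setOn', 'setBrightness'}  # 'setOn' collides
```

Because each RPC method must resolve to exactly one handler in `_handlers_cls_map`, any overlap like `setOn` here is rejected at map-building time rather than surfacing as nondeterministic dispatch later.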
+
+
+class TranslationLayer:
+ """Translation Layer for JSON-RPC and Device Control Libraries mapping."""
+
+ def __init__(self, client: ams_client.AmsClient):
+ # Command handlers
+ self._handlers_cls_map = collections.defaultdict(dict)
+ self._cmd_handlers = {}
+
+ # GDM manager
+ self._mgr = gdm_manager.GdmManager(self.update_handlers_cls_map)
+
+ # Checks busy devices
+ self._rpc_execution_lock = threading.RLock()
+ self._busy_devices = set()
+
+ # Tracks RPC timeout
+ self._ams_client = client
+ self._rpc_execution_start_time = {}
+ self._termination_event = None
+ self._rpc_timeout_checker = threading.Thread(
+ target=self._check_rpc_timeout, daemon=True)
+ self._timeout_rpc = set()
+
+ def start(self, termination_event: threading.Event) -> None:
+ """Starts the suite session manager by enabling the background threads.
+
+ Args:
+ termination_event: The termination threading event for the thread.
+ """
+ self._termination_event = termination_event
+ self._rpc_timeout_checker.start()
+
+ def create_devices(self, dut_ids: List[str], test_suite_dir: str) -> None:
+ """Creates GDM device instances.
+
+ Args:
+ dut_ids: List of GDM device ids.
+ test_suite_dir: Test suite directory.
+ """
+ dut_ids = list(set(dut_ids))
+ self._mgr.create_devices(dut_ids, test_suite_dir)
+
+ def close_devices(self) -> None:
+ """Closes all GDM devices and clears handler maps."""
+ self._mgr.close_open_devices()
+ self._cmd_handlers.clear()
+
+ @contextlib.contextmanager
+ def device_operation_handler(self, dut_device_id: str, rpc_id: str) -> None:
+ """Context manager for device operation.
+
+ Marks the device as busy when entering the context, unmarks
+ the device when exiting the context.
+ Also records the RPC execution start time when entering the
+ context, clears the record when exiting.
+
+ Args:
+ dut_device_id: DUT device id in GDM.
+ rpc_id: JSON-RPC request id.
+
+ Raises:
+ InvalidRPCError: When the RPC is already executing or the
+ requested device is still busy.
+ """
+ try:
+ with self._rpc_execution_lock:
+ if rpc_id in self._rpc_execution_start_time:
+ raise agent_errors.InvalidRPCError(
+ f'RPC {rpc_id} is already executing.')
+
+ if dut_device_id in self._busy_devices:
+ raise agent_errors.InvalidRPCError(
+ f'Invalid RPC request: {dut_device_id} is still busy.')
+
+ self._busy_devices.add(dut_device_id)
+ self._rpc_execution_start_time[rpc_id] = time.time()
+
+ yield None
+
+ finally:
+ with self._rpc_execution_lock:
+ if dut_device_id in self._busy_devices:
+ self._busy_devices.remove(dut_device_id)
+ if rpc_id in self._rpc_execution_start_time:
+ del self._rpc_execution_start_time[rpc_id]
+
+ def update_handlers_cls_map(
+ self, device_type: str, capabilities: List[str]) -> None:
+ """Updates the handlers_cls_map for the given device_type.
+
+ Args:
+ device_type: GDM device type.
+ capabilities: List of GDM capabilities.
+ """
+ if device_type in self._handlers_cls_map:
+ return
+
+ matched_handlers = set()
+ for capability in capabilities:
+ for handler in GDM_CAPABILITIES_TO_COMMAND_HANDLERS.get(
+ capability, []):
+ matched_handlers.add(handler)
+ matched_handlers = list(matched_handlers)
+
+ validate_handlers_cls_map(matched_handlers)
+
+ for handler in matched_handlers:
+ for method in handler.SUPPORTED_METHODS:
+ self._handlers_cls_map[device_type][method] = handler
+
+ def detect_devices(self) -> List[Dict[str, Any]]:
+ """Detects connected devices.
+
+ Returns:
+ The list of device dict, each dict includes the following fields:
+ deviceId, serialNumber, capabilities and deviceType in GDM.
+ """
+ return self._mgr.detect_devices()
+
+ def dispatch_to_cmd_handler(
+ self, rpc_request: Dict[str, str]) -> Dict[str, Any]:
+ """Subroutine for handling regular device related rpc request.
+
+ Args:
+ rpc_request: JSON-RPC request.
+
+ Raises:
+ InvalidRPCError: Invalid rpc.
+
+ Returns:
+ RPC response.
+ """
+ rpc_id = rpc_request['id']
+ dut_device_id = rpc_request['params'].get('dutDeviceId')
+ if dut_device_id is None:
+ raise agent_errors.InvalidRPCError(
+ 'Invalid rpc request: no dutDeviceId in params.')
+
+ with self.device_operation_handler(dut_device_id, rpc_id):
+ self._mgr.check_device_connected(dut_device_id)
+ cmd_handler = self._get_cmd_handler(dut_device_id,
+ rpc_request['method'])
+ resp = cmd_handler.handle_request(rpc_request)
+
+ logger.info(f'Completed request for {dut_device_id}: {rpc_request}')
+ return resp
+
+ def is_rpc_timeout(self, rpc_id: str) -> bool:
+ """Returns if the RPC request has timed out."""
+ return rpc_id in self._timeout_rpc
+
+ def _get_cmd_handler(self,
+ dut_device_id: str,
+ method: str) -> Callable[..., Any]:
+ """Gets the corresponding command handler via device id and rpc command.
+
+ Args:
+ dut_device_id: DUT device id in GDM.
+ method: device operation in RPC command.
+
+ Returns:
+ The command handler which matches the given device ID and rpc
+ command.
+
+ Raises:
+ HandlerNotFoundError: When no matching request handlers.
+ """
+ if dut_device_id in self._cmd_handlers:
+ handler = self._cmd_handlers[dut_device_id]
+ if method not in handler.SUPPORTED_METHODS:
+ raise agent_errors.HandlerNotFoundError(
+ 'No matching command handler, '
+ f'method: {method} is not implemented')
+ return handler
+
+ device_type = self._mgr.get_device_type(dut_device_id)
+
+ if device_type not in self._handlers_cls_map:
+ raise agent_errors.HandlerNotFoundError(
+ 'No matching command handler, '
+ f'device type: {device_type} is not implemented')
+
+ target_handler_cls = self._handlers_cls_map[device_type].get(method)
+ if target_handler_cls is None:
+ raise agent_errors.HandlerNotFoundError(
+ 'No matching command handler, '
+ f'method: {method} is not implemented')
+
+ dut = self._mgr.get_device_instance(dut_device_id)
+ handler = target_handler_cls(dut)
+ self._cmd_handlers[dut_device_id] = handler
+
+ return handler
+
+ def _check_rpc_timeout(self) -> None:
+ """Checks if RPC request handling times out.
+
+ Iterates over the in-progress RPCs, sends a timeout failure
+ response for each RPC that has timed out, and marks it as timed out.
+ """
+ while (self._termination_event is not None and
+ not self._termination_event.wait(_RPC_TIME_OUT_INTERVAL_SECONDS)):
+ with self._rpc_execution_lock:
+ now = time.time()
+ for rpc_id, start_time in self._rpc_execution_start_time.items():
+ if (not self.is_rpc_timeout(rpc_id) and
+ now - start_time >= _RPC_TIME_OUT):
+ err_mesg = (f'Handling RPC request {rpc_id} has timed out '
+ f'(over {_RPC_TIME_OUT_HUMAN_READABLE}; DUT may'
+ ' be unresponsive).')
+ err_resp = {'id': rpc_id, 'jsonrpc': '2.0'}
+ err_resp['error'] = {
+ 'code': agent_errors.RpcTimeOutError.err_code,
+ 'message': err_mesg}
+ self._ams_client.send_rpc_response(err_resp)
+ self._timeout_rpc.add(rpc_id)
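`_check_rpc_timeout` uses `termination_event.wait(interval)` both as its sleep and as its shutdown signal, so the checker thread wakes immediately when the event is set instead of sleeping out the remaining interval. A reduced, fast-running sketch of the same pattern, with the 900 s timeout and 30 s interval shrunk to fractions of a second for illustration:

```python
import threading
import time

_TIMEOUT = 0.2          # stands in for _RPC_TIME_OUT (900 s)
_CHECK_INTERVAL = 0.05  # stands in for _RPC_TIME_OUT_INTERVAL_SECONDS (30 s)


def watch_for_timeouts(start_times, timed_out, termination_event, lock):
    """Marks entries in start_times as timed out until the event is set."""
    # wait() returns False on timeout (keep looping) and True once set.
    while not termination_event.wait(_CHECK_INTERVAL):
        with lock:
            now = time.time()
            for rpc_id, started in start_times.items():
                if rpc_id not in timed_out and now - started >= _TIMEOUT:
                    timed_out.add(rpc_id)


lock = threading.RLock()
start_times = {'rpc-1': time.time()}
timed_out = set()
stop = threading.Event()
checker = threading.Thread(
    target=watch_for_timeouts,
    args=(start_times, timed_out, stop, lock),
    daemon=True)
checker.start()
time.sleep(0.5)  # long enough for rpc-1 to exceed _TIMEOUT
stop.set()       # wakes the wait() immediately; the loop exits
checker.join()
```

The same event doubles as the agent-wide termination signal, which is why `start()` stores it before launching the daemon thread.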
diff --git a/local_agent/version.py b/local_agent/version.py
new file mode 100644
index 0000000..a47123c
--- /dev/null
+++ b/local_agent/version.py
@@ -0,0 +1,21 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dedicated module to store package version info.
+
+This is one of the techniques to have single source of truth for the package
+version.
+"""
+
+__version__ = '0.0.1'
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..d383f3b
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,57 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Build the local-agent Python package."""
+
+import os
+from setuptools import find_packages, setup
+
+
+def get_version() -> str:
+ """Gets the package version from single source of truth."""
+ version_module = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)),
+ 'local_agent/version.py')
+ globals_in_version_module = {}
+ exec(open(version_module, 'r').read(), globals_in_version_module)
+ return globals_in_version_module['__version__']
+
+
+# Package meta-data
+NAME = 'Local-Agent'
+DESCRIPTION = 'Local Agent Process in Rainier Test Infrastructure'
+REQUIRES_PYTHON = '>=3.7'
+VERSION = get_version()
+LICENSE = 'Copyright 2021 Google LLC'
+
+REQUIRED_MODULES = [
+ 'gazoo-device', 'inflection', 'requests >= 2.25.0']
+TEST_REQUIRED_MODULES = {
+ 'test': ['coverage', 'immutabledict', 'parameterized']}
+
+
+setup(
+ name=NAME,
+ version=VERSION,
+ description=DESCRIPTION,
+ license=LICENSE,
+ packages=find_packages(exclude=('tests',)),
+ python_requires=REQUIRES_PYTHON,
+ install_requires=REQUIRED_MODULES,
+ extras_require=TEST_REQUIRED_MODULES,
+ entry_points={
+ 'console_scripts': ['local-agent = local_agent.local_agent:main']
+ },
+)