| # Copyright 2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Module for AmsClient class.""" |
import http
import json
import os
import time
from typing import Any, Dict, Optional, Tuple

import immutabledict
import requests

from local_agent import errors
from local_agent import logger as logger_module
| |
| |
logger = logger_module.get_logger()

# Relative URL path of every AMS endpoint this client calls, keyed by a short
# endpoint name. DELETE_RPC_REQUEST is a str.format() template that must be
# filled with an ``rpc_id`` value.
_ENDPOINTS = immutabledict.immutabledict(
    REGISTER='/tsb/api/local-agent/register',
    AUTH='/tsb/api/local-agent/auth',
    REPORT_INFO='/tsb/api/local-agent/report-info',
    GET_RPC_REQUEST='/tsb/api/local-agent/rpc/request',
    DELETE_RPC_REQUEST='/tsb/api/local-agent/rpc/request/{rpc_id}',
    SEND_RPC_RESPONSE='/tsb/api/local-agent/rpc/response',
    UPLOAD_ARTIFACT='/tsb/api/local-agent/artifacts',
)

# Optional per-endpoint timeout overrides (seconds), keyed like _ENDPOINTS.
# Endpoints without an entry fall back to _DEFAULT_TIMEOUT.
_TIMEOUTS = immutabledict.immutabledict()

# Request timeout in seconds, as passed to requests' ``timeout=`` argument.
_DEFAULT_TIMEOUT = 5.0
# Number of retries (in addition to the first attempt) per API call.
_DEFAULT_RETRY = 3
# Pause between attempts, in seconds.
_DEFAULT_RETRY_INTERVAL = 1.0
_DEFAULT_SCHEME = 'https'
# Substring of the AMS error message that signals the agent was unlinked.
_UNLINKED_ERR_MSG = 'Invalid agent id'
| |
| |
def extract_error_message_from_api_response(
        response: requests.models.Response) -> Optional[str]:
    """Extracts error message from an AMS API response.

    AMS error responses carry their message in the ``errorMessage`` field of a
    JSON object. Any response that does not have that shape is logged and
    treated as having no message, so callers that are already handling a
    failed API call never trip over a second exception here.

    Args:
        response: The AMS API response.

    Returns:
        The error message as string, or None if no AMS error message found.
    """
    try:
        payload = response.json()
    except ValueError:
        # json.JSONDecodeError, simplejson's JSONDecodeError and
        # requests.exceptions.JSONDecodeError are all ValueError subclasses;
        # which one is raised depends on the installed requests version.
        logger.warning('API response cannot be parsed as JSON.')
        return None

    try:
        return payload['errorMessage']
    except (KeyError, TypeError):
        # KeyError: a JSON object without the field.
        # TypeError: valid JSON that is not an object (list, string, null...).
        logger.warning('API response does not have errorMessage field')
        return None
| |
| |
def _local_agent_is_unlinked(
        response_status_code: int, error_msg: Optional[str]) -> bool:
    """Tells whether an AMS response means this local agent was unlinked.

    AMS reports an unlinked agent as a 400 BAD REQUEST whose error message
    contains _UNLINKED_ERR_MSG.

    Args:
        response_status_code: The AMS API response status code.
        error_msg: The AMS API response error message, if any.

    Returns:
        True if the local agent is unlinked, false otherwise.
    """
    if response_status_code != http.HTTPStatus.BAD_REQUEST:
        return False
    if error_msg is None:
        # Without a message we cannot distinguish "unlinked" from other
        # bad-request failures.
        return False
    return _UNLINKED_ERR_MSG in error_msg
| |
| |
class AmsClient:
    """Client class to communicate with Agent Management Service (AMS).

    Local Agent uses this client class to communicate with AMS, i.e., to send
    APIs to AMS and receive responses. All requests share one
    requests.Session, so the Authorization header installed by
    _get_auth_token() is sent with every subsequent call.
    """

    def __init__(
            self,
            host: str,
            port: Optional[int] = None,
            scheme: str = _DEFAULT_SCHEME):
        """Initializes the AMS client.

        The port and scheme parameters are optional. If not
        provided, we will use the defaults.

        Args:
            host: The host address of AMS server.
            port: The port of AMS server. If None, no port is put in the
                base URL.
            scheme: Either 'http' or 'https'.
        """
        if port is not None:
            self._base_url = f'{scheme}://{host}:{port}'
        else:
            self._base_url = f'{scheme}://{host}'

        # Credentials, set by set_local_agent_credentials().
        self._local_agent_id = None
        self._local_agent_secret = None

        # HTTP request session; carries the auth header once obtained.
        self._session = requests.Session()

    def _endpoint_url(self, endpoint_name: str, **path_params: Any) -> str:
        """Builds the absolute URL of an AMS endpoint.

        Args:
            endpoint_name: Key into _ENDPOINTS, e.g. 'REGISTER'.
            **path_params: Values substituted into the endpoint path template
                with str.format(), e.g. rpc_id for DELETE_RPC_REQUEST.

        Returns:
            The base URL joined with the endpoint path.
        """
        path = _ENDPOINTS[endpoint_name]
        if path_params:
            path = path.format(**path_params)
        return self._base_url + path

    @staticmethod
    def _endpoint_timeout(endpoint_name: str) -> float:
        """Returns the configured timeout of an endpoint, or the default."""
        return _TIMEOUTS.get(endpoint_name) or _DEFAULT_TIMEOUT

    def set_local_agent_credentials(self,
                                    local_agent_id: int,
                                    local_agent_secret: str) -> None:
        """Sets the local agent ID and local agent secret.

        If the local agent has been registered before, use this method to set
        the credentials instead of calling register() again. An auth token is
        fetched immediately with the new credentials.

        Args:
            local_agent_id: The agent ID.
            local_agent_secret: The agent secret.

        Raises:
            ApiTimeoutError: If the auth API timed out.
            UnlinkedError: If AMS reports the agent as unlinked.
            CredentialsError: If the credentials are rejected.
        """
        self._local_agent_id = local_agent_id
        self._local_agent_secret = local_agent_secret
        self._get_auth_token()

    def register(self, linking_code: str) -> Tuple[str, str]:
        """Registers a local agent with linking code.

        This method registers a new local agent and will retrieve local agent
        credentials. If you already have credentials for your local agent, use
        set_local_agent_credentials() instead.

        Args:
            linking_code: The linking code obtained from frontend UI.

        Returns:
            The local agent ID and local agent secret.

        Raises:
            ApiTimeoutError: If the register API timed out.
            CredentialsError: If the linking code is not valid.
            ApiError: If the registration API failed.
        """
        endpoint = self._endpoint_url('REGISTER')
        timeout = self._endpoint_timeout('REGISTER')
        post_data = {'linkingCode': linking_code}

        try:
            response = self._request_wrapper(
                method='POST',
                url=endpoint,
                json=post_data,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Register API timed out.')
            raise

        if response.status_code == http.HTTPStatus.BAD_REQUEST:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.CredentialsError(
                f'Not a valid linking code: {linking_code}, '
                f'AMS error message: {ams_err_msg}')
        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Registration failed with status {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

        result = response.json()['result']
        self.set_local_agent_credentials(
            local_agent_id=result['agentId'],
            local_agent_secret=result['agentSecret'])
        return result['agentId'], result['agentSecret']

    def report_info(self, info: Dict[str, Any]) -> None:
        """Reports local agent info.

        Calling REPORT_INFO API to send the info of local agent to AMS.

        Args:
            info: The information of local agent.

        Raises:
            ApiTimeoutError: If the report info API timed out.
            ApiError: If API response status is not 200 OK.
        """
        endpoint = self._endpoint_url('REPORT_INFO')
        # Previously looked up the REGISTER timeout by mistake.
        timeout = self._endpoint_timeout('REPORT_INFO')
        try:
            response = self._request_wrapper(
                method='POST',
                url=endpoint,
                json=info,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Report info API timed out.')
            raise
        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Report info API failed: status {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

    def get_rpc_request_from_ams(self) -> Optional[Dict[str, Any]]:
        """Gets an RPC request from AMS, returns None if there isn't any.

        Returns:
            The JSON-RPC request object, a dictionary. Or None if there's no
            request from AMS (204 NO CONTENT, or a 200 without 'result').

        Raises:
            ApiTimeoutError: If the get JSON-RPC request API timed out.
            ApiError: If API response has an error status code.
        """
        endpoint = self._endpoint_url('GET_RPC_REQUEST')
        timeout = self._endpoint_timeout('GET_RPC_REQUEST')

        try:
            response = self._request_wrapper(
                method='GET',
                url=endpoint,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Get RPC request API timed out.')
            raise
        if response.status_code == http.HTTPStatus.OK:
            return response.json().get('result')
        if response.status_code == http.HTTPStatus.NO_CONTENT:
            return None
        ams_err_msg = extract_error_message_from_api_response(response)
        raise errors.ApiError(
            f'Get RPC request API failed: status {response.status_code}, '
            f'AMS error message: {ams_err_msg}')

    def remove_rpc_request_from_ams(self,
                                    rpc_request: Dict[str, Any]) -> None:
        """Removes the RPC request from AMS.

        Local agent should call this method to remove a JSON-RPC request from
        AMS after retrieving the request during polling. Otherwise, the next
        iteration of the polling loop will get the same JSON-RPC request.

        Args:
            rpc_request: The JSON-RPC request to remove from AMS.

        Raises:
            ApiTimeoutError: If the delete JSON-RPC request API timed out.
            ApiError: If API response has an unexpected status code.
        """
        # NOTE(review): a request without an 'id' key yields a URL ending in
        # '/None'; confirm AMS always sets the JSON-RPC id.
        rpc_id = rpc_request.get('id')
        endpoint = self._endpoint_url('DELETE_RPC_REQUEST', rpc_id=rpc_id)
        timeout = self._endpoint_timeout('DELETE_RPC_REQUEST')

        try:
            response = self._request_wrapper(
                method='DELETE',
                url=endpoint,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Remove RPC request API timed out.')
            raise
        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Remove RPC API failed: status {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

    def send_rpc_response(self, rpc_response: Dict[str, Any]) -> None:
        """Sends a JSON-RPC response to AMS.

        A failed call is retried up to _DEFAULT_RETRY (3) times, i.e. at most
        4 attempts in total.

        Args:
            rpc_response: The JSON-RPC response to send.

        Raises:
            ApiTimeoutError: If API timed out.
            ApiError: If API response has unexpected status code.
        """
        endpoint = self._endpoint_url('SEND_RPC_RESPONSE')
        timeout = self._endpoint_timeout('SEND_RPC_RESPONSE')

        try:
            response = self._request_wrapper(
                method='POST',
                url=endpoint,
                json=rpc_response,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Send RPC response API timed out.')
            raise

        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Send RPC response API failed. Status: '
                f'{response.status_code}, AMS error message: {ams_err_msg}')

    def upload_artifact(self, artifact_path: str, test_result_id: str) -> None:
        """Uploads a test artifact (file) to AMS.

        Args:
            artifact_path: Path to the artifact.
            test_result_id: ID of the test result this artifact belongs to.

        Raises:
            OSError: If the artifact file cannot be read.
            ApiTimeoutError: If API timed out.
            ApiError: If API response has unexpected status code.
        """
        endpoint = self._endpoint_url('UPLOAD_ARTIFACT')
        timeout = self._endpoint_timeout('UPLOAD_ARTIFACT')

        # Read the content up front: _request_wrapper may send the request
        # several times, and a file object would already be at EOF on a retry,
        # silently uploading an empty file. requests reads the whole file into
        # memory to build the multipart body anyway, so this costs nothing.
        with open(artifact_path, 'rb') as artifact_stream:
            artifact_content = artifact_stream.read()
        # Same part filename requests would have derived from the file object.
        file_name = os.path.basename(artifact_path)

        try:
            response = self._request_wrapper(
                method='POST',
                url=endpoint,
                files={'file': (file_name, artifact_content)},
                data={'testResultId': test_result_id},
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Upload artifact API timed out.')
            raise
        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Upload artifact API failed. Status: {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

    def _get_auth_token(self) -> None:
        """Gets auth token, and sets the HTTP session header.

        Raises:
            ApiTimeoutError: If the auth API timed out.
            UnlinkedError: If AMS reports the local agent as unlinked.
            CredentialsError: If cannot get auth token with the local agent
                credentials.
        """
        endpoint = self._endpoint_url('AUTH')
        timeout = self._endpoint_timeout('AUTH')
        post_data = {'agentId': self._local_agent_id,
                     'agentSecret': self._local_agent_secret}

        try:
            # refresh_auth=False: a 401 here must not trigger another auth
            # call from inside the wrapper.
            response = self._request_wrapper(
                refresh_auth=False,
                method='POST',
                url=endpoint,
                json=post_data,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Auth API timed out.')
            raise

        if response.status_code != http.HTTPStatus.CREATED:
            ams_err_msg = extract_error_message_from_api_response(response)
            if _local_agent_is_unlinked(response.status_code, ams_err_msg):
                raise errors.UnlinkedError(
                    f'Local agent is unlinked: {ams_err_msg}.')
            raise errors.CredentialsError(
                f'Cannot get auth token. Status {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

        auth_token = response.json()['result']['authToken']
        self._session.headers.update({'Authorization': auth_token})

    def _request_wrapper(self,
                         *,
                         num_retries: int = _DEFAULT_RETRY,
                         refresh_auth: bool = True,
                         **kwargs) -> requests.models.Response:
        """Sends a request and provides retry functionality.

        We will retry when:
        1) the API response has a 4xx or 5xx status code, or
        2) the request has timed out.

        Consecutive attempts are separated by _DEFAULT_RETRY_INTERVAL seconds.
        If the retries are exhausted, we return the last response, or raise
        ApiTimeoutError if the last request timed out.

        Args:
            num_retries: The number of retries if API did not succeed. If a
                negative integer or zero is provided, we don't retry.
            refresh_auth: If True, we will refresh the auth token (at most
                once per call) when the first response has a 401 UNAUTHORIZED
                status. Refreshing auth token is not considered a retry.
            **kwargs: Other keyword arguments that will be passed to the
                requests.sessions.Session.request method.

        Returns:
            The API response.

        Raises:
            ApiTimeoutError: If the last request timed out.
            UnlinkedError, CredentialsError: Propagated from _get_auth_token()
                when refreshing the auth token fails.
        """
        # Ensure we set a timeout so a stalled server cannot hang the agent.
        kwargs.setdefault('timeout', _DEFAULT_TIMEOUT)

        max_attempts = max(num_retries, 0) + 1
        attempt = 0
        auth_refreshed = False
        response = None
        last_timeout = None

        while attempt < max_attempts:
            attempt += 1
            try:
                response = self._session.request(**kwargs)
            except requests.exceptions.Timeout as timeout_err:
                last_timeout = timeout_err
            else:
                last_timeout = None
                status_code = response.status_code
                if not 400 <= status_code <= 599:
                    return response
                if (status_code == http.HTTPStatus.UNAUTHORIZED
                        and refresh_auth
                        and attempt == 1
                        and not auth_refreshed):
                    # Presumably the token expired: fetch a new one and redo
                    # the first attempt without consuming a retry. Only once —
                    # previously a server that kept answering 401 made this
                    # loop refresh forever.
                    self._get_auth_token()
                    auth_refreshed = True
                    attempt -= 1
            # No point in waiting after the final attempt.
            if attempt < max_attempts:
                time.sleep(_DEFAULT_RETRY_INTERVAL)

        if last_timeout is not None:
            raise errors.ApiTimeoutError() from last_timeout
        return response