blob: 373230c6d4808c728a0cfec84fa1a0ba05554aca [file] [log] [blame]
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for AmsClient class."""
import http
import json
import immutabledict
import requests
import time
from typing import Any, Dict, Optional, Tuple
from local_agent import errors
from local_agent import logger as logger_module
logger = logger_module.get_logger()
# _ENDPOINTS stores the AMS endpoint paths, keyed by API name. Each path is
# appended to the client's base URL. DELETE_RPC_REQUEST contains an {rpc_id}
# placeholder that is filled in with str.format() before use.
_ENDPOINTS = immutabledict.immutabledict({
    'REGISTER': '/tsb/api/local-agent/register',
    'AUTH': '/tsb/api/local-agent/auth',
    'REPORT_INFO': '/tsb/api/local-agent/report-info',
    'GET_RPC_REQUEST': '/tsb/api/local-agent/rpc/request',
    'DELETE_RPC_REQUEST': '/tsb/api/local-agent/rpc/request/{rpc_id}',
    'SEND_RPC_RESPONSE': '/tsb/api/local-agent/rpc/response',
    'UPLOAD_ARTIFACT': '/tsb/api/local-agent/artifacts',
})
# _TIMEOUTS stores per-endpoint timeout overrides, keyed by the same names as
# _ENDPOINTS. It is currently empty, so every lookup falls back to
# _DEFAULT_TIMEOUT.
_TIMEOUTS = immutabledict.immutabledict({})
# Default value for the `timeout` argument passed to the requests session.
_DEFAULT_TIMEOUT = 5.0
# Default number of retries performed by AmsClient._request_wrapper.
_DEFAULT_RETRY = 3
# Delay handed to time.sleep() between request attempts.
_DEFAULT_RETRY_INTERVAL = 1.0
# URL scheme used when the caller of AmsClient does not provide one.
_DEFAULT_SCHEME = 'https'
# Substring of the AMS error message that, combined with a 400 BAD_REQUEST
# status, is treated as "this local agent is unlinked".
_UNLINKED_ERR_MSG = 'Invalid agent id'
def extract_error_message_from_api_response(
        response: requests.models.Response) -> Optional[str]:
    """Extracts error message from an AMS API response.

    This function never raises on a malformed body: any response whose body
    is not a JSON object with an 'errorMessage' field yields None.

    Args:
        response: The AMS API response.

    Returns:
        The value of the 'errorMessage' field, or None if the body is not
        valid JSON, is not a JSON object, or has no such field.
    """
    try:
        payload = response.json()
    except ValueError:
        # ValueError is the common base class of json.JSONDecodeError and of
        # the decode errors requests raises when simplejson is installed, so
        # catching it covers both JSON backends.
        logger.warning('API response cannot be parsed as JSON.')
        return None

    # A valid JSON body may still be a list, string, number or null; only a
    # JSON object can carry the errorMessage field.
    if isinstance(payload, dict) and 'errorMessage' in payload:
        return payload['errorMessage']

    logger.warning('API response does not have errorMessage field')
    return None
def _local_agent_is_unlinked(
response_status_code: int, error_msg: Optional[str]) -> bool:
"""Determines whether local agent is unlinked.
Args:
response_status_code: The AMS API response status code.
error_msg: The AMS API response error message.
Returns:
True if the local agent is unlinked, false otherwise.
"""
return (response_status_code == http.HTTPStatus.BAD_REQUEST
and error_msg is not None
and _UNLINKED_ERR_MSG in error_msg)
class AmsClient:
    """Client class to communicate with Agent Management Service (AMS).

    Local Agent uses this client class to communicate with AMS, i.e., to send
    APIs to AMS and receive responses. All requests go through one shared
    requests.Session, whose 'Authorization' header is set once an auth token
    has been obtained.
    """

    def __init__(
            self,
            host: str,
            port: Optional[int] = None,
            scheme: str = _DEFAULT_SCHEME):
        """Initializes the AMS client.

        The port and scheme parameters are optional. If not
        provided, we will use the defaults.

        Args:
            host: The host address of AMS server.
            port: The port of AMS server. If None, no port is put in the URL.
            scheme: Either 'http' or 'https'.
        """
        if port is not None:
            self._base_url = f'{scheme}://{host}:{port}'
        else:
            self._base_url = f'{scheme}://{host}'

        # Credentials
        self._local_agent_id = None
        self._local_agent_secret = None

        # HTTP request session.
        self._session = requests.Session()

    def set_local_agent_credentials(self,
                                    local_agent_id: int,
                                    local_agent_secret: str) -> None:
        """Sets the local agent ID and local agent secret.

        If the local agent has been registered before, use this method to set
        the credentials instead of calling register() again. The credentials
        are used immediately to fetch an auth token from AMS.

        Args:
            local_agent_id: The agent ID.
            local_agent_secret: The agent secret.

        Raises:
            ApiTimeoutError: If the auth API timed out.
            UnlinkedError: If AMS reports the local agent as unlinked.
            CredentialsError: If no auth token can be obtained with the
                given credentials.
        """
        self._local_agent_id = local_agent_id
        self._local_agent_secret = local_agent_secret
        self._get_auth_token()

    def register(self, linking_code: str) -> Tuple[str, str]:
        """Registers a local agent with linking code.

        This method registers a new local agent and will retrieve local agent
        credentials. If you already have credentials for your local agent,
        use set_local_agent_credentials() instead.

        Args:
            linking_code: The linking code obtained from frontend UI.

        Returns:
            The local agent ID and local agent secret.
            NOTE(review): set_local_agent_credentials() annotates the ID as
            int while this method annotates it as str - confirm which one
            AMS actually returns.

        Raises:
            ApiTimeoutError: If the register API timed out.
            CredentialsError: If the linking code is not valid.
            ApiError: If the registration API failed.
        """
        endpoint = self._base_url + _ENDPOINTS['REGISTER']
        timeout = _TIMEOUTS.get('REGISTER') or _DEFAULT_TIMEOUT
        post_data = {'linkingCode': linking_code}

        try:
            response = self._request_wrapper(
                method='POST',
                url=endpoint,
                json=post_data,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Register API timed out.')
            raise

        # AMS answers 400 when the linking code is rejected.
        if response.status_code == http.HTTPStatus.BAD_REQUEST:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.CredentialsError(
                f'Not a valid linking code: {linking_code}, '
                f'AMS error message: {ams_err_msg}')

        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Registration failed with status {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

        result = response.json()['result']
        # Storing the credentials also fetches the first auth token.
        self.set_local_agent_credentials(
            local_agent_id=result['agentId'],
            local_agent_secret=result['agentSecret'])
        return result['agentId'], result['agentSecret']

    def report_info(self, info: Dict[str, Any]) -> None:
        """Reports local agent info.

        Calling REPORT_INFO API to send the info of local agent to AMS.

        Args:
            info: The information of local agent.

        Raises:
            ApiTimeoutError: If the report info API timed out.
            ApiError: If API response status is not 200 OK.
        """
        endpoint = self._base_url + _ENDPOINTS['REPORT_INFO']
        # Look up this endpoint's own timeout override (not REGISTER's).
        timeout = _TIMEOUTS.get('REPORT_INFO') or _DEFAULT_TIMEOUT

        try:
            response = self._request_wrapper(
                method='POST',
                url=endpoint,
                json=info,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Report info API timed out.')
            raise

        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Report info API failed: status {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

    def get_rpc_request_from_ams(self) -> Optional[Dict[str, Any]]:
        """Gets an RPC request from AMS, returns None if there isn't any.

        Returns:
            The JSON-RPC request object, a dictionary. Or None if there's no
            request from AMS (204 NO_CONTENT, or a 200 body without a
            'result' field).

        Raises:
            ApiTimeoutError: If the get JSON-RPC request API timed out.
            ApiError: If API response has an error status code.
        """
        endpoint = self._base_url + _ENDPOINTS['GET_RPC_REQUEST']
        timeout = _TIMEOUTS.get('GET_RPC_REQUEST') or _DEFAULT_TIMEOUT

        try:
            response = self._request_wrapper(
                method='GET',
                url=endpoint,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Get RPC request API timed out.')
            raise

        if response.status_code == http.HTTPStatus.OK:
            return response.json().get('result')
        if response.status_code == http.HTTPStatus.NO_CONTENT:
            return None

        ams_err_msg = extract_error_message_from_api_response(response)
        raise errors.ApiError(
            f'Get RPC request API failed: status {response.status_code}, '
            f'AMS error message: {ams_err_msg}')

    def remove_rpc_request_from_ams(self,
                                    rpc_request: Dict[str, Any]) -> None:
        """Removes the RPC request from AMS.

        Local agent should call this method to remove a JSON-RPC request from
        AMS after retrieving the request during polling. Otherwise, the next
        iteration of the polling loop will get the same JSON-RPC request.

        Args:
            rpc_request: The JSON-RPC request to remove from AMS. Its 'id'
                field is used to build the request URL.

        Raises:
            ApiTimeoutError: If the delete JSON-RPC request API timed out.
            ApiError: If API response has an unexpected status code.
        """
        rpc_id = rpc_request.get('id')
        endpoint = (self._base_url +
                    _ENDPOINTS['DELETE_RPC_REQUEST'].format(rpc_id=rpc_id))
        timeout = _TIMEOUTS.get('DELETE_RPC_REQUEST') or _DEFAULT_TIMEOUT

        try:
            response = self._request_wrapper(
                method='DELETE',
                url=endpoint,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Remove RPC request API timed out.')
            raise

        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Remove RPC API failed: status {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

    def send_rpc_response(self, rpc_response: Dict[str, Any]) -> None:
        """Sends a JSON-RPC response to AMS.

        If the API does not succeed, the request is retried by
        _request_wrapper using its default retry count.

        Args:
            rpc_response: The JSON-RPC response to send.

        Raises:
            ApiTimeoutError: If API timed out.
            ApiError: If API response has unexpected status code.
        """
        endpoint = self._base_url + _ENDPOINTS['SEND_RPC_RESPONSE']
        timeout = _TIMEOUTS.get('SEND_RPC_RESPONSE') or _DEFAULT_TIMEOUT

        try:
            response = self._request_wrapper(
                method='POST',
                url=endpoint,
                json=rpc_response,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Send RPC response API timed out.')
            raise

        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Send RPC response API failed. Status: '
                f'{response.status_code}, AMS error message: {ams_err_msg}')

    def upload_artifact(self, artifact_path: str, test_result_id: str) -> None:
        """Uploads a test artifact (file) to AMS.

        Args:
            artifact_path: Path to the artifact.
            test_result_id: ID of the test result this artifact belongs to.

        Raises:
            ApiTimeoutError: If API timed out.
            ApiError: If API response has unexpected status code.
        """
        endpoint = self._base_url + _ENDPOINTS['UPLOAD_ARTIFACT']
        timeout = _TIMEOUTS.get('UPLOAD_ARTIFACT') or _DEFAULT_TIMEOUT

        try:
            # The file is opened in binary mode and stays open for the whole
            # request, including any retries done by _request_wrapper.
            with open(artifact_path, 'rb') as artifact_stream:
                response = self._request_wrapper(
                    method='POST',
                    url=endpoint,
                    files={'file': artifact_stream},
                    data={'testResultId': test_result_id},
                    timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Upload artifact API timed out.')
            raise

        if response.status_code != http.HTTPStatus.OK:
            ams_err_msg = extract_error_message_from_api_response(response)
            raise errors.ApiError(
                f'Upload artifact API failed. Status: {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

    def _get_auth_token(self) -> None:
        """Gets auth token, and sets the HTTP session header.

        Raises:
            ApiTimeoutError: If the auth API timed out.
            UnlinkedError: If AMS reports the local agent as unlinked.
            CredentialsError: If cannot get auth token with the local agent
                credentials.
        """
        endpoint = self._base_url + _ENDPOINTS['AUTH']
        timeout = _TIMEOUTS.get('AUTH') or _DEFAULT_TIMEOUT
        post_data = {'agentId': self._local_agent_id,
                     'agentSecret': self._local_agent_secret}

        try:
            # refresh_auth=False: the auth call must not trigger another
            # auth refresh from inside _request_wrapper.
            response = self._request_wrapper(
                refresh_auth=False,
                method='POST',
                url=endpoint,
                json=post_data,
                timeout=timeout)
        except errors.ApiTimeoutError:
            logger.error('Auth API timed out.')
            raise

        # The auth endpoint signals success with 201 CREATED.
        if response.status_code != http.HTTPStatus.CREATED:
            ams_err_msg = extract_error_message_from_api_response(response)
            if _local_agent_is_unlinked(response.status_code, ams_err_msg):
                raise errors.UnlinkedError(
                    f'Local agent is unlinked: {ams_err_msg}.')
            raise errors.CredentialsError(
                f'Cannot get auth token. Status {response.status_code}, '
                f'AMS error message: {ams_err_msg}')

        auth_token = response.json()['result']['authToken']
        self._session.headers.update({'Authorization': auth_token})

    def _request_wrapper(self,
                         *,
                         num_retries: int = _DEFAULT_RETRY,
                         refresh_auth: bool = True,
                         **kwargs) -> requests.models.Response:
        """Sends a request and provides retry functionality.

        We will retry when:
        1) the API response has a 4xx or 5xx status code, or
        2) the request has timed out.

        If the retries exhausted, we return the last response, or raise
        ApiTimeoutError if the last request timed out. In total at most
        num_retries + 1 attempts are made, plus one extra attempt if the
        auth token was refreshed.

        Args:
            num_retries: The number of retries if API did not succeed. If a
                negative integer or zero is provided, we don't retry.
            refresh_auth: If True, we will refresh the auth token when the
                first response has a 401 UNAUTHORIZED status. Refreshing auth
                token is not considered a retry, and is done at most once
                per call.
            **kwargs: Other keyword arguments that will be passed to the
                requests.sessions.Session.request method.

        Returns:
            The API response.

        Raises:
            ApiTimeoutError: If the last request timed out.

        Exceptions raised by _get_auth_token() propagate unchanged when the
        token refresh itself fails.
        """
        # Ensure we set a timeout.
        kwargs.setdefault('timeout', _DEFAULT_TIMEOUT)
        num_retries = max(num_retries, 0)

        auth_refreshed = False
        current_try = 0
        last_exception = None
        while current_try <= num_retries:
            current_try += 1
            try:
                response = self._session.request(**kwargs)
            except requests.exceptions.Timeout:
                last_exception = errors.ApiTimeoutError()
            else:
                last_exception = None
                status_code = response.status_code
                if not 400 <= status_code <= 599:
                    return response

                if (status_code == http.HTTPStatus.UNAUTHORIZED
                        and refresh_auth
                        and not auth_refreshed
                        and current_try == 1):
                    # Only refresh token when first attempt is UNAUTHORIZED.
                    # And refresh token will not be considered a retry.
                    # The auth_refreshed flag guarantees a single refresh:
                    # without it the decrement below would make every
                    # following 401 look like a "first attempt" again and
                    # the loop would never terminate.
                    self._get_auth_token()
                    auth_refreshed = True
                    current_try -= 1

            # Wait only if another attempt will actually follow; sleeping
            # after the final attempt would just delay the result.
            if current_try <= num_retries:
                time.sleep(_DEFAULT_RETRY_INTERVAL)

        if last_exception is not None:
            raise last_exception
        return response