blob: ba0f9c2910f7705c429c15340182387cf5a87d07 [file] [log] [blame]
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for base command handlers."""
import re
import unittest
from unittest import mock

import inflection
from parameterized import parameterized

from local_agent import errors as agent_errors
from local_agent.translation_layer.command_handlers import base
# JSON-RPC method names used to exercise request dispatch; the last one is
# deliberately not handled so dispatch should reject it.
_FAKE_GET_OP = 'getFake'
_FAKE_SET_OP = 'setFake'
_UNKNOWN_METHOD = 'unknownMethod'

# Values the mocked handler methods return. The set operation yields None,
# which rpc_response renders as an empty result object.
_FAKE_GET_RESULT = 'fake-result'
_FAKE_SET_RESULT = None

# A single params entry used by the validate_key_in_params tests.
_FAKE_PARAM_KEY, _FAKE_PARAM_VAL = 'fake-param-key', 10
def rpc_request(method, params):
    """Builds a JSON-RPC 2.0 request dict (id is always 0) for `method`."""
    request = dict(jsonrpc='2.0', id=0)
    request.update(method=method, params=params)
    return request
def rpc_response(func_result):
    """Builds the JSON-RPC 2.0 response (id 0) expected for a handler result.

    A None result becomes an empty ``result`` object; any other value
    (including falsy ones like 0) is wrapped as ``{'value': func_result}``.
    """
    if func_result is None:
        payload = {}
    else:
        payload = {'value': func_result}
    return dict(id=0, jsonrpc='2.0', result=payload)
class BaseCommandHandlerTest(unittest.TestCase):
    """Unit tests for BaseCommandHandler."""

    def setUp(self):
        super().setUp()
        # NOTE(review): the handler is built with None as its only
        # constructor argument; presumably that dependency is unused by the
        # dispatch/validation paths tested here — confirm against base.py.
        self.handler = base.BaseCommandHandler(None)

    @parameterized.expand(
        [(_FAKE_GET_OP, _FAKE_GET_RESULT), (_FAKE_SET_OP, _FAKE_SET_RESULT)])
    def test_01_handle_request_on_success(self, fake_method, fake_result):
        """Verifies handle_request on success.

        Installs a mock on the handler under `_<snake_case method name>`
        (e.g. 'getFake' -> '_get_fake') and expects handle_request to call
        it once and wrap its return value in a JSON-RPC response.
        """
        fake_func = mock.Mock(return_value=fake_result)
        method_name = f'_{inflection.underscore(fake_method)}'
        setattr(self.handler, method_name, fake_func)
        fake_rpc = rpc_request(method=fake_method, params={})
        expected_response = rpc_response(fake_result)

        response = self.handler.handle_request(fake_rpc)

        self.assertEqual(expected_response, response)
        fake_func.assert_called_once()

    def test_01_handle_request_on_failure_unknown_method(self):
        """Verifies handle_request on failure with unknown method."""
        fake_rpc = rpc_request(method=_UNKNOWN_METHOD, params={})
        error_regex = 'Unknown method'
        with self.assertRaisesRegex(agent_errors.InvalidRPCError, error_regex):
            self.handler.handle_request(fake_rpc)

    def test_02_validate_key_in_params_on_success(self):
        """Verifies validate_key_in_params method on success.

        Passes as long as no exception is raised for a present key whose
        value has the expected type.
        """
        params = {_FAKE_PARAM_KEY: _FAKE_PARAM_VAL}
        self.handler.validate_key_in_params(
            params=params,
            param_key=_FAKE_PARAM_KEY,
            expected_type=int)

    def test_02_validate_key_in_params_on_failure_missing_key(self):
        """Verifies validate_key_in_params on failure with missing key."""
        # Escape so the message (including its trailing '.') is matched
        # literally rather than as regex metacharacters.
        error_regex = re.escape(
            f'Missing field {_FAKE_PARAM_KEY} from RPC request.')
        with self.assertRaisesRegex(ValueError, error_regex):
            self.handler.validate_key_in_params(
                params={},
                param_key=_FAKE_PARAM_KEY,
                expected_type=int)

    def test_02_validate_key_in_params_on_failure_invalid_type(self):
        """Verifies validate_key_in_params on failure with invalid type."""
        params = {_FAKE_PARAM_KEY: _FAKE_PARAM_VAL}
        # Escape so the trailing '.' is matched literally, not as a wildcard.
        error_regex = re.escape(f'Invalid type for {_FAKE_PARAM_KEY}.')
        with self.assertRaisesRegex(ValueError, error_regex):
            self.handler.validate_key_in_params(
                params=params,
                param_key=_FAKE_PARAM_KEY,
                expected_type=str)
if __name__ == '__main__':
    # Run this module's tests directly, stopping at the first failure.
    unittest.main(failfast=True)