# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for Translation Layer."""
import collections
import contextlib
import threading
import time
from typing import Any, Callable, Dict, Iterator, List

from local_agent import ams_client
from local_agent import errors as agent_errors
from local_agent import logger as logger_module
from local_agent.translation_layer import gdm_manager
from local_agent.translation_layer.command_handlers import base
from local_agent.translation_layer.command_handlers.handler_registry import GDM_CAPABILITIES_TO_COMMAND_HANDLERS
from ui_automator import errors as ua_errors
from ui_automator import ui_automator

logger = logger_module.get_logger()

# An RPC still executing after this many seconds is reported as timed out.
_RPC_TIME_OUT = 15 * 60  # 15 mins in seconds
# Display form of _RPC_TIME_OUT used in the timeout error message. Derived
# from _RPC_TIME_OUT so the two cannot drift apart.
_RPC_TIME_OUT_HUMAN_READABLE = f'{_RPC_TIME_OUT // 60} mins'
# Polling period, in seconds, of the background RPC timeout checker.
_RPC_TIME_OUT_INTERVAL_SECONDS = 30


# ======================== Module level functions ========================== #
def validate_handlers_cls_map(handlers: List[Any]) -> None:
    """Checks that a device type's handler classes are valid and disjoint.

    Every handler class must derive from BaseCommandHandler, and no RPC
    method may be claimed by more than one entry of the list.

    Args:
        handlers: Command handler classes matched for one device type.

    Raises:
        HandlerInvalidError: A handler does not derive from
            BaseCommandHandler.
        HandlersCollisionError: A method appears more than once across the
            handlers' SUPPORTED_METHODS.
    """
    owner_of = {}  # RPC method name -> handler class that first claimed it
    for handler_cls in handlers:
        if not issubclass(handler_cls, base.BaseCommandHandler):
            raise agent_errors.HandlerInvalidError(
                f'{handler_cls.__name__} is not a subclass of BaseCommandHandler.')

        for method in handler_cls.SUPPORTED_METHODS:
            previous_owner = owner_of.get(method)
            if previous_owner is not None:
                raise agent_errors.HandlersCollisionError(
                    f'Handlers {previous_owner.__name__} and '
                    f'{handler_cls.__name__} '
                    f'have duplicate methods: {method}.')
            owner_of[method] = handler_cls
# ========================================================================== #


class TranslationLayer:
    """Translation Layer for JSON-RPC and Device Control Libraries mapping.

    Routes device JSON-RPC requests to the command handler registered for
    the target device's type and method, rejects concurrent requests on the
    same device, and runs a background thread that sends a timeout error
    response for RPCs executing longer than _RPC_TIME_OUT seconds.
    """

    def __init__(self, client: ams_client.AmsClient):
        """Initializes the translation layer.

        Args:
            client: AMS client used to send RPC timeout error responses.
        """
        # Command handlers.
        # Device type -> {RPC method -> command handler class}.
        self._handlers_cls_map = collections.defaultdict(dict)
        # DUT device id -> {command handler class -> handler instance}.
        # A device type may be served by several handler classes (one per
        # matched GDM capability), so instances are cached per class.
        self._cmd_handlers = {}

        # GDM manager, given update_handlers_cls_map as a callback.
        self._mgr = gdm_manager.GdmManager(self.update_handlers_cls_map)

        # UI Automator, used for commissioning to the Google fabric.
        self._ui_automator = ui_automator.UIAutomator()

        # Busy-device tracking; the lock also guards the RPC start times.
        self._rpc_execution_lock = threading.RLock()
        self._busy_devices = set()

        # RPC timeout tracking.
        self._ams_client = client
        self._rpc_execution_start_time = {}  # rpc_id -> time.time() at start
        self._termination_event = None
        self._rpc_timeout_checker = threading.Thread(
            target=self._check_rpc_timeout, daemon=True)
        self._timeout_rpc = set()  # ids of RPCs already reported timed out

    def start(self, termination_event: threading.Event) -> None:
        """Starts the background RPC timeout checker thread.

        Args:
            termination_event: Event that stops the checker once set.
        """
        self._termination_event = termination_event
        self._rpc_timeout_checker.start()

    def create_devices(self, dut_ids: List[str], test_suite_dir: str) -> None:
        """Creates GDM device instances.

        Args:
            dut_ids: List of GDM device ids; duplicates are ignored.
            test_suite_dir: Test suite directory.
        """
        # De-duplicate while keeping the caller's order (unlike set()).
        unique_dut_ids = list(dict.fromkeys(dut_ids))
        self._mgr.create_devices(unique_dut_ids, test_suite_dir)

    def close_devices(self) -> None:
        """Closes all open GDM devices and drops cached handler instances.

        The device type -> handler class map is left intact.
        """
        self._mgr.close_open_devices()
        self._cmd_handlers.clear()

    @contextlib.contextmanager
    def device_operation_handler(
        self, dut_device_id: str, rpc_id: str) -> Iterator[None]:
        """Context manager for device operation.

        Marks the device as busy and records the RPC execution start time
        when entering the context; undoes both when exiting, even if the
        body raises.

        Args:
            dut_device_id: DUT device id in GDM.
            rpc_id: JSON-RPC request id.

        Yields:
            None.

        Raises:
            InvalidRPCError: When an RPC with the same id is already
                executing or the requested device is still busy.
        """
        with self._rpc_execution_lock:
            if rpc_id in self._rpc_execution_start_time:
                raise agent_errors.InvalidRPCError(
                    f'RPC {rpc_id} is already executing.')

            if dut_device_id in self._busy_devices:
                raise agent_errors.InvalidRPCError(
                    f'Invalid RPC request: {dut_device_id} is still busy.')

            self._busy_devices.add(dut_device_id)
            self._rpc_execution_start_time[rpc_id] = time.time()

        # The checks above raise before this try block. A rejected request
        # therefore never clears the busy flag or start time owned by the
        # request that is already in flight.
        try:
            yield None
        finally:
            with self._rpc_execution_lock:
                self._busy_devices.discard(dut_device_id)
                self._rpc_execution_start_time.pop(rpc_id, None)

    def update_handlers_cls_map(
        self, device_type: str, capabilities: List[str]) -> None:
        """Registers the command handler classes for a device type.

        Does nothing if device_type is already registered. Capabilities
        without an entry in GDM_CAPABILITIES_TO_COMMAND_HANDLERS are
        ignored.

        Args:
            device_type: GDM device type.
            capabilities: List of GDM capabilities.

        Raises:
            HandlerInvalidError: A matched handler is not a subclass of
                BaseCommandHandler.
            HandlersCollisionError: Two matched handlers support the same
                method.
        """
        if device_type in self._handlers_cls_map:
            return

        candidates = (GDM_CAPABILITIES_TO_COMMAND_HANDLERS.get(capability)
                      for capability in capabilities)
        # dict.fromkeys de-duplicates while keeping capability order, so
        # validation (and its error messages) is deterministic.
        matched_handlers = list(
            dict.fromkeys(h for h in candidates if h is not None))

        validate_handlers_cls_map(matched_handlers)

        for handler in matched_handlers:
            for method in handler.SUPPORTED_METHODS:
                self._handlers_cls_map[device_type][method] = handler

    def detect_devices(self) -> List[Dict[str, Any]]:
        """Detects connected devices.

        Returns:
            The list of device dict, each dict includes the following fields:
            deviceId, serialNumber, capabilities and deviceType in GDM.
        """
        return self._mgr.detect_devices()

    def dispatch_to_cmd_handler(
        self, rpc_request: Dict[str, Any]) -> Dict[str, Any]:
        """Subroutine for handling regular device related rpc request.

        Args:
            rpc_request: JSON-RPC request; params must contain dutDeviceId.

        Raises:
            InvalidRPCError: Invalid rpc, or the device is busy.

        Returns:
            RPC response.
        """
        rpc_id = rpc_request['id']
        dut_device_id = rpc_request['params'].get('dutDeviceId')
        if dut_device_id is None:
            raise agent_errors.InvalidRPCError(
                'Invalid rpc request: no dutDeviceId in params.')

        with self.device_operation_handler(dut_device_id, rpc_id):
            self._mgr.check_device_connected(dut_device_id)
            cmd_handler = self._get_cmd_handler(dut_device_id,
                                                rpc_request['method'])
            resp = cmd_handler.handle_request(rpc_request)

        logger.info(f'Completed request for {dut_device_id}: {rpc_request}')
        return resp

    def is_rpc_timeout(self, rpc_id: str) -> bool:
        """Returns whether a timeout response was already sent for rpc_id."""
        return rpc_id in self._timeout_rpc

    def _get_cmd_handler(self,
                         dut_device_id: str,
                         method: str) -> base.BaseCommandHandler:
        """Gets the command handler for a device and RPC method.

        Handler instances are created lazily and cached per (device,
        handler class). A device whose type is served by several handler
        classes can therefore dispatch methods of all of them.

        Args:
            dut_device_id: DUT device id in GDM.
            method: device operation in RPC command.

        Returns:
            The command handler instance supporting the given method for
            the given device.

        Raises:
            HandlerNotFoundError: When no matching request handlers.
        """
        device_type = self._mgr.get_device_type(dut_device_id)

        # .get() avoids inserting an empty entry into the defaultdict.
        method_to_handler_cls = self._handlers_cls_map.get(device_type)
        if method_to_handler_cls is None:
            raise agent_errors.HandlerNotFoundError(
                'No matching command handler, '
                f'device type: {device_type} is not implemented')

        target_handler_cls = method_to_handler_cls.get(method)
        if target_handler_cls is None:
            raise agent_errors.HandlerNotFoundError(
                'No matching command handler, '
                f'method: {method} is not implemented')

        device_handlers = self._cmd_handlers.setdefault(dut_device_id, {})
        handler = device_handlers.get(target_handler_cls)
        if handler is None:
            dut = self._mgr.get_device_instance(dut_device_id)
            handler = target_handler_cls(dut)
            device_handlers[target_handler_cls] = handler

        return handler

    def _check_rpc_timeout(self) -> None:
        """Periodically reports RPCs that have exceeded _RPC_TIME_OUT.

        Wakes every _RPC_TIME_OUT_INTERVAL_SECONDS until the termination
        event is set. For each in-flight RPC that has run at least
        _RPC_TIME_OUT seconds and was not reported yet, it sends a timeout
        error response through the AMS client. It then marks the RPC so that
        is_rpc_timeout returns True.
        """
        while (self._termination_event is not None and
               not self._termination_event.wait(
                   _RPC_TIME_OUT_INTERVAL_SECONDS)):
            # The lock is held while sending. device_operation_handler
            # therefore cannot release an RPC between its timeout response
            # being sent and it being marked as timed out.
            with self._rpc_execution_lock:
                now = time.time()
                for rpc_id, start_time in (
                        self._rpc_execution_start_time.items()):
                    if (self.is_rpc_timeout(rpc_id) or
                            now - start_time < _RPC_TIME_OUT):
                        continue
                    err_resp = {
                        'id': rpc_id,
                        'jsonrpc': '2.0',
                        'error': {
                            'code': agent_errors.RpcTimeOutError.err_code,
                            'message': (
                                f'Handling RPC request {rpc_id} has timed '
                                f'out. (over {_RPC_TIME_OUT_HUMAN_READABLE}, '
                                'DUT may be unresponsive)'),
                        },
                    }
                    try:
                        self._ams_client.send_rpc_response(err_resp)
                    except Exception:  # pylint: disable=broad-except
                        # Keep the checker thread alive. The RPC is left
                        # unmarked, so the report is retried next interval.
                        logger.exception(
                            f'Failed to send timeout response for RPC {rpc_id}.')
                        continue
                    self._timeout_rpc.add(rpc_id)

    def commission_to_google_fabric(
        self, rpc_request: Dict[str, Any]) -> Dict[str, Any]:
        """Subroutine for handling COMMISSION_TO_GOOGLE_FABRIC command.

        Commissioning failures raised as AndroidDeviceNotReadyError or
        MoblySnippetError are reported in the result instead of raised.

        Args:
            rpc_request: JSON-RPC request; params must contain pairingCode,
                deviceName and ghaRoom.

        Raises:
            InvalidRPCError: Invalid RPC (a required param is missing).

        Returns:
            RPC response whose result holds isCommissioned and, on failure,
            errorLog.
        """
        params = rpc_request['params']

        pairing_code = params.get('pairingCode')
        if pairing_code is None:
            raise agent_errors.InvalidRPCError(
                'Invalid rpc command, no pairingCode in params was found.'
            )

        device_name = params.get('deviceName')
        if device_name is None:
            raise agent_errors.InvalidRPCError(
                'Invalid rpc command, no deviceName in params was found.'
            )

        gha_room = params.get('ghaRoom')
        if gha_room is None:
            raise agent_errors.InvalidRPCError(
                'Invalid rpc command, no ghaRoom in params was found.'
            )

        request_id = rpc_request['id']
        try:
            self._ui_automator.commission_device(
                device_name=device_name, pairing_code=pairing_code,
                gha_room=gha_room
            )
        except (ua_errors.AndroidDeviceNotReadyError,
                ua_errors.MoblySnippetError) as e:
            result = {'isCommissioned': False, 'errorLog': str(e)}
        else:
            result = {'isCommissioned': True}
        finally:
            logger.info(
                f'Completed request for {request_id}, pairingCode: {pairing_code},'
                f' deviceName: {device_name}, ghaRoom: {gha_room}'
            )

        return {
            'id': request_id,
            'jsonrpc': '2.0',
            'result': result,
        }
