# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fake front end for local agent and real AMS integration test."""
import argparse
import http
import immutabledict
import logging
import requests
import signal
import sys
import time
import threading
from typing import Any, Dict, List, Optional, Set, Tuple

import fake_test_suite


# Default TSB endpoint: a TSB instance running on this machine.
_LOCAL_TSB_HOST: str = '127.0.0.1'
_LOCAL_TSB_PORT: int = 8080
# TSB project the fake front end creates (if needed) and runs the test under.
_TEST_PROJECT: str = 'test-project'

# Timing knobs, all in seconds.
_POLLING_PERIOD: int = 1  # Between agent-link / RPC-response polls.
_STATUS_POLLING_PERIOD: int = 30  # Between local agent status fetches.
_COOL_DOWN_SEC: int = 1  # Pause after each RPC of a suite.
_DETECTION_COOL_DOWN_SEC: int = 3  # To wait for GDM detection to complete.

# Module logger; setup_logger() attaches its output handler.
logger: logging.Logger = logging.getLogger(__name__)

# ======================== TSB endpoints ========================== #
# All test-suite endpoints share one prefix; local-agent ones nest under it.
_TSB_TEST_SUITE_API = '/tsb/api/test-suite'
_TSB_LOCAL_AGENT_API = f'{_TSB_TEST_SUITE_API}/local-agent'

TEST_SUITE_AUTH = f'{_TSB_TEST_SUITE_API}/auth'
TEST_SUITE_PROJECTS = f'{_TSB_TEST_SUITE_API}/projects'
LINKING_CODE = f'{_TSB_LOCAL_AGENT_API}/linking-code'
AGENT_STATUS = f'{_TSB_LOCAL_AGENT_API}/info'
AGENT_RPC = f'{_TSB_LOCAL_AGENT_API}/rpc'
UNLINK_AGENT = f'{_TSB_LOCAL_AGENT_API}/unlink'
RPC_METADATA = f'{AGENT_RPC}/metadata'
# ================================================================= #


# ======================== Constants ========================== #
# Every fake suite this front end can run; _generate_suites() instantiates the
# ones whose is_applicable_to() accepts a device's traits.
ALL_SUITE_CLASSES: Tuple[Any, ...] = (
    fake_test_suite.BrightnessSuite,
    fake_test_suite.ColorSuite,
    fake_test_suite.DeviceCommonSuite,
    fake_test_suite.LightOnOffSuite,
    fake_test_suite.LockUnlockSuite,
)
# GDM capability name -> HG trait name. Capabilities without an entry are
# skipped when a device's traits are derived from its capabilities.
GDM_CAPABILITY_TO_HG_TRAIT = immutabledict.immutabledict(
    pw_rpc_common='Common',
    pw_rpc_light='OnOff',
    pw_rpc_lock='LockUnlock',
)
# ============================================================= #


# ======================== Module level functions ========================== #
def setup_logger() -> None:
    """Sets up the module logger to emit DEBUG-and-above records to stderr.

    Safe to call more than once: the stream handler is attached only on the
    first call, so repeated calls do not print every log line multiple times.
    """
    # Name used to recognise the handler this function installed.
    handler_name = 'fake_front_end_stream'
    logger.setLevel(logging.DEBUG)
    if any(h.name == handler_name for h in logger.handlers):
        return
    handler = logging.StreamHandler()
    handler.name = handler_name
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(
        logging.Formatter('[%(asctime)s %(levelname)s] %(message)s'))
    logger.addHandler(handler)


def parse_args() -> Tuple[str, Optional[int]]:
    """Parses the TSB host/port flags and removes them from sys.argv.

    Arguments this parser does not recognise are written back to
    ``sys.argv[1:]``.

    Port selection:
      * an explicitly passed ``--tsb_port`` is always used;
      * otherwise the local TSB host gets ``_LOCAL_TSB_PORT`` and any other
        host gets ``None`` (the URL then carries no explicit port).

    Returns:
        Tuple: TSB host, TSB port (None when no port should be used).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-host', '--tsb_host', type=str, required=False,
        default=_LOCAL_TSB_HOST, help='TSB host')
    # Default to None so "not given" can be told apart from an explicit port;
    # previously an explicit port was silently dropped for non-local hosts.
    parser.add_argument(
        '-port', '--tsb_port', type=int, required=False, default=None,
        help=f'TSB port (default: {_LOCAL_TSB_PORT} for {_LOCAL_TSB_HOST}, '
             'none for other hosts)')
    args, leftover = parser.parse_known_args(sys.argv[1:])
    sys.argv[1:] = leftover

    tsb_host = args.tsb_host
    tsb_port = args.tsb_port
    if tsb_port is None and tsb_host == _LOCAL_TSB_HOST:
        tsb_port = _LOCAL_TSB_PORT

    return tsb_host, tsb_port


def raise_exception(response: requests.models.Response, err_msg: str) -> None:
    """Raises a RuntimeError describing an unexpected HTTP response.

    This function never returns normally, so callers can rely on control not
    continuing past the call.

    Args:
        response: HTTP response whose status code was not the expected one.
        err_msg: Error message prefix.

    Raises:
        RuntimeError: Always. The message contains the HTTP status code and,
            when the body is a JSON object with a non-empty 'errorMessage',
            that message as well.
    """
    err_msg = f'{err_msg}. Status: {response.status_code}'
    try:
        ams_err_msg = response.json()['errorMessage']
    except (ValueError, KeyError, TypeError):
        # ValueError: body is not JSON (requests' JSONDecodeError subclasses
        # it). KeyError/TypeError: JSON without an 'errorMessage' key, or not
        # an object at all. Narrowed from a bare `except:` so that
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        ams_err_msg = ''
    if ams_err_msg:
        err_msg += f', AMS error message: {ams_err_msg}'
    raise RuntimeError(err_msg)
# ========================================================================== #


class TSBService:
    """TSB endpoint service.

    Thin HTTP client for the TSB test-suite API. Apart from the initial
    auth-token exchange, every call sends the ``Authorization`` header obtained
    in ``__init__``. Unexpected HTTP statuses are turned into RuntimeError by
    ``raise_exception``.
    """

    # Upper bound, in seconds, on a single HTTP call. `requests` has no
    # default timeout, so an unresponsive TSB would otherwise block the polling
    # loops (and therefore shutdown) indefinitely.
    _REQUEST_TIMEOUT_SEC = 60

    def __init__(self, tsb_host: str, tsb_port: Optional[int]):
        """Builds the TSB base URL and fetches the auth token.

        Args:
            tsb_host: TSB host.
            tsb_port: TSB port, or None to leave the port out of the URL.

        Raises:
            RuntimeError: Fetching the auth token failed.
        """
        if tsb_port is None:
            self._base_url = f'http://{tsb_host}'
        else:
            self._base_url = f'http://{tsb_host}:{tsb_port}'
        auth_token = self.get_auth_token()
        self._auth = {'Authorization': auth_token}

    def _request(self,
                 method: str,
                 endpoint: str,
                 err_msg: str,
                 *,
                 expected_status: int = http.HTTPStatus.OK,
                 **kwargs: Any) -> 'requests.Response':
        """Sends one HTTP request to TSB and checks the status code.

        Args:
            method: HTTP method, e.g. 'GET' or 'POST'.
            endpoint: Endpoint path appended to the base URL.
            err_msg: Error message prefix used if the status is unexpected.
            expected_status: The only status code treated as success.
            **kwargs: Forwarded to ``requests.request`` (headers, json,
                params, ...).

        Returns:
            The HTTP response.

        Raises:
            RuntimeError: The status code differs from ``expected_status``.
            requests.exceptions.RequestException: Network failure or timeout.
        """
        resp = requests.request(
            method, self._base_url + endpoint,
            timeout=self._REQUEST_TIMEOUT_SEC, **kwargs)
        if resp.status_code != expected_status:
            raise_exception(resp, err_msg)
        return resp

    def get_auth_token(self) -> str:
        """Service to obtain the test suite user's auth token.

        Returns:
            Test suite user auth token ('authToken' of the response body).

        Raises:
            RuntimeError: HTTP response status code is not 200.
        """
        resp = self._request(
            'POST', TEST_SUITE_AUTH, 'Failed to get auth token',
            json={'idToken': 'debug-token'})
        return resp.json().get('authToken')

    def get_or_create_test_project(self) -> None:
        """Service to create the test project if it does not exist yet.

        Raises:
            RuntimeError: Listing projects did not return 200, or creating
                the project did not return 201.
        """
        resp = self._request(
            'GET', TEST_SUITE_PROJECTS, 'Failed to get test project',
            headers=self._auth)
        # A missing or null 'result' is treated as "no projects yet".
        projects = resp.json().get('result') or []
        if any(project.get('id') == _TEST_PROJECT for project in projects):
            return
        self._request(
            'POST', TEST_SUITE_PROJECTS, 'Failed to create test project',
            expected_status=http.HTTPStatus.CREATED,
            headers=self._auth, json={'projectIds': [_TEST_PROJECT]})

    def get_linking_code(self) -> str:
        """Service to obtain the linking code.

        Returns:
            Linking code.

        Raises:
            RuntimeError: HTTP response status code is not 200.
        """
        resp = self._request(
            'POST', LINKING_CODE, 'Failed to get linking code',
            headers=self._auth, json={'projectId': _TEST_PROJECT})
        return resp.json()['result'].get('code')

    def get_agent_status(self) -> Optional[Dict[str, Any]]:
        """Service to retrieve the local agent status.

        Returns:
            Local Agent status dict, or None if TSB answers 404 (agent is not
            linked).

        Raises:
            RuntimeError: HTTP response status code is not 200 nor 404.
        """
        # Two acceptable statuses, so this does not go through _request.
        resp = requests.get(
            self._base_url + AGENT_STATUS,
            headers=self._auth,
            params={'projectId': _TEST_PROJECT},
            timeout=self._REQUEST_TIMEOUT_SEC)
        if resp.status_code == http.HTTPStatus.NOT_FOUND:
            return None
        if resp.status_code != http.HTTPStatus.OK:
            raise_exception(resp, 'Failed to get agent status')  # Raises.
        return resp.json()['result']

    def send_rpc_request(self, rpc_request: Dict[str, Any]) -> str:
        """Service to send RPC request to the AMS.

        Args:
            rpc_request: JSON RPC request.

        Returns:
            JSON-RPC id.

        Raises:
            RuntimeError: HTTP response status code is not 200.
        """
        resp = self._request(
            'POST', AGENT_RPC, 'Failed to send RPC request',
            headers=self._auth, json=rpc_request)
        return resp.json()['result']['id']

    def get_rpc_metadata(self, rpc_id: str) -> Dict[str, Optional[int]]:
        """Service to get RPC metadata with given rpc_id.

        Args:
            rpc_id: JSON-RPC id.

        Returns:
            RPC metadata.

        Raises:
            RuntimeError: HTTP response status code is not 200.
        """
        # params= URL-encodes rpc_id instead of splicing it into the URL.
        resp = self._request(
            'GET', RPC_METADATA, 'Failed to get RPC metadata',
            headers=self._auth,
            params={'projectId': _TEST_PROJECT, 'rpcId': rpc_id})
        return resp.json()['result']

    def get_rpc_response(self, rpc_id: str) -> Dict[str, Any]:
        """Service to get RPC response with given rpc_id.

        Args:
            rpc_id: JSON-RPC id.

        Returns:
            RPC response.

        Raises:
            RuntimeError: HTTP response status code is not 200.
        """
        resp = self._request(
            'GET', AGENT_RPC, 'Failed to get RPC response',
            headers=self._auth,
            params={'projectId': _TEST_PROJECT, 'rpcId': rpc_id})
        return resp.json()['result']

    def unlink_agent(self) -> None:
        """API to unlink local agent.

        Raises:
            RuntimeError: HTTP response status code is not 200.
        """
        self._request(
            'POST', UNLINK_AGENT, 'Failed to unlink local agent',
            json={'projectId': _TEST_PROJECT}, headers=self._auth)


class FakeFrontEnd:
    """Fake front end module.

    Simulates the front end for an integration test: links the local agent,
    runs the applicable fake suites against every connected device by sending
    RPCs through TSB, and polls the local agent status in the background.
    SIGINT requests termination, after which the agent is unlinked.
    """

    def __init__(self, host: str, port: Optional[int]):
        """Installs the SIGINT handler and prepares the TSB test project.

        Args:
            host: TSB host.
            port: TSB port, or None to leave the port out of the URL.

        Raises:
            RuntimeError: A TSB setup request failed.
        """
        # Must exist before the SIGINT handler is installed: the handler sets
        # it, and a Ctrl-C during the setup requests below would otherwise
        # raise AttributeError inside the handler.
        self._termination_event = threading.Event()
        signal.signal(signal.SIGINT, self._terminate)

        # Retrieves auth token and creates test project.
        self._tsb_service = TSBService(host, port)
        self._tsb_service.get_or_create_test_project()

        # Latest status published by the status worker: None before the first
        # fetch, and reset to None whenever TSB reports no linked agent.
        self._local_agent_status: Optional[Dict[str, Any]] = None

        # Worker threads: executing test plan and retrieving agent status.
        self._test_plan_worker = threading.Thread(
            target=self._create_and_execute_test_plan, daemon=True)
        self._agent_status_worker = threading.Thread(
            target=self._retrieve_agent_status, daemon=True)

    def _terminate(self, sig_num: int, frame: Any) -> None:
        """Signal handler upon receiving a SIGINT.

        Only flags termination; the main loop and the workers check the event
        and wind down on their own.

        Args:
            sig_num: Signal number passed to the handler.
            frame: Current stack frame passed to the handler.
        """
        del sig_num, frame  # Unused
        logger.warning('Terminates fake front end process.')
        self._termination_event.set()

    def run(self) -> None:
        """Runs fake front end.

        Simulate the front end behaviors:
        1. Links the local agent.
        2. Sends RPC requests to TSB, polls the metadata and gets response.
        3. Retrieves the local agent status simultaneously.
        4. Unlinks the agent once the test plan has finished or SIGINT is
           received, whichever comes first.
        """
        if not self._link_agent():
            return  # Interrupted by SIGINT before an agent was linked.

        time.sleep(_DETECTION_COOL_DOWN_SEC)
        self._test_plan_worker.start()
        self._agent_status_worker.start()

        # NOTE: this thread uses time.sleep, not self._termination_event.wait.
        # The SIGINT handler runs on this (main) thread and calls Event.set(),
        # which takes the event's non-reentrant internal lock. If the signal
        # arrived while Event.wait() held that lock, the handler would
        # deadlock.
        while (not self._termination_event.is_set() and
               self._test_plan_worker.is_alive()):
            time.sleep(_POLLING_PERIOD)

        self._unlink_agent()

    def _checks_if_agent_is_linked(self) -> bool:
        """Returns if local agent is linked.

        Returns:
            True if TSB returns a status whose 'status' is not 'OFFLINE',
            false otherwise.
        """
        status = self._tsb_service.get_agent_status()
        return status is not None and status.get('status') != 'OFFLINE'

    def _checks_if_response_is_stored(self, rpc_id: str) -> bool:
        """Returns if the rpc response of rpc_id is stored.

        Args:
            rpc_id: JSON-RPC id.

        Returns:
            True if the metadata carries a 'responseStoredTimestamp', false
            otherwise.
        """
        metadata = self._tsb_service.get_rpc_metadata(rpc_id=rpc_id)
        return metadata.get('responseStoredTimestamp') is not None

    def _link_agent(self) -> bool:
        """Links local agent.

        Prints the linking code, then polls until an agent is linked or
        SIGINT is received. Runs on the main thread, hence time.sleep.

        Returns:
            True if agent is linked, false if interrupted by SIGINT.
        """
        linking_code = self._tsb_service.get_linking_code()
        print(f'\033[1m******************** Linking Code: {linking_code} ****'
              '****************\033[0m')
        while (not self._termination_event.is_set() and
               not self._checks_if_agent_is_linked()):
            logger.info(f'No agent is linked, sleep {_POLLING_PERIOD} sec...')
            time.sleep(_POLLING_PERIOD)
        if self._termination_event.is_set():
            return False
        logger.info('The local agent is linked.')
        return True

    def _run_rpc_requests(
        self, rpc_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Runs RPC request.

        Simulates the FE behavior: sends the rpc request to BE, polls for the
        rpc metadata, once it gets updated, retrieves the rpc response from BE.
        Called from the test plan worker thread.

        Args:
            rpc_request: JSON-RPC request.

        Returns:
            JSON-RPC response or None if fake front end is interrupted by SIGINT.
        """
        rpc_id = self._tsb_service.send_rpc_request(rpc_request=rpc_request)
        logger.info(f'Sent RPC request: {rpc_request}.')

        while (not self._termination_event.is_set() and
               not self._checks_if_response_is_stored(rpc_id)):
            logger.info(
                f'RPC response not available, sleep {_POLLING_PERIOD} sec...')
            # Worker thread: an interruptible wait is safe here.
            self._termination_event.wait(_POLLING_PERIOD)

        # Interrupted by SIGINT
        if self._termination_event.is_set():
            return None

        logger.info('Metadata is updated.')
        return self._tsb_service.get_rpc_response(rpc_id=rpc_id)

    def _unlink_agent(self) -> None:
        """Unlinks local agent."""
        logger.info('Unlinking local agent.')
        self._tsb_service.unlink_agent()
        logger.info('Local agent unlinked.')

    def _create_and_execute_test_plan(self) -> None:
        """Creates and executes test plan.

        Waits for the first agent status, maps each connected device's GDM
        capabilities to HG traits, then runs every applicable suite. Returns
        early once termination is requested.
        """
        status = self._local_agent_status
        while status is None:
            logger.info('Local Agent status not available yet.')
            if self._termination_event.wait(_POLLING_PERIOD):
                return
            status = self._local_agent_status

        # Read from the local snapshot: the status worker may replace (or
        # reset to None) self._local_agent_status at any moment.
        connected_devices = [
            (device_info['deviceId'],
             self._get_hg_traits(device_info['capabilities']))
            for device_info in status['devices']
        ]

        for device_id, hg_traits in connected_devices:
            for suite in self._generate_suites(device_id, hg_traits):
                if self._termination_event.is_set():
                    return
                logger.info(
                    f'Executes suite {suite} for device {device_id} ...')
                self._run_suite(suite)

    def _get_hg_traits(self, capabilities: List[str]) -> List[str]:
        """Maps the GDM capability to the corresponding HG trait.

        Capabilities without an entry in GDM_CAPABILITY_TO_HG_TRAIT are
        skipped.

        Args:
            capabilities: List of GDM capability.

        Returns:
            List of HG traits, in capability order.
        """
        hg_traits = []
        for capability in capabilities:
            hg_trait = GDM_CAPABILITY_TO_HG_TRAIT.get(capability)
            if hg_trait is not None:
                hg_traits.append(hg_trait)
        return hg_traits

    def _retrieve_agent_status(self) -> None:
        """Retrieves local agent status until termination is requested."""
        while not self._termination_event.is_set():
            status = self._tsb_service.get_agent_status()
            if status is not None:
                logger.info(f'Retrieves local agent status: {status}')
            self._local_agent_status = status
            # Worker thread: wakes immediately when termination is requested.
            self._termination_event.wait(_STATUS_POLLING_PERIOD)

    def _generate_suites(
        self, device_id: str, device_traits: List[str]) -> List[Any]:
        """Generates test suites based on device trait.

        In reality, FE sends device info to BE for test plan/suite generation.
        To reduce the maintenance effort, the fake FE here is to simply map the
        corresponding fake-suite directly.

        Args:
            device_id: GDM device id.
            device_traits: Device traits on HG (as produced by
                _get_hg_traits).

        Returns:
            The list of suites which are applicable to the device traits.
        """
        return [
            suite_class(device_id)
            for suite_class in ALL_SUITE_CLASSES
            if suite_class.is_applicable_to(device_traits)
        ]

    def _run_suite(self, suite: Any) -> None:
        """Runs suite procedures: start-suite RPC, procedures, end-suite RPC.

        Stops sending further RPCs once termination is requested.

        Args:
            suite: Suite instance.
        """
        start_suite_rpc, end_suite_rpc = (
            fake_test_suite.generate_start_end_suite_rpc([suite.device_id]))

        all_procedures = [start_suite_rpc, *suite.procedures, end_suite_rpc]

        for rpc_request in all_procedures:
            if self._termination_event.is_set():
                return
            logger.info(f'Runs RPC request {rpc_request}')
            rpc_response = self._run_rpc_requests(rpc_request=rpc_request)
            logger.info(f'Retrieves RPC response: {rpc_response}')
            self._termination_event.wait(_COOL_DOWN_SEC)


def main() -> None:
    """Entry point: configures logging, reads the TSB flags, runs the fake FE."""
    setup_logger()
    host, port = parse_args()
    FakeFrontEnd(host=host, port=port).run()


if __name__ == '__main__':
    main()
