# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for LocalAgentProcess."""
import argparse
from concurrent import futures
import configparser
import enum
import importlib
import json
import os
import shutil
import signal
import sys
import threading
import time
import traceback
from typing import Any, Dict, List, Optional, Tuple

import gazoo_device

from local_agent import errors
from local_agent import ams_client
from local_agent import suite_session_manager
from local_agent import version
from local_agent import logger as logger_module
from local_agent.translation_layer import translation_layer


# Module-level logger used by the functions and classes defined in this file.
logger = logger_module.get_logger()

# ========================= Constants / Configs ========================= #
# Any of these appearing in sys.argv makes the agent log its versions and exit.
VERSION_FLAGS = {'-v', '-version', '--version'}

# Compressed artifact archives of this size or larger are refused for upload.
# NOTE(review): the name suggests this mirrors an App Engine request size
# limit on the AMS side — confirm before changing.
APP_ENGINE_DATA_SIZE_LIMIT = 31 * 1024 * 1024  # 31 MB
APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE = '31 MB'

# Artifacts root used when the user config does not set ARTIFACTS_DIR.
DEFAULT_ARTIFACTS_DIR = '/tmp/local_agent_artifacts'

# JSON file holding the agent_id / agent_secret obtained by linking.
AUTH_FILE = os.path.expanduser(
    '~/.config/google/matter_local_agent_auth.json')
# User config path used when --user_config is not given.
DEFAULT_USER_CONFIG = os.path.expanduser(
    '~/.config/google/local_agent_config.ini')

# JSON-RPC method names that drive the test-suite lifecycle.
_START_TEST_SUITE_METHOD = 'startTestSuite'
_END_TEST_SUITE_METHOD = 'endTestSuite'

# Section and option names read from the INI user config.
_USER_CONFIG_ROOT_KEY = 'ServerConfig'
_USER_CONFIG_AMS_HOST = 'AMS_HOST'
_USER_CONFIG_AMS_PORT = 'AMS_PORT'
_USER_CONFIG_AMS_SCHEME = 'AMS_SCHEME'
_USER_CONFIG_ARTIFACTS_DIR = 'ARTIFACTS_DIR'
_USER_CONFIG_MATTER_DEVICE_CONTROLLERS = 'MatterDeviceControllerPackages'

# AMS connection defaults used when the user config omits the settings.
# NOTE(review): "EAP" presumably names an early-access deployment — confirm.
_EAP_AMS_HOST = 'matter-test-suite-eap.withgoogle.com'
_EAP_AMS_PORT = None
_EAP_AMS_SCHEME = 'https'
# ======================================================================= #


class RpcRequestType(enum.Enum):
    """Category of an incoming JSON-RPC request.

    START_TEST_SUITE and END_TEST_SUITE cover the two suite-lifecycle
    methods; any other method is a DEVICE_QUERY_CONTROL request.
    """
    START_TEST_SUITE = 1
    END_TEST_SUITE = 2
    DEVICE_QUERY_CONTROL = 3

# ======================== Module level functions ========================== #
def rpc_request_type(method: str) -> RpcRequestType:
    """Maps a JSON-RPC method name to its RpcRequestType.

    Args:
        method: The "method" field of a JSON-RPC request.

    Returns:
        START_TEST_SUITE or END_TEST_SUITE for the two suite-lifecycle
        methods, DEVICE_QUERY_CONTROL for every other method.
    """
    suite_lifecycle_methods = (
        (_START_TEST_SUITE_METHOD, RpcRequestType.START_TEST_SUITE),
        (_END_TEST_SUITE_METHOD, RpcRequestType.END_TEST_SUITE),
    )
    for lifecycle_name, lifecycle_type in suite_lifecycle_methods:
        if method == lifecycle_name:
            return lifecycle_type
    # Anything that is not a suite-lifecycle call goes to the device layer.
    return RpcRequestType.DEVICE_QUERY_CONTROL
# ========================================================================== #


class LocalAgentProcess:
    """Local Agent Process.

    A continuously running process which constantly sends GET
    requests to the Agent Management Service (AMS) for incoming RPC
    requests, executes them on a thread pool, and periodically reports
    agent information back to AMS.
    """

    _REPORT_INFO_INTERVAL_SECONDS = 30
    _REPORT_INFO_THREAD_TERMINATION_WAIT_SECONDS = 10
    _POLL_RPC_COOL_DOWN_SECONDS = 1
    _POLL_RPC_THREAD_TERMINATION_WAIT_SECONDS = 10
    _MAIN_THREAD_KEEP_ALIVE_COOLDOWN_SECONDS = 1

    _MAX_WORKERS_FOR_RPC_EXECUTION_THREAD_POOL = 5

    # Status for local agent, defined in the AMS.
    # Note that we don't need 'OFFLINE' status because that is determined by
    # the AMS.
    _STATUS_RUNNING = 'RUNNING'
    _STATUS_IDLE = 'IDLE'

    # Permission bits for a newly created auth file. The file stores the
    # agent secret, so only the owning user may read or write it.
    _AUTH_FILE_MODE = 0o600

    def __init__(self, client: ams_client.AmsClient, artifacts_dir: str):
        """Initializes local agent.

        Args:
            client: The AmsClient instance used for all AMS communication.
            artifacts_dir: Root directory that holds one artifacts
              sub-directory per test suite.
        """
        self._ams_client = client
        self._artifacts_dir = artifacts_dir

        # Threads
        self._termination_event = threading.Event()
        self._rpc_polling_thread = None
        self._info_reporting_thread = None
        self._rpc_execution_thread_pool = futures.ThreadPoolExecutor(
            max_workers=self._MAX_WORKERS_FOR_RPC_EXECUTION_THREAD_POOL)

        # IDs of RPC execution Futures that have not completed yet. A
        # non-empty set means the agent reports itself as RUNNING.
        self._rpc_execution_future_ids = set()

        # Translation Layer
        self._translator = translation_layer.TranslationLayer(client)

        # Suite session manager
        self._suite_mgr = suite_session_manager.SuiteSessionManager(
            artifacts_fn=self._compress_artifacts_and_upload,
            artifact_root_dir=artifacts_dir,
            create_devices_fn=self._translator.create_devices,
            close_devices_fn=self._translator.close_devices)

    def run(self) -> None:
        """Runs the local agent, starting polling JSON-RPC and reporting info.

        After credentials are set up, the translation layer and suite session
        manager are started, and two threads are launched: one polling
        JSON-RPC requests from AMS and one reporting info to AMS. The main
        thread then blocks until the termination event is set or one of the
        two threads dies. Returns immediately if linking fails.
        """
        if not self._setup_credentials():
            logger.warning('Local Agent linking failed, exit the process.')
            return
        logger.info('Local Agent is linked successfully.')

        # Register termination signal handler
        signal.signal(signal.SIGINT, self._terminate)

        self._translator.start(termination_event=self._termination_event)
        self._suite_mgr.start(termination_event=self._termination_event)
        self._start_info_reporting_thread()
        self._start_rpc_polling_thread()
        while not self._termination_event.is_set():
            # If any top level thread is dead, terminate the local agent.
            if self._terminate_if_thread_not_running(
                self._info_reporting_thread):
                break
            if self._terminate_if_thread_not_running(
                self._rpc_polling_thread):
                break

            time.sleep(self._MAIN_THREAD_KEEP_ALIVE_COOLDOWN_SECONDS)

    def _terminate_if_thread_not_running(
        self, target_thread: Optional[threading.Thread]) -> bool:
        """Terminates local agent if the target thread is not running.

        Args:
            target_thread: The thread to check running.

        Returns:
            True if termination procedure initiated, i.e., the thread is not
            running.
        """
        if target_thread is None or not target_thread.is_alive():
            logger.error('Thread is dead or not even started.')
            self._terminate(None, None)
            return True
        return False

    def _start_info_reporting_thread(self) -> None:
        """Starts the _report_info job in a thread."""
        if self._info_reporting_thread is None:
            self._info_reporting_thread = threading.Thread(
                target=self._report_info, name='Info-reporting-thread')
        if not self._info_reporting_thread.is_alive():
            self._info_reporting_thread.start()

    def _start_rpc_polling_thread(self) -> None:
        """Starts the _poll_rpc job in a thread."""
        if self._rpc_polling_thread is None:
            self._rpc_polling_thread = threading.Thread(
                target=self._poll_rpc, name='RPC-polling-thread')
        if not self._rpc_polling_thread.is_alive():
            self._rpc_polling_thread.start()

    def _start_rpc_execution_thread(self, rpc_request: Dict[str, Any]) -> None:
        """Submits a _execute_rpc job to RPC execution thread pool.

        The future's id is recorded before the completion callback is
        attached, so the callback always finds the id to remove.
        """
        future = self._rpc_execution_thread_pool.submit(self._execute_rpc,
                                                        rpc_request)
        self._rpc_execution_future_ids.add(id(future))
        future.add_done_callback(self._callback_for_rpc_execution_complete)

    def _report_info(self) -> None:
        """Reports local agent information back to AMS periodically.

        Note that this method contains an infinite loop, and is designed to be
        run in a separate thread, instead of the main thread.
        Typically we should not invoke this method directly, and instead we
        use _start_info_reporting_thread method.

        The information being reported includes:
        - The devices connected to this local agent.
        - The version of this local agent.
        - GDM version in use.
        - The status of this local agent (RUNNING while any RPC is executing,
          IDLE otherwise).

        The loop ends when the termination event is set or AMS reports the
        agent as unlinked (which also removes the auth file).
        """
        while True:
            logger.info('Reporting status to AMS.')

            devices = self._translator.detect_devices()
            status = (self._STATUS_RUNNING if self._rpc_execution_future_ids
                      else self._STATUS_IDLE)
            local_agent_info = {
                'devices': devices,
                'gdmVersion': gazoo_device.__version__,
                'status': status,
                'version': version.__version__,
            }
            try:
                self._ams_client.report_info(local_agent_info)
            except errors.ApiTimeoutError:
                logger.warning('Report info API timed out.')
            except errors.ApiError as e:
                logger.warning('Report status failed. %s', e)
            except errors.UnlinkedError:
                logger.warning('The local agent is unlinked.')
                self._clean_up_and_terminate_agent(remove_auth_file=True)
                break

            # Event.wait returns True as soon as termination is requested.
            if self._termination_event.wait(self._REPORT_INFO_INTERVAL_SECONDS):
                break

        logger.info('Stopped reporting info because stop event is set.')

    def _poll_rpc(self) -> None:
        """Polls AMS for JSON-RPC requests.

        Note that this method contains an infinite loop, and is designed to be
        run in a separate thread, instead of the main thread.

        Each fetched request is removed from AMS and then submitted to the
        RPC execution thread pool. The loop ends when the termination event
        is set, when AMS reports the agent as unlinked, or when a fetched
        request cannot be removed from AMS.
        """
        while not self._termination_event.is_set():
            logger.info('Polling JSON-RPC requests from AMS.')

            try:
                rpc_request = self._ams_client.get_rpc_request_from_ams()
            except errors.ApiTimeoutError:
                logger.warning('Get RPC request API timed out.')
                rpc_request = None
            except errors.ApiError as e:
                logger.warning(f'Failed to get RPC request from AMS: {e}')
                rpc_request = None
            except errors.UnlinkedError:
                logger.warning('The local agent is unlinked.')
                self._clean_up_and_terminate_agent(remove_auth_file=True)
                break

            if rpc_request is not None:
                # The request is removed from AMS before execution starts;
                # if removal fails the agent stops rather than executing it.
                try:
                    self._ams_client.remove_rpc_request_from_ams(rpc_request)
                except (errors.ApiTimeoutError, errors.ApiError):
                    logger.exception(
                        'Failed to remove JSON-RPC request from AMS. '
                        'Terminating the local agent process.')
                    self._clean_up_and_terminate_agent()
                    break
                self._start_rpc_execution_thread(rpc_request)

            # Wait on the termination event instead of sleeping so that a
            # termination request ends the loop without the extra cool-down.
            self._termination_event.wait(self._POLL_RPC_COOL_DOWN_SECONDS)
        logger.info('Stopped polling RPC because stop event is set.')

    def _execute_rpc(self, rpc_request: Dict[str, Any]) -> None:
        """Executes the JSON-RPC request, and sends result back to AMS.

        The response is dropped (not sent) if the translation layer reports
        the RPC as timed out.
        """
        rpc_id = rpc_request.get('id')
        logger.info(f'Executing JSON-RPC: {rpc_id}')

        rpc_response = self._handle_rpc_request(rpc_request)

        if self._translator.is_rpc_timeout(rpc_id):
            logger.warning(f'RPC {rpc_id} has timed out, ignoring response.')
        else:
            try:
                self._ams_client.send_rpc_response(rpc_response)
            except (errors.ApiTimeoutError, errors.ApiError):
                logger.exception('Failed to send JSON-RPC response to AMS.')

    def _handle_rpc_request(
        self, rpc_request: Dict[str, Any]) -> Dict[str, Any]:
        """Handles the JSON-RPC request and returns the response.

        Any exception raised while handling (including a request without a
        'method' field) is converted into a JSON-RPC error response instead
        of propagating.

        Args:
            rpc_request: JSON-RPC request.

        Returns:
            JSON-RPC response, or an error response of the form
            {'id': ..., 'jsonrpc': '2.0', 'error': {'code': ..., 'message': ...}}.
        """
        try:
            req_type = rpc_request_type(rpc_request['method'])

            if req_type == RpcRequestType.START_TEST_SUITE:
                resp = self._suite_mgr.start_test_suite(rpc_request)

            elif req_type == RpcRequestType.END_TEST_SUITE:
                resp = self._suite_mgr.end_test_suite(rpc_request)

            elif req_type == RpcRequestType.DEVICE_QUERY_CONTROL:
                resp = self._translator.dispatch_to_cmd_handler(rpc_request)

            else:
                raise errors.InvalidRPCError(
                    f'Invalid RPC request type {req_type}.')

        except Exception as e:  # Top-level boundary: report, don't crash.
            logger.exception('Error when handling JSON-RPC.')
            err_code = getattr(e, 'err_code', errors.DEFAULT_ERROR_CODE)
            # Errors without a specific code carry the full stack trace to
            # aid debugging; coded errors carry only their message.
            err_msg = (traceback.format_exc()
                       if err_code == errors.DEFAULT_ERROR_CODE else str(e))
            return {
                # .get(): a request lacking an id must still yield a response
                # rather than raising KeyError inside this handler.
                'id': rpc_request.get('id'),
                'jsonrpc': '2.0',
                'error': {'code': err_code, 'message': err_msg},
            }

        return resp

    def _callback_for_rpc_execution_complete(self,
                                             future: futures.Future) -> None:
        """Callback function when an RPC execution is complete.

        What this callback does:
        1) We remove the id(future) from the self._rpc_execution_future_ids.
        2) We log the exception if there is one.

        This callback should be registered to an RPC execution future object
        using Future.add_done_callback() instead of being called directly.

        Args:
            future: The Future object when we submit an RPC execution job to
              the ThreadPoolExecutor.
        """
        self._rpc_execution_future_ids.discard(id(future))
        if future.cancelled():
            # Future.exception() raises CancelledError on cancelled futures.
            return
        exc = future.exception()
        if exc is not None:
            logger.error('RPC execution encounters exception: %s', exc)

    def _terminate(self, sig_num, frame) -> None:
        """Termination procedure for local agent. A signal handler.

        We set the termination event and stop the 2 top-level threads:
        info-reporting and rpc-polling. Also shutdown the ThreadPoolExecutor
        for RPC execution without waiting for running RPCs.

        Args:
          sig_num: Signal number passed to a signal handler. See:
            https://docs.python.org/3/library/signal.html#signal.signal
          frame: Current stack frame passed to a signal handler. See:
            https://docs.python.org/3/library/signal.html#signal.signal
        """
        del sig_num, frame  # Unused.

        logger.warning('Terminating local agent process.')

        self._clean_up_and_terminate_agent()

        thread_and_wait_time = (
            (self._rpc_polling_thread,
             self._POLL_RPC_THREAD_TERMINATION_WAIT_SECONDS),
            (self._info_reporting_thread,
             self._REPORT_INFO_THREAD_TERMINATION_WAIT_SECONDS))

        for thread, wait_time in thread_and_wait_time:
            if thread is None:
                # The thread wasn't even created. Skipping.
                continue
            logger.info(f'Waiting for thread {thread.name} to stop. '
                        f'(Timeout = {wait_time} seconds)')
            thread.join(timeout=wait_time)
            if thread.is_alive():
                logger.error('Thread %s still alive after waiting %s seconds.',
                             thread.name, wait_time)

        self._rpc_execution_thread_pool.shutdown(wait=False)

        logger.warning('Local agent process terminated.')

    def _read_auths(self) -> Tuple[str, str]:
        """Read agent auths from local config.

        Reads the stored agent_id and agent_secret locally.

        Returns:
            Tuple of agent_id and agent_secret.

        Raises:
            FileNotFoundError: The auth file does not exist.
            json.JSONDecodeError: The auth file is not valid JSON.
        """
        with open(AUTH_FILE, 'r') as fstream:
            auths = json.load(fstream)
        return auths['agent_id'], auths['agent_secret']

    def _write_auths(self, agent_id: str, agent_secret: str) -> None:
        """Writes agent auths into local config.

        Stores the agent_id and agent_secret into the credential auth file
        locally, creating the parent directory if needed. A newly created
        file is readable and writable by the owner only; the permissions of
        an already existing file are left unchanged.

        Args:
            agent_id: local agent id.
            agent_secret: local agent secret.
        """
        auth_dir = os.path.dirname(AUTH_FILE)
        if auth_dir:
            # On a fresh machine ~/.config/google may not exist yet.
            os.makedirs(auth_dir, exist_ok=True)
        fd = os.open(AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                     self._AUTH_FILE_MODE)
        with open(fd, 'w') as fstream:
            auths = {'agent_id': agent_id, 'agent_secret': agent_secret}
            json.dump(auths, fstream)

    def _setup_credentials(self) -> bool:
        """Sets up credentials.

        Read credentials from a local auth file. If credentials are not
        available, unreadable, or rejected, start the interactive linking
        process, which prompts for a linking code until registration succeeds
        or the user aborts (Ctrl-C or end of input).

        Returns:
            True if credentials are set up successfully. False otherwise.
        """
        try:
            agent_id, agent_secret = self._read_auths()
            self._ams_client.set_local_agent_credentials(
                local_agent_id=agent_id,
                local_agent_secret=agent_secret)
            return True
        except (FileNotFoundError, json.JSONDecodeError,
                errors.CredentialsError, errors.UnlinkedError):
            # json.JSONDecodeError: a corrupted/truncated auth file is
            # treated like a missing one, so the agent can relink.
            logger.info('Start the linking process')

        while True:
            try:
                linking_code = input('Linking Code:')
                agent_id, agent_secret = self._ams_client.register(
                    linking_code=linking_code)
                self._write_auths(agent_id, agent_secret)
                return True
            except errors.CredentialsError as e:
                # We don't use logger.exception here in order not to
                # overwhelm the user interface.
                logger.warning(
                    'Invalid linking code. Please retry. (%s)', e)
            except errors.ApiTimeoutError as e:
                logger.warning(
                    'Register API timed out. Please retry. (%s)', e)
            except errors.ApiError as e:
                logger.warning(
                    'Agent registration failed. Please retry. (%s)', e)
            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin is closed or not interactive, so no linking
                # code can ever be read; give up instead of crashing.
                return False

    def _compress_artifacts_and_upload(
        self,
        test_suite_id: str,
        test_result_id: Optional[str] = None) -> None:
        """Compresses the artifacts and uploads if needed.

        Compresses the artifact directory into <dir>.tar.gz, removes the
        uncompressed directory, and uploads the archive to AMS if the test
        result ID is provided. Upload API failures are logged, not raised.

        Args:
            test_suite_id: Test suite ID.
            test_result_id: Test result ID.

        Raises:
            RuntimeError: The compressed archive is at or above
              APP_ENGINE_DATA_SIZE_LIMIT, so it is not uploaded.
        """
        logger.info(f'Compressing artifacts for {test_suite_id}')

        test_suite_dir = os.path.join(self._artifacts_dir, test_suite_id)

        # Remove the logging handler of this suite's log file before the
        # directory containing it is archived and deleted.
        local_agent_log = os.path.join(test_suite_dir, 'local_agent.log')
        logger_module.remove_file_handler(local_agent_log)

        # Compress the artifacts directory and remove it after the compression.
        shutil.make_archive(test_suite_dir, 'gztar', test_suite_dir)
        shutil.rmtree(test_suite_dir)

        if test_result_id is not None:
            logger.info(f'Uploading artifacts for {test_suite_id}')

            # make_archive(..., 'gztar') appends the .tar.gz suffix.
            artifacts_name = test_suite_dir + '.tar.gz'
            if os.stat(artifacts_name).st_size >= APP_ENGINE_DATA_SIZE_LIMIT:
                raise RuntimeError(
                    f'The file size of {artifacts_name} is larger than '
                    f'{APP_ENGINE_DATA_SIZE_LIMIT_HUMAN_READABLE}.')

            # Upload the artifact.
            try:
                self._ams_client.upload_artifact(artifacts_name,
                                                 test_result_id=test_result_id)
            except (errors.ApiTimeoutError, errors.ApiError):
                logger.exception('Failed to upload artifact.')

    def _clean_up_and_terminate_agent(
        self, remove_auth_file: bool = False) -> None:
        """Cleanup method for local agent.

        Cleans up suite session and sets the terminate event.
        Removes the auth file if remove_auth_file is true.

        Args:
            remove_auth_file: To remove the auth file or not.
        """
        if remove_auth_file:
            try:
                os.remove(AUTH_FILE)
            except FileNotFoundError:
                pass  # Already absent; nothing to remove.
        self._suite_mgr.clean_up()
        self._termination_event.set()


def read_config() -> Tuple[Dict[str, Any], List[Tuple[str, str]]]:
    """Reads user data from the INI configuration file.

    The config file is parsed with configparser (INI format). User can
    specify the path to their config file with the -u/--user_config command
    line argument; if not provided, DEFAULT_USER_CONFIG is used. That
    argument is removed from sys.argv; all other arguments are left in place.

    The configuration file is not required: when it does not exist, an empty
    server config and an empty controller list are returned.

    Raises:
        ValueError: If the config file has no _USER_CONFIG_ROOT_KEY
          ([ServerConfig]) section.

    Returns:
        A (server_config, matter_controllers_config) tuple: the mapping of the
        [ServerConfig] section, and the (name, value) pairs of the
        _USER_CONFIG_MATTER_DEVICE_CONTROLLERS section (empty if absent).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-u', '--user_config', type=str, required=False,
                        default=DEFAULT_USER_CONFIG,
                        help='Local Agent user config file.')
    args, leftover = parser.parse_known_args(sys.argv[1:])
    # Keep only the arguments this parser did not consume.
    sys.argv[1:] = leftover

    if not os.path.exists(args.user_config):
        # Second element is a list, matching the declared return type.
        return {}, []

    config = configparser.ConfigParser(allow_no_value=True)
    config.read(args.user_config)

    if _USER_CONFIG_ROOT_KEY not in config:
        raise ValueError(
            f'Invalid config file, no section {_USER_CONFIG_ROOT_KEY}. '
            'Please refer to example_config.ini for reference.')

    # allow_no_value=True lets controller packages be listed as bare names,
    # in which case the value of the pair is None.
    # NOTE(review): ConfigParser lowercases option names by default, so
    # package names read here arrive lowercased.
    matter_controllers_config = (
        config.items(_USER_CONFIG_MATTER_DEVICE_CONTROLLERS)
        if _USER_CONFIG_MATTER_DEVICE_CONTROLLERS in config else [])

    return config[_USER_CONFIG_ROOT_KEY], matter_controllers_config


def register_extension_controllers(
    matter_controllers_config: List[Tuple[str, str]]) -> None:
    """Imports each configured controller package and registers it in GDM.

    Packages are handled one at a time in config order: each is imported
    and immediately registered before the next one is imported.

    Args:
        matter_controllers_config: (package name, value) pairs; only the
          package name is used.
    """
    for package_name, _unused_value in matter_controllers_config:
        gazoo_device.register(importlib.import_module(package_name))


def exit_if_query_module_versions() -> None:
    """Logs the Local Agent and GDM versions, then exits, on a version flag.

    Triggered by any member of VERSION_FLAGS anywhere in sys.argv; when none
    is present the function returns without doing anything.
    """
    if VERSION_FLAGS.isdisjoint(sys.argv):
        return
    version_banner = (
        f'\n******* Local Agent version {version.__version__} *******'
        f'\n******* GDM version {gazoo_device.__version__} *******')
    logger.info(version_banner)
    sys.exit(0)


def main() -> None:
    """Entry point of the Local Agent.

    Handles the version query flags, loads the optional user config,
    registers any extension Matter device controllers, and then runs the
    local agent process.
    """
    exit_if_query_module_versions()

    user_config, matter_controllers_config = read_config()
    # Each setting falls back to its built-in default when the config omits it.
    ams_host, ams_port, ams_scheme, artifacts_dir = (
        user_config.get(setting_key, fallback)
        for setting_key, fallback in (
            (_USER_CONFIG_AMS_HOST, _EAP_AMS_HOST),
            (_USER_CONFIG_AMS_PORT, _EAP_AMS_PORT),
            (_USER_CONFIG_AMS_SCHEME, _EAP_AMS_SCHEME),
            (_USER_CONFIG_ARTIFACTS_DIR, DEFAULT_ARTIFACTS_DIR),
        ))

    register_extension_controllers(matter_controllers_config)

    agent_client = ams_client.AmsClient(
        host=ams_host, port=ams_port, scheme=ams_scheme)
    LocalAgentProcess(client=agent_client, artifacts_dir=artifacts_dir).run()
