# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for GDM manager."""
import threading
from typing import Any, Callable, Dict, List, Optional

import gazoo_device
from gazoo_device import errors as gdm_errors

from local_agent import errors as agent_errors
from local_agent import logger as logger_module
from local_agent.translation_layer.command_handlers import common


# Module-level logger obtained from the local agent's logger module.
logger = logger_module.get_logger()

# ============= Constants ============= #
# Log directory handed to the GDM manager's detect() call during device
# detection.
_DEVICE_DETECTION_LOG_DIR = '/tmp'
# ===================================== #


class GdmManager:
    """GDM manager for device management.

    Wraps a single ``gazoo_device.Manager`` instance. Every call into the
    manager is made while holding one re-entrant lock, and the information
    read from each detected device is cached so that the device is only
    queried the first time it is seen.
    """

    def __init__(
        self, update_handlers_cls_map_fn: Callable[[str, Any], None]):
        """Initializes the GDM manager.

        Args:
            update_handlers_cls_map_fn: Callback invoked with the GDM device
                type and the list of supported endpoints the first time a
                device is detected.
        """
        self._mgr_lock = threading.RLock()
        self._mgr = gazoo_device.Manager()
        # True until the first detect_devices() call has completed; it is
        # forwarded to GDM as ``force_overwrite``.
        self._first_detection = True
        self._update_handlers_cls_map = update_handlers_cls_map_fn

        # Caching device information to avoid race condition when doing device
        # communication multiple times.
        # Maps device id -> device information dict.
        self._connected_devices: Dict[str, Dict[str, Any]] = {}

    def detect_devices(self) -> List[Dict[str, Any]]:
        """Detects connected devices.

        Runs GDM detection, then returns the information of every known
        device that is currently connected. A device is opened and queried
        only the first time it is detected; later calls reuse the cached
        information.

        Returns:
            The list of device dict, each dict includes the following fields:
            deviceId, serialNumber, capabilities and deviceType in GDM.
            The dicts are the cached objects themselves, not copies.
        """
        devices: List[Dict[str, Any]] = []
        with self._mgr_lock:
            self._mgr.detect(force_overwrite=self._first_detection,
                             log_directory=_DEVICE_DETECTION_LOG_DIR)
            for device_id, info in self._mgr.get_devices('all').items():

                # The device has been removed.
                # NOTE(review): its cache entry (if any) is kept, so a device
                # that reconnects under the same id reuses the cached
                # information - confirm this is intended.
                if not self._mgr.is_device_connected(device_id):
                    continue

                # The first time this device is detected.
                if device_id not in self._connected_devices:
                    self._connected_devices[device_id] = (
                        self._query_device_info(device_id, info))

                devices.append(self._connected_devices[device_id])

            # Clear the flag while the lock is still held. Clearing it after
            # the lock is released would let another thread waiting on the
            # lock run detection with force_overwrite=True a second time.
            self._first_detection = False

        return devices

    def _query_device_info(
        self, device_id: str, info: Dict[str, Any]) -> Dict[str, Any]:
        """Builds the information dict for a newly detected device.

        Opens the device through GDM to read its supported Matter endpoints
        and clusters, then reports the device type and endpoints to the
        update-handlers callback. The caller must hold ``self._mgr_lock``.

        Args:
            device_id: GDM device id.
            info: The device entry returned by the GDM manager's
                ``get_devices``; its ``persistent`` section must contain
                ``serial_number`` and ``device_type``.

        Returns:
            Dict with the fields deviceId, serialNumber, deviceType and
            capabilities.
        """
        serial_number = info['persistent']['serial_number']
        device_type = info['persistent']['device_type']

        # Retrieve the Matter endpoints and clusters in the first detection.
        with self._mgr.create_and_close_device(device_id) as device:
            matter_endpoints = (
                device.matter_endpoints.get_supported_endpoints())
            endpoint_clusters_mapping = (
                device.matter_endpoints
                .get_supported_endpoints_and_clusters())
            # Flatten the per-endpoint clusters, dropping duplicates.
            cluster_capabilities = list(
                set().union(*endpoint_clusters_mapping.values()))

            # Additional non-Matter capabilities.
            if device.has_capabilities([common.PWRPC_COMMON_CAPABILITY]):
                matter_endpoints.append(common.PWRPC_COMMON_CAPABILITY)
                cluster_capabilities.append(common.PWRPC_COMMON_CAPABILITY)

        # The endpoints are read only on the first detection and the result
        # is cached by the caller, so the command handlers only need to be
        # updated once per device.
        self._update_handlers_cls_map(device_type, matter_endpoints)

        return {
            'deviceId': device_id,
            'serialNumber': serial_number,
            'deviceType': device_type,
            'capabilities': cluster_capabilities,
        }

    def create_devices(
        self,
        identifiers: List[str],
        log_directory: Optional[str] = None) -> None:
        """Creates GDM device instances.

        Args:
            identifiers: List of GDM device id.
            log_directory: GDM log directory.
        """
        with self._mgr_lock:
            for device_id in identifiers:
                self._mgr.create_device(
                    identifier=device_id, log_directory=log_directory)

    def check_device_connected(self, identifier: str) -> None:
        """Checks if the device is connected.

        Args:
            identifier: GDM device id.

        Raises:
            DeviceNotConnectedError: When device is not connected.
            DeviceError: When unexpected error occurs.
        """
        with self._mgr_lock:
            try:
                if not self._mgr.is_device_connected(identifier):
                    raise agent_errors.DeviceNotConnectedError(
                        f'Device {identifier} is not connected.')
            except gdm_errors.DeviceError as e:
                logger.warning(f'check_device_connected failed: {e}')
                # Bare raise re-raises the active exception unchanged.
                raise

    def get_device_instance(self, identifier: str) -> Any:
        """Gets the device instance in GDM.

        Args:
            identifier: GDM device id.

        Returns:
            Device instance in GDM.

        Raises:
            DeviceNotOpenError: When the device is not among GDM's open
                devices.
        """
        with self._mgr_lock:
            if identifier not in self._mgr.get_open_device_names():
                raise agent_errors.DeviceNotOpenError(f'{identifier} is not open.')
            return self._mgr.get_open_device(identifier)

    def get_device_type(self, identifier: str) -> str:
        """Gets the device type in GDM.

        Args:
            identifier: GDM device id.

        Returns:
            Device type in GDM.

        Raises:
            DeviceNotOpenError: When the device is not among GDM's open
                devices.
        """
        # The lock is re-entrant, so get_device_instance() can take it again.
        with self._mgr_lock:
            dut = self.get_device_instance(identifier)
            return dut.device_type

    def close_open_devices(self) -> None:
        """Closes all open devices in GDM."""
        with self._mgr_lock:
            self._mgr.close_open_devices()
