#   Copyright (c) 2016-2021, Xilinx, Inc.
#   SPDX-License-Identifier: BSD-3-Clause

import os
import collections
import ctypes
import itertools
import re
import struct
import warnings
from copy import deepcopy

import pynqutils
from pynqmetadata.frontends import Metadata

from .bitstream import Bitstream
from .metadata.append_drivers_pass import bind_drivers_to_metadata
from .mmio import MMIO
from .ps import Clocks
from .registers import RegisterMap

if os.environ.get("PYNQ_REMOTE_DEVICES", False):
    from .pl_server.remote_device import RemoteGPIO as GPIO
    from .pl_server.remote_device import RemoteInterrupt as Interrupt
else:
    from .gpio import GPIO
    from .interrupt import Interrupt


# Package / extension group scanned for IP driver plugins:
# Overlay._register_drivers imports this package and every module that a
# pynqutils ExtensionsManager lists for this group, so the drivers those
# modules define become available for binding.
DRIVERS_GROUP = "pynq.lib"


class UnsupportedConfiguration(Exception):
    """Raised by an IP driver that cannot handle the requested IP configuration.

    When a driver's ``__init__`` raises this exception, the binding logic
    emits a warning and falls back to a plain `DefaultIP` instance for that
    IP instead of failing.
    """


def _assign_drivers(description, ignore_version, device):
    """Assign a driver to each IP and hierarchy in the description.

    Recurses into every nested hierarchy first. Every entry of
    ``description["hierarchies"]`` and ``description["ip"]`` is updated in
    place with ``"device"`` and ``"driver"`` keys.

    Parameters
    ----------
    description : dict
        Hierarchical description with ``"hierarchies"`` and ``"ip"``
        sub-dictionaries; nested hierarchies have the same shape.
    ignore_version : bool
        If True, an IP whose exact VLNV has no driver is bound to a driver
        registered for the same type without its version.
    device : pynq.Device
        Stored in every entry under ``"device"``.

    """
    for details in description["hierarchies"].values():
        _assign_drivers(details, ignore_version, device)
        details["device"] = device
        details["driver"] = DocumentHierarchy
        # First matching registered hierarchy driver wins.
        for hip in _hierarchy_drivers:
            if hip.checkhierarchy(details):
                details["driver"] = hip
                break

    for details in description["ip"].values():
        details["device"] = device
        ip_type = details["type"]
        if ip_type in _ip_drivers:
            details["driver"] = _ip_drivers[ip_type]
            continue
        # Drop the trailing ":version" component of the VLNV. A type with no
        # ':' has nothing to strip; rpartition would yield "" which must not
        # be looked up, so treat it as "no version-less key".
        no_version_ip = (ip_type.rpartition(":")[0] if ip_type else None) or None
        if no_version_ip is None or no_version_ip not in _ip_drivers:
            details["driver"] = DefaultIP
        elif ignore_version:
            details["driver"] = _ip_drivers[no_version_ip]
        else:
            other_versions = [
                v for v in _ip_drivers.keys() if v.startswith(no_version_ip + ":")
            ]
            message = (
                "IP {0} is of type {1} and driver found for [{2}]. "
                + "Use ignore_version=True to use this driver."
            ).format(details["fullpath"], details["type"], ", ".join(other_versions))
            warnings.warn(message, UserWarning)
            details["driver"] = DefaultIP


def _complete_description(
    ip_dict, hierarchy_dict, ignore_version, mem_dict, device, overlay
):
    """Build the complete hierarchical description of an overlay.

    Combines the IP, hierarchy and memory dictionaries parsed from the HWH
    file into one dictionary, tags every hierarchy with ``overlay`` and
    assigns drivers to all IP and hierarchies.

    Parameters
    ----------
    ip_dict : dict
        IP entries keyed by name.
    hierarchy_dict : dict
        Hierarchy entries keyed by name; each entry gains an ``"overlay"`` key.
    ignore_version : bool
        Forwarded to driver assignment.
    mem_dict : dict
        Memory entries; only those whose ``"used"`` value is truthy are kept.
    device : pynq.Device
        Device recorded in the description and in every entry.
    overlay : Overlay
        Overlay stored on every hierarchy entry.

    Returns
    -------
    dict
        Keys ``"ip"``, ``"hierarchies"``, ``"interrupts"``, ``"gpio"``,
        ``"memories"`` and ``"device"``.

    """
    # Memory names are reduced to [A-Za-z0-9_] so they can be exposed as
    # attribute names.
    memories = {}
    for mem_name, mem_info in mem_dict.items():
        if mem_info.get("used"):
            memories[re.sub("[^A-Za-z0-9_]", "", mem_name)] = mem_info

    hierarchies = dict(hierarchy_dict.items())
    for hier_info in hierarchies.values():
        hier_info["overlay"] = overlay

    result = {
        "ip": dict(ip_dict.items()),
        "hierarchies": hierarchies,
        "interrupts": {},
        "gpio": {},
        "memories": memories,
        "device": device,
    }
    _assign_drivers(result, ignore_version, device)
    return result


_class_aliases = {
    "pynq.overlay.DocumentOverlay": "pynq.overlay.DefaultOverlay",
    "pynq.overlay.DocumentHierarchy": "pynq.overlay.DefaultHierarchy",
}


def _classname(class_):
    """Returns the full name for a class. Has option for overriding
    some class names to hide internal details. The overrides are
    stored in the `_class_aliases` dictionaries.

    """
    rawname = "{}.{}".format(class_.__module__, class_.__name__)

    if rawname in _class_aliases:
        return _class_aliases[rawname]
    else:
        return rawname


def _build_docstring(description, name, type_):
    """Render a default documentation string for a hierarchical description.

    Parameters
    ----------
    description : dict
        Description with ``"ip"``, ``"hierarchies"``, ``"interrupts"``,
        ``"gpio"`` and ``"memories"`` entries. IP and hierarchy entries must
        carry a ``"driver"``.
    name : str
        Name of the object, inserted into the opening sentence.
    type_ : str
        Kind of object, normally 'overlay' or 'hierarchy'.

    Returns
    -------
    str : The generated documentation string; each line after the first is
        prefixed with four spaces.

    """

    def row(label, value):
        return "{0: <20} : {1}".format(label, value)

    ip_rows = [
        row(ip, _classname(details["driver"]))
        for ip, details in description["ip"].items()
    ]
    hierarchy_rows = [
        row(hierarchy, _classname(details["driver"]))
        for hierarchy, details in description["hierarchies"].items()
    ]
    interrupt_rows = [
        row(interrupt, "pynq.interrupt.Interrupt")
        for interrupt in description["interrupts"]
    ]
    gpio_rows = [row(gpio, "pynq.gpio.GPIO") for gpio in description["gpio"]]
    memory_rows = [
        row(
            mem,
            "Stream"
            if ("streaming" in mem_desc and mem_desc["streaming"])
            else "Memory",
        )
        for mem, mem_desc in description["memories"].items()
    ]

    # (title, underline, rows) - underlines reproduced exactly as emitted.
    sections = (
        ("IP Blocks", "----------", ip_rows),
        ("Hierarchies", "-----------", hierarchy_rows),
        ("Interrupts", "----------", interrupt_rows),
        ("GPIO Outputs", "------------", gpio_rows),
        ("Memories", "------------", memory_rows),
    )

    out = [
        "Default documentation for {} {}. The following".format(type_, name),
        "attributes are available on this {}:".format(type_),
        "",
    ]
    for title, underline, rows in sections:
        out.append(title)
        out.append(underline)
        out.extend(rows if rows else ["None"])
        out.append("")
    out.append("")
    return "\n    ".join(out)


class Overlay(Bitstream):
    """This class keeps track of a single bitstream's state and contents.

    The overlay class holds the state of the bitstream and enables run-time
    protection of bindings.

    Our definition of overlay is: "post-bitstream configurable design".
    Hence, this class must expose configurability through content discovery
    and runtime protection.

    The overlay class exposes the IP and hierarchies as attributes in the
    overlay. If no other drivers are available the `DefaultIP` is constructed
    for IP cores at top level and `DefaultHierarchy` for any hierarchies that
    contain addressable IP. Custom drivers can be bound to IP and hierarchies
    by subclassing `DefaultIP` and `DefaultHierarchy`. See the help entries
    for those class for more details.

    This class stores four dictionaries: IP, GPIO, interrupt controller
    and interrupt pin dictionaries.

    Each entry of the IP dictionary is a mapping:
    'name' -> {phys_addr, addr_range, type, config, state}, where
    name (str) is the key of the entry.
    phys_addr (int) is the physical address of the IP.
    addr_range (int) is the address range of the IP.
    type (str) is the type of the IP.
    config (dict) is a dictionary of the configuration parameters.
    state (str) is the state information about the IP.

    Each entry of the GPIO dictionary is a mapping:
    'name' -> {index, state}, where
    name (str) is the key of the entry.
    index (int) is the user index of the GPIO, starting from 0.
    state (str) is the state information about the GPIO.

    Each entry in the interrupt controller dictionary is a mapping:
    'name' -> {parent, index}, where
    name (str) is the name of the interrupt controller.
    parent (str) is the name of the parent controller or '' if attached
    directly to the PS.
    index (int) is the index of the interrupt attached to.

    Each entry in the interrupt pin dictionary is a mapping:
    'name' -> {controller, index}, where
    name (str) is the name of the pin.
    controller (str) is the name of the interrupt controller.
    index (int) is the line index.

    Attributes
    ----------
    bitfile_name : str
        The absolute path of the bitstream.
    dtbo : str
        The absolute path of the dtbo file for the full bitstream.
    ip_dict : dict
        All the addressable IPs from PS. Key is the name of the IP; value is
        a dictionary mapping the physical address, address range, IP type,
        parameters, registers, and the state associated with that IP:
        {str: {'phys_addr' : int, 'addr_range' : int, \
               'type' : str, 'parameters' : dict, 'registers': dict, \
               'state' : str}}.
    gpio_dict : dict
        All the GPIO pins controlled by PS. Key is the name of the GPIO pin;
        value is a dictionary mapping user index (starting from 0),
        and the state associated with that GPIO pin:
        {str: {'index' : int, 'state' : str}}.
    interrupt_controllers : dict
        All AXI interrupt controllers in the system attached to
        a PS interrupt line. Key is the name of the controller;
        value is a dictionary mapping parent interrupt controller and the
        line index of this interrupt:
        {str: {'parent': str, 'index' : int}}.
        The PS is the root of the hierarchy and is unnamed.
    interrupt_pins : dict
        All pins in the design attached to an interrupt controller.
        Key is the name of the pin; value is a dictionary
        mapping the interrupt controller and the line index used:
        {str: {'controller' : str, 'index' : int}}.
    pr_dict : dict
        Dictionary mapping from the name of the partial-reconfigurable
        hierarchical blocks to the loaded partial bitstreams:
        {str: {'loaded': str, 'dtbo': str}}.
    device : pynq.Device
        The device that the overlay is loaded on

    """

    def __init__(
        self, bitfile_name, dtbo=None, download=True, ignore_version=False, device=None, gen_cache=False
    ):
        """Return a new Overlay object.

        An overlay instantiates a bitstream object as a member initially.

        Parameters
        ----------
        bitfile_name : str
            The bitstream name or absolute path as a string.
        dtbo : str
            The dtbo file name or absolute path as a string.
        download : bool
            Whether the overlay should be downloaded.
        ignore_version : bool
            Indicate whether or not to ignore the driver versions.
        device : pynq.Device
            Device on which to load the Overlay. Defaults to
            pynq.Device.active_device
        gen_cache: bool
            If True and `download` is False, generate a pickled cache of
            the metadata without downloading.

        Note
        ----
        This class requires a HWH file to be next to bitstream file
        with same name (e.g. `base.bit` and `base.hwh`).

        """
        super().__init__(bitfile_name, dtbo, partial=False, device=device)

        self._register_drivers()

        self.device.set_bitfile_name(self.bitfile_name)
        self.parser = self.device.parser

        # Private deep copies of the parser's dictionaries, so state recorded
        # on this overlay never leaks into the parser's shared metadata.
        self._deepcopy_dict_from(self.parser)
        self.clock_dict = self.parser.clock_dict
        self.pr_dict = dict()
        self.ignore_version = ignore_version
        description = _complete_description(
            self.ip_dict,
            self.hierarchy_dict,
            self.ignore_version,
            self.mem_dict,
            self.device,
            self,
        )
        self._ip_map = _IPMap(description)

        # Setting device class variables
        Clocks.set_device(self.device)

        # If we have a system graph information, expose it
        if hasattr(self.parser, "systemgraph"):
            self.systemgraph = self.parser.systemgraph
        else:
            self.systemgraph = None

        if download:
            self.download()
        elif gen_cache:
            self.gen_cache()

        self.__doc__ = _build_docstring(
            self._ip_map._description, bitfile_name, "overlay"
        )

    def __getattr__(self, key):
        """Return the driver for an IP, hierarchy, interrupt, GPIO or memory.

        Raises
        ------
        AttributeError
            For special ``__dunder__`` names, and for names the overlay
            does not contain.
        RuntimeError
            If the overlay is not currently loaded.

        """
        # Python protocols (copy, pickle, ...) probe for special methods via
        # getattr/hasattr. These are never IP names; answering them here
        # avoids a RuntimeError on unloaded overlays and unbounded recursion
        # on instances whose attributes have not been set yet.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(
                "{} object has no attribute {}".format(type(self).__name__, key)
            )
        if self.is_loaded():
            return getattr(self._ip_map, key)
        raise RuntimeError("Overlay not currently loaded")

    def _deepcopy_dict_from(self, source):
        """Replace this overlay's metadata dictionaries with deep copies of
        those held by ``source``, each wrapped in a
        `pynqutils.runtime.ReprDict`.

        """
        self.ip_dict = pynqutils.runtime.ReprDict(
            deepcopy(source.ip_dict), rootname="ip_dict"
        )
        self.gpio_dict = pynqutils.runtime.ReprDict(
            deepcopy(source.gpio_dict), rootname="gpio_dict"
        )
        self.interrupt_controllers = pynqutils.runtime.ReprDict(
            deepcopy(source.interrupt_controllers), rootname="interrupt_controllers"
        )
        self.interrupt_pins = pynqutils.runtime.ReprDict(
            deepcopy(source.interrupt_pins), rootname="interrupt_pins"
        )
        self.hierarchy_dict = pynqutils.runtime.ReprDict(
            deepcopy(source.hierarchy_dict), rootname="hierarchy_dict"
        )
        self.mem_dict = pynqutils.runtime.ReprDict(
            deepcopy(source.mem_dict), rootname="mem_dict"
        )

    def free(self):
        """Remove the inserted dtbo (if this overlay has one) and close the
        device.

        """
        if self.dtbo:
            self.remove_dtbo()
        self.device.close()

    def gen_cache(self):
        """ Generate a pickled cache of the metadata even if a download has not occurred """
        super().gen_cache(self.parser)

    def download(self, dtbo=None):
        """The method to download a full bitstream onto PL.

        After the bitstream has been downloaded, the "timestamp" in PL will be
        updated. In addition, all the dictionaries on PL will
        be reset automatically.

        This method will use parameter `dtbo` or `self.dtbo` to configure the
        device tree.

        The download method will also configure some of the PS registers
        based on the metadata file provided, e.g. PL clocks,
        AXI master port width.

        Parameters
        ----------
        dtbo : str
            The path of the dtbo file.

        """
        # Configure PL clocks from the metadata before programming: enabled
        # clocks get their divisors, others are set with default arguments.
        for i in self.clock_dict:
            if "enable" in self.clock_dict[i]:
                enable = self.clock_dict[i]["enable"]
                div0 = self.clock_dict[i]["divisor0"]
                div1 = self.clock_dict[i]["divisor1"]
                if enable:
                    Clocks.set_pl_clk(i, div0, div1)
                else:
                    Clocks.set_pl_clk(i)

        super().download(self.parser)
        if dtbo:
            super().insert_dtbo(dtbo)
        elif self.dtbo:
            super().insert_dtbo()

    def pr_download(self, partial_region, partial_bit, dtbo=None, program=True):
        """The method to download a partial bitstream onto PL.

        In this method, the corresponding parser will only be
        added once the `download()` method of the hierarchical block is called.

        This method always uses the parameter `dtbo` to configure the device
        tree.

        Note
        ----
        There is no check on whether the partial region specified by users
        is really partial-reconfigurable. So users have to make sure the
        `partial_region` provided is correct.

        Parameters
        ----------
        partial_region : str
            The name of the hierarchical block corresponding to the PR region.
        partial_bit : str
            The name of the partial bitstream.
        dtbo : str
            The path of the dtbo file.
        program : bool
            Whether the overlay should be downloaded.

        """
        pr_block = self.__getattr__(partial_region)
        pr_block.download(bitfile_name=partial_bit, dtbo=dtbo, program=program)

    def is_loaded(self):
        """This method checks whether a bitstream is loaded.

        This method returns true if the loaded PL bitstream is same
        as this Overlay's member bitstream. The timestamp is compared when
        one is available, otherwise the bitstream file name.

        Returns
        -------
        bool
            True if bitstream is loaded.

        """
        if self.timestamp != "":
            return self.timestamp == self.device.timestamp
        return self.bitfile_name == self.device.bitfile_name

    def reset(self):
        """This function resets all the dictionaries kept in the overlay.

        If the overlay is currently loaded, the device state is reset first.
        The overlay's dictionaries are then replaced with fresh deep copies
        of the parser's dictionaries, as during construction. Later updates,
        such as `load_ip_data` recording an IP state, therefore do not
        modify the parser's shared metadata.

        This function should be used with caution. In most cases, only those
        dictionaries keeping track of states need to be updated.

        Returns
        -------
        None

        """
        if self.is_loaded():
            self.device.reset(self.parser, self.timestamp, self.bitfile_name)
        self._deepcopy_dict_from(self.parser)

    def load_ip_data(self, ip_name, data):
        """This method loads the data to the addressable IP.

        Calls the method in the super class to load the data. This method can
        be used to program the IP. For example, users can use this method to
        load the program to the Microblaze processors on PL.

        Note
        ----
        The data is assumed to be in binary format (.bin). The data name will
        be stored as a state information in the IP dictionary.

        Parameters
        ----------
        ip_name : str
            The name of the addressable IP.
        data : str
            The absolute path of the data to be loaded.

        Returns
        -------
        None

        """
        self.device.load_ip_data(ip_name, data)
        self.ip_dict[ip_name]["state"] = data

    def __dir__(self):
        return sorted(
            set(super().__dir__() + list(self.__dict__.keys()) + self._ip_map._keys())
        )

    def _register_drivers(self):
        """Imports plugin modules registered against `pynq.lib`, so that IP
        drivers contained in these modules can be registered automatically.
        """
        import importlib

        drivers_ext_man = pynqutils.setup_utils.ExtensionsManager(DRIVERS_GROUP)
        importlib.import_module(DRIVERS_GROUP)
        for ext in drivers_ext_man.list:
            importlib.import_module(ext.module_name)


# IP driver registry: maps a VLNV string, and the same string with its
# trailing ":version" removed, to the driver class bound to it.
_ip_drivers = dict()
# Hierarchy driver registry; newly registered drivers are added at the front.
_hierarchy_drivers = collections.deque()


class RegisterIP(type):
    """Metaclass that registers every class defining `bindto` as an IP driver.

    The `bindto` attribute of such a class should be an array of strings
    containing the VLNV of each IP the driver should bind to. The class is
    registered under the full VLNV and under the VLNV without its trailing
    ``:version`` component (when there is one).

    """

    def __init__(cls, name, bases, attrs):
        if "bindto" in attrs:
            for vlnv in cls.bindto:
                _ip_drivers[vlnv] = cls
                vln = vlnv.rpartition(":")[0]
                # Without a ':' there is no version to strip and rpartition
                # yields ""; registering that key would let unrelated
                # version-less lookups match this driver.
                if vln:
                    _ip_drivers[vln] = cls
        super().__init__(name, bases, attrs)

    def unregister(cls):
        """Unregister a subclass from the driver registry.

        Only entries that still map to ``cls`` are removed, so a key that has
        since been claimed by another driver is left untouched.

        """
        for vlnv in getattr(cls, "bindto", ()):
            vln = vlnv.rpartition(":")[0]
            for key in (vlnv, vln):
                if _ip_drivers.get(key, None) is cls:
                    del _ip_drivers[key]


_struct_dict = {
    # Base Vitis int types
    "char": "c",
    "signed char": "b",
    "unsigned char": "B",
    "short": "h",
    "short int": "h",
    "signed short": "h",
    "signed short int": "h",
    "unsigned short": "H",
    "unsigned short int": "H",
    "int": "i",
    "signed": "i",
    "signed int": "i",
    "unsigned": "I",
    "unsigned int": "I",
    "long": "l",
    "long int": "l",
    "signed long": "l",
    "signed long int": "l",
    "unsigned long": "L",
    "unsigned long int": "L",
    "long long": "q",
    "long long int": "q",
    "signed long long": "q",
    "signed long long int": "q",
    "unsigned long long": "Q",
    "unsigned long long int": "Q",
    # Base Vitis floating point types
    "float": "f",
    "double": "d",
    # Other types seen in the wild
    "long": "l",
    "uint": "I",
    "ushort": "H",
}


def _ctype_to_struct(ctype):
    ctype = ctype.replace("const", "").strip()
    return _struct_dict[ctype]


# Per-argument details produced by `_create_call`: register name, argument
# index, C type string and the register's "memory" entry (None if absent).
XrtArgument = collections.namedtuple("XrtArgument", ["name", "index", "type", "mem"])


def _create_call(regmap):
    """Build call metadata for a kernel from its register map.

    Registers are processed in order of ``address_offset``. Each register
    appends to a `struct` format string, in this order:

    * padding bytes for any gap since the previous register;
    * ``Q`` for a pointer type, otherwise the format character of its
      scalar type.

    Every register except ``CTRL`` also becomes a positional-or-keyword
    parameter of the returned signature.

    Parameters
    ----------
    regmap : dict
        Maps register name to a dict with at least ``"address_offset"`` and
        ``"type"``; an optional ``"memory"`` entry is copied into the
        argument details.

    Returns
    -------
    tuple
        ``(signature, struct_string, ptr_list, arg_details)``:

        * ``signature`` is an `inspect.Signature`.
        * ``struct_string`` is a `struct` format string with standard sizes
          and no alignment (``=``).
        * ``ptr_list`` holds one bool per parameter, True for pointers.
        * ``arg_details`` maps each parameter name to an `XrtArgument`.

    Raises
    ------
    RuntimeError
        If a register starts before the end of the previously packed one.
    KeyError
        If a scalar register type is not known to `_ctype_to_struct`.

    """
    from inspect import Parameter, Signature

    parameters = []
    ptr_list = []
    struct_string = "="
    arg_details = {}

    ordered = sorted(regmap.items(), key=lambda item: item[1]["address_offset"])
    for reg_name, reg in ordered:
        curr_offset = struct.calcsize(struct_string)
        reg_offset = reg["address_offset"]
        if reg_offset < curr_offset:
            raise RuntimeError("Struct string generation failed")
        elif reg_offset > curr_offset:
            # Pad the gap between the previous register and this one.
            struct_string += "{}x".format(reg_offset - curr_offset)
        reg_type = reg["type"]
        ptr_type = "*" in reg_type
        if ptr_type:
            # Pointer arguments are packed as 64-bit unsigned addresses.
            struct_string += "Q"
        else:
            # Use the helper so const-qualified scalar types also resolve.
            struct_string += _ctype_to_struct(reg_type)
        if reg_name != "CTRL":
            ptr_list.append(ptr_type)
            parameters.append(
                Parameter(
                    reg_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=reg_type
                )
            )
            arg_details[reg_name] = XrtArgument(
                reg_name,
                len(parameters),
                reg_type,
                reg["memory"] if "memory" in reg else None,
            )
    signature = Signature(parameters)
    return signature, struct_string, ptr_list, arg_details


class WaitHandle:
    """Handle used to block until ``target`` reports completion.

    Presumably returned by a driver's ``start`` method (``DefaultIP._call``
    calls ``self.start(...).wait()``) - TODO confirm. Completion is detected
    by polling the register at offset 0 of ``target.mmio``.
    """

    def __init__(self, target):
        self.target = target

    def wait(self):
        """Busy-poll register 0 of ``target.mmio`` until bit 2 (0x4) is set."""
        while True:
            status = self.target.mmio.read(0)
            if (status & 0x4) == 0x4:
                return


class DefaultIP(metaclass=RegisterIP):
    """Fallback driver for an IP that has no more specific driver.

    Wraps the IP's address range in an MMIO object and serves as the base
    class for more specific drivers. Interrupt inputs and GPIO outputs
    listed in the description are exposed as attributes. A subclass that
    inherits from `DefaultIP` and defines a `bindto` list of IP VLNVs is
    registered automatically for those IP.

    Attributes
    ----------
    mmio : pynq.MMIO
        Underlying MMIO driver for the device
    _interrupts : dict
        Subset of the PL.interrupt_pins related to this IP
    _gpio : dict
        Subset of the PL.gpio_dict related to this IP

    """

    def __init__(self, description):
        if "device" not in description:
            from .pl_server.device import Device

            self.device = Device.active_device
        else:
            self.device = description["device"]

        self.mmio = MMIO(
            description["phys_addr"], description["addr_range"], device=self.device
        )
        self._interrupts = (
            description["interrupts"] if "interrupts" in description else {}
        )
        self._gpio = description["gpio"] if "gpio" in description else {}

        # Interrupts that cannot be created are exposed as None (with a
        # warning) rather than aborting driver construction.
        for irq_name, irq_info in self._interrupts.items():
            try:
                irq = Interrupt(irq_info["fullpath"])
            except ValueError as err:
                warnings.warn("Interrupt {} not created: {}".format(irq_name, str(err)))
                irq = None
            setattr(self, irq_name, irq)

        for pin_name, pin_info in self._gpio.items():
            pin_number = GPIO.get_gpio_pin(pin_info["index"])
            setattr(self, pin_name, GPIO(pin_number, "out"))

        if "registers" not in description:
            self._registers = None
        else:
            self._registers = description["registers"]
            self._fullpath = description["fullpath"]
            self._register_name = description["fullpath"].rpartition("/")[2]

        if "index" in description:
            self.cu_mask = 1 << self.device.open_context(description)

        if "streams" in description:
            self.streams = {}
            for stream_name, stream_info in description["streams"].items():
                stream = self.device.get_memory_by_name(stream_info["stream_id"])
                self.streams[stream_name] = stream
                direction = stream_info["direction"]
                if direction == "output":
                    stream.source_ip = self
                elif direction == "input":
                    stream.sink_ip = self

    @property
    def register_map(self):
        """Register map of the IP, built on first access and then cached.

        Raises AttributeError when the description carried no registers.
        """
        if not hasattr(self, "_register_map"):
            if not self._registers:
                raise AttributeError(
                    "register_map only available if the .hwh is provided"
                )
            map_class = RegisterMap.create_subclass(
                self._register_name, self._registers
            )
            self._register_map = map_class(self.mmio.array)
        return self._register_map

    @property
    def signature(self):
        """The signature of the `call` method"""
        return getattr(self, "_signature", None)

    def _call(self, *args, **kwargs):
        # Relies on a ``start`` method that is not defined in this class.
        handle = self.start(*args, **kwargs)
        handle.wait()

    def read(self, offset=0):
        """Read from the MMIO device

        Parameters
        ----------
        offset : int
            Address to read

        """
        return self.mmio.read(offset)

    def write(self, offset, value):
        """Write to the MMIO device

        Parameters
        ----------
        offset : int
            Address to write to
        value : int or bytes
            Data to write

        """
        self.mmio.write(offset, value)


class _IPMap:
    """Class that stores drivers to IP, hierarchies, interrupts and
    gpio as attributes.

    """

    def __init__(self, desc):
        """Create a new _IPMap based on a hierarchical description."""
        self._description = desc

    def __getattr__(self, key):
        if key in self._description["hierarchies"]:
            hierdescription = self._description["hierarchies"][key]
            hierarchy = hierdescription["driver"](hierdescription)
            setattr(self, key, hierarchy)
            return hierarchy
        elif key in self._description["ip"]:
            ipdescription = self._description["ip"][key]
            try:
                driver = ipdescription["driver"](ipdescription)
            except UnsupportedConfiguration as e:
                warnings.warn(
                    "Configuration if IP {} not supported: {}".format(key, str(e.args)),
                    UserWarning,
                )
                print(f"Creating Driver for {key}  from description {ipdescription}")
                driver = DefaultIP(ipdescription)
            setattr(self, key, driver)
            return driver
        elif key in self._description["interrupts"]:
            interrupt = Interrupt(self._description["interrupts"][key]["fullpath"])
            setattr(self, key, interrupt)
            return interrupt
        elif key in self._description["gpio"]:
            gpio_index = self._description["gpio"][key]["index"]
            gpio_number = GPIO.get_gpio_pin(gpio_index)
            gpio = GPIO(gpio_number, "out")
            setattr(self, key, gpio)
            return gpio
        elif key in self._description["memories"]:
            mem = self._description["device"].get_memory(
                self._description["memories"][key]
            )
            setattr(self, key, mem)
            return mem
        else:
            raise AttributeError(
                "Could not find IP or hierarchy {} in overlay".format(key)
            )

    def _keys(self):
        """The set of keys that can be accessed through the IP map"""
        return (
            list(self._description["hierarchies"].keys())
            + list(i for i in self._description["ip"].keys())
            + list(i for i in self._description["interrupts"].keys())
            + list(i for i in self._description["gpio"].keys())
            + list(g for g in self._description["memories"].keys())
        )

    def __dir__(self):
        return sorted(
            set(super().__dir__() + list(self.__dict__.keys()) + self._keys())
        )


def DocumentOverlay(bitfile, download):
    """Create an overlay whose class carries a generated docstring.

    Mimics a class constructor. A one-off `Overlay` subclass is defined and
    instantiated for ``bitfile``, forwarding ``download``. The subclass's
    ``__doc__`` is then set from the instance's description, and the
    instance is returned.

    """

    class DocumentedOverlay(Overlay):
        def __init__(self):
            super().__init__(bitfile, download=download)

    instance = DocumentedOverlay()
    ip_description = instance._ip_map._description
    DocumentedOverlay.__doc__ = _build_docstring(ip_description, bitfile, "overlay")
    return instance


def DocumentHierarchy(description):
    """Create a hierarchy driver whose class carries a generated docstring.

    Mimics a class constructor. A one-off `DefaultHierarchy` subclass is
    instantiated with ``description``, and its ``__doc__`` is built from
    that description. The subclass does not define `checkhierarchy`, so it
    is not added to the hierarchy driver registry.

    """

    class DocumentedHierarchy(DefaultHierarchy):
        def __init__(self):
            super().__init__(description)

    instance = DocumentedHierarchy()
    hierarchy_path = description["fullpath"]
    DocumentedHierarchy.__doc__ = _build_docstring(
        description, hierarchy_path, "hierarchy"
    )
    return instance


class RegisterHierarchy(type):
    """Metaclass to register classes as hierarchy drivers

    Any class with this metaclass that defines `checkhierarchy` in its own
    body is added to the front of the global hierarchy driver list. Drivers
    defined later are therefore consulted first.

    """

    def __init__(cls, name, bases, attrs):
        defines_check = "checkhierarchy" in attrs
        if defines_check:
            _hierarchy_drivers.appendleft(cls)
        super().__init__(name, bases, attrs)

    def unregister(cls):
        """Remove ``cls`` from the hierarchy driver list if it is present."""
        try:
            _hierarchy_drivers.remove(cls)
        except ValueError:
            pass


class DefaultHierarchy(_IPMap, metaclass=RegisterHierarchy):
    """Hierarchy exposing all IP and hierarchies as attributes

    This Hierarchy is used when no more specific registered hierarchy
    driver matches. More specific drivers should inherit from
    `DefaultHierarchy` and call its constructor in __init__ prior to any
    other initialisation. `checkhierarchy` should also be redefined to
    return True if the driver matches a hierarchy. Any derived class that
    meets these requirements will automatically be registered in the
    driver database.

    Attributes
    ----------
    description : dict
        Dictionary storing relevant information about the hierarchy.
    parsers : dict
        Parser objects for partial block design metadata, keyed by the
        `bitfile_name` of the corresponding partial `Bitstream`.
    bitstreams : dict
        Bitstream objects for partial designs, keyed by the bitstream name
        passed to `download`.
    pr_loaded : str
        The absolute path of the partial bitstream loaded ("" if none).

    """

    def __init__(self, description):
        """Store the description and the ``"device"``/``"overlay"`` it holds,
        and start with no partial bitstreams.

        """
        self.description = description
        self.parsers = dict()
        self.bitstreams = dict()
        self.pr_loaded = ""
        self.device = description["device"]
        self._overlay = description["overlay"]
        super().__init__(description)

    @staticmethod
    def checkhierarchy(description):
        """Function to check if the driver matches a particular hierarchy

        This function should be redefined in derived classes to return True
        if the description matches what is expected by the driver. The default
        implementation always returns False so that drivers that forget don't
        get loaded for hierarchies they don't expect.

        """
        return False

    def download(self, bitfile_name, dtbo=None, program=True):
        """Function to download a partial bitstream for the hierarchy block.

        Since it is hard to know which hierarchy is to be reconfigured by only
        looking at the metadata, we assume users will tell this information.
        Thus, this function should be called only when users are sure about
        the hierarchy name of the partial region.

        After loading, the partial metadata is registered with the device,
        and the parent overlay's dictionaries, ``pr_dict`` entry and IP map
        are rebuilt.

        Parameters
        ----------
        bitfile_name : str
            The name of the partial bitstream.
        dtbo : str
            The relative or absolute path of the partial dtbo file.
        program : bool
            Whether the overlay should be downloaded.

        """
        # Remove the device-tree overlay of the previously loaded partial
        # bitstream before loading the new one.
        if self.pr_loaded:
            self._find_bitstream_by_abs(self.pr_loaded).remove_dtbo()
        self._locate_metadata(bitfile_name, dtbo)
        self._parse(bitfile_name)
        self._load_bitstream(bitfile_name, program)
        if dtbo:
            self.bitstreams[bitfile_name].insert_dtbo()

        # Hand the (path-prefixed) partial metadata to the device, then
        # refresh the parent overlay from the device's updated dictionaries.
        self.device.update_partial_region(
            self.description["fullpath"], self.parsers[self.pr_loaded]
        )

        self._overlay._deepcopy_dict_from(self.device)
        self._overlay.pr_dict[self.description["fullpath"]] = {
            "loaded": self.pr_loaded,
            "dtbo": dtbo,
        }
        description = _complete_description(
            self._overlay.ip_dict,
            self._overlay.hierarchy_dict,
            self._overlay.ignore_version,
            self._overlay.mem_dict,
            self._overlay.device,
            self._overlay,
        )
        self._overlay._ip_map = _IPMap(description)

    def _find_bitstream_by_abs(self, absolute_path):
        """Return the stored Bitstream whose `bitfile_name` equals
        ``absolute_path``, or None if there is none.

        """
        for i in self.bitstreams.keys():
            if self.bitstreams[i].bitfile_name == absolute_path:
                return self.bitstreams[i]
        return None

    def _locate_metadata(self, bitfile_name, dtbo):
        """Create the partial Bitstream for ``bitfile_name`` and fetch its
        metadata from the device.

        The Bitstream is stored under the user-supplied name; the parser is
        stored under the Bitstream's resolved `bitfile_name`.

        """
        self.bitstreams[bitfile_name] = Bitstream(bitfile_name, dtbo, partial=True, device=self.device)
        bitfile_name = self.bitstreams[bitfile_name].bitfile_name
        self.parsers[bitfile_name] = self.device.get_bitfile_metadata(bitfile_name, partial=True)

    def _parse(self, bitfile_name):
        """Make the partial parser's names relative to this hierarchy.

        The parser's dictionaries are rewritten as follows:

        * IP keys and each entry's ``"fullpath"`` are prefixed with
          ``"<hierarchy fullpath>/"``.
        * ``nets`` keys are prefixed with ``"<hierarchy fullpath>_"`` and
          their pin names with ``"<hierarchy fullpath>/"``.
        * ``pins`` keys get the ``"/"`` prefix and their net values the
          ``"_"`` prefix.

        Entries whose key or value is None are dropped from ``nets`` and
        ``pins``.

        """
        bitfile_name = self.bitstreams[bitfile_name].bitfile_name
        fullpath = self.description["fullpath"]
        ip_dict = dict()
        # NOTE: v is the parser's own entry; its "fullpath" is rewritten in
        # place rather than on a copy.
        for k, v in self.parsers[bitfile_name].ip_dict.items():
            ip_dict_id = fullpath + "/" + v["fullpath"]
            ip_dict[ip_dict_id] = v
            ip_dict[ip_dict_id]["fullpath"] = fullpath + "/" + v["fullpath"]
        self.parsers[bitfile_name].ip_dict = ip_dict

        self.parsers[bitfile_name].nets = {
            fullpath + "_" + s: {fullpath + "/" + i for i in p}
            for s, p in self.parsers[bitfile_name].nets.items()
            if s is not None and p is not None
        }

        self.parsers[bitfile_name].pins = {
            fullpath + "/" + p: fullpath + "_" + s
            for p, s in self.parsers[bitfile_name].pins.items()
            if s is not None and p is not None
        }

    def _load_bitstream(self, bitfile_name, program):
        """Download the partial bitstream when ``program`` is true, and record
        its resolved path in `pr_loaded` in either case.

        """
        if program:
            self.bitstreams[bitfile_name].download()
        self.pr_loaded = self.bitstreams[bitfile_name].bitfile_name


