from __future__ import annotations

import os
import random
import string
import subprocess
import sys
import tempfile

from pt.machine import Machine
from pt.pt import PageTableDump
from pt.pt_aarch64_parse import PT_Aarch64_Backend
from pt.pt_riscv64_parse import PT_RiscV64_Backend
from pt.pt_x86_64_parse import PT_x86_64_Backend

import pwndbg
import pwndbg.aglib
import pwndbg.aglib.kernel
import pwndbg.aglib.kernel.paging
import pwndbg.aglib.memory
import pwndbg.aglib.qemu
import pwndbg.color.message as message
import pwndbg.dbg_mod
import pwndbg.lib.cache
import pwndbg.lib.memory
from pwndbg.lib.memory import Page


class KernelVmmap:
    """
    Labels the pages of a kernel memory map and computes per-page offsets.

    The section markers come from the paging-info object returned by
    `pwndbg.aglib.kernel.arch_paginginfo()`. When no paging info is available
    the pages are left untouched.

    All adjustments are applied in place to the `Page` objects in `pages`.
    """

    def __init__(self, pages: tuple[Page, ...]):
        self.pages = pages
        # Stays None when no paging info is available for the current target.
        self.sections: tuple[tuple[str, int], ...] | None = None
        self.pi = pwndbg.aglib.kernel.arch_paginginfo()
        if self.pi:
            self.sections = self.pi.markers()
        self.adjust()

    def get_name(self, addr: int) -> str | None:
        """
        Return the name of the section whose `[start, next_start)` range holds `addr`.

        Each marker is a `(name, start)` pair; a section ends where the next
        marker starts, so the last marker only acts as an upper bound. Markers
        with a missing name or boundary are skipped. Returns None when nothing
        matches, or when `addr` or the markers are unavailable.
        """
        if addr is None or self.sections is None:
            return None
        # Walk consecutive marker pairs: (current section, following section).
        for (name, start), (_, end) in zip(self.sections, self.sections[1:]):
            if name is None or start is None or end is None:
                continue
            if start <= addr < end:
                return name
        return None

    def adjust(self) -> None:
        """
        Name every page after its section, then refine user and kernel pages
        and finally compute the offsets. Does nothing without paging info or pages.
        """
        if self.pi is None or not self.pages:
            return
        for page in self.pages:
            name = self.get_name(page.start)
            if name is not None:
                page.objfile = name
        self.handle_user_pages()
        self.pi.handle_kernel_pages(self.pages)
        self.handle_offsets()

    def handle_user_pages(self) -> None:
        """
        Refine the names of the leading run of userland pages.

        The distance of each page from a moving base address decides the label:
        a small gap (<= 0x100000) moves the base forward, a medium gap marks
        the page as heap, and a huge gap (> 0x100000000000) marks executable
        pages as library and read-write pages as stack. Processing stops at
        the first page that is not labelled as userland.
        """
        if not self.pages:
            return
        base_offset = self.pages[0].start
        for page in self.pages:
            if page.objfile != self.pi.USERLAND:
                break
            diff = page.start - base_offset
            if diff <= 0x100000:
                # Still close to the previous mapping: just move the base forward.
                base_offset = page.start
            elif diff <= 0x100000000000:
                page.objfile = "userland [heap]"
            elif page.execute:
                page.objfile = "userland [library]"
            elif page.rw:
                page.objfile = "userland [stack]"

    def handle_offsets(self) -> None:
        """
        Set `page.offset` to the distance from the first page of the current
        run of identically named pages.
        """
        prev_objfile, base = "", 0
        for page in self.pages:
            # KERNELRO pages deliberately keep the base of the preceding region, which
            # makes getting offsets for symbols such as `init_creds` more convenient.
            if page.objfile != self.pi.KERNELRO and prev_objfile != page.objfile:
                prev_objfile = page.objfile
                base = page.start
            page.offset = page.start - base
            # Offsets whose hex form is longer than 9 characters (0x10000000 and up
            # for non-negative values) are reset to 0.
            if len(hex(page.offset)) > 9:
                page.offset = 0


# Most of the QemuMachine code was inherited from gdb-pt-dump, thanks to Martin Radev (@martinradev),
# under the MIT license; see:
# https://github.com/martinradev/gdb-pt-dump/blob/21158ac3f9b36d0e5e0c86193e0ef018fc628e74/pt_gdb/pt_gdb.py#L11-L80
class QemuMachine(Machine):
    """
    `Machine` backend that reads guest physical memory straight from the
    QEMU process via `/proc/<qemu-pid>/mem`, and registers via pwndbg.
    """

    def __init__(self):
        """
        Locate the QEMU process and open its `/proc/<pid>/mem` read-only.

        Raises:
            ProcessLookupError: the QEMU pid could not be determined.
            OSError: `/proc/<pid>/mem` could not be opened (for example
                PermissionError).
        """
        super().__init__()
        # Assigned before anything that may raise so __del__ always has a value to check.
        self.file = None
        self.pid = QemuMachine.get_qemu_pid()
        self.file = os.open(f"/proc/{self.pid}/mem", os.O_RDONLY)

    def __del__(self):
        # Compare against None explicitly: a descriptor of 0 is valid but falsy.
        fd = getattr(self, "file", None)
        if fd is not None:
            os.close(fd)
            self.file = None

    @staticmethod
    def search_pids_for_file(pids: list[str], filename: str) -> str | None:
        """
        Return the first pid in `pids` with an open descriptor that links to
        `filename`, or None when no candidate has it open.
        """
        for pid in pids:
            fd_dir = f"/proc/{pid}/fd"
            try:
                for fd in os.listdir(fd_dir):
                    if os.readlink(f"{fd_dir}/{fd}") == filename:
                        return pid
            except (FileNotFoundError, PermissionError):
                # FileNotFoundError: the process has gone or its fds are changing, not our pid.
                # PermissionError: the process is owned by another user.
                continue

        return None

    @staticmethod
    def get_qemu_pid():
        """
        Determine the pid of the QEMU process being debugged.

        A single `pgrep qemu-system` match is returned directly. Otherwise a
        uniquely named temporary file is attached to QEMU as a chardev backend
        and the candidate processes are searched for the one holding it open.
        The candidates are the pgrep matches, or every process listed in
        `/proc` when pgrep found nothing or is not installed.

        Raises:
            ProcessLookupError: no candidate process has the file open.
        """
        pids: list[str] = []
        try:
            out = subprocess.check_output(["pgrep", "qemu-system"], encoding="utf8")
        except (subprocess.CalledProcessError, FileNotFoundError):
            # No process named `qemu-system` was found (the binary name may vary,
            # e.g. `qemu_system`) or pgrep itself is missing: fall back to the
            # chardev probe over every visible process.
            pass
        else:
            pids = out.split()
            if len(pids) == 1:
                return int(pids[0], 10)

        if not pids:
            try:
                pids = [entry for entry in os.listdir("/proc") if entry.isdigit()]
            except OSError:
                pids = []

        # We add a chardev file backend (we don't add a frontend, so it doesn't affect
        # the guest). We can then look through proc to find which process has the file
        # open. This approach is agnostic to namespaces (pid, network and mount).
        chardev_id = "gdb-pt-dump" + "-" + "".join(random.choices(string.ascii_letters, k=16))
        inferior = pwndbg.dbg.selected_inferior()
        with tempfile.NamedTemporaryFile() as tmpf:
            inferior.send_monitor(f"chardev-add file,id={chardev_id},path={tmpf.name}")
            try:
                pid_found = QemuMachine.search_pids_for_file(pids, tmpf.name)
            finally:
                # Always detach the probe chardev, even if the search failed.
                inferior.send_monitor(f"chardev-remove {chardev_id}")

        if not pid_found:
            raise ProcessLookupError("Could not find qemu-system pid")

        return int(pid_found, 10)

    def read_physical_memory(self, physical_address: int, length: int) -> bytes:
        """
        Read `length` bytes of guest physical memory starting at `physical_address`.

        The address is translated with the QEMU monitor command `gpa2hva`; the
        last space-separated token of the reply is parsed as a hexadecimal
        offset into the QEMU process memory.

        Raises:
            OSError: the translation or the read failed; the original error is
                chained as the cause.
        """
        res = pwndbg.dbg.selected_inferior().send_monitor(f"gpa2hva {hex(physical_address)}")

        # It's not possible to pread large sizes, so let's break the request
        # into a few smaller ones.
        max_block_size = 1024 * 1024 * 256
        try:
            hva = int(res.split(" ")[-1], 16)
            blocks = [
                os.pread(self.file, min(length - offset, max_block_size), hva + offset)
                for offset in range(0, length, max_block_size)
            ]
            return b"".join(blocks)
        except Exception as e:
            msg = f"Physical address ({hex(physical_address)}, +{hex(length)}) is not accessible. Reason: {e}. gpa2hva result: {res}"
            raise OSError(msg) from e

    def read_register(self, register_name: str) -> int:
        """Return the value of `register_name`; a single leading `$` is ignored."""
        return int(pwndbg.aglib.regs.read_reg(register_name.removeprefix("$")))


@pwndbg.lib.cache.cache_until("stop")
def kernel_vmmap_via_page_tables() -> tuple[Page, ...]:
    """
    Build the kernel memory map by parsing the page tables with the
    gdb-pt-dump backends, reading guest memory through `QemuMachine`.

    Returns an empty tuple when not debugging a QEMU kernel, when not on
    Linux, when the QEMU process cannot be accessed, when the architecture
    has no backend, or when paging is disabled.
    """
    # QemuMachine requires access to /proc/{qemu-pid}/mem, which is only available on Linux
    if not pwndbg.aglib.qemu.is_qemu_kernel() or sys.platform != "linux":
        return ()

    failure: str | None = None
    try:
        machine = QemuMachine()
    except PermissionError:
        failure = (
            "Permission error when attempting to parse page tables with gdb-pt-dump.\n"
            "Either change the kernel-vmmap setting, re-run GDB as root, or disable "
            "`ptrace_scope` (`echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope`)"
        )
    except ProcessLookupError:
        failure = (
            "Could not find the PID for process named `qemu-system`.\n"
            "This might happen if pwndbg is running on a different machine than `qemu-system`,\n"
            "or if the `qemu-system` binary has a different name."
        )
    if failure is not None:
        print(message.error(failure))
        return ()

    arch_name: str = pwndbg.aglib.arch.name
    pointer_size: int = pwndbg.aglib.arch.ptrsize

    # i386 shares the x86-64 backend.
    backend_by_arch = {
        "aarch64": PT_Aarch64_Backend,
        "i386": PT_x86_64_Backend,
        "x86-64": PT_x86_64_Backend,
        "rv64": PT_RiscV64_Backend,
    }
    backend_cls = backend_by_arch.get(arch_name)
    if backend_cls is None:
        print(
            message.error(
                f"The {pwndbg.aglib.arch.name} architecture does"
                " not support the `vmmap_via_page_tables`.\n"
                "Run `help show kernel-vmmap` for other options."
            )
        )
        return ()
    arch_backend = backend_cls(machine)

    # If paging is not enabled, we shouldn't attempt to parse page tables
    if not pwndbg.aglib.kernel.paging_enabled():
        return ()

    dumper = PageTableDump(machine, arch_backend)
    entries = dumper.arch_backend.parse_tables(dumper.cache, dumper.parser.parse_args(""))

    converted: list[Page] = []
    for entry in entries:
        va = entry.va
        length = entry.page_size
        # Read permission (4) is always implied; write is 2 and execute is 1.
        perms = 4
        perms |= 2 if entry.pwndbg_is_writeable() else 0
        perms |= 1 if entry.pwndbg_is_executable() else 0
        # The label is the hex address without the "0x" prefix and its last three digits.
        label = f"[pt_{hex(va)[2:-3]}]"
        converted.append(Page(va, length, perms, 0, pointer_size, label))
    return tuple(converted)


# One-shot latch for `_parser_mem_info_line_x86`: True until that parser has printed
# its "end-start != size" warning, then set to False so the warning is not repeated.
monitor_info_mem_not_warned = True


def _parser_mem_info_line_x86(line: str) -> Page | None:
    """
    Turn one x86 `info mem` line into a `Page`.

    A line has the form `<start>-<end> <size> <perms>`, for example:
    ```
    ffff903580000000-ffff903580099000 0000000000099000 -rw
    ffff903580099000-ffff90358009b000 0000000000002000 -r-
    ffff90358009b000-ffff903582200000 0000000002165000 -rw
    ffff903582200000-ffff903582803000 0000000000603000 -r-
    ```

    Raises ValueError for lines that do not follow this layout. The first
    time `end - start` disagrees with the reported size, a warning is printed
    (once per session).
    """
    global monitor_info_mem_not_warned

    # Text before the first space is the address range; the rest is
    # "<size> <perms>", split on the last space of the line.
    addr_range, _, remainder = line.partition(" ")
    start_text, _, end_text = addr_range.partition("-")
    size_text, _, perm = remainder.rpartition(" ")

    start = int(start_text, 16)
    end = int(end_text, 16)
    size = int(size_text, 16)

    flags = 0
    for letter, bit in (("r", Page.R_OK), ("w", Page.W_OK), ("x", Page.X_OK)):
        if letter in perm:
            flags |= bit

    if monitor_info_mem_not_warned and end - start != size:
        print(
            message.warn(
                "The vmmap output may be incorrect as `monitor info mem` output assertion/assumption\n"
                "that end-start==size failed. The values are:\n"
                f"end={end:#x}; start={start:#x}; size={size:#x}; end-start={end - start:#x}\n"
                "Note that this warning will not show up again in this Pwndbg/GDB session."
            )
        )
        monitor_info_mem_not_warned = False

    return Page(start, size, flags, 0, pwndbg.aglib.arch.ptrsize, "<qemu>")


def _parser_mem_info_line_riscv64(line: str) -> Page | None:
    """
    Turn one RISC-V `info mem` line into a `Page`.

    The monitor prints four space-separated columns, for example:
    ```
    vaddr            paddr            size             attr
    ---------------- ---------------- ---------------- -------
    0000000000010000 00000000feece000 0000000000001000 r-xu-a-
    0000000000011000 00000000fefeb000 0000000000002000 r-xu-a-
    0000000000013000 00000000a0a7a000 0000000000002000 r-xu-a-
    0000000000015000 00000000bfe02000 0000000000002000 r-xu-a-
    ```

    Only `vaddr`, `size` and `attr` are used. Raises ValueError for lines
    that do not have four columns or whose numbers are not hexadecimal
    (which includes the two header lines).
    """
    columns = line.split(" ", 3)
    if len(columns) != 4:
        raise ValueError("invalid line format")

    vaddr_text, size_text, attrs = columns[0], columns[2], columns[3]
    vaddr = int(vaddr_text, 16)
    length = int(size_text, 16)

    flags = 0
    for letter, bit in (("r", Page.R_OK), ("w", Page.W_OK), ("x", Page.X_OK)):
        if letter in attrs:
            flags |= bit

    return Page(vaddr, length, flags, 0, pwndbg.aglib.arch.ptrsize, "<qemu>")


@pwndbg.lib.cache.cache_until("stop")
def kernel_vmmap_via_monitor_info_mem() -> tuple[Page, ...]:
    """
    Returns Linux memory maps information by parsing the `monitor info mem`
    output of the QEMU kernel GDB stub.
    Works only on X86/X64/RISC-V as this is what QEMU supports.

    Lines that cannot be parsed are silently skipped. An empty tuple is
    returned when not debugging a QEMU kernel, or when the architecture or
    the monitor command is unsupported.

    Consider using the `kernel_vmmap_via_page_tables` method
    as it is probably more reliable/better.

    See also: https://github.com/pwndbg/pwndbg/pull/685
    (TODO: revisit with future QEMU versions)
    """
    if not pwndbg.aglib.qemu.is_qemu_kernel():
        return ()

    try:
        raw_output = pwndbg.dbg.selected_inferior().send_monitor("info mem")
    except pwndbg.dbg_mod.Error:
        # Older versions of QEMU/GDB may throw `gdb.error: "monitor" command
        # not supported by this target`. Newer versions will not throw, but will
        # return a string starting with 'unknown command:'. Normalise both cases.
        # (This handler might be removable once old QEMU versions are irrelevant.)
        raw_output = "unknown command"

    line_parsers = {
        "i386": _parser_mem_info_line_x86,
        "x86-64": _parser_mem_info_line_x86,
        "rv64": _parser_mem_info_line_riscv64,
    }
    parse_line = line_parsers.get(pwndbg.aglib.arch.name)

    if parse_line is None or "unknown command" in raw_output:
        print(
            message.error(
                f"The {pwndbg.aglib.arch.name} architecture does"
                " not support the `monitor info mem` command.\n"
                "Run `help show kernel-vmmap` for other options."
            )
        )
        return ()

    parsed: list[Page] = []
    for text in raw_output.splitlines():
        try:
            parsed.append(parse_line(text))
        except Exception:
            # Not a mapping line (header, separator, unexpected format): skip it.
            continue

    return tuple(parsed)


# Selects how `kernel_vmmap_pages` obtains the memory map.
# NOTE: `pt-dump` is the mode that goes through QemuMachine and therefore reads
# /proc/$qemu-pid/mem, so the same-machine / same-PID-namespace caveat in the help
# text is attached to `pt-dump` (it previously named `page-tables`).
kernel_vmmap_mode = pwndbg.config.add_param(
    "kernel-vmmap",
    "page-tables",
    "the method to get vmmap information when debugging via QEMU kernel",
    help_docstring="""\
Values explained:

+ `page-tables` - walk page tables to render vmmap
+ `pt-dump` - read /proc/$qemu-pid/mem to parse kernel page tables to render vmmap
+ `monitor` - use QEMU's `monitor info mem` to render vmmap
+ `none` - disable vmmap rendering; useful if rendering is particularly slow

Note that the `pt-dump` method will require the QEMU kernel process to be on the same machine and within the same PID namespace. Running QEMU kernel and GDB in different Docker containers will not work. Consider running both containers with --pid=host (meaning they will see and so be able to interact with all processes on the machine).
""",
    param_class=pwndbg.lib.config.PARAM_ENUM,
    enum_sequence=["page-tables", "pt-dump", "monitor", "none"],
)


@pwndbg.lib.cache.cache_until("stop")
def kernel_vmmap_pages() -> tuple[Page, ...]:
    """
    Return the raw kernel pages using the method selected by the
    `kernel-vmmap` setting.

    `page-tables` is only used on x86-64 and aarch64; on other architectures
    a warning is printed and `monitor` is used instead. Any other value
    (such as `none`) yields an empty tuple.
    """
    mode = str(kernel_vmmap_mode)
    arch_name = pwndbg.aglib.arch.name
    if mode == "page-tables" and arch_name not in ("x86-64", "aarch64"):
        # TODO: remove this by implementing `RiscvPagingInfo`, `RiscvOps`, etc
        print(
            message.warn(
                f"`kernel-vmmap = {mode}` unsupported for {arch_name}, defaulting to `monitor`"
            )
        )
        mode = "monitor"

    if mode == "monitor":
        return kernel_vmmap_via_monitor_info_mem()
    if mode == "pt-dump":
        return kernel_vmmap_via_page_tables()
    if mode != "page-tables":
        return ()

    # Use the pgd the user selected with kcurrent, if any. Per the original note
    # it is None when unset, which the calls below are expected to handle.
    pgd = pwndbg.commands.kcurrent.KCURRENT_PGD
    if pwndbg.aglib.memory.is_kernel(pgd):
        # A kernel virtual address: resolve it to its physical address first.
        pgd = pwndbg.aglib.kernel.pagewalk(pgd, virt=False).phys
    return pwndbg.aglib.kernel.pagescan(pgd)


def kernel_vmmap() -> tuple[pwndbg.lib.memory.Page, ...]:
    """
    Return the annotated kernel memory map for the current QEMU kernel target.

    The pages from `kernel_vmmap_pages` are labelled in place by
    `KernelVmmap`. Returns an empty tuple when not debugging a QEMU kernel
    or when the architecture is not one of the supported ones.
    """
    supported_arches = ("i386", "x86-64", "aarch64", "rv32", "rv64")
    if (
        not pwndbg.aglib.qemu.is_qemu_kernel()
        or pwndbg.aglib.arch.name not in supported_arches
    ):
        return ()

    pages = kernel_vmmap_pages()
    kv = KernelVmmap(pages)

    needs_exec_fixup = kernel_vmmap_mode == "monitor" and pwndbg.aglib.arch.name == "x86-64"
    if not needs_exec_fixup:
        return tuple(pages)

    # TODO: check version here when QEMU displays the x bit for x64
    # see: https://github.com/pwndbg/pwndbg/pull/3020#issuecomment-2914573242
    for page in pages:
        if page.objfile != kv.pi.ESPSTACK:
            pte = pwndbg.aglib.kernel.pagewalk(page.start).entry
            # A present entry with bit 63 clear (presumably the x86-64 NX bit)
            # gets the execute flag (1).
            if pte and pte >> 63 == 0:
                page.flags |= 1

    return tuple(pages)
