#include <inttypes.h>
#include <xen/xen.h>

#include "emu.h"
#include "emu-mm.c"

/* --------------------------------------------------------------------- */

/*
 * Render the page-table entry flag bits in @flags as a space-prefixed
 * list of names followed by a newline (e.g. " user write present\n").
 *
 * Returns a pointer to a static buffer that is overwritten by the next
 * call, so the function is not reentrant.
 */
static char *print_pgflags(uint32_t flags)
{
    /* names are emitted in this order, each with a leading blank */
    static const struct {
        uint32_t    bit;
        const char *name;
    } names[] = {
        { _PAGE_GLOBAL,   " global"   },
        { _PAGE_PSE,      " pse"      },
        { _PAGE_DIRTY,    " dirty"    },
        { _PAGE_ACCESSED, " accessed" },
        { _PAGE_PCD,      " pcd"      },
        { _PAGE_PWT,      " pwt"      },
        { _PAGE_USER,     " user"     },
        { _PAGE_RW,       " write"    },
        { _PAGE_PRESENT,  " present"  },
    };
    static char buf[80];
    size_t len = 0;
    size_t i;

    /* every name plus "\n" is 54 characters, so buf can never truncate */
    for (i = 0; i < sizeof(names) / sizeof(names[0]); i++) {
        if (flags & names[i].bit)
            len += (size_t)snprintf(buf + len, sizeof(buf) - len, "%s",
                                    names[i].name);
    }
    snprintf(buf + len, sizeof(buf) - len, "\n");
    return buf;
}

/*
 * Log the 4-level page-table walk for virtual address @va, starting at
 * the top-level table in machine frame @root_mfn, via printk() at
 * verbosity @level.  Nothing is printed unless vmconf.debug_level is at
 * least @level.
 *
 * One line is printed per level (pgd, pud, pmd, pte).  The walk stops
 * after the first entry without _PAGE_PRESENT, or after a pmd entry with
 * _PAGE_PSE set (a large page: there is no pte table below it).
 *
 * NOTE(review): _PAGE_PSE on a pud entry (1GB page) is not treated as a
 * leaf here; the walk descends into it as if it were a table.
 */
void pgtable_walk(int level, uint64_t va, uint64_t root_mfn)
{
    /* per-level labels; the padding produces a staircase in the log */
    static const char *const label[4] = {
        "pgd   ", " pud  ", "  pmd ", "   pte",
    };
    void *physmem = (void*)XEN_RAM_64;
    uint32_t slot[4];
    uint64_t *table;
    uint64_t mfn = root_mfn;
    uint32_t flags;
    int i;

    if (vmconf.debug_level < level)
        return;

    printk(level, "page table walk for va %" PRIx64 ", root_mfn %" PRIx64 "\n",
           va, root_mfn);

    slot[0] = PGD_INDEX_64(va);
    slot[1] = PUD_INDEX_64(va);
    slot[2] = PMD_INDEX_64(va);
    slot[3] = PTE_INDEX_64(va);

    for (i = 0; i < 4; i++) {
        table = physmem + frame_to_addr(mfn);
        mfn   = get_pgframe_64(table[slot[i]]);
        flags = get_pgflags_64(table[slot[i]]);
        /* slot is unsigned -> %u; %p requires a void * argument */
        printk(level, "%s: %p +%3u  |  mfn %4" PRIx64 "  |  %s",
               label[i], (void *)table, slot[i], mfn, print_pgflags(flags));
        if (!(flags & _PAGE_PRESENT))
            return;
        /* a large page at the pmd level is the leaf of the walk */
        if (i == 2 && (flags & _PAGE_PSE))
            return;
    }
}

/*
 * Set @flag on the paging entries that translate @va, and flush the TLB
 * entry for @va if anything was modified.
 *
 * The leaf pte is checked first through the linear page table.  If it
 * lacks @flag, only the leaf is fixed.  Otherwise the pgd, pud and pmd
 * entries reached from @cpu's cr3 are checked, and each one missing
 * @flag gets it.
 *
 * Returns the number of entries modified (0 if none).
 *
 * NOTE(review): the linear-page-table lookup presumably assumes a 4K
 * mapping of @va; a _PAGE_PSE pmd is not special-cased here.
 */
int pgtable_fixup_flag(struct xen_cpu *cpu, uint64_t va, uint32_t flag)
{
    void *ram = (void*)XEN_RAM_64;
    uint64_t *leaf = find_pte_64(va);
    uint64_t *table, *entry;
    int changed = 0;

    if (!test_pgflag_64(*leaf, flag)) {
        *leaf |= flag;
        changed = 1;
    } else {
        /* the leaf already has the flag, so inspect the upper levels */
        table = ram + frame_to_addr(read_cr3_mfn(cpu));
        entry = &table[PGD_INDEX_64(va)];
        if (!test_pgflag_64(*entry, flag)) {
            *entry |= flag;
            changed++;
        }

        table = ram + frame_to_addr(get_pgframe_64(*entry));
        entry = &table[PUD_INDEX_64(va)];
        if (!test_pgflag_64(*entry, flag)) {
            *entry |= flag;
            changed++;
        }

        table = ram + frame_to_addr(get_pgframe_64(*entry));
        entry = &table[PMD_INDEX_64(va)];
        if (!test_pgflag_64(*entry, flag)) {
            *entry |= flag;
            changed++;
        }
    }

    if (changed)
        flush_tlb_addr(va);
    return changed;
}

/*
 * Report whether @va is mapped by the 4-level page table whose top-level
 * table lives in machine frame @root_mfn.
 *
 * Returns 1 if every entry on the walk has _PAGE_PRESENT set, else 0.
 * A pmd entry with _PAGE_PSE set is a large page and ends the walk
 * (there is no pte table below it), matching pgtable_walk().
 *
 * NOTE(review): _PAGE_PSE on a pud entry (1GB page) is not handled; such
 * an entry would be descended as if it pointed at a pmd table.
 */
int pgtable_is_present(uint64_t va, uint64_t root_mfn)
{
    void *physmem = (void*)XEN_RAM_64;
    uint64_t *pgd, *pud, *pmd, *pte;
    uint32_t slot;

    pgd  = physmem + frame_to_addr(root_mfn);
    slot = PGD_INDEX_64(va);
    if (!test_pgflag_64(pgd[slot], _PAGE_PRESENT))
        return 0;

    pud  = physmem + frame_to_addr(get_pgframe_64(pgd[slot]));
    slot = PUD_INDEX_64(va);
    if (!test_pgflag_64(pud[slot], _PAGE_PRESENT))
        return 0;

    pmd  = physmem + frame_to_addr(get_pgframe_64(pud[slot]));
    slot = PMD_INDEX_64(va);
    if (!test_pgflag_64(pmd[slot], _PAGE_PRESENT))
        return 0;
    /* large page: the pmd entry is the leaf (this test was inverted) */
    if (test_pgflag_64(pmd[slot], _PAGE_PSE))
        return 1;

    pte  = physmem + frame_to_addr(get_pgframe_64(pmd[slot]));
    slot = PTE_INDEX_64(va);
    /* check the pte itself, not the pmd entry again */
    if (!test_pgflag_64(pte[slot], _PAGE_PRESENT))
        return 0;

    return 1;
}

/* --------------------------------------------------------------------- */

/*
 * Translate machine address @maddr into a pointer inside the emulator's
 * window at XEN_RAM_64, which paging_init() maps with frames
 * 0 .. vmconf.pg_total-1.
 */
void *map_page(uint64_t maddr)
{
    char *base = (char *)XEN_RAM_64;

    return base + maddr;
}

/*
 * Return a pointer to the leaf pte for @va inside the linear page-table
 * window at XEN_LPT_64.  Only the low 48 bits of @va (the x86-64 virtual
 * address width) are used; the upper sign-extension bits are dropped.
 */
uint64_t *find_pte_64(uint64_t va)
{
    const uint64_t va_mask = (UINT64_C(1) << 48) - 1;
    uint64_t *lpt = (uint64_t *)XEN_LPT_64;

    return &lpt[(va & va_mask) >> PAGE_SHIFT];
}

/*
 * Prepare the top-level table in machine frame @cr3_mfn for use by the
 * emulator:
 *  - for each pgd slot from PGD_INDEX_64(XEN_M2P_64) up to (excluding)
 *    PGD_INDEX_64(XEN_DOM_64), copy emu_pgd's entry if it is present
 *    there and the slot is not already present in the new table;
 *    the linear-page-table slot is skipped;
 *  - point the linear-page-table slot at the table itself (self-map).
 */
void update_emu_mappings(uint64_t cr3_mfn)
{
    uint64_t *pgd = map_page(frame_to_addr(cr3_mfn));
    int lpt_slot = PGD_INDEX_64(XEN_LPT_64);
    int slot;

    for (slot = PGD_INDEX_64(XEN_M2P_64); slot < PGD_INDEX_64(XEN_DOM_64); slot++) {
        if (slot == lpt_slot)
            continue;
        if (test_pgflag_64(pgd[slot], _PAGE_PRESENT) ||
            !test_pgflag_64(emu_pgd[slot], _PAGE_PRESENT))
            continue;
        pgd[slot] = emu_pgd[slot];
    }

    /* linear pgtable mapping: the table maps itself */
    pgd[lpt_slot] = get_pgentry_64(cr3_mfn, LPT_PGFLAGS);
}

/*
 * Return a pointer to the pgd entry for @va in the top-level table in
 * machine frame @mfn.  If the entry is not present and @alloc is set, a
 * new pud page is allocated and installed.  The new entry is also copied
 * into emu_pgd so the emulator's boot page table stays in sync.
 */
static inline uint64_t *find_pgd(uint64_t va, uint64_t mfn, int alloc)
{
    int slot = PGD_INDEX_64(va);
    uint64_t *entry = (uint64_t *)map_page(frame_to_addr(mfn)) + slot;

    if (!test_pgflag_64(*entry, _PAGE_PRESENT) && alloc) {
        uint64_t *next = get_pages(1, "pud");

        *entry = get_pgentry_64(EMU_MFN(next), PGT_PGFLAGS_64);
        emu_pgd[slot] = *entry; /* sync emu boot pgt */
    }
    return entry;
}

/*
 * Return a pointer to the pud entry for @va in the table in machine
 * frame @mfn.  If the entry is not present and @alloc is set, a new pmd
 * page is allocated and installed.
 */
static inline uint64_t *find_pud(uint64_t va, uint64_t mfn, int alloc)
{
    uint64_t *entry = (uint64_t *)map_page(frame_to_addr(mfn)) + PUD_INDEX_64(va);

    if (!test_pgflag_64(*entry, _PAGE_PRESENT) && alloc) {
        uint64_t *next = get_pages(1, "pmd");

        *entry = get_pgentry_64(EMU_MFN(next), PGT_PGFLAGS_64);
    }
    return entry;
}

/*
 * Return a pointer to the pmd entry for @va in the table in machine
 * frame @mfn.  If the entry is not present and @alloc is set, a new pte
 * page is allocated and installed.
 */
static inline uint64_t *find_pmd(uint64_t va, uint64_t mfn, int alloc)
{
    uint64_t *entry = (uint64_t *)map_page(frame_to_addr(mfn)) + PMD_INDEX_64(va);

    if (!test_pgflag_64(*entry, _PAGE_PRESENT) && alloc) {
        uint64_t *next = get_pages(1, "pte");

        *entry = get_pgentry_64(EMU_MFN(next), PGT_PGFLAGS_64);
    }
    return entry;
}

/*
 * Return a pointer to the pte for @va in the pte table in machine
 * frame @mfn.  Never allocates.
 */
static inline uint64_t *find_pte(uint64_t va, uint64_t mfn)
{
    uint64_t *table = map_page(frame_to_addr(mfn));

    return &table[PTE_INDEX_64(va)];
}

/*
 * Map @count machine frames starting at frame @start to virtual addresses
 * starting at @va_start, using large (_PAGE_PSE) pmd entries.  Each entry
 * covers PMD_COUNT_64 frames.  Missing pgd/pud levels are allocated; the
 * pmd entry is written directly, so no pte table is needed.  The tables
 * of @cpu's current cr3 are modified.  Always returns 0.
 *
 * NOTE(review): @start and @va_start presumably must be aligned to
 * PMD_COUNT_64 frames, and @count is effectively rounded up to a multiple
 * of it — confirm with callers.
 */
static int map_region_pse(struct xen_cpu *cpu, uint64_t va_start, uint32_t flags,
                          uint64_t start, uint64_t count)
{
    uint64_t root = read_cr3_mfn(cpu);   /* invariant across the loop */
    uint32_t pse_flags = flags | _PAGE_PSE;
    uint64_t end = start + count;
    uint64_t frame;

    for (frame = start; frame < end; frame += PMD_COUNT_64) {
        uint64_t addr = va_start + frame_to_addr(frame - start);
        uint64_t *l4 = find_pgd(addr, root, 1);
        uint64_t *l3 = find_pud(addr, get_pgframe_64(*l4), 1);
        uint64_t *l2 = find_pmd(addr, get_pgframe_64(*l3), 0);

        *l2 = get_pgentry_64(frame, pse_flags);
    }
    return 0;
}

/*
 * Map the machine page containing @maddr at the next slot of the
 * XEN_MAP_64 window, in the tables of @cpu's current cr3.  Any missing
 * page-table levels are allocated.  Returns the virtual address that
 * corresponds to @maddr, including its offset within the page.
 *
 * NOTE(review): slots are handed out sequentially and never reused, and
 * nothing checks the slot against the size of the XEN_MAP_64 window —
 * confirm callers stay within it.
 */
void *fixmap_page(struct xen_cpu *cpu, uint64_t maddr)
{
    static int fixmap_slot = 0;
    uint64_t frame  = addr_to_frame(maddr);
    uint32_t offset = addr_offset(maddr);
    uint64_t va     = XEN_MAP_64 + PAGE_SIZE * fixmap_slot++;
    uint64_t *l4 = find_pgd(va, read_cr3_mfn(cpu), 1);
    uint64_t *l3 = find_pud(va, get_pgframe_64(*l4), 1);
    uint64_t *l2 = find_pmd(va, get_pgframe_64(*l3), 1);
    uint64_t *l1 = find_pte(va, get_pgframe_64(*l2));

    *l1 = get_pgentry_64(frame, EMU_PGFLAGS);
#if 0 /* debug */
    pgtable_walk(1, va, read_cr3_mfn(cpu));
#endif
    return (char *)va + offset;
}

/*
 * Set up the emulator's large-page mappings in @cpu's current page
 * tables, then point m2p at its mapping.  The two regions are mapped in
 * this order because mapping the first may allocate page-table levels.
 */
void paging_init(struct xen_cpu *cpu)
{
    /* frames 0 .. vmconf.pg_total-1 at XEN_RAM_64 */
    map_region_pse(cpu, XEN_RAM_64, EMU_PGFLAGS, 0, vmconf.pg_total);
    /* vmconf.pg_m2p frames starting at vmconf.mfn_m2p, at XEN_M2P_64 */
    map_region_pse(cpu, XEN_M2P_64, M2P_PGFLAGS_64, vmconf.mfn_m2p, vmconf.pg_m2p);

    m2p = (void*)XEN_M2P_64;
}
