import { createHAMTRecord, hamtEmpty, hamtGet, ImmutableHAMT } from "@/ds/hamt";
import shallowEqual from "@/misc/shallowEqual";
import { Diff } from "@/misc/types";

/**
 * Read-only snapshot handed out by a `Counter`. In `createCounter` this is a frozen
 * array of the values that currently have a non-zero count for at least one tab.
 */
export type CounterState<T> = Readonly<T[]>;

/**
 * A per-tab reference counter over values of type `T`, driven through read-only
 * `CounterState` snapshots: each mutating method receives the previous state and
 * returns the next one. `createCounter` returns the very same `state` object when
 * the set of counted values did not change, so states can be compared by reference.
 */
export interface Counter<T> {
    /** Count one occurrence of every element of `ids` (duplicates included) for `tabId`. */
    add(state: CounterState<T>, tabId: string, ids: T[]): CounterState<T>;
    /** Release one occurrence of every element of `ids` for `tabId`; values not counted for `tabId` are ignored. */
    remove(state: CounterState<T>, tabId: string, ids: T[]): CounterState<T>;
    /** Release every occurrence counted for `tabId`. */
    removeTabId(state: CounterState<T>, tabId: string): CounterState<T>;
    /** Return an empty starting state. */
    getInitialState: () => CounterState<T>;
}

/**
 * Create a `Counter` that remembers, for every value, how many times each tab has
 * added it.
 *
 * Bookkeeping lives in a private `Map` captured by this closure
 * (value -> tabId -> count). The `CounterState` passed to and returned from each
 * method is a frozen array of that map's keys, i.e. the values that currently have
 * at least one counted tab, in the map's insertion order.
 *
 * After every mutation the fresh key list is compared with the `state` argument
 * using `shallowEqual`. When they match, the original `state` object is returned
 * untouched, so callers can detect a meaningful update by reference inequality.
 *
 * @returns a `Counter` whose methods all share the same internal map.
 */
export function createCounter<T>(): Counter<T> {
    // value -> (tabId -> outstanding add count from that tab)
    const tally = new Map<T, Map<string, number>>();

    // Rebuild the visible key list; reuse `previous` when it is shallow-equal.
    const settle = (previous: CounterState<T>): CounterState<T> => {
        const keys: CounterState<T> = Array.from(tally.keys());
        if (!shallowEqual(previous, keys)) {
            return Object.freeze(keys);
        }
        return previous;
    };

    // Bump the per-tab count of every id (each duplicate counts separately).
    const add = (state: CounterState<T>, tabId: string, ids: T[]): CounterState<T> => {
        ids.forEach((id) => {
            let perTab = tally.get(id);
            if (!perTab) {
                perTab = new Map<string, number>();
                tally.set(id, perTab);
            }
            perTab.set(tabId, (perTab.get(tabId) ?? 0) + 1);
        });
        return settle(state);
    };

    // Lower the per-tab count of every id; values or tabs that were never counted
    // are skipped silently. A value disappears once no tab holds it any more.
    const remove = (state: CounterState<T>, tabId: string, ids: T[]): CounterState<T> => {
        ids.forEach((id) => {
            const perTab = tally.get(id);
            const held = perTab?.get(tabId);
            if (perTab === undefined || held === undefined) return;

            if (held >= 2) {
                perTab.set(tabId, held - 1);
            } else {
                perTab.delete(tabId);
                if (perTab.size === 0) tally.delete(id);
            }
        });
        return settle(state);
    };

    // Forget every count held by `tabId`, dropping values no other tab still holds.
    const removeTabId = (state: CounterState<T>, tabId: string): CounterState<T> => {
        const orphaned: T[] = [];
        tally.forEach((perTab, id) => {
            if (perTab.delete(tabId) && perTab.size === 0) {
                orphaned.push(id);
            }
        });
        for (const id of orphaned) {
            tally.delete(id);
        }
        return settle(state);
    };

    return {
        getInitialState: () => Object.freeze([]),
        add,
        remove,
        removeTabId,
    };
}

/** Bump a per-tab count by one; a missing (or falsy) count starts at 1. */
const increment = (n: number | undefined): number => {
    if (!n) return 1;
    return n + 1;
};
/**
 * Lower a per-tab count by one. A missing, zero or exactly-1 count yields
 * `undefined` instead of 0 (presumably so `innerOps.update` drops the entry).
 */
const decrement = (n: number | undefined): number | undefined => {
    if (n && n !== 1) return n - 1;
    return undefined;
};

/**
 * State value produced by a `HAMTCounter` (frozen by `createHAMTCounter`).
 *
 * - `hamt`: maps each tracked id to an inner HAMT of `tabId -> count`.
 * - `latest`: the ids that the operation which produced this state newly added
 *   (`added`) or fully removed (`removed`); `createHAMTCounter` initialises it to `{}`.
 */
export type HAMTCounterState<T extends string = string> = {
    hamt: ImmutableHAMT<T, ImmutableHAMT<string, number>>;
    latest: Diff<T>;
};

/**
 * Persistent per-tab reference counter over `HAMTCounterState`, implemented by
 * `createHAMTCounter`. Each mutating method takes the previous state and returns
 * the next one; `createHAMTCounter` returns the input `state` object itself when
 * there is nothing to do (empty `ids`, or an unchanged outer HAMT on removal).
 */
export interface HAMTCounter<T extends string = string> {
    /** Count one more reference from `tabId` to each of `ids`. */
    add(state: HAMTCounterState<T>, tabId: string, ids: T[]): HAMTCounterState<T>;
    /** Drop one reference from `tabId` to each of `ids`; ids with no count for `tabId` are left alone. */
    remove(state: HAMTCounterState<T>, tabId: string, ids: T[]): HAMTCounterState<T>;
    /** Drop every reference held by `tabId`. */
    removeTabId(state: HAMTCounterState<T>, tabId: string): HAMTCounterState<T>;
    /** Build the empty starting state. */
    getInitialState: () => Readonly<HAMTCounterState<T>>;
}

// Record operations for the inner per-id HAMTs (tabId -> count). Created once at
// module load and shared by every counter returned from `createHAMTCounter`.
const innerOps = createHAMTRecord<string, number>();
/**
 * Create a `HAMTCounter`: a per-tab reference counter whose state is a pair of
 * an outer HAMT (id -> inner HAMT of tabId -> count) and `latest`, the ids that
 * appeared (`added`) or disappeared (`removed`) in the operation that produced it.
 *
 * This function keeps no mutable collection of its own: every method derives the
 * next state from its `state` argument through the record operations below.
 * `add`/`remove` return the input `state` unchanged when `ids` is empty, and
 * `remove`/`removeTabId` also do so when the outer HAMT reference did not change.
 *
 * @returns a `HAMTCounter` whose methods share one set of outer record operations.
 */
export function createHAMTCounter<T extends string = string>(): HAMTCounter<T> {
    type TabCounts = ImmutableHAMT<string, number>;
    type IdTable = ImmutableHAMT<T, TabCounts>;

    const ops = createHAMTRecord<T, TabCounts>();

    // Package a new frozen state. Only the outer object is frozen; the arrays
    // inside `latest` are left as plain arrays.
    const nextState = (hamt: IdTable, latest: Diff<T>) => Object.freeze({ hamt, latest });

    // Reduce `tabId`'s share of one id's tab counts using `shrink`.
    // - a missing inner map stays missing (`undefined`);
    // - a map without a (truthy) count for `tabId` is returned untouched;
    // - a map emptied by `shrink` becomes `undefined`, presumably read by `ops` as
    //   "drop this id" — NOTE(review): confirm against `updateMany`/`updateAll`.
    const shrinkTab = (
        tabCounts: TabCounts | undefined,
        tabId: string,
        shrink: (counts: TabCounts) => TabCounts,
    ): TabCounts | undefined => {
        if (tabCounts === undefined) return undefined;
        if (!hamtGet(tabCounts, tabId)) return tabCounts;
        const shrunk = shrink(tabCounts);
        return hamtEmpty(shrunk) ? undefined : shrunk;
    };

    const getInitialState = () =>
        Object.freeze({
            hamt: ops.initial(),
            latest: {},
        });

    const add = (state: HAMTCounterState<T>, tabId: string, ids: T[]): HAMTCounterState<T> => {
        if (ids.length === 0) return state;

        // Duplicate ids are passed straight through rather than pre-aggregated;
        // they are not expected to be common.
        const added: T[] = [];
        const hamt = ops.updateMany(
            state.hamt,
            ids,
            (_: T, tabCounts: TabCounts | undefined) =>
                !tabCounts
                    ? innerOps.initial({ [tabId]: 1 })
                    : innerOps.update(tabCounts, tabId, increment),
            (id, before, after) => {
                if (!before && after) added.push(id);
            },
        );

        return nextState(hamt, { added, removed: [] });
    };

    const remove = (state: HAMTCounterState<T>, tabId: string, ids: T[]): HAMTCounterState<T> => {
        if (ids.length === 0) return state;

        const removed: T[] = [];
        const hamt = ops.updateMany(
            state.hamt,
            ids,
            (_: T, tabCounts: TabCounts | undefined) =>
                shrinkTab(tabCounts, tabId, (counts) => innerOps.update(counts, tabId, decrement)),
            (id, before, after) => {
                if (before && !after) removed.push(id);
            },
        );

        return hamt === state.hamt ? state : nextState(hamt, { added: [], removed });
    };

    const removeTabId = (state: HAMTCounterState<T>, tabId: string): HAMTCounterState<T> => {
        const removed: T[] = [];
        const hamt = ops.updateAll(
            state.hamt,
            (_: T, tabCounts: TabCounts | undefined) =>
                shrinkTab(tabCounts, tabId, (counts) => innerOps.remove(counts, tabId)),
            (id, before, after) => {
                if (before && !after) removed.push(id);
            },
        );

        return hamt === state.hamt ? state : nextState(hamt, { added: [], removed });
    };

    return { getInitialState, add, remove, removeTabId };
}
