import GenerationBackend from "./backend.js";
import * as Util from "../misc/util.js";
import settings from "../misc/settings.js";
import InstructTemplate from "@/ai/instruct-template.js";

// Identifier sent as `genkey` with both the streaming generate request and the abort
// request, presumably so the server can match an abort to our generation.  It is a
// decimal string derived from Math.random(): crypto.randomUUID() would be nicer, but
// browsers only expose it in secure contexts, which a server reached over the local
// network usually isn't.
const genkey = String(Math.round(1e9 * Math.random()));

/**
 * Generation backend that talks to a Kobold server's HTTP API (`api/v1/...` and
 * `api/extra/...` endpoints) at the URL stored in `settings.values.serverUrl`.
 */
export default class GenerationBackendKobold extends GenerationBackend
{
    constructor()
    {
        super();

        // Set by generate(): true only when the last generation stopped on EOS.
        this._generationComplete = false;
    }

    /** Base URL of the server, read from settings on every access. */
    get url() { return settings.values.serverUrl; }

    /**
     * Whether the most recent generate() call ended because the model emitted EOS, as
     * opposed to hitting the output limit, a stop sequence, an abort, or an error.
     */
    get generationComplete() { return this._generationComplete; }

    /**
     * Fetch the loaded model's name.
     *
     * @returns {Promise<*>} The `result` field of `api/v1/model`, or undefined if the
     *     request failed.
     */
    async getModelName()
    {
        let { result: model } = await this.koboldGet("api/v1/model");
        return model;
    }

    /**
     * Fetch the server's configured maximum context length.
     *
     * @returns {Promise<*>} The `value` field of `api/v1/config/max_context_length`, or
     *     undefined if the request failed.
     */
    async getMaxContextLength()
    {
        let { value: tokens } = await this.koboldGet("api/v1/config/max_context_length");
        return tokens;
    }

    /**
     * Count the tokens in `prompt` using the server's tokenizer.
     *
     * @param {string} prompt - Text to count.
     * @returns {Promise<*>} The `value` field of `api/extra/tokencount`, or undefined if
     *     the request failed.
     */
    async getTokenCount(prompt)
    {
        let { value: tokens } = await this.koboldPost("api/extra/tokencount", {
            prompt
        });

        return tokens;
    }

    /**
     * Tokenize `prompt` using the server's tokenizer.
     *
     * @param {string} prompt - Text to tokenize.
     * @returns {Promise<*>} The `ids` field of `api/extra/tokencount`, or undefined if
     *     the request failed.
     */
    async tokenizeString(prompt)
    {
        let { ids } = await this.koboldPost("api/extra/tokencount", {
            prompt
        });
        return ids;
    }

    /**
     * Stream a completion for `prompt`, yielding the `token` field of each event from
     * the server's event stream as it arrives.
     *
     * If the server can't be reached or answers with an HTTP error, this logs and returns
     * without yielding anything.  When the stream ends normally, generationComplete is
     * updated from the `stop_reason` reported by `api/extra/perf`.  If `signal` fires or
     * the stream throws, the server is asked to abort (without waiting for the reply) and
     * generationComplete is left false.
     *
     * @param {Object} [options]
     * @param {string} [options.prompt] - Prompt text to send.
     * @param {?number} [options.seed] - Generation seed, or null for none (-1 is sent).
     * @param {Object} [options.settingsOverrides] - Values that take precedence over
     *     settings.values for this call only.
     * @param {?AbortSignal} [options.signal] - Cancels the generation.
     * @throws Rethrows any error raised while starting or reading the stream, including
     *     the signal's abort reason if it is already aborted on entry.
     */
    async *generate({
        prompt='',

        // If set, specify a generation seed.
        seed=null,

        settingsOverrides={},
        signal=null,
    }={})
    {
        signal ??= (new AbortController()).signal;

        // Don't let the previous generation's result leak into this one if we return early.
        this._generationComplete = false;

        signal.throwIfAborted();

        // Combine the saved settings with the per-call overrides; overrides win.
        let effectiveSettings = {
            ...settings.values,
            ...settingsOverrides,
        };

        // Gather settings.
        let {
            contextSize, maxOutputLength, temperature,
            topK, topP, topA, minP, typicalP, tfs,
            samplerOrder,
            repetitionPenalty, repetitionPenaltyRange,
            grammar,
        } = effectiveSettings;

        // -1 is sent when no seed was requested (presumably the API's "random seed").
        if(seed === null)
            seed = -1;

        let instructTemplate = InstructTemplate.getActiveTemplate();

        let stopSequences = [];
        if(instructTemplate.stopSequence.length > 0)
            stopSequences.push(instructTemplate.stopSequence);

        // API arguments.  Not currently sent: logit_bias, dynatemp_range, dynatemp_exponent.
        let args = {
            max_context_length: contextSize,
            max_length: maxOutputLength,
            temperature,
            top_p: topP,
            top_k: topK,
            top_a: topA,
            min_p: minP,
            typical: typicalP,
            tfs,
            rep_pen: repetitionPenalty,
            rep_pen_range: repetitionPenaltyRange,
            sampler_order: samplerOrder,
            stop_sequence: stopSequences,
            smoothing_factor: 0,
            presence_penalty: 0,
            grammar,
            sampler_seed: seed,
            prompt,
        };

        try {
            let response = await this.beginGeneration(args, { signal });
            if(!response)
            {
                // koboldRequest returns null both for network failures and for a fetch
                // cancelled by our signal.  In the abort case the request may already have
                // started a generation on the server, so tell it to stop.
                if(signal.aborted)
                {
                    void this.abortGeneration();
                    return;
                }

                console.log("Connection error during generation");
                return;
            }

            // An HTTP error body isn't an event stream; don't try to read tokens from it.
            if(!response.ok)
            {
                console.log(`Generation request failed with HTTP status ${response.status}`);
                return;
            }

            let reader = response.body.getReader();
            for await(let { data } of Util.eventStreamReader(reader, { signal }))
            {
                let { token } = data;
                yield token;
            }
        } catch(e) {
            // On error, abort generation.  Don't wait for this to finish.
            console.log("Cancelling generation due to error:", e);
            void this.abortGeneration();
            this._generationComplete = false;
            throw e;
        }

        // If the abort signal was fired, make sure generation is stopped server-side.
        // We don't wait for this to complete.
        if(signal.aborted)
        {
            void this.abortGeneration();
            this._generationComplete = false;
            return;
        }

        // See if generation finished.  stopReason is 0 if we reached the output limit (the
        // model has more to say), 1 if we hit EOS (we're finished), or 2 if we hit a stop
        // sequence.
        let { stop_reason: stopReason } = await this.koboldGet("api/extra/perf");
        this._generationComplete = Number(stopReason) === 1;
    }

    /**
     * Start a streaming generation request, tagged with this module's genkey.
     *
     * This isn't generalized yet.
     *
     * @param {Object} args - API arguments to send as the JSON body.
     * @param {Object} [options]
     * @param {?AbortSignal} [options.signal] - Cancels the HTTP request.
     * @returns {Promise<?Response>} The raw fetch Response, or null if the request
     *     failed or was aborted.
     */
    async beginGeneration(args, { signal }={})
    {
        args = {
            ...args,
            genkey,
        };

        return await this.koboldPost("api/extra/generate/stream", {
            ...args,
            json: false,
            signal,
        });
    }

    /**
     * Ask the server to abort the generation tagged with this module's genkey, and log
     * whether it reported success.
     *
     * @param {Object} [options]
     * @param {?AbortSignal} [options.signal] - Cancels the abort request itself.
     */
    async abortGeneration({ signal }={})
    {
        let { success } = await this.koboldPost("api/extra/abort", {
            genkey,
            signal,
        });

        // The server reports success as the string "true" rather than a boolean; accept a
        // real boolean too in case that changes.
        if(success === "true" || success === true)
            console.log("Aborted generation");
        else
            console.log("Error aborting generation");
    }

    /**
     * Send a request to the server.
     *
     * Any options other than those listed below are serialized as the JSON body of a POST
     * request; they are ignored for other methods.
     *
     * @param {string} path - Endpoint path relative to the server URL.  Leading slashes
     *     are ignored, so "api/x" and "/api/x" are equivalent.
     * @param {Object} [options]
     * @param {string} [options.method="POST"] - HTTP method.
     * @param {?AbortSignal} [options.signal] - Cancels the request.
     * @param {boolean} [options.json=true] - If true, parse and return the response JSON.
     *     If false, return the raw Response.
     * @returns {Promise<Object|Response|null>} With json=true: the parsed body, or {} if
     *     the request or the parse failed.  With json=false: the Response, or null if the
     *     request failed (including when it was aborted).
     */
    async koboldRequest(path, {
        method="POST",
        signal=null,
        json=true,
        ...args
    }={})
    {
        signal ??= (new AbortController()).signal;

        // Join with exactly one slash, whether or not the configured URL ends with one or
        // the path starts with one.
        let baseUrl = `${this.url}`.replace(/\/+$/, "");
        let relativePath = path.replace(/^\/+/, "");
        let requestUrl = `${baseUrl}/${relativePath}`;

        let options = {
            method,
            headers: {
                "Content-Type": "application/json",
            },
            signal,
        };

        if(method === "POST")
            options.body = JSON.stringify(args);

        let response;
        try {
            response = await fetch(requestUrl, options);
        } catch(e) {
            console.log("Connection error", e);
            return json? { }: null;
        }

        if(!json)
            return response;

        try {
            return await response.json();
        } catch(e) {
            console.log(`Error parsing JSON from request ${path}`, e);
            return { };
        }
    }

    /** koboldRequest() with method "GET".  See koboldRequest for options and return. */
    async koboldGet(path, { ...args }={})
    {
        return this.koboldRequest(path, {
            method: "GET",
            ...args
        });
    }

    /** koboldRequest() with method "POST".  See koboldRequest for options and return. */
    async koboldPost(path, { ...args }={})
    {
        return this.koboldRequest(path, {
            method: "POST",
            ...args
        });
    }
}
