/*
 * Decompiled with CFR 0.152.
 */
package org.jkiss.dbeaver.model.ai.impl;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.jkiss.code.NotNull;
import org.jkiss.dbeaver.model.ai.AIMessage;
import org.jkiss.dbeaver.model.ai.AIMessageType;
import org.jkiss.dbeaver.model.ai.impl.TokenCounter;

/**
 * Trims a chat transcript so that it fits into a fixed token budget.
 *
 * <p>All {@link AIMessageType#SYSTEM} messages are merged into a single message that is placed
 * first in the result. The remaining messages are kept starting from the end of the input list
 * (presumably the most recent ones — the input is assumed to be in chronological order) until the
 * budget is exhausted. The first message that does not fit entirely is cut down to its leading
 * part, and every message before it is dropped.
 *
 * <p>The usable headroom is {@code maxTokens - reserveForReply - reserveForOverhead}. Up to
 * {@code reserveForSystem} tokens of that headroom are set aside for the merged system prompt
 * before the other messages are picked. Whatever the other messages leave unused is also made
 * available to the system prompt.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public final class ChatTruncator {
    /** Total token limit the truncated conversation (plus reserves) must fit into. */
    private final int maxTokens;
    /** Upper bound of tokens set aside for the merged system prompt before picking other messages. */
    private final int reserveForSystem;
    /** Tokens kept free for the model's reply. */
    private final int reserveForReply;
    /** Tokens kept free for per-request overhead not counted by {@link #counter}. */
    private final int reserveForOverhead;
    /** Strategy used to measure message content in tokens. */
    private final TokenCounter counter;

    private ChatTruncator(Builder b) {
        this.maxTokens = b.maxTokens;
        this.reserveForSystem = b.reserveForSystem;
        this.reserveForReply = b.reserveForReply;
        this.reserveForOverhead = b.reserveForOverhead;
        this.counter = Objects.requireNonNull(b.counter, "TokenCounter is required");
        if (reserveForSystem < 0 || reserveForReply < 0 || reserveForOverhead < 0) {
            // Negative reserves would silently enlarge the headroom beyond maxTokens.
            throw new IllegalArgumentException("Token reserves must not be negative (system="
                + reserveForSystem + ", reply=" + reserveForReply + ", overhead=" + reserveForOverhead + ")");
        }
        // long arithmetic so that very large reserves cannot overflow into a passing check
        if ((long) maxTokens <= (long) reserveForReply + (long) reserveForOverhead) {
            throw new IllegalArgumentException("maxTokens too small for the reserves");
        }
    }

    /**
     * Returns the part of {@code input} that fits into the configured token budget.
     *
     * <p>{@code null} entries and messages with {@code null} or blank content are ignored. SYSTEM
     * messages are merged into one message that comes first in the result. The other kept messages
     * retain their relative order. If the merged system prompt has to be cut down to nothing, it is
     * omitted rather than returned with empty content.
     *
     * @param input chat messages; {@code null} or empty yields an empty list
     * @return a new list of messages fitting the budget, possibly empty
     */
    @NotNull
    public List<AIMessage> truncate(@NotNull List<AIMessage> input) {
        List<AIMessage> messages = filterNonEmpty(input);
        if (messages.isEmpty()) {
            return List.of();
        }

        List<AIMessage> systems = new ArrayList<>();
        List<AIMessage> rest = new ArrayList<>(messages.size());
        for (AIMessage m : messages) {
            if (m.getRole() == AIMessageType.SYSTEM) {
                systems.add(m);
            } else {
                rest.add(m);
            }
        }

        AIMessage mergedSystem = systems.isEmpty() ? null : mergeSystems(systems);
        int systemTokens = mergedSystem != null ? counter.count(mergedSystem.getContent()) : 0;
        int systemCap = Math.min(systemTokens, reserveForSystem);
        int headroom = maxTokens - reserveForReply - reserveForOverhead;
        int budget = Math.max(0, headroom - systemCap);

        // Walk from the end of the list, keeping whole messages while they fit; the first one
        // that does not fit is cut to the remaining budget (if any), and the walk stops there.
        List<AIMessage> picked = new ArrayList<>(rest.size());
        int used = 0;
        for (int i = rest.size() - 1; i >= 0; i--) {
            AIMessage m = rest.get(i);
            int remaining = budget - used;
            int tokens = counter.count(m.getContent());
            if (tokens <= remaining) {
                picked.add(m);
                used += tokens;
                continue;
            }
            if (remaining > 0) {
                AIMessage cut = truncateToTokens(m, remaining);
                int cutTokens = counter.count(cut.getContent());
                if (cutTokens > 0) {
                    picked.add(cut);
                    used += cutTokens;
                }
            }
            break;
        }
        Collections.reverse(picked);

        List<AIMessage> result = new ArrayList<>(picked.size() + 1);
        if (mergedSystem != null) {
            // The system prompt may use everything the other messages left unused.
            AIMessage system = truncateToTokens(mergedSystem, Math.max(0, headroom - used));
            if (!system.getContent().isBlank()) {
                result.add(system);
            }
        }
        result.addAll(picked);
        return result;
    }

    /**
     * Combines SYSTEM messages into one by joining their non-blank contents with a
     * {@code "\n\n---\n\n"} separator. The result is produced via {@code withContent} on the first
     * message of the list.
     */
    private static AIMessage mergeSystems(List<AIMessage> systems) {
        assert !systems.isEmpty() : "At least one SYSTEM message is required";
        String merged = systems.stream()
            .map(AIMessage::getContent)
            .filter(s -> !s.isBlank())
            .collect(Collectors.joining("\n\n---\n\n"));
        return systems.getFirst().withContent(merged);
    }

    /**
     * Returns a new list without {@code null} messages and without messages whose content is
     * {@code null} or blank. A {@code null} or empty input yields an empty immutable list.
     */
    private static List<AIMessage> filterNonEmpty(List<AIMessage> in) {
        if (in == null || in.isEmpty()) {
            return List.of();
        }
        List<AIMessage> out = new ArrayList<>(in.size());
        for (AIMessage m : in) {
            if (m == null) {
                continue;
            }
            String content = m.getContent();
            if (content == null || content.isBlank()) {
                continue;
            }
            out.add(m);
        }
        return out;
    }

    /**
     * Returns {@code message} with its content cut to the longest leading part whose token count
     * is at most {@code limit}.
     *
     * @return the original message if it already fits; a copy with empty content when
     *     {@code limit <= 0} or no non-empty prefix fits; otherwise a copy with truncated content
     */
    private AIMessage truncateToTokens(AIMessage message, int limit) {
        if (limit <= 0) {
            return message.withContent("");
        }
        String content = message.getContent();
        if (counter.count(content) <= limit) {
            return message;
        }
        // Binary search over the prefix length. This assumes a prefix never has more tokens
        // than a longer prefix of the same string.
        int lo = 0;
        int hi = content.length();
        AIMessage best = message.withContent("");
        while (lo <= hi) {
            int mid = (lo + hi) >>> 1;
            AIMessage candidate = message.withContent(safeHead(content, mid));
            if (counter.count(candidate.getContent()) <= limit) {
                best = candidate;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return best;
    }

    /**
     * Returns at most the first {@code headLen} chars of {@code s}. If the cut would separate a
     * UTF-16 surrogate pair, it steps back one char so no unpaired high surrogate is produced.
     */
    private static String safeHead(String s, int headLen) {
        if (headLen <= 0) {
            return "";
        }
        if (headLen >= s.length()) {
            return s;
        }
        int end = headLen;
        if (Character.isHighSurrogate(s.charAt(end - 1)) && Character.isLowSurrogate(s.charAt(end))) {
            end--;
        }
        return s.substring(0, end);
    }

    /** Creates a new, empty {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link ChatTruncator}. A {@link TokenCounter} is required. All reserves
     * default to {@code 0} and must not be negative.
     */
    public static final class Builder {
        private int maxTokens;
        private int reserveForSystem;
        private int reserveForReply;
        private int reserveForOverhead;
        private TokenCounter counter;

        /** Sets the total token limit. */
        public Builder maxTokens(int v) {
            this.maxTokens = v;
            return this;
        }

        /** Sets the maximum number of tokens set aside for the merged system prompt. */
        public Builder reserveForSystem(int v) {
            this.reserveForSystem = v;
            return this;
        }

        /** Sets the number of tokens kept free for the model's reply. */
        public Builder reserveForReply(int v) {
            this.reserveForReply = v;
            return this;
        }

        /** Sets the number of tokens kept free for request overhead. */
        public Builder reserveForOverhead(int v) {
            this.reserveForOverhead = v;
            return this;
        }

        /** Sets the token counter used to measure message content (required). */
        public Builder tokenCounter(TokenCounter c) {
            this.counter = c;
            return this;
        }

        /**
         * Builds the truncator.
         *
         * @throws NullPointerException if no token counter was set
         * @throws IllegalArgumentException if a reserve is negative or {@code maxTokens} does not
         *     exceed {@code reserveForReply + reserveForOverhead}
         */
        public ChatTruncator build() {
            assert ((long) this.maxTokens > (long) this.reserveForReply + this.reserveForOverhead + this.reserveForSystem)
                : "maxTokens must be greater than the sum of reserves";
            return new ChatTruncator(this);
        }
    }
}

