/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.dict;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.kylin.common.util.Bytes;
import org.apache.kylin.common.util.BytesUtil;
import org.apache.kylin.common.util.ClassUtil;
import org.apache.kylin.common.util.Dictionary;
import org.apache.kylin.dict.BytesConverter;
import org.apache.kylin.dict.CacheDictionary;
import org.apache.kylin.dict.StringBytesConverter;
import org.apache.kylin.dict.TrieDictionaryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A read-only dictionary whose values and ids are encoded in a single serialized trie.
 *
 * <p>Layout of {@code trieBytes}, as decoded by {@code init}:
 * <ul>
 * <li><b>head</b> &mdash; the 8-byte {@link #MAGIC}, then {@code headSize} (short),
 * {@code bodyLen} (int), {@code sizeChildOffset} (1 byte), {@code sizeNoValuesBeneath} (1 byte),
 * {@code baseId} (short), {@code maxValueLength} (short) and the converter class name
 * (modified UTF-8, possibly empty). {@code headSize} counts the whole head, magic included.</li>
 * <li><b>body</b> &mdash; trie nodes starting at offset {@code headSize}; the whole array is
 * {@code headSize + bodyLen} bytes long.</li>
 * </ul>
 *
 * <p>Each node is laid out as: a {@code sizeChildOffset}-byte offset of its first child (relative
 * to {@code headSize}; 0 means "no children"), whose first byte also carries the
 * {@link #BIT_IS_LAST_CHILD} / {@link #BIT_IS_END_OF_VALUE} flags; a
 * {@code sizeNoValuesBeneath}-byte count of values in the node's subtree; a 1-byte length; and
 * that many value bytes. Siblings are stored back to back. The lookup code assumes siblings are
 * ordered by their first value byte (unsigned).
 */
public class TrieDictionary<T>
extends CacheDictionary<T> {
    private static final long serialVersionUID = 1L;

    /** File signature: the ASCII bytes of "TrieDict". */
    public static final byte[] MAGIC = new byte[]{0x54, 0x72, 0x69, 0x65, 0x44, 0x69, 0x63, 0x74};
    public static final int MAGIC_SIZE_I = MAGIC.length;

    /** Flag bit in a node's first byte: this node is the last among its siblings. */
    public static final int BIT_IS_LAST_CHILD = 0x80;
    /** Flag bit in a node's first byte: a dictionary value ends at this node. */
    public static final int BIT_IS_END_OF_VALUE = 0x40;

    private static final Logger logger = LoggerFactory.getLogger(TrieDictionary.class);

    /** The complete serialized trie (head + body); the only persisted state. */
    private byte[] trieBytes;

    // Derived from trieBytes by init(); not serialized.
    private transient int headSize;
    @SuppressWarnings("unused")
    private transient int bodyLen;
    private transient int sizeChildOffset;
    private transient int sizeNoValuesBeneath;
    private transient int maxValueLength;
    private transient int nValues;
    private transient int sizeOfId;
    /** Clears the two flag bits from a raw child-offset field. */
    private transient long childOffsetMask;
    /** Distance from the start of a node to its first value byte. */
    private transient int firstByteOffset;

    /** Creates an uninitialized instance; its state is expected to be filled in by {@link #readFields(DataInput)}. */
    public TrieDictionary() {
    }

    /**
     * Wraps an already-serialized trie.
     *
     * @param trieBytes complete trie bytes, starting with {@link #MAGIC}
     * @throws IllegalArgumentException if the magic does not match
     */
    public TrieDictionary(byte[] trieBytes) {
        this.init(trieBytes);
    }

    /** Throws {@link IllegalArgumentException} unless {@code bytes} starts with {@link #MAGIC}. */
    private static void checkMagic(byte[] bytes) {
        if (bytes.length < MAGIC_SIZE_I || BytesUtil.compareBytes(MAGIC, 0, bytes, 0, MAGIC_SIZE_I) != 0) {
            throw new IllegalArgumentException("Wrong file type (magic does not match)");
        }
    }

    /**
     * Adopts {@code trieBytes} and decodes the header fields into the transient members.
     * Checked exceptions are wrapped in {@link RuntimeException}; runtime exceptions propagate as-is.
     */
    private void init(byte[] trieBytes) {
        this.trieBytes = trieBytes;
        checkMagic(trieBytes);
        try (DataInputStream headIn = new DataInputStream(
                new ByteArrayInputStream(trieBytes, MAGIC_SIZE_I, trieBytes.length - MAGIC_SIZE_I))) {
            this.headSize = headIn.readShort();
            this.bodyLen = headIn.readInt();
            this.sizeChildOffset = headIn.read();
            this.sizeNoValuesBeneath = headIn.read();
            this.baseId = headIn.readShort();
            this.maxValueLength = headIn.readShort();
            if (this.maxValueLength < 0) {
                throw new IllegalStateException("maxValueLength is negative (" + this.maxValueLength + "). Dict value is too long, whose length is larger than " + Short.MAX_VALUE);
            }
            String converterName = headIn.readUTF();
            if (!converterName.isEmpty()) {
                this.setConverterByName(converterName);
            }
            // The root node's "values beneath" count is the total number of values.
            this.nValues = BytesUtil.readUnsigned(trieBytes, this.headSize + this.sizeChildOffset, this.sizeNoValuesBeneath);
            // Computed in long to avoid int overflow; +1 keeps one id beyond the largest real id.
            this.sizeOfId = BytesUtil.sizeForValue((long) this.baseId + this.nValues + 1L);
            // The flags occupy the two high bits of the first (most significant) offset byte.
            this.childOffsetMask = ~((long) (BIT_IS_LAST_CHILD | BIT_IS_END_OF_VALUE) << ((this.sizeChildOffset - 1) * 8));
            this.firstByteOffset = this.sizeChildOffset + this.sizeNoValuesBeneath + 1;
            this.enableCache();
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Instantiates the {@link BytesConverter} named in the header via its no-arg constructor. */
    protected void setConverterByName(String converterName) throws Exception {
        this.bytesConvert = ClassUtil.forName(converterName, BytesConverter.class).getDeclaredConstructor().newInstance();
    }

    @Override
    public int getMinId() {
        return this.baseId;
    }

    @Override
    public int getMaxId() {
        return this.baseId + this.nValues - 1;
    }

    @Override
    public int getSizeOfId() {
        return this.sizeOfId;
    }

    /** Returns the length in bytes of the longest value, as recorded in the header. */
    @Override
    public int getSizeOfValue() {
        return this.maxValueLength;
    }

    /** Returns the size of the serialized trie in bytes. */
    public int getStorageSizeInBytes() {
        return this.trieBytes.length;
    }

    /**
     * Looks up the id of {@code value[offset, offset+len)}.
     *
     * @param roundingFlag 0 for an exact match, negative to round down to the previous value,
     *                     positive to round up to the next value
     * @return the id, or -1 if there is no match (or the rounded position falls outside the dictionary)
     */
    @Override
    protected int getIdFromValueBytesWithoutCache(byte[] value, int offset, int len, int roundingFlag) {
        int seq = this.lookupSeqNoFromValue(this.headSize, value, offset, offset + len, roundingFlag);
        int id = this.calcIdFromSeqNo(seq);
        // Guard so the value is only converted when the debug message will actually be emitted.
        if (id < 0 && logger.isDebugEnabled()) {
            logger.debug("Not a valid value: {}", this.bytesConvert.convertFromBytes(value, offset, len));
        }
        return id;
    }

    /**
     * Walks the trie from node {@code n}, matching {@code inp[o, inpEnd)}, and returns the
     * sequence number of the match, or a rounded position per {@link #roundSeqNo}.
     */
    private int lookupSeqNoFromValue(int n, byte[] inp, int o, int inpEnd, int roundingFlag) {
        if (o == inpEnd) {
            // Empty input matches only if the root node itself ends a value.
            return this.checkFlag(this.headSize, BIT_IS_END_OF_VALUE) ? 0 : this.roundSeqNo(roundingFlag, -1, -1, 0);
        }
        int seq = 0; // number of values known to sort before the input
        while (true) {
            // Match the rest of this node's value; its first byte was matched when the parent chose it.
            int p = n + this.firstByteOffset;
            int end = p + BytesUtil.readUnsigned(this.trieBytes, p - 1, 1);
            for (p++; p < end && o < inpEnd; p++, o++) {
                if (this.trieBytes[p] != inp[o]) {
                    if (BytesUtil.compareByteUnsigned(this.trieBytes[p], inp[o]) < 0) {
                        // The whole subtree sorts before the input.
                        seq += BytesUtil.readUnsigned(this.trieBytes, n + this.sizeChildOffset, this.sizeNoValuesBeneath);
                    }
                    return this.roundSeqNo(roundingFlag, seq - 1, -1, seq);
                }
            }
            boolean isEndOfValue = this.checkFlag(n, BIT_IS_END_OF_VALUE);
            if (o == inpEnd) {
                // Exact hit only if the node was fully matched and a value ends here.
                return p == end && isEndOfValue ? seq : this.roundSeqNo(roundingFlag, seq - 1, -1, seq);
            }
            if (isEndOfValue) {
                seq++; // the value ending here is a proper prefix of the input
            }
            int c = this.getChildOffset(n);
            if (c == this.headSize) {
                return this.roundSeqNo(roundingFlag, seq - 1, -1, seq); // no children to descend into
            }
            byte inpByte = inp[o];
            // Scan siblings for the child whose first byte equals the next input byte.
            while (true) {
                p = c + this.firstByteOffset;
                int comp = BytesUtil.compareByteUnsigned(this.trieBytes[p], inpByte);
                if (comp == 0) {
                    break;
                }
                if (comp > 0) {
                    return this.roundSeqNo(roundingFlag, seq - 1, -1, seq); // passed the insertion point
                }
                seq += BytesUtil.readUnsigned(this.trieBytes, c + this.sizeChildOffset, this.sizeNoValuesBeneath);
                if (this.checkFlag(c, BIT_IS_LAST_CHILD)) {
                    return this.roundSeqNo(roundingFlag, seq - 1, -1, seq);
                }
                c = p + BytesUtil.readUnsigned(this.trieBytes, p - 1, 1);
            }
            n = c;
            o++;
        }
    }

    /** Returns the absolute offset of node {@code n}'s first child; equals {@code headSize} when it has none. */
    private int getChildOffset(int n) {
        long offset = (long) this.headSize + (BytesUtil.readLong(this.trieBytes, n, this.sizeChildOffset) & this.childOffsetMask);
        assert (offset < (long) this.trieBytes.length);
        return (int) offset;
    }

    /** Picks {@code j} when {@code roundingFlag == 0}, {@code i} when negative, {@code k} when positive. */
    private int roundSeqNo(int roundingFlag, int i, int j, int k) {
        if (roundingFlag == 0) {
            return j;
        }
        return roundingFlag < 0 ? i : k;
    }

    @Override
    protected byte[] getValueBytesFromIdWithoutCache(int id) {
        byte[] buf = new byte[this.maxValueLength];
        int len = this.getValueBytesFromIdImpl(id, buf, 0);
        // NOTE(review): len is -1 for a corrupted trie, which makes copyOf throw NegativeArraySizeException.
        return len == buf.length ? buf : Arrays.copyOf(buf, len);
    }

    /** Writes the value bytes of {@code id} into {@code returnValue} at {@code offset}; returns their length or -1. */
    protected int getValueBytesFromIdImpl(int id, byte[] returnValue, int offset) {
        int seq = this.calcSeqNoFromId(id);
        return this.lookupValueFromSeqNo(this.headSize, seq, returnValue, offset);
    }

    /**
     * Descends from node {@code n} to the {@code seq}-th value beneath it, copying value bytes into
     * {@code returnValue} starting at {@code offset}.
     *
     * @return number of bytes written, or -1 if the trie runs out of nodes before {@code seq} is reached
     */
    private int lookupValueFromSeqNo(int n, int seq, byte[] returnValue, int offset) {
        int o = offset;
        while (true) {
            int p = n + this.firstByteOffset;
            int len = BytesUtil.readUnsigned(this.trieBytes, p - 1, 1);
            System.arraycopy(this.trieBytes, p, returnValue, o, len);
            o += len;
            // A value ending at this node consumes one sequence number.
            if (this.checkFlag(n, BIT_IS_END_OF_VALUE)) {
                seq--;
                if (seq < 0) {
                    return o - offset;
                }
            }
            int c = this.getChildOffset(n);
            if (c == this.headSize) {
                return -1;
            }
            // Skip whole sibling subtrees until reaching the one that contains seq.
            while (true) {
                int nValuesBeneath = BytesUtil.readUnsigned(this.trieBytes, c + this.sizeChildOffset, this.sizeNoValuesBeneath);
                if (seq - nValuesBeneath < 0) {
                    break;
                }
                seq -= nValuesBeneath;
                if (this.checkFlag(c, BIT_IS_LAST_CHILD)) {
                    return -1;
                }
                p = c + this.firstByteOffset;
                c = p + BytesUtil.readUnsigned(this.trieBytes, p - 1, 1);
            }
            n = c;
        }
    }

    /** Tests a flag bit in the first byte of the node at {@code offset}. */
    private boolean checkFlag(int offset, int bit) {
        return (this.trieBytes[offset] & bit) > 0;
    }

    /** Maps a sequence number to an id, or -1 when it is outside {@code [0, nValues)}. */
    private int calcIdFromSeqNo(int seq) {
        if (seq < 0 || seq >= this.nValues) {
            return -1;
        }
        return this.baseId + seq;
    }

    @Override
    public void write(DataOutput out) throws IOException {
        out.write(this.trieBytes);
    }

    /**
     * Reads a trie written by {@link #write(DataOutput)}. It first reads just the magic,
     * {@code headSize} and {@code bodyLen}, then the remaining {@code headSize + bodyLen}
     * bytes, so it consumes exactly one dictionary from {@code in}.
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        // magic + headSize (short) + bodyLen (int); these are byte counts, not bit counts.
        byte[] headPartial = new byte[MAGIC.length + Short.BYTES + Integer.BYTES];
        in.readFully(headPartial);
        checkMagic(headPartial);
        int storedHeadSize;
        int storedBodyLen;
        try (DataInputStream headIn = new DataInputStream(
                new ByteArrayInputStream(headPartial, MAGIC_SIZE_I, headPartial.length - MAGIC_SIZE_I))) {
            storedHeadSize = headIn.readShort();
            storedBodyLen = headIn.readInt();
        }
        long total = (long) storedHeadSize + storedBodyLen;
        if (total < headPartial.length || total > Integer.MAX_VALUE) {
            throw new IOException("Corrupted TrieDictionary header: headSize=" + storedHeadSize + ", bodyLen=" + storedBodyLen);
        }
        byte[] all = new byte[(int) total];
        System.arraycopy(headPartial, 0, all, 0, headPartial.length);
        in.readFully(all, headPartial.length, all.length - headPartial.length);
        this.init(all);
    }

    /** Java serialization: a length prefix followed by the raw trie bytes. */
    private void writeObject(ObjectOutputStream stream) throws IOException {
        stream.writeInt(this.trieBytes.length);
        stream.write(this.trieBytes);
    }

    /** Java deserialization counterpart of {@link #writeObject}; fails on a truncated stream. */
    private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
        int length = stream.readInt();
        if (length < 0) {
            throw new IOException("Corrupted TrieDictionary stream: negative length " + length);
        }
        byte[] bytes = new byte[length];
        stream.readFully(bytes);
        this.init(bytes);
    }

    /** Returns all values in trie pre-order (a node's own value before its children's). */
    @Override
    public List<T> enumeratorValues() {
        List<T> result = Lists.newArrayListWithExpectedSize((int) this.getSize());
        byte[] buf = new byte[this.maxValueLength];
        this.visitNode(this.headSize, buf, 0, result);
        return result;
    }

    @VisibleForTesting
    List<T> enumeratorValuesByParent() {
        return super.enumeratorValues();
    }

    /** Recursively appends the value at node {@code n} (if any) and all values beneath it to {@code result}. */
    private void visitNode(int n, byte[] returnValue, int offset, List<T> result) {
        int p = n + this.firstByteOffset;
        int len = BytesUtil.readUnsigned(this.trieBytes, p - 1, 1);
        System.arraycopy(this.trieBytes, p, returnValue, offset, len);
        int o = offset + len;
        if (this.checkFlag(n, BIT_IS_END_OF_VALUE)) {
            result.add(this.bytesConvert.convertFromBytes(returnValue, 0, o));
        }
        int c = this.getChildOffset(n);
        if (c == this.headSize) {
            return;
        }
        while (true) {
            this.visitNode(c, returnValue, o, result);
            if (this.checkFlag(c, BIT_IS_LAST_CHILD)) {
                return;
            }
            p = c + this.firstByteOffset;
            c = p + BytesUtil.readUnsigned(this.trieBytes, p - 1, 1);
        }
    }

    /** Prints every id (decimal and hex) with its value. */
    @Override
    public void dump(PrintStream out) {
        out.println("Total " + this.nValues + " values");
        for (int seq = 0; seq < this.nValues; ++seq) {
            int id = this.calcIdFromSeqNo(seq);
            Object value = this.getValueFromId(id);
            out.println(id + " (" + Integer.toHexString(id) + "): " + value);
        }
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(this.trieBytes);
    }

    /** Two trie dictionaries are equal when their serialized bytes are identical. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof TrieDictionary)) {
            logger.info("Equals return false because it's not TrieDictionary");
            return false;
        }
        TrieDictionary<?> that = (TrieDictionary<?>) o;
        return Arrays.equals(this.trieBytes, that.trieBytes);
    }

    /** Returns true if every value of {@code other} is also present in this dictionary. */
    @Override
    @SuppressWarnings("unchecked") // values are only used for lookup; erasure makes the cast unchecked
    public boolean contains(Dictionary other) {
        if (other.getSize() > this.getSize()) {
            return false;
        }
        for (int i = other.getMinId(); i <= other.getMaxId(); ++i) {
            T v = (T) other.getValueFromId(i);
            if (!this.containsValue(v)) {
                return false;
            }
        }
        return true;
    }

    /** Manual smoke test: builds a small dictionary, round-trips it through Java serialization and prints it. */
    public static void main(String[] args) throws Exception {
        TrieDictionaryBuilder<String> b = new TrieDictionaryBuilder<String>(new StringBytesConverter());
        for (String v : new String[]{"part", "part", "par", "partition", "party", "parties", "paint"}) {
            b.addValue(v);
            b.print();
        }
        TrieDictionary<String> dict = b.build(0);

        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(dict);
        }
        TrieDictionary<?> dict2;
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) {
            dict2 = (TrieDictionary<?>) ois.readObject();
        }
        Preconditions.checkArgument(dict.contains(dict2));
        Preconditions.checkArgument(dict2.contains(dict));
        Preconditions.checkArgument(dict.equals(dict2));
        for (int i = 0; i <= dict.getMaxId(); ++i) {
            System.out.println(Bytes.toString(dict.getValueBytesFromIdWithoutCache(i)));
        }
    }
}

