/*
 * Decompiled with CFR 0.152.
 */
package com.enterprisedt.bouncycastle.tls.crypto.impl;

import com.enterprisedt.bouncycastle.tls.ProtocolVersion;
import com.enterprisedt.bouncycastle.tls.SecurityParameters;
import com.enterprisedt.bouncycastle.tls.TlsFatalAlert;
import com.enterprisedt.bouncycastle.tls.TlsUtils;
import com.enterprisedt.bouncycastle.tls.crypto.TlsCipher;
import com.enterprisedt.bouncycastle.tls.crypto.TlsCryptoParameters;
import com.enterprisedt.bouncycastle.tls.crypto.TlsCryptoUtils;
import com.enterprisedt.bouncycastle.tls.crypto.TlsDecodeResult;
import com.enterprisedt.bouncycastle.tls.crypto.TlsEncodeResult;
import com.enterprisedt.bouncycastle.tls.crypto.TlsSecret;
import com.enterprisedt.bouncycastle.tls.crypto.impl.TlsAEADCipherImpl;
import com.enterprisedt.bouncycastle.tls.crypto.impl.TlsImplUtils;
import com.enterprisedt.bouncycastle.util.Arrays;
import java.io.IOException;

/**
 * {@link TlsCipher} implementation that protects records with a pair of AEAD ciphers
 * (one per direction).
 *
 * <p>Two nonce constructions are supported (see {@link #NONCE_RFC5288} and
 * {@link #NONCE_RFC7905}). When "inner plaintext" is in use for a direction (TLS 1.3, or a
 * non-empty connection ID for that direction) the real content type travels as one extra
 * byte appended to the plaintext and the outer record type is replaced by an opaque one.
 */
public final class TlsAEADCipher
implements TlsCipher {
    public static final int AEAD_CCM = 1;
    public static final int AEAD_CHACHA20_POLY1305 = 2;
    public static final int AEAD_GCM = 3;

    /**
     * Nonce = 4-byte fixed IV followed by an 8-byte per-record part that is also sent at the
     * start of the record (the construction described by RFC 5288).
     */
    private static final int NONCE_RFC5288 = 1;
    /**
     * Nonce = 12-byte fixed IV XORed with the sequence number written into its last 8 bytes;
     * nothing extra is sent in the record (the construction described by RFC 7905).
     */
    private static final int NONCE_RFC7905 = 2;

    // Numeric alert descriptions handed to TlsFatalAlert, named after the corresponding
    // entries of the TLS alert registry.
    private static final short ALERT_UNEXPECTED_MESSAGE = 10;
    private static final short ALERT_BAD_RECORD_MAC = 20;
    private static final short ALERT_DECODE_ERROR = 50;
    private static final short ALERT_INTERNAL_ERROR = 80;

    // Outer record types used when the real content type is hidden inside the plaintext.
    private static final short CONTENT_TYPE_APPLICATION_DATA = 23;
    private static final short CONTENT_TYPE_TLS12_CID = 25;

    /** All-ones value written where the sequence number would sit in connection-ID additional data. */
    private static final long SEQUENCE_NUMBER_PLACEHOLDER = -1L;

    private final TlsCryptoParameters cryptoParams;
    private final int keySize;
    private final int macSize;
    /** Length of the per-direction nonce material kept in this object. */
    private final int fixedIvLength;
    /** Number of nonce bytes carried explicitly at the front of every record. */
    private final int recordIvLength;
    private final TlsAEADCipherImpl decryptCipher;
    private final TlsAEADCipherImpl encryptCipher;
    private final byte[] decryptNonce;
    private final byte[] encryptNonce;
    private final byte[] decryptConnectionID;
    private final byte[] encryptConnectionID;
    private final boolean decryptUseInnerPlaintext;
    private final boolean encryptUseInnerPlaintext;
    private final boolean isTLSv13;
    private final int nonceMode;

    /**
     * Creates the cipher and installs the initial keys and nonces for both directions.
     *
     * @param cryptoParams   source of the handshake security parameters and the server/client role
     * @param encryptCipher  AEAD primitive used by {@link #encodePlaintext}
     * @param decryptCipher  AEAD primitive used by {@link #decodeCiphertext}
     * @param keySize        key length in bytes for each direction
     * @param macSize        tag length passed to the AEAD primitives on every {@code init}
     * @param aeadType       one of {@link #AEAD_CCM}, {@link #AEAD_CHACHA20_POLY1305}, {@link #AEAD_GCM}
     * @throws IOException a {@link TlsFatalAlert} (code 80) if the negotiated version fails the
     *                     {@code isTLSv12} check, the AEAD type is unknown, or key setup is inconsistent
     */
    public TlsAEADCipher(TlsCryptoParameters cryptoParams, TlsAEADCipherImpl encryptCipher, TlsAEADCipherImpl decryptCipher, int keySize, int macSize, int aeadType) throws IOException {
        final SecurityParameters handshakeParams = cryptoParams.getSecurityParametersHandshake();
        final ProtocolVersion negotiatedVersion = handshakeParams.getNegotiatedVersion();
        if (!TlsImplUtils.isTLSv12(negotiatedVersion)) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }

        this.isTLSv13 = TlsImplUtils.isTLSv13(negotiatedVersion);
        this.nonceMode = getNonceMode(this.isTLSv13, aeadType);

        this.decryptConnectionID = handshakeParams.getConnectionIDPeer();
        this.encryptConnectionID = handshakeParams.getConnectionIDLocal();
        this.decryptUseInnerPlaintext = this.isTLSv13 || !Arrays.isNullOrEmpty(this.decryptConnectionID);
        this.encryptUseInnerPlaintext = this.isTLSv13 || !Arrays.isNullOrEmpty(this.encryptConnectionID);

        if (this.nonceMode == NONCE_RFC5288) {
            this.fixedIvLength = 4;
            this.recordIvLength = 8;
        } else if (this.nonceMode == NONCE_RFC7905) {
            this.fixedIvLength = 12;
            this.recordIvLength = 0;
        } else {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }

        this.cryptoParams = cryptoParams;
        this.keySize = keySize;
        this.macSize = macSize;
        this.decryptCipher = decryptCipher;
        this.encryptCipher = encryptCipher;
        this.decryptNonce = new byte[this.fixedIvLength];
        this.encryptNonce = new byte[this.fixedIvLength];

        final boolean isServer = cryptoParams.isServer();
        if (this.isTLSv13) {
            // TLS 1.3: each direction is keyed from its own traffic secret. A server reads
            // with the client secret and writes with the server secret, and vice versa.
            rekeyCipher(handshakeParams, decryptCipher, this.decryptNonce, !isServer);
            rekeyCipher(handshakeParams, encryptCipher, this.encryptNonce, isServer);
        } else {
            installKeyBlock(isServer);
        }
    }

    /**
     * Pre-1.3 key setup: slices one key block into two keys followed by two fixed IVs.
     * The first key and first IV go to the direction the client writes in (a server
     * decrypts with them, a client encrypts with them); the second pair goes to the other
     * direction.
     */
    private void installKeyBlock(boolean isServer) throws IOException {
        final int keyBlockSize = 2 * this.keySize + 2 * this.fixedIvLength;
        final byte[] keyBlock = TlsImplUtils.calculateKeyBlock(this.cryptoParams, keyBlockSize);

        final TlsAEADCipherImpl firstCipher = isServer ? this.decryptCipher : this.encryptCipher;
        final TlsAEADCipherImpl secondCipher = isServer ? this.encryptCipher : this.decryptCipher;
        final byte[] firstNonce = isServer ? this.decryptNonce : this.encryptNonce;
        final byte[] secondNonce = isServer ? this.encryptNonce : this.decryptNonce;

        int pos = 0;
        firstCipher.setKey(keyBlock, pos, this.keySize);
        pos += this.keySize;
        secondCipher.setKey(keyBlock, pos, this.keySize);
        pos += this.keySize;
        System.arraycopy(keyBlock, pos, firstNonce, 0, this.fixedIvLength);
        pos += this.fixedIvLength;
        System.arraycopy(keyBlock, pos, secondNonce, 0, this.fixedIvLength);
        pos += this.fixedIvLength;

        // Sanity check: the whole key block must have been consumed exactly.
        if (pos != keyBlockSize) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }
    }

    /** Largest ciphertext that can result from decoding up to {@code plaintextLimit} plaintext bytes. */
    @Override
    public int getCiphertextDecodeLimit(int plaintextLimit) {
        return plaintextLimit + innerTypeLength(this.decryptUseInnerPlaintext) + this.macSize + this.recordIvLength;
    }

    /** Largest ciphertext that encoding {@code plaintextLimit} plaintext bytes can produce. */
    @Override
    public int getCiphertextEncodeLimit(int plaintextLimit) {
        return plaintextLimit + innerTypeLength(this.encryptUseInnerPlaintext) + this.macSize + this.recordIvLength;
    }

    /** Plaintext capacity of a received record of {@code ciphertextLimit} bytes; may be negative. */
    @Override
    public int getPlaintextDecodeLimit(int ciphertextLimit) {
        return ciphertextLimit - this.macSize - this.recordIvLength - innerTypeLength(this.decryptUseInnerPlaintext);
    }

    /** Plaintext capacity of an outgoing record of {@code ciphertextLimit} bytes; may be negative. */
    @Override
    public int getPlaintextEncodeLimit(int ciphertextLimit) {
        return ciphertextLimit - this.macSize - this.recordIvLength - innerTypeLength(this.encryptUseInnerPlaintext);
    }

    /** One byte for the trailing content type when inner plaintext is used, otherwise zero. */
    private static int innerTypeLength(boolean useInnerPlaintext) {
        return useInnerPlaintext ? 1 : 0;
    }

    /**
     * Encrypts one record.
     *
     * <p>The returned buffer has {@code headerAllocation} untouched leading bytes, then the
     * explicit nonce part (if the nonce mode has one), then the AEAD output.
     *
     * @param seqNo             record sequence number, used for the nonce and the additional data
     * @param contentType       real content type of the record
     * @param recordVersion     version written into the additional data
     * @param headerAllocation  number of bytes to reserve at the front of the output
     * @param plaintext         buffer holding the plaintext
     * @param plaintextOffset   start of the plaintext in {@code plaintext}
     * @param plaintextLength   number of plaintext bytes
     * @return the output buffer together with the record type to place in the record header
     * @throws IOException a {@link TlsFatalAlert} (code 80) if the cipher fails or produces an
     *                     unexpected amount of output
     */
    @Override
    public TlsEncodeResult encodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion, int headerAllocation, byte[] plaintext, int plaintextOffset, int plaintextLength) throws IOException {
        final byte[] nonce = createEncryptNonce(seqNo);

        final int innerPlaintextLength = plaintextLength + innerTypeLength(this.encryptUseInnerPlaintext);

        this.encryptCipher.init(nonce, this.macSize);
        final int encryptionLength = this.encryptCipher.getOutputSize(innerPlaintextLength);
        final int ciphertextLength = this.recordIvLength + encryptionLength;

        final byte[] output = new byte[headerAllocation + ciphertextLength];
        int outputPos = headerAllocation;

        if (this.recordIvLength != 0) {
            // The per-record tail of the nonce is sent in the clear ahead of the ciphertext.
            System.arraycopy(nonce, nonce.length - this.recordIvLength, output, outputPos, this.recordIvLength);
            outputPos += this.recordIvLength;
        }

        final short recordType;
        if (!this.encryptUseInnerPlaintext) {
            recordType = contentType;
        } else {
            recordType = this.isTLSv13 ? CONTENT_TYPE_APPLICATION_DATA : CONTENT_TYPE_TLS12_CID;
        }

        final byte[] additionalData = getAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
            innerPlaintextLength, this.encryptConnectionID);

        try {
            // Encrypt in place inside the output buffer.
            System.arraycopy(plaintext, plaintextOffset, output, outputPos, plaintextLength);
            if (this.encryptUseInnerPlaintext) {
                output[outputPos + plaintextLength] = (byte)contentType;
            }
            outputPos += this.encryptCipher.doFinal(additionalData, output, outputPos, innerPlaintextLength, output, outputPos);
        } catch (RuntimeException e) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR, (Throwable)e);
        }

        // The additional data already committed to the length, so the prediction must be exact.
        if (outputPos != output.length) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }

        return new TlsEncodeResult(output, 0, output.length, recordType);
    }

    /**
     * Authenticates and decrypts one record in place inside {@code ciphertext}.
     *
     * @param seqNo             record sequence number
     * @param recordType        record type taken from the record header
     * @param recordVersion     version written into the additional data
     * @param ciphertext        buffer holding the record body; overwritten with the plaintext
     * @param ciphertextOffset  start of the record body
     * @param ciphertextLength  length of the record body
     * @return the plaintext location inside {@code ciphertext} and the effective content type
     * @throws IOException a {@link TlsFatalAlert}: code 50 if the record is too short, 20 if the
     *                     cipher rejects it, 10 if an inner plaintext holds no content type,
     *                     80 on internal inconsistencies
     */
    @Override
    public TlsDecodeResult decodeCiphertext(long seqNo, short recordType, ProtocolVersion recordVersion, byte[] ciphertext, int ciphertextOffset, int ciphertextLength) throws IOException {
        if (getPlaintextDecodeLimit(ciphertextLength) < 0) {
            throw new TlsFatalAlert(ALERT_DECODE_ERROR);
        }

        final byte[] nonce = createDecryptNonce(seqNo, ciphertext, ciphertextOffset);
        this.decryptCipher.init(nonce, this.macSize);

        final int encryptionOffset = ciphertextOffset + this.recordIvLength;
        final int encryptionLength = ciphertextLength - this.recordIvLength;
        final int innerPlaintextLength = this.decryptCipher.getOutputSize(encryptionLength);

        final byte[] additionalData = getAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
            innerPlaintextLength, this.decryptConnectionID);

        final int written;
        try {
            written = this.decryptCipher.doFinal(additionalData, ciphertext, encryptionOffset, encryptionLength, ciphertext, encryptionOffset);
        } catch (RuntimeException e) {
            throw new TlsFatalAlert(ALERT_BAD_RECORD_MAC, (Throwable)e);
        }

        if (written != innerPlaintextLength) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }

        short contentType = recordType;
        int plaintextLength = innerPlaintextLength;

        if (this.decryptUseInnerPlaintext) {
            // Walk back over trailing zero bytes; the last non-zero byte is the real content
            // type and everything before it is the plaintext.
            int pos = innerPlaintextLength - 1;
            while (pos >= 0 && ciphertext[encryptionOffset + pos] == 0) {
                --pos;
            }
            if (pos < 0) {
                throw new TlsFatalAlert(ALERT_UNEXPECTED_MESSAGE);
            }
            contentType = (short)(ciphertext[encryptionOffset + pos] & 0xFF);
            plaintextLength = pos;
        }

        return new TlsDecodeResult(ciphertext, encryptionOffset, plaintextLength, contentType);
    }

    /** Re-derives the read key and nonce from the connection's current peer traffic secret. */
    @Override
    public void rekeyDecoder() throws IOException {
        final SecurityParameters connectionParams = this.cryptoParams.getSecurityParametersConnection();
        final boolean peerIsServer = !this.cryptoParams.isServer();
        rekeyCipher(connectionParams, this.decryptCipher, this.decryptNonce, peerIsServer);
    }

    /** Re-derives the write key and nonce from the connection's current local traffic secret. */
    @Override
    public void rekeyEncoder() throws IOException {
        final SecurityParameters connectionParams = this.cryptoParams.getSecurityParametersConnection();
        final boolean localIsServer = this.cryptoParams.isServer();
        rekeyCipher(connectionParams, this.encryptCipher, this.encryptNonce, localIsServer);
    }

    @Override
    public boolean usesOpaqueRecordTypeDecode() {
        return this.decryptUseInnerPlaintext;
    }

    @Override
    public boolean usesOpaqueRecordTypeEncode() {
        return this.encryptUseInnerPlaintext;
    }

    /** Builds the full nonce for an outgoing record according to the nonce mode. */
    private byte[] createEncryptNonce(long seqNo) throws IOException {
        final byte[] nonce = new byte[this.encryptNonce.length + this.recordIvLength];
        if (this.nonceMode == NONCE_RFC5288) {
            System.arraycopy(this.encryptNonce, 0, nonce, 0, this.encryptNonce.length);
            // The sequence number doubles as the explicit per-record nonce part.
            TlsUtils.writeUint64(seqNo, nonce, this.encryptNonce.length);
        } else if (this.nonceMode == NONCE_RFC7905) {
            xorSequenceNonce(nonce, this.encryptNonce, seqNo);
        } else {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }
        return nonce;
    }

    /** Builds the full nonce for an incoming record, reading the explicit part from the record if any. */
    private byte[] createDecryptNonce(long seqNo, byte[] ciphertext, int ciphertextOffset) throws IOException {
        final byte[] nonce = new byte[this.decryptNonce.length + this.recordIvLength];
        if (this.nonceMode == NONCE_RFC5288) {
            System.arraycopy(this.decryptNonce, 0, nonce, 0, this.decryptNonce.length);
            System.arraycopy(ciphertext, ciphertextOffset, nonce, nonce.length - this.recordIvLength, this.recordIvLength);
        } else if (this.nonceMode == NONCE_RFC7905) {
            xorSequenceNonce(nonce, this.decryptNonce, seqNo);
        } else {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }
        return nonce;
    }

    /** Writes {@code seqNo} into the last 8 bytes of {@code nonce}, then XORs {@code fixedIv} over it. */
    private static void xorSequenceNonce(byte[] nonce, byte[] fixedIv, long seqNo) {
        TlsUtils.writeUint64(seqNo, nonce, nonce.length - 8);
        for (int i = 0; i < fixedIv.length; ++i) {
            nonce[i] = (byte)(nonce[i] ^ fixedIv[i]);
        }
    }

    /**
     * Builds the AEAD additional data for one record. Three layouts exist:
     * <ul>
     * <li>non-empty connection ID: placeholder(8) | 25 | cid length | 25 | version(2) |
     *     seqNo(8) | cid | plaintext length(2);</li>
     * <li>TLS 1.3: record type | version(2) | ciphertext length(2);</li>
     * <li>otherwise: seqNo(8) | record type | version(2) | plaintext length(2).</li>
     * </ul>
     */
    private byte[] getAdditionalData(long seqNo, short recordType, ProtocolVersion recordVersion, int ciphertextLength, int plaintextLength, byte[] connectionID) throws IOException {
        if (!Arrays.isNullOrEmpty(connectionID)) {
            final int cidLength = connectionID.length;
            final byte[] data = new byte[23 + cidLength];
            int pos = 0;
            TlsUtils.writeUint64(SEQUENCE_NUMBER_PLACEHOLDER, data, pos);
            pos += 8;
            TlsUtils.writeUint8(CONTENT_TYPE_TLS12_CID, data, pos);
            pos += 1;
            TlsUtils.writeUint8(cidLength, data, pos);
            pos += 1;
            TlsUtils.writeUint8(CONTENT_TYPE_TLS12_CID, data, pos);
            pos += 1;
            TlsUtils.writeVersion(recordVersion, data, pos);
            pos += 2;
            TlsUtils.writeUint64(seqNo, data, pos);
            pos += 8;
            System.arraycopy(connectionID, 0, data, pos, cidLength);
            pos += cidLength;
            TlsUtils.writeUint16(plaintextLength, data, pos);
            return data;
        }

        if (this.isTLSv13) {
            final byte[] data = new byte[5];
            TlsUtils.writeUint8(recordType, data, 0);
            TlsUtils.writeVersion(recordVersion, data, 1);
            TlsUtils.writeUint16(ciphertextLength, data, 3);
            return data;
        }

        final byte[] data = new byte[13];
        TlsUtils.writeUint64(seqNo, data, 0);
        TlsUtils.writeUint8(recordType, data, 8);
        TlsUtils.writeVersion(recordVersion, data, 9);
        TlsUtils.writeUint16(plaintextLength, data, 11);
        return data;
    }

    /**
     * Re-keys one direction from a traffic secret. Only valid for TLS 1.3.
     *
     * @param serverSecret {@code true} to use the server traffic secret, {@code false} for the client's
     * @throws IOException a {@link TlsFatalAlert} (code 80) if not TLS 1.3 or the secret is absent
     */
    private void rekeyCipher(SecurityParameters securityParameters, TlsAEADCipherImpl cipher, byte[] nonce, boolean serverSecret) throws IOException {
        if (!this.isTLSv13) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }

        final TlsSecret secret;
        if (serverSecret) {
            secret = securityParameters.getTrafficSecretServer();
        } else {
            secret = securityParameters.getTrafficSecretClient();
        }

        if (secret == null) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }

        setup13Cipher(cipher, nonce, secret, securityParameters.getPRFCryptoHashAlgorithm());
    }

    /** Derives key and IV from {@code secret} with the "key" / "iv" labels and installs them. */
    private void setup13Cipher(TlsAEADCipherImpl cipher, byte[] nonce, TlsSecret secret, int cryptoHashAlgorithm) throws IOException {
        final byte[] key = TlsCryptoUtils.hkdfExpandLabel(secret, cryptoHashAlgorithm, "key", TlsUtils.EMPTY_BYTES, this.keySize).extract();
        final byte[] iv = TlsCryptoUtils.hkdfExpandLabel(secret, cryptoHashAlgorithm, "iv", TlsUtils.EMPTY_BYTES, this.fixedIvLength).extract();

        cipher.setKey(key, 0, this.keySize);
        System.arraycopy(iv, 0, nonce, 0, this.fixedIvLength);
    }

    /**
     * Selects the nonce construction: ChaCha20-Poly1305 always uses the XOR form, while CCM and
     * GCM use it only under TLS 1.3 and the fixed-plus-explicit form otherwise.
     *
     * @throws IOException a {@link TlsFatalAlert} (code 80) for an unknown {@code aeadType}
     */
    private static int getNonceMode(boolean isTLSv13, int aeadType) throws IOException {
        if (aeadType == AEAD_CHACHA20_POLY1305) {
            return NONCE_RFC7905;
        }
        if (aeadType == AEAD_CCM || aeadType == AEAD_GCM) {
            return isTLSv13 ? NONCE_RFC7905 : NONCE_RFC5288;
        }
        throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
    }
}

