/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.util.Arrays;
import java.util.Iterator;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.colgroup.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.ColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.utils.ABitmap;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.IJV;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Abstract base for DDC column groups, in which every row of the group is mapped to one tuple of a shared
 * dictionary.
 * <p>
 * Conventions relied on by the methods of this class (the physical row-to-tuple mapping is supplied by subclasses):
 * <ul>
 * <li>{@link #getIndex(int)} returns the tuple index of a row. Tuple {@code t} occupies the entries
 * {@code [t * ncol, (t + 1) * ncol)} of the flat array returned by {@code getValues()}.</li>
 * <li>A tuple index equal to the tuple count ({@code getNumValues()}, or the {@code numVals} argument of the
 * multiplication kernels) denotes an all-zero row that has no entry in the dictionary.</li>
 * <li>{@link #getIndex(int, int)} returns a position that is passed directly to {@code _dict.getValue}
 * (presumably {@code getIndex(r) * ncol + colIx} — confirm against subclasses).</li>
 * </ul>
 */
public abstract class ColGroupDDC
extends ColGroupValue {
    private static final long serialVersionUID = -3204391646123465004L;

    /** No-argument constructor (presumably needed for deserialization — confirm). */
    protected ColGroupDDC() {
    }

    /**
     * Creates the group from a bitmap, delegating dictionary construction to {@link ColGroupValue}.
     *
     * @param colIndices column indices covered by this group
     * @param numRows    number of rows in the group
     * @param ubm        bitmap of distinct values
     * @param cs         compression settings
     */
    protected ColGroupDDC(int[] colIndices, int numRows, ABitmap ubm, CompressionSettings cs) {
        super(colIndices, numRows, ubm, cs);
    }

    /**
     * Creates the group around an existing dictionary.
     *
     * @param colIndices column indices covered by this group
     * @param numRows    number of rows in the group
     * @param dict       the dictionary of tuples
     */
    protected ColGroupDDC(int[] colIndices, int numRows, ADictionary dict) {
        super(colIndices, numRows, dict);
    }

    @Override
    public ColGroup.CompressionType getCompType() {
        return ColGroup.CompressionType.DDC;
    }

    /**
     * Appends every value (zeros included) of rows {@code [rl, ru)} to {@code target}, row by row, at the group's
     * column positions.
     */
    @Override
    public void decompressToBlock(MatrixBlock target, int rl, int ru) {
        final int ncol = getNumCols();
        final double[] values = getValues();
        for (int row = rl; row < ru; row++) {
            for (int colIx = 0; colIx < ncol; colIx++) {
                target.appendValue(row, _colIndexes[colIx], getData(row, colIx, values));
            }
        }
    }

    /**
     * Writes all rows of this group into {@code target}. The output column for group column {@code colIx} is
     * {@code colIndexTargets[getColIndex(colIx)]}.
     */
    @Override
    public void decompressToBlock(MatrixBlock target, int[] colIndexTargets) {
        final int ncol = getNumCols();
        final double[] dictionary = getValues();
        // The target column of each group column does not depend on the row, so resolve it once.
        final int[] targetCols = new int[ncol];
        for (int colIx = 0; colIx < ncol; colIx++) {
            targetCols[colIx] = colIndexTargets[getColIndex(colIx)];
        }
        for (int row = 0; row < _numRows; row++) {
            for (int colIx = 0; colIx < ncol; colIx++) {
                target.quickSetValue(row, targetCols[colIx], getData(row, colIx, dictionary));
            }
        }
    }

    /**
     * Decompresses the group-local column {@code colpos} into the dense values of {@code target}. It writes one
     * entry per row and sets the target's non-zero count to the number of non-zeros written.
     */
    @Override
    public void decompressToBlock(MatrixBlock target, int colpos) {
        final int ncol = getNumCols();
        final int numVals = getNumValues();
        final double[] c = target.getDenseBlockValues();
        final double[] values = getValues();
        int nnz = 0;
        for (int row = 0; row < _numRows; row++) {
            final int index = getIndex(row);
            // Compare the tuple index with the tuple count, not with values.length (= numVals * ncol). Otherwise the
            // zero tuple would be read from past the end of values whenever ncol > 1.
            final double v = index != numVals ? values[index * ncol + colpos] : 0.0;
            c[row] = v;
            if (v != 0.0) {
                nnz++;
            }
        }
        target.setNonZeros(nnz);
    }

    /**
     * Returns the value at row {@code r} and absolute column {@code c}.
     * <p>
     * Assumes {@code _colIndexes} is sorted ascending, as {@link Arrays#binarySearch(int[], int)} requires.
     *
     * @throws RuntimeException if {@code c} is not one of this group's columns
     */
    @Override
    public double get(int r, int c) {
        final int ix = Arrays.binarySearch(_colIndexes, c);
        if (ix < 0) {
            throw new RuntimeException("Column index " + c + " not in DDC group.");
        }
        // Detect the zero tuple via the row's tuple index. The flat position returned by getIndex(r, ix) is only
        // comparable to the tuple count when the group has a single column.
        if (getIndex(r) == getNumValues()) {
            return 0.0;
        }
        return _dict.getValue(getIndex(r, ix));
    }

    /**
     * Adds the number of non-zero values of each row in {@code [rl, ru)} to {@code rnnz[row - rl]}.
     */
    @Override
    public void countNonZerosPerRow(int[] rnnz, int rl, int ru) {
        final int ncol = getNumCols();
        final int numVals = getNumValues();
        for (int row = rl; row < ru; row++) {
            if (getIndex(row) >= numVals) {
                continue; // zero tuple: the row contributes no non-zeros
            }
            int lnnz = 0;
            for (int colIx = 0; colIx < ncol; colIx++) {
                if (_dict.getValue(getIndex(row, colIx)) != 0.0) {
                    lnnz++;
                }
            }
            rnnz[row - rl] += lnnz;
        }
    }

    /** Adds the sum of all values in the group, weighted by the tuple counts, to {@code c[0]}. */
    @Override
    protected void computeSum(double[] c, KahanFunction kplus) {
        c[0] += _dict.sum(getCounts(), _colIndexes.length, kplus);
    }

    /** Adds the per-column sums of the group, weighted by the tuple counts, into {@code c}. */
    @Override
    protected void computeColSums(double[] c, KahanFunction kplus) {
        _dict.colSum(c, getCounts(), _colIndexes, kplus);
    }

    /**
     * Accumulates the row sums of rows {@code [rl, ru)} into {@code c}. Rows mapped to the zero tuple are skipped.
     */
    @Override
    protected void computeRowSums(double[] c, KahanFunction kplus, int rl, int ru, boolean mean) {
        final int numVals = getNumValues();
        final KahanObject kbuff = new KahanObject(0.0, 0.0);
        final KahanPlus kplus2 = KahanPlus.getKahanPlusFnObject();
        // vals is indexed by tuple index (presumably one row sum per dictionary tuple).
        final double[] vals = _dict.sumAllRowsToDouble(kplus, kbuff, _colIndexes.length);
        // Each output row occupies 2 slots of c, or 3 when mean is set; the slot layout is owned by setandExecute.
        final int stride = mean ? 3 : 2;
        for (int rix = rl; rix < ru; rix++) {
            final int index = getIndex(rix);
            if (index != numVals) {
                setandExecute(c, kbuff, kplus2, vals[index], rix * stride);
            }
        }
    }

    /**
     * Folds every value of each row {@code i} in {@code [rl, ru)} into {@code c[i]} using {@code builtin}
     * (presumably min or max). A row mapped to the zero tuple folds in a single {@code 0.0}.
     */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru) {
        final int numVals = getNumValues();
        final int ncol = getNumCols();
        final double[] dictionary = getValues();
        for (int i = rl; i < ru; i++) {
            final int index = getIndex(i);
            if (index == numVals) {
                c[i] = builtin.execute(c[i], 0.0);
                continue;
            }
            // Tuple t starts at t * ncol in the flat dictionary, so scale the tuple index by the column count.
            final int off = index * ncol;
            for (int j = 0; j < ncol; j++) {
                c[i] = builtin.execute(c[i], dictionary[off + j]);
            }
        }
    }

    /** Equivalent to {@code postScaling(values, vals, c, numVals, 0, 0)}. */
    public void postScaling(double[] values, double[] vals, double[] c, int numVals) {
        postScaling(values, vals, c, numVals, 0, 0);
    }

    /**
     * Multiplies per-tuple aggregates by the dictionary and adds the result into output row {@code i}.
     * <p>
     * For each group column {@code j}, this computes
     * {@code c[_colIndexes[j] + i * totalCols] += sum_k vals[k] * values[k * ncol + j]}.
     *
     * @param values    flat dictionary values
     * @param vals      per-tuple aggregates of length at least {@code numVals}, without the zero tuple
     * @param c         output array
     * @param numVals   number of dictionary tuples
     * @param i         output row
     * @param totalCols row length of {@code c}
     */
    public void postScaling(double[] values, double[] vals, double[] c, int numVals, int i, int totalCols) {
        final int ncol = getNumCols();
        final int rowOff = i * totalCols;
        for (int j = 0; j < ncol; j++) {
            final int colIx = _colIndexes[j] + rowOff;
            // vals holds only real tuples (preAggregate drops the zero tuple), so every k contributes. Comparing the
            // offset k * ncol with numVals, as before, wrongly dropped tuple numVals / ncol for multi-column groups.
            for (int k = 0, valOff = 0; k < numVals; k++, valOff += ncol) {
                c[colIx] += vals[k] * values[valOff + j];
            }
        }
    }

    /** Counts the rows mapped to each tuple over all rows; see {@link #getCounts(int, int, int[])}. */
    @Override
    public int[] getCounts(int[] counts) {
        return getCounts(0, _numRows, counts);
    }

    /**
     * Increments {@code counts[getIndex(row)]} for every row in {@code [rl, ru)}. Rows mapped to the zero tuple are
     * counted at position {@code getNumValues()}, so {@code counts} must have room for that slot.
     *
     * @return {@code counts}
     */
    @Override
    public int[] getCounts(int rl, int ru, int[] counts) {
        for (int row = rl; row < ru; row++) {
            counts[getIndex(row)]++;
        }
        return counts;
    }

    @Override
    public void rightMultByMatrix(double[] b, double[] c, int numVals, double[] values, int rl, int ru, int vOff) {
        throw new NotImplementedException("Not Implemented");
    }

    /**
     * Left-multiplies rows of {@code a} with this group and accumulates into {@code c}.
     * <p>
     * Output row {@code i} in {@code [rl, ru)} uses input row {@code voff + (i - rl)} of {@code a}. That input row
     * is read as {@code _numRows} consecutive entries starting at {@code (voff + (i - rl)) * _numRows}. The results
     * are added to {@code c} at offset {@code i * numCols}.
     *
     * @param numVals number of dictionary tuples, or {@code -1} to use {@code getNumValues()}
     * @param numRows unused
     */
    @Override
    public void leftMultByMatrix(double[] a, double[] c, int numVals, double[] values, int numRows, int numCols, int rl, int ru, int voff) {
        final int nv = numVals == -1 ? getNumValues() : numVals;
        final int ncol = _colIndexes.length;
        // Pre-aggregate per tuple when there are many more rows than tuples; otherwise scatter row by row.
        final boolean preAgg = 8 * nv < _numRows;
        for (int i = rl, aRow = voff; i < ru; i++, aRow++) {
            if (preAgg) {
                final double[] vals = preAggregate(a, nv, aRow);
                postScaling(values, vals, c, nv, i, numCols);
                continue;
            }
            final int cOff = i * numCols;
            for (int k = 0, aOff = aRow * _numRows; k < _numRows; k++, aOff++) {
                final double aval = a[aOff];
                if (aval == 0.0) {
                    continue;
                }
                final int index = getIndex(k);
                if (index == nv) {
                    continue; // zero tuple contributes nothing
                }
                final int valOff = index * ncol;
                for (int h = 0; h < ncol; h++) {
                    c[_colIndexes[h] + cOff] += aval * values[valOff + h];
                }
            }
        }
    }

    /** Left-multiplies the row vector {@code a} with this group using the group's own dictionary values. */
    @Override
    public void leftMultByRowVector(double[] a, double[] result, int numVals) {
        numVals = numVals == -1 ? getNumValues() : numVals;
        double[] values = getValues();
        leftMultByRowVector(a, result, numVals, values);
    }

    /** Equivalent to {@code preAggregate(a, numVals, 0)}. */
    public double[] preAggregate(double[] a, int numVals) {
        return preAggregate(a, numVals, 0);
    }

    /**
     * Sums the entries of one row of {@code a} per tuple.
     * <p>
     * The row starts at offset {@code aRows * _numRows}, or at 0 when {@code aRows <= 0}. Rows mapped to the zero
     * tuple ({@code numVals}) are skipped.
     *
     * @return array from {@code allocDVector(numVals, true)} whose entry {@code t} holds the sum for tuple {@code t}
     */
    public double[] preAggregate(double[] a, int numVals, int aRows) {
        final double[] vals = allocDVector(numVals, true);
        int off = aRows > 0 ? _numRows * aRows : 0;
        for (int row = 0; row < _numRows; row++, off++) {
            final int index = getIndex(row);
            if (index != numVals) {
                vals[index] += a[off];
            }
        }
        return vals;
    }

    /**
     * Computes {@code c[_colIndexes[j]] += sum_row a[row] * value(row, j)} for this group.
     *
     * @param numVals number of dictionary tuples, or {@code -1} to use {@code getNumValues()}
     * @param values  flat dictionary values
     */
    @Override
    public void leftMultByRowVector(double[] a, double[] c, int numVals, double[] values) {
        final int nv = numVals == -1 ? getNumValues() : numVals;
        if (8 * nv < _numRows) {
            postScaling(values, preAggregate(a, nv), c, nv);
            return;
        }
        final int ncol = _colIndexes.length;
        for (int row = 0; row < _numRows; row++) {
            final double aval = a[row];
            if (aval == 0.0) {
                continue;
            }
            final int index = getIndex(row);
            // Test the tuple index (not the scaled offset) so the zero tuple is skipped and no real tuple is.
            if (index == nv) {
                continue;
            }
            final int valOff = index * ncol;
            for (int j = 0; j < ncol; j++) {
                c[_colIndexes[j]] += aval * values[valOff + j];
            }
        }
    }

    /** Returns a cell iterator over rows {@code [rl, ru)}; {@code rowMajor} is ignored (always row-major). */
    @Override
    public Iterator<IJV> getIterator(int rl, int ru, boolean inclZeros, boolean rowMajor) {
        return new DDCIterator(rl, ru, inclZeros);
    }

    @Override
    public ColGroup.ColGroupRowIterator getRowIterator(int rl, int ru) {
        return new DDCRowIterator(rl, ru);
    }

    @Override
    public String toString() {
        return super.toString();
    }

    /**
     * Returns the dictionary tuple index of row {@code r}. A value equal to {@code getNumValues()} denotes the
     * all-zero tuple.
     */
    protected abstract int getIndex(int r);

    /**
     * Returns the position of the value at row {@code r} and local column {@code colIx}. Callers in this class pass
     * it directly to {@code _dict.getValue}.
     */
    protected abstract int getIndex(int r, int colIx);

    /** Not used in this class; semantics defined by subclasses ({@code r} presumably a row index — confirm). */
    protected abstract double getData(int r, double[] values);

    /**
     * Returns the value at row {@code r} and local column {@code colIx}, read from {@code values} (the array from
     * {@code getValues()}).
     */
    protected abstract double getData(int r, int colIx, double[] values);

    /** Not used in this class; presumably stores the tuple code for row {@code r} — confirm against subclasses. */
    protected abstract void setData(int r, int code);

    /** Row iterator that copies a row's tuple into the group's positions of an output buffer. */
    private class DDCRowIterator
    extends ColGroup.ColGroupRowIterator {
        public DDCRowIterator(int rl, int ru) {
            // NOTE(review): this looks like the decompiler's rendering of the implicit enclosing-instance argument
            // of an inner superclass; rl and ru are unused.
            super(ColGroupDDC.this);
        }

        @Override
        public void next(double[] buff, int rowIx, int segIx, boolean last) {
            final int clen = ColGroupDDC.this.getNumCols();
            final int[] cols = ColGroupDDC.this._colIndexes;
            final int index = ColGroupDDC.this.getIndex(rowIx);
            if (index == ColGroupDDC.this.getNumValues()) {
                // The zero tuple has no entry in values; write zeros instead of reading past it.
                for (int j = 0; j < clen; j++) {
                    buff[cols[j]] = 0.0;
                }
                return;
            }
            final double[] values = ColGroupDDC.this.getValues();
            final int off = index * clen;
            for (int j = 0; j < clen; j++) {
                buff[cols[j]] = values[off + j];
            }
        }
    }

    /**
     * Row-major cell iterator over rows {@code [rl, ru)}. It reuses a single {@link IJV} instance and skips zeros
     * unless {@code inclZeros} is set.
     */
    private class DDCIterator
    implements Iterator<IJV> {
        private final int _ru;
        private final boolean _inclZeros;
        private final IJV _buff = new IJV();
        private int _rpos = -1;
        private int _cpos = -1;
        private double _value = 0.0;

        public DDCIterator(int rl, int ru, boolean inclZeros) {
            this._ru = ru;
            this._inclZeros = inclZeros;
            this._rpos = rl;
            this._cpos = -1;
            this.getNextValue();
        }

        @Override
        public boolean hasNext() {
            return _rpos < _ru;
        }

        @Override
        public IJV next() {
            if (!hasNext()) {
                throw new java.util.NoSuchElementException();
            }
            _buff.set(_rpos, ColGroupDDC.this._colIndexes[_cpos], _value);
            getNextValue();
            return _buff;
        }

        /** Advances to the next cell (skipping zeros unless requested) and caches its value. */
        private void getNextValue() {
            final int ncol = ColGroupDDC.this.getNumCols();
            final int numVals = ColGroupDDC.this.getNumValues();
            do {
                if (_cpos + 1 >= ncol) {
                    _rpos++;
                    _cpos = 0;
                } else {
                    _cpos++;
                }
                if (_rpos >= _ru) {
                    return;
                }
                // Rows mapped to the zero tuple have no dictionary entry.
                _value = ColGroupDDC.this.getIndex(_rpos) == numVals
                    ? 0.0
                    : ColGroupDDC.this._dict.getValue(ColGroupDDC.this.getIndex(_rpos, _cpos));
            } while (!_inclZeros && _value == 0.0);
        }
    }
}

