/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.cplan.cuda;

import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
import org.apache.sysds.runtime.matrix.data.LibMatrixNative;

/**
 * CUDA code templates for binary operations ({@link CNodeBinary.BinType}).
 *
 * <p>Every template is a line of generated source containing {@code %...%} placeholders
 * (e.g. {@code %TMP%}, {@code %IN1%}, {@code %IN2%}, {@code %POS1%}, {@code %LEN%}) that the
 * caller substitutes. The chosen template depends on the operation type, on the sparsity of the
 * inputs, and on whether {@link LibMatrixNative#isSinglePrecision()} is enabled.
 */
public class Binary
extends CodeTemplate {

    /**
     * Returns the code template for the given binary operation.
     *
     * @param type         the binary operation to generate code for
     * @param sparseLhs    whether the left input is sparse
     * @param sparseRhs    whether the right input is sparse
     * @param scalarVector whether the operation is scalar (left) op vector (right)
     * @param scalarInput  only used for {@code VECT_CBIND}: whether the left input is a scalar
     * @param vectorVector only used for {@code VECT_CBIND}: whether both inputs are vectors
     * @return the template string with unresolved {@code %...%} placeholders
     * @throws RuntimeException if {@code type} has no CUDA template
     */
    @Override
    public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput, boolean vectorVector) {
        // cbind uses the same template regardless of precision
        if (type == CNodeBinary.BinType.VECT_CBIND) {
            return getCbindTemplate(sparseLhs, scalarInput, vectorVector);
        }
        return LibMatrixNative.isSinglePrecision()
            ? getSinglePrecisionTemplate(type, sparseLhs, sparseRhs, scalarVector)
            : getDoublePrecisionTemplate(type, sparseLhs, sparseRhs, scalarVector);
    }

    /** Template for {@code VECT_CBIND}, shared by single and double precision. */
    private static String getCbindTemplate(boolean sparseLhs, boolean scalarInput, boolean vectorVector) {
        if (scalarInput) {
            return "\t\tVector<T>& %TMP% = vectCbindWrite(%IN1%, %IN2%, this);\n";
        }
        if (!vectorVector) {
            return sparseLhs
                ? "\t\tVector<T>& %TMP% = vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%, this);\n"
                : "\t\tVector<T>& %TMP% = vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%, this);\n";
        }
        return sparseLhs
            ? "\t\tVector<T>& %TMP% = vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN1%, %LEN2%, this);\n"
            : "\t\tVector<T>& %TMP% = vectCbindWrite(%IN1%, %IN2%, %POS1%, %POS2%, %LEN1%, %LEN2%, this);\n";
    }

    /**
     * Templates used when {@link LibMatrixNative#isSinglePrecision()} is true (float math
     * functions such as {@code fminf}, {@code logf}, {@code powf}).
     *
     * <p>NOTE(review): the vector templates here still emit Java-style code ({@code T[]},
     * {@code LibSpoofPrimitives.}), which looks copied from the Java backend; confirm whether
     * these paths are reachable for CUDA.
     */
    private static String getSinglePrecisionTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector) {
        switch (type) {
            case DOT_PRODUCT: {
                return sparseLhs
                    ? "\tT %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n"
                    : "\tT %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
            }
            case VECT_MATRIXMULT: {
                return sparseLhs
                    ? "\tT[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n"
                    : "\tT[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
            }
            case VECT_OUTERMULT_ADD: {
                if (sparseLhs) {
                    return "\tLibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n";
                }
                return sparseRhs
                    ? "\tLibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n"
                    : "\tLibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
            }
            case VECT_MULT_ADD: 
            case VECT_DIV_ADD: 
            case VECT_MINUS_ADD: 
            case VECT_PLUS_ADD: 
            case VECT_POW_ADD: 
            case VECT_XOR_ADD: 
            case VECT_MIN_ADD: 
            case VECT_MAX_ADD: 
            case VECT_EQUAL_ADD: 
            case VECT_NOTEQUAL_ADD: 
            case VECT_LESS_ADD: 
            case VECT_LESSEQUAL_ADD: 
            case VECT_GREATER_ADD: 
            case VECT_GREATEREQUAL_ADD: 
            case VECT_CBIND_ADD: {
                String vectName = type.getVectorPrimitiveName();
                if (scalarVector) {
                    return sparseLhs
                        ? "\tLibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n"
                        : "\tLibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
                }
                return sparseLhs
                    ? "\tLibSpoofPrimitives.vect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n"
                    : "\tLibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
            }
            case VECT_MULT_SCALAR: 
            case VECT_DIV_SCALAR: 
            case VECT_MINUS_SCALAR: 
            case VECT_PLUS_SCALAR: 
            case VECT_POW_SCALAR: 
            case VECT_XOR_SCALAR: 
            case VECT_BITWAND_SCALAR: 
            case VECT_MIN_SCALAR: 
            case VECT_MAX_SCALAR: 
            case VECT_EQUAL_SCALAR: 
            case VECT_NOTEQUAL_SCALAR: 
            case VECT_LESS_SCALAR: 
            case VECT_LESSEQUAL_SCALAR: 
            case VECT_GREATER_SCALAR: 
            case VECT_GREATEREQUAL_SCALAR: {
                String vectName = type.getVectorPrimitiveName();
                if (scalarVector) {
                    return sparseRhs
                        ? "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n"
                        : "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
                }
                return sparseLhs
                    ? "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n"
                    : "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
            }
            case VECT_MULT: 
            case VECT_DIV: 
            case VECT_MINUS: 
            case VECT_PLUS: 
            case VECT_XOR: 
            case VECT_BITWAND: 
            case VECT_BIASADD: 
            case VECT_BIASMULT: 
            case VECT_MIN: 
            case VECT_MAX: 
            case VECT_EQUAL: 
            case VECT_NOTEQUAL: 
            case VECT_LESS: 
            case VECT_LESSEQUAL: 
            case VECT_GREATER: 
            case VECT_GREATEREQUAL: {
                String vectName = type.getVectorPrimitiveName();
                if (sparseLhs) {
                    return "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n";
                }
                return sparseRhs
                    ? "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n"
                    : "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
            }
            case MULT: {
                return "\tT %TMP% = %IN1% * %IN2%;\n";
            }
            case DIV: {
                return "\tT %TMP% = %IN1% / %IN2%;\n";
            }
            case PLUS: {
                return "\tT %TMP% = %IN1% + %IN2%;\n";
            }
            case MINUS: {
                return "\tT %TMP% = %IN1% - %IN2%;\n";
            }
            case MODULUS: {
                return "\tT %TMP% = modulus(%IN1%, %IN2%);\n";
            }
            case INTDIV: {
                return "\tT %TMP% = intDiv(%IN1%, %IN2%);\n";
            }
            case LESS: {
                return "\tT %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
            }
            case LESSEQUAL: {
                return "\tT %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
            }
            case GREATER: {
                return "\tT %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
            }
            case GREATEREQUAL: {
                return "\tT %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
            }
            case EQUAL: {
                return "\tT %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
            }
            case NOTEQUAL: {
                return "\tT %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
            }
            case MIN: {
                return "\tT %TMP% = fminf(%IN1%, %IN2%);\n";
            }
            case MAX: {
                return "\tT %TMP% = fmaxf(%IN1%, %IN2%);\n";
            }
            case LOG: {
                // log base %IN2% via change of base; must use the CUDA logf, not Java's Math.log
                return "\tT %TMP% = logf(%IN1%) / logf(%IN2%);\n";
            }
            case LOG_NZ: {
                return "\tT %TMP% = (%IN1% == 0) ? 0 : logf(%IN1%) / logf(%IN2%);\n";
            }
            case POW: {
                return "\tT %TMP% = powf(%IN1%, %IN2%);\n";
            }
            case MINUS1_MULT: {
                return "\tT %TMP% = 1 - %IN1% * %IN2%;\n";
            }
            case MINUS_NZ: {
                return "\tT %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
            }
            case XOR: {
                return "\tT %TMP% = ( (%IN1% != 0) != (%IN2% != 0) ) ? 1.0f : 0.0f;\n";
            }
            case BITWAND: {
                return "\tT %TMP% = bwAnd(%IN1%, %IN2%);\n";
            }
            case SEQ_RIX: {
                return "\tT %TMP% = %IN1% + grix * %IN2%;\n";
            }
        }
        throw new RuntimeException("Invalid binary type: " + type);
    }

    /**
     * Templates used when {@link LibMatrixNative#isSinglePrecision()} is false (double math
     * functions such as {@code min}, {@code log}, {@code pow}).
     *
     * <p>NOTE(review): the sparse variants below still emit Java-style code ({@code T[]},
     * {@code LibSpoofPrimitives.}); confirm whether sparse inputs can reach this CUDA path.
     */
    private static String getDoublePrecisionTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector) {
        switch (type) {
            case DOT_PRODUCT: {
                return sparseLhs
                    ? "\t\tT %TMP% = dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n"
                    : "\t\tT %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%);\n";
            }
            case VECT_MATRIXMULT: {
                return sparseLhs
                    ? "\tT[] %TMP% = vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n"
                    : "\t\tVector<T>& %TMP% = vectMatrixMult(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
            }
            case VECT_OUTERMULT_ADD: {
                if (sparseLhs) {
                    return "\tLibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n";
                }
                return sparseRhs
                    ? "\tLibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n"
                    : "\t\tvectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
            }
            case VECT_MULT_ADD: 
            case VECT_DIV_ADD: 
            case VECT_MINUS_ADD: 
            case VECT_PLUS_ADD: 
            case VECT_POW_ADD: 
            case VECT_XOR_ADD: 
            case VECT_MIN_ADD: 
            case VECT_MAX_ADD: 
            case VECT_EQUAL_ADD: 
            case VECT_NOTEQUAL_ADD: 
            case VECT_LESS_ADD: 
            case VECT_LESSEQUAL_ADD: 
            case VECT_GREATER_ADD: 
            case VECT_GREATEREQUAL_ADD: 
            case VECT_CBIND_ADD: {
                String vectName = type.getVectorPrimitiveName();
                if (scalarVector) {
                    return sparseLhs
                        ? "\t\tvect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n"
                        : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
                }
                return sparseLhs
                    ? "\t\tvect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n"
                    : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, static_cast<uint32_t>(%POSOUT%), %LEN%);\n";
            }
            case VECT_MULT_SCALAR: 
            case VECT_DIV_SCALAR: 
            case VECT_MINUS_SCALAR: 
            case VECT_PLUS_SCALAR: 
            case VECT_POW_SCALAR: 
            case VECT_XOR_SCALAR: 
            case VECT_BITWAND_SCALAR: 
            case VECT_MIN_SCALAR: 
            case VECT_MAX_SCALAR: 
            case VECT_EQUAL_SCALAR: 
            case VECT_NOTEQUAL_SCALAR: 
            case VECT_LESS_SCALAR: 
            case VECT_LESSEQUAL_SCALAR: 
            case VECT_GREATER_SCALAR: 
            case VECT_GREATEREQUAL_SCALAR: {
                String vectName = type.getVectorPrimitiveName();
                if (scalarVector) {
                    return sparseRhs
                        ? "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n"
                        : "\t\tVector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%, this);\n";
                }
                return sparseLhs
                    ? "\t\tVector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%, this);\n"
                    : "\t\tVector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
            }
            case VECT_MULT: 
            case VECT_DIV: 
            case VECT_MINUS: 
            case VECT_PLUS: 
            case VECT_XOR: 
            case VECT_BITWAND: 
            case VECT_BIASADD: 
            case VECT_BIASMULT: 
            case VECT_MIN: 
            case VECT_MAX: 
            case VECT_EQUAL: 
            case VECT_NOTEQUAL: 
            case VECT_LESS: 
            case VECT_LESSEQUAL: 
            case VECT_GREATER: 
            case VECT_GREATEREQUAL: {
                String vectName = type.getVectorPrimitiveName();
                if (sparseLhs) {
                    return "\t\tVector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n";
                }
                return sparseRhs
                    ? "\t\tVector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n"
                    : "\t\tVector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
            }
            case MULT: {
                return "\t\tT %TMP% = %IN1% * %IN2%;\n";
            }
            case DIV: {
                return "\tT %TMP% = %IN1% / %IN2%;\n";
            }
            case PLUS: {
                return "\t\tT %TMP% = %IN1% + %IN2%;\n";
            }
            case MINUS: {
                return "\tT %TMP% = %IN1% - %IN2%;\n";
            }
            case MODULUS: {
                return "\tT %TMP% = modulus(%IN1%, %IN2%);\n";
            }
            case INTDIV: {
                return "\tT %TMP% = intDiv(%IN1%, %IN2%);\n";
            }
            case LESS: {
                return "\tT %TMP% = (%IN1% < %IN2%) ? 1.0 : 0.0;\n";
            }
            case LESSEQUAL: {
                return "\tT %TMP% = (%IN1% <= %IN2%) ? 1.0 : 0.0;\n";
            }
            case GREATER: {
                // NOTE(review): unlike LESS/GREATEREQUAL and the single-precision template, this
                // compares against %IN2% + EPSILON; kept as-is - confirm the asymmetry is intended.
                return "\tT %TMP% = (%IN1% > (%IN2% + EPSILON)) ? 1.0 : 0.0;\n";
            }
            case GREATEREQUAL: {
                return "\tT %TMP% = (%IN1% >= %IN2%) ? 1.0 : 0.0;\n";
            }
            case EQUAL: {
                return "\tT %TMP% = (%IN1% == %IN2%) ? 1.0 : 0.0;\n";
            }
            case NOTEQUAL: {
                return "\tT %TMP% = (%IN1% != %IN2%) ? 1.0 : 0.0;\n";
            }
            case MIN: {
                return "\tT %TMP% = min(%IN1%, %IN2%);\n";
            }
            case MAX: {
                return "\tT %TMP% = max(%IN1%, %IN2%);\n";
            }
            case LOG: {
                // log base %IN2% via change of base; must use the CUDA log, not Java's Math.log
                return "\tT %TMP% = log(%IN1%) / log(%IN2%);\n";
            }
            case LOG_NZ: {
                return "\tT %TMP% = (%IN1% == 0) ? 0 : log(%IN1%) / log(%IN2%);\n";
            }
            case POW: {
                return "\tT %TMP% = pow(%IN1%, %IN2%);\n";
            }
            case MINUS1_MULT: {
                return "\tT %TMP% = 1 - %IN1% * %IN2%;\n";
            }
            case MINUS_NZ: {
                return "\tT %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
            }
            case XOR: {
                // Non-zero means "true". A "< EPSILON" test would wrongly treat negative values as false.
                return "\tT %TMP% = ( (%IN1% != 0) != (%IN2% != 0) ) ? 1.0 : 0.0;\n";
            }
            case BITWAND: {
                return "\tT %TMP% = bwAnd(%IN1%, %IN2%);\n";
            }
            case SEQ_RIX: {
                return "\t\tT %TMP% = %IN1% + grix * %IN2%;\n";
            }
        }
        throw new RuntimeException("Invalid binary type: " + type);
    }
}

