/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.ipa.IPAPass;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * IPA pass that replaces {@code eval} calls whose function name is a literal
 * with a direct {@link FunctionOp} call to that function.
 *
 * <p>The pass walks all function bodies and all top-level statement blocks of
 * the program, descends into while/if/for bodies, and inspects the HOP roots of
 * every remaining statement block. A root is rewritten only if it is a
 * transient or persistent write whose single input is an {@code eval} operator
 * with a literal first input and exactly one parent.
 */
public class IPAPassReplaceEvalFunctionCalls
extends IPAPass {
    /** Namespace assumed when the default function dictionary contains the function. */
    private static final String DEFAULT_NS = ".defaultNS";
    /** Namespace assumed for all other functions without an explicit namespace. */
    private static final String BUILTIN_NS = ".builtinNS";
    /** Delimiter between namespace and function name in a function key. */
    private static final String NS_DELIM = "::";

    /**
     * Indicates if this pass should run for the given call graph.
     *
     * @param fgraph function call graph of the program
     * @return true if the graph contains a second-order call and the
     *         optimizer flag for eval replacement is enabled
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return fgraph.containsSecondOrderCall() && OptimizerUtils.ALLOW_EVAL_FCALL_REPLACEMENT;
    }

    /**
     * Rewrites all functions and all top-level statement blocks of the program.
     *
     * @param prog       program to rewrite in place
     * @param fgraph     function call graph (only passed through)
     * @param fcallSizes function call size information (not used by this pass)
     * @return true if at least one eval call was replaced
     */
    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        boolean ret = false;
        // functions of all namespaces
        for (String namespaceKey : prog.getNamespaces().keySet()) {
            for (String fname : prog.getFunctionStatementBlocks(namespaceKey).keySet()) {
                FunctionStatementBlock fsblock = prog.getFunctionStatementBlock(namespaceKey, fname);
                ret |= rewriteStatementBlock(prog, fsblock, fgraph);
            }
        }
        // main program
        for (StatementBlock sb : prog.getStatementBlocks()) {
            ret |= rewriteStatementBlock(prog, sb, fgraph);
        }
        return ret;
    }

    /**
     * Recursively rewrites a statement block: control-flow and function blocks
     * are descended into, all other blocks are checked for replaceable eval calls.
     *
     * @param prog   program that owns the block
     * @param sb     statement block to process
     * @param fgraph function call graph (only passed through)
     * @return true if at least one eval call was replaced in this block or its children
     */
    private static boolean rewriteStatementBlock(DMLProgram prog, StatementBlock sb, FunctionCallGraph fgraph) {
        boolean ret = false;
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
            for (StatementBlock csb : fstmt.getBody()) {
                ret |= rewriteStatementBlock(prog, csb, fgraph);
            }
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
            for (StatementBlock csb : wstmt.getBody()) {
                ret |= rewriteStatementBlock(prog, csb, fgraph);
            }
        } else if (sb instanceof IfStatementBlock) {
            IfStatement istmt = (IfStatement) sb.getStatement(0);
            for (StatementBlock csb : istmt.getIfBody()) {
                ret |= rewriteStatementBlock(prog, csb, fgraph);
            }
            for (StatementBlock csb : istmt.getElseBody()) {
                ret |= rewriteStatementBlock(prog, csb, fgraph);
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatement fstmt = (ForStatement) sb.getStatement(0);
            for (StatementBlock csb : fstmt.getBody()) {
                ret |= rewriteStatementBlock(prog, csb, fgraph);
            }
        } else {
            ret |= checkAndReplaceEvalFunctionCall(prog, sb, fgraph);
        }
        return ret;
    }

    /**
     * Checks if a HOP root matches the pattern this pass can rewrite.
     *
     * @param root root of a HOP DAG
     * @return true if root is a transient/persistent write whose first input is
     *         an eval operator with a literal first input and a single parent
     */
    private static boolean isReplaceableEvalRoot(Hop root) {
        if (!HopRewriteUtils.isData(root, Types.OpOpData.TRANSIENTWRITE, Types.OpOpData.PERSISTENTWRITE)) {
            return false;
        }
        Hop in = root.getInput(0);
        return HopRewriteUtils.isNary(in, Types.OpOpN.EVAL)
            && in.getInput(0) instanceof LiteralOp
            && in.getParent().size() == 1;
    }

    /**
     * Replaces every matching {@code write(eval("fname", args...))} root of the
     * given block with a direct function call that writes the same output variable.
     *
     * <p>A replacement is skipped (and the eval left untouched) if the target
     * function cannot be resolved, if a list argument is passed to a function
     * whose first parameter is not a list, if the eval returns a list, or if
     * the function does not have exactly one matrix output.
     *
     * @param prog   program used to resolve the called function
     * @param sb     statement block whose HOP roots are inspected and modified in place
     * @param fgraph function call graph (not used)
     * @return true if at least one root was replaced
     */
    private static boolean checkAndReplaceEvalFunctionCall(DMLProgram prog, StatementBlock sb, FunctionCallGraph fgraph) {
        if (sb.getHops() == null) {
            return false;
        }
        ArrayList<Hop> roots = sb.getHops();
        boolean ret = false;
        for (int i = 0; i < roots.size(); ++i) {
            Hop root = roots.get(i);
            if (!isReplaceableEvalRoot(root)) {
                continue;
            }
            Hop eval = root.getInput(0);
            String outvar = ((DataOp) root).getName();
            // input 0 is the function name, all further inputs are call arguments
            boolean hasArgs = eval.getInput().size() > 1;

            // resolve namespace and function name (an explicit "ns::name" key wins)
            String fname = ((LiteralOp) eval.getInput(0)).getStringValue();
            String fnamespace = prog.getDefaultFunctionDictionary().containsFunction(fname) ? DEFAULT_NS : BUILTIN_NS;
            if (fname.contains(NS_DELIM)) {
                String[] fparts = DMLProgram.splitFunctionKey(fname);
                fnamespace = fparts[0];
                fname = fparts[1];
            }
            if (fnamespace.equals(BUILTIN_NS)) {
                // the internal builtin name depends on the data type of the first
                // argument; without any argument it cannot be determined here
                if (!hasArgs) {
                    LOG.warn("IPA: eval(" + fnamespace + "::" + fname + ") not replaced, because the builtin function name cannot be resolved without arguments.");
                    continue;
                }
                fname = Builtins.getInternalFName(fname, eval.getInput(1).getDataType());
            }

            FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fnamespace, fname);
            FunctionStatement fstmt = fsb != null ? (FunctionStatement) fsb.getStatement(0) : null;

            // list arguments are only supported if the function itself takes a list first
            boolean firstParamIsList = fstmt != null
                && !fstmt.getInputParams().isEmpty()
                && fstmt.getInputParams().get(0).getDataType().isList();
            if (hasArgs && eval.getInput(1).getDataType().isList() && !firstParamIsList) {
                LOG.warn("IPA: eval(" + fnamespace + "::" + fname + ") applicable for replacement, but list inputs not yet supported.");
                continue;
            }
            if (eval.getDataType().isList()) {
                LOG.warn("IPA: eval(" + fnamespace + "::" + fname + ") applicable for replacement, but list output not yet supported.");
                continue;
            }
            // without a resolved function there is no signature to build the call from
            if (fstmt == null) {
                LOG.warn("IPA: eval(" + fnamespace + "::" + fname + ") not replaced, because the function is not available in the program.");
                continue;
            }
            if (fstmt.getOutputParams().size() != 1 || !fstmt.getOutputParams().get(0).getDataType().isMatrix()) {
                LOG.warn("IPA: eval(" + fnamespace + "::" + fname + ") applicable for replacement, but function output is not a matrix.");
                continue;
            }

            // construct the direct call and swap it in for the write root
            FunctionOp fop = new FunctionOp(FunctionOp.FunctionType.DML, fnamespace, fname, fstmt.getInputParamNames(), eval.getInput().subList(1, eval.getInput().size()), new String[]{outvar}, true);
            HopRewriteUtils.copyLineNumbers(eval, fop);
            HopRewriteUtils.removeAllChildReferences(eval);
            roots.set(i, fop);
            ret = true;
        }
        return ret;
    }
}

