/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysml.runtime.io.IOUtilFunctions;

public class FunctionCallCPInstruction
extends CPInstruction {
    private final String _functionName;
    private final String _namespace;
    private final CPOperand[] _boundInputs;
    private final List<String> _boundInputNames;
    private final List<String> _funArgNames;
    private final List<String> _boundOutputNames;

    public FunctionCallCPInstruction(String namespace, String functName, CPOperand[] boundInputs, List<String> boundInputNames, List<String> funArgNames, List<String> boundOutputNames, String istr) {
        super(CPInstruction.CPType.External, null, functName, istr);
        this._functionName = functName;
        this._namespace = namespace;
        this._boundInputs = boundInputs;
        this._boundInputNames = boundInputNames;
        this._funArgNames = funArgNames;
        this._boundOutputNames = boundOutputNames;
    }

    public String getFunctionName() {
        return this._functionName;
    }

    public String getNamespace() {
        return this._namespace;
    }

    public static FunctionCallCPInstruction parseInstruction(String str) {
        int i;
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String namespace = parts[1];
        String functionName = parts[2];
        int numInputs = Integer.valueOf(parts[3]);
        int numOutputs = Integer.valueOf(parts[4]);
        CPOperand[] boundInputs = new CPOperand[numInputs];
        ArrayList<String> boundInputNames = new ArrayList<String>();
        ArrayList<String> funArgNames = new ArrayList<String>();
        ArrayList<String> boundOutputNames = new ArrayList<String>();
        for (i = 0; i < numInputs; ++i) {
            String[] nameValue = IOUtilFunctions.splitByFirst(parts[5 + i], "=");
            boundInputs[i] = new CPOperand(nameValue[1]);
            funArgNames.add(nameValue[0]);
            boundInputNames.add(boundInputs[i].getName());
        }
        for (i = 0; i < numOutputs; ++i) {
            boundOutputNames.add(parts[5 + numInputs + i]);
        }
        return new FunctionCallCPInstruction(namespace, functionName, boundInputs, boundInputNames, funArgNames, boundOutputNames, str);
    }

    @Override
    public Instruction preprocessInstruction(ExecutionContext ec) {
        Instruction tmp = super.preprocessInstruction(ec);
        if (DMLScript.ENABLE_DEBUG_MODE) {
            ec.handleDebugFunctionEntry((FunctionCallCPInstruction)tmp);
        }
        return tmp;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        FunctionProgramBlock fpb;
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)("Executing instruction : " + this.toString()));
        }
        if (this._boundInputs.length < (fpb = ec.getProgram().getFunctionProgramBlock(this._namespace, this._functionName)).getInputParams().size()) {
            throw new DMLRuntimeException("Number of bound input parameters does not match the function signature (" + this._boundInputs.length + ", but " + fpb.getInputParams().size() + " expected)");
        }
        LocalVariableMap functionVariables = new LocalVariableMap();
        for (int i = 0; i < this._boundInputs.length; ++i) {
            void var8_11;
            CPOperand input = this._boundInputs[i];
            if (!input.isLiteral() && !ec.containsVariable(input.getName())) {
                throw new DMLRuntimeException("Input variable '" + input.getName() + "' not existing on call of " + DMLProgram.constructFunctionKey(this._namespace, this._functionName) + " (line " + this.getLineNum() + ").");
            }
            String argName = this._funArgNames.get(i);
            DataIdentifier currFormalParam = fpb.getInputParam(argName);
            if (currFormalParam == null) {
                throw new DMLRuntimeException("Non-existing named function argument: '" + argName + "' (line " + this.getLineNum() + ").");
            }
            Data data = ec.getVariable(input);
            if (data.getDataType() == Expression.DataType.SCALAR && data.getValueType() != currFormalParam.getValueType()) {
                ScalarObject scalarObject = ScalarObjectFactory.createScalarObject(currFormalParam.getValueType(), (ScalarObject)data);
            }
            functionVariables.put(currFormalParam.getName(), (Data)var8_11);
        }
        boolean[] pinStatus = ec.pinVariables(this._boundInputNames);
        ExecutionContext fn_ec = ExecutionContextFactory.createContext(false, ec.getProgram());
        if (DMLScript.USE_ACCELERATOR) {
            fn_ec.setGPUContexts(ec.getGPUContexts());
            fn_ec.getGPUContext(0).initializeThread();
        }
        fn_ec.setVariables(functionVariables);
        try {
            fpb._functionName = this._functionName;
            fpb._namespace = this._namespace;
            fpb.execute(fn_ec);
        }
        catch (DMLScriptException e) {
            throw e;
        }
        catch (Exception e) {
            String fname = DMLProgram.constructFunctionKey(this._namespace, this._functionName);
            throw new DMLRuntimeException("error executing function " + (String)fname, e);
        }
        HashSet<String> expectRetVars = new HashSet<String>();
        for (DataIdentifier dataIdentifier : fpb.getOutputParams()) {
            expectRetVars.add(dataIdentifier.getName());
        }
        LocalVariableMap retVars = fn_ec.getVariables();
        for (String varName : new ArrayList<String>(retVars.keySet())) {
            if (expectRetVars.contains(varName)) continue;
            fn_ec.cleanupDataObject(fn_ec.removeVariable(varName));
        }
        ec.unpinVariables(this._boundInputNames, pinStatus);
        int n = Math.min(this._boundOutputNames.size(), fpb.getOutputParams().size());
        for (int i = 0; i < n; ++i) {
            String boundVarName = this._boundOutputNames.get(i);
            Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
            if (boundValue == null) {
                throw new DMLRuntimeException(boundVarName + " was not assigned a return value");
            }
            Data exdata = ec.removeVariable(boundVarName);
            if (exdata != boundValue) {
                ec.cleanupDataObject(exdata);
            }
            ec.setVariable(boundVarName, boundValue);
        }
    }

    @Override
    public void postprocessInstruction(ExecutionContext ec) {
        if (DMLScript.ENABLE_DEBUG_MODE) {
            ec.handleDebugFunctionExit(this);
        }
        super.postprocessInstruction(ec);
    }

    @Override
    public void printMe() {
        LOG.debug((Object)("ExternalBuiltInFunction: " + this.toString()));
    }

    public List<String> getBoundOutputParamNames() {
        return this._boundOutputNames;
    }

    public String updateInstStringFunctionName(String pattern, String replace) {
        String[] parts = this.instString.split("\u00b0");
        if (parts[3].equals(pattern)) {
            parts[3] = replace;
        }
        StringBuilder sb = new StringBuilder();
        for (String part : parts) {
            sb.append(part);
            sb.append("\u00b0");
        }
        return sb.substring(0, sb.length() - "\u00b0".length());
    }
}

