/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.federated;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.BiFunction;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Maps the index ranges of a federated object to the {@link FederatedData} handles that back
 * them, together with the federated variable ID and the federation type ({@link FType}).
 * <p>
 * {@link #broadcastSliced} and {@link #execute} pair requests with map entries positionally,
 * i.e., they rely on iterating the underlying map in the same order both times.
 */
public class FederationMap {
    /** Federated variable ID; negative while uninitialized (see {@link #isInitialized()}). */
    private long _ID = -1L;
    /** Range-to-data mapping; stored by reference and mutated in place by {@link #bind} and {@link #transpose()}. */
    private final Map<FederatedRange, FederatedData> _fedMap;
    /** Federation type; swapped between ROW and COL by {@link #transpose()}. */
    private FType _type;

    /**
     * Creates an uninitialized map (ID {@code -1}) of type {@link FType#OTHER}.
     *
     * @param fedMap range-to-data mapping; stored by reference, not copied
     */
    public FederationMap(Map<FederatedRange, FederatedData> fedMap) {
        this(-1L, fedMap);
    }

    /**
     * Creates a map with the given ID and type {@link FType#OTHER}.
     *
     * @param ID     federated variable ID
     * @param fedMap range-to-data mapping; stored by reference, not copied
     */
    public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap) {
        this(ID, fedMap, FType.OTHER);
    }

    /**
     * Creates a map with the given ID and federation type.
     *
     * @param ID     federated variable ID
     * @param fedMap range-to-data mapping; stored by reference, not copied
     * @param type   federation type
     */
    public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap, FType type) {
        this._ID = ID;
        this._fedMap = fedMap;
        this._type = type;
    }

    /** @return the federated variable ID (negative if uninitialized) */
    public long getID() {
        return _ID;
    }

    /** @return the federation type */
    public FType getType() {
        return _type;
    }

    /** @return {@code true} iff the federated variable ID is non-negative */
    public boolean isInitialized() {
        return _ID >= 0L;
    }

    /** @param type the new federation type */
    public void setType(FType type) {
        _type = type;
    }

    /** @return the federated ranges, in the iteration order of the underlying map */
    public FederatedRange[] getFederatedRanges() {
        return _fedMap.keySet().toArray(new FederatedRange[0]);
    }

    /**
     * Creates (but does not execute) a single PUT_VAR request that ships the whole data block
     * under a freshly allocated federated ID.
     *
     * @param data data object whose block is read via {@code acquireReadAndRelease()}
     * @return the PUT_VAR request
     */
    public FederatedRequest broadcast(CacheableData<?> data) {
        long id = FederationUtils.getNextFedDataID();
        CacheBlock cb = data.acquireReadAndRelease();
        return new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, cb);
    }

    /**
     * Creates (but does not execute) a PUT_VAR request carrying a scalar under a freshly
     * allocated federated ID.
     *
     * @param scalar scalar to ship
     * @return the PUT_VAR request
     */
    public FederatedRequest broadcast(ScalarObject scalar) {
        long id = FederationUtils.getNextFedDataID();
        return new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, scalar);
    }

    /**
     * Creates one PUT_VAR request per federated range, all sharing one freshly allocated ID.
     * Each request carries the slice of {@code data} matching that range: the range's row
     * interval (all columns) or, if {@code transposed}, that interval applied to the columns
     * (all rows). Requests are returned in the map's iteration order so they can be passed as
     * {@code frSlices} to {@link #execute(long, boolean, FederatedRequest[], FederatedRequest...)}.
     *
     * @param data       data object to slice
     * @param transposed whether the range's first dimension selects columns instead of rows
     * @return one request per federated range
     */
    public FederatedRequest[] broadcastSliced(CacheableData<?> data, boolean transposed) {
        long id = FederationUtils.getNextFedDataID();
        CacheBlock cb = data.acquireReadAndRelease();
        FederatedRequest[] ret = new FederatedRequest[_fedMap.size()];
        int pos = 0;
        for (FederatedRange range : _fedMap.keySet()) {
            int beg = range.getBeginDimsInt()[0];
            int end = range.getEndDimsInt()[0] - 1; // slice() takes inclusive upper bounds
            int rl = transposed ? 0 : beg;
            int ru = transposed ? cb.getNumRows() - 1 : end;
            int cl = transposed ? beg : 0;
            int cu = transposed ? end : cb.getNumColumns() - 1;
            CacheBlock slice = cb.slice(rl, ru, cl, cu, new MatrixBlock());
            ret[pos++] = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, slice);
        }
        return ret;
    }

    /**
     * Checks whether every range of this map (transposed if requested) is found in
     * {@code that} map with data at an equal address.
     *
     * @param that       map to compare against
     * @param transposed whether to transpose this map's ranges before the lookup
     * @return {@code true} iff all ranges of this map are aligned with {@code that}
     */
    public boolean isAligned(FederationMap that, boolean transposed) {
        for (Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
            FederatedRange range = transposed ? new FederatedRange(e.getKey()).transpose() : e.getKey();
            if (!e.getValue().equalAddress(that._fedMap.get(range))) {
                return false;
            }
        }
        return true;
    }

    /** Executes {@code fr} against all federated data without waiting. */
    public Future<FederatedResponse>[] execute(long tid, FederatedRequest ... fr) {
        return execute(tid, false, fr);
    }

    /** Executes {@code fr} against all federated data, optionally waiting for completion. */
    public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest ... fr) {
        return execute(tid, wait, (FederatedRequest[]) null, fr);
    }

    /** Executes per-range {@code frSlices} plus shared {@code fr} without waiting. */
    public Future<FederatedResponse>[] execute(long tid, FederatedRequest[] frSlices, FederatedRequest ... fr) {
        return execute(tid, false, frSlices, fr);
    }

    /**
     * Sets the thread ID on all given requests and sends them to every federated data object.
     * If {@code frSlices} is given, the i-th slice request (in the map's iteration order) is
     * prepended to {@code fr} for the i-th federated data object.
     *
     * @param tid      thread ID assigned to every request
     * @param wait     whether to wait for all responses before returning
     * @param frSlices optional per-range requests, one per federated range (may be {@code null})
     * @param fr       requests sent to every federated data object
     * @return one future per federated data object, in the map's iteration order
     * @throws DMLRuntimeException if fewer slice requests than federated ranges are given
     */
    public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest[] frSlices, FederatedRequest ... fr) {
        // validate up front so no request is dispatched before failing
        if (frSlices != null && frSlices.length < _fedMap.size()) {
            throw new DMLRuntimeException("Federated execution requires one sliced request per federated range: got "
                + frSlices.length + " slices for " + _fedMap.size() + " ranges");
        }
        setThreadID(tid, frSlices, fr);
        ArrayList<Future<FederatedResponse>> ret = new ArrayList<>(_fedMap.size());
        int pos = 0;
        for (FederatedData data : _fedMap.values()) {
            FederatedRequest[] reqs = (frSlices != null) ? addAll(frSlices[pos++], fr) : fr;
            ret.add(data.executeFederatedOperation(reqs));
        }
        if (wait) {
            FederationUtils.waitFor(ret);
        }
        @SuppressWarnings("unchecked") // generic array creation; elements are Future<FederatedResponse>
        Future<FederatedResponse>[] arr = ret.toArray(new Future[0]);
        return arr;
    }

    /**
     * Sends a GET_VAR request for this map's ID to every federated data object.
     *
     * @return pairs of federated range and the corresponding pending response
     * @throws DMLRuntimeException if this map is not initialized
     */
    public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
        if (!isInitialized()) {
            throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");
        }
        List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>(_fedMap.size());
        FederatedRequest request = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, _ID);
        for (Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
            readResponses.add(new ImmutablePair<>(e.getKey(), e.getValue().executeFederatedOperation(request)));
        }
        return readResponses;
    }

    /**
     * Creates (but does not execute) an EXEC_INST request that runs the remove-variable
     * instruction for the given IDs.
     *
     * @param tid thread ID assigned to the request
     * @param id  federated variable IDs to remove
     * @return the cleanup request
     */
    public FederatedRequest cleanup(long tid, long ... id) {
        FederatedRequest request = new FederatedRequest(FederatedRequest.RequestType.EXEC_INST, -1L,
            VariableCPInstruction.prepareRemoveInstruction(id).toString());
        request.setTID(tid);
        return request;
    }

    /**
     * Sends the request built by {@link #cleanup(long, long...)} to every federated data
     * object and waits for all responses.
     *
     * @param tid thread ID assigned to the request
     * @param id  federated variable IDs to remove
     */
    public void execCleanup(long tid, long ... id) {
        FederatedRequest request = cleanup(tid, id);
        ArrayList<Future<FederatedResponse>> tmp = new ArrayList<>(_fedMap.size());
        for (FederatedData fd : _fedMap.values()) {
            tmp.add(fd.executeFederatedOperation(request));
        }
        FederationUtils.waitFor(tmp);
    }

    /** Returns a new array with {@code a} followed by all elements of {@code b}. */
    private static FederatedRequest[] addAll(FederatedRequest a, FederatedRequest[] b) {
        FederatedRequest[] ret = new FederatedRequest[b.length + 1];
        ret[0] = a;
        System.arraycopy(b, 0, ret, 1, b.length);
        return ret;
    }

    /**
     * Executes a copy instruction from this map's ID to {@code id} on every federated data
     * object, waits for all responses, and returns a copy of this map under {@code id}
     * (keeping the federation type).
     *
     * @param tid thread ID assigned to the copy request
     * @param id  target federated variable ID
     * @return a copy of this map with ID {@code id}
     * @throws DMLRuntimeException if any copy fails or waiting is interrupted
     */
    public FederationMap identCopy(long tid, long id) {
        FederatedRequest copyReq = new FederatedRequest(FederatedRequest.RequestType.EXEC_INST, _ID,
            VariableCPInstruction.prepareCopyInstruction(Long.toString(_ID), Long.toString(id)).toString());
        for (Future<FederatedResponse> future : execute(tid, copyReq)) {
            try {
                FederatedResponse response = future.get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // preserve interrupt status for callers
                throw new DMLRuntimeException(e);
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        }
        return copyWithNewID(id); // already carries this map's type
    }

    /** @return a copy of this map under a freshly allocated federated ID */
    public FederationMap copyWithNewID() {
        return copyWithNewID(FederationUtils.getNextFedDataID());
    }

    /**
     * Creates a copy with copied ranges, each data object copied under {@code id}, and the
     * same federation type. The copy is backed by a new {@link TreeMap}.
     *
     * @param id federated variable ID of the copy
     * @return the copy
     */
    public FederationMap copyWithNewID(long id) {
        Map<FederatedRange, FederatedData> map = new TreeMap<>();
        for (Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
            map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
        }
        return new FederationMap(id, map, _type);
    }

    /**
     * Like {@link #copyWithNewID(long)}, but each range is rebuilt via
     * {@code new FederatedRange(range, clen)}. The federation type is preserved, consistent
     * with {@link #copyWithNewID(long)}.
     *
     * @param id   federated variable ID of the copy
     * @param clen value passed to the {@code FederatedRange(FederatedRange, long)} constructor
     * @return the copy
     */
    public FederationMap copyWithNewID(long id, long clen) {
        Map<FederatedRange, FederatedData> map = new TreeMap<>();
        for (Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
            map.put(new FederatedRange(e.getKey(), clen), e.getValue().copyWithNewID(id));
        }
        return new FederationMap(id, map, _type);
    }

    /**
     * Adds all entries of {@code that} to this map in place, shifting their ranges by the
     * given offsets and copying their data under this map's ID.
     *
     * @param rOffset row offset applied to each added range
     * @param cOffset column offset applied to each added range
     * @param that    map whose entries are added
     * @return this map
     */
    public FederationMap bind(long rOffset, long cOffset, FederationMap that) {
        for (Map.Entry<FederatedRange, FederatedData> e : that._fedMap.entrySet()) {
            _fedMap.put(new FederatedRange(e.getKey()).shift(rOffset, cOffset), e.getValue().copyWithNewID(_ID));
        }
        return this;
    }

    /**
     * Transposes this map in place: every range is transposed, every data object is
     * re-copied under this map's ID, and the type is swapped ROW&lt;-&gt;COL (any other type
     * becomes OTHER).
     *
     * @return this map
     */
    public FederationMap transpose() {
        Map<FederatedRange, FederatedData> tmp = new TreeMap<>(_fedMap);
        _fedMap.clear();
        for (Map.Entry<FederatedRange, FederatedData> e : tmp.entrySet()) {
            _fedMap.put(new FederatedRange(e.getKey()).transpose(), e.getValue().copyWithNewID(_ID));
        }
        switch (_type) {
            case ROW:
                _type = FType.COL;
                break;
            case COL:
                _type = FType.ROW;
                break;
            default:
                _type = FType.OTHER;
        }
        return this;
    }

    /**
     * @param dim dimension index
     * @return the maximum end index over all ranges in dimension {@code dim}, or {@code -1}
     *         if the map is empty
     */
    public long getMaxIndexInRange(int dim) {
        return _fedMap.keySet().stream().mapToLong(range -> range.getEndDims()[dim]).max().orElse(-1L);
    }

    /**
     * Applies {@code forEachFunction} to every (range, data) entry of this map in parallel;
     * after each application the data's var ID is set to this map's ID.
     *
     * @param forEachFunction function applied per entry (return value ignored)
     */
    public void forEachParallel(BiFunction<FederatedRange, FederatedData, Void> forEachFunction) {
        ArrayList<MappingTask> mappingTasks = new ArrayList<>(_fedMap.size());
        for (Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
            mappingTasks.add(new MappingTask(e.getKey(), e.getValue(), forEachFunction, _ID));
        }
        ExecutorService pool = CommonThreadPool.get(_fedMap.size());
        CommonThreadPool.invokeAndShutdown(pool, mappingTasks);
    }

    /**
     * Copies this map (see {@link #copyWithNewID(long)}), applies {@code mappingFunction} to
     * every entry of the copy in parallel, sets each copied data object's var ID to
     * {@code newVarID}, and returns the copy under {@code newVarID}.
     *
     * @param newVarID        federated variable ID of the result
     * @param mappingFunction function applied per entry of the copy (return value ignored)
     * @return the mapped copy
     */
    public FederationMap mapParallel(long newVarID, BiFunction<FederatedRange, FederatedData, Void> mappingFunction) {
        FederationMap fedMapCopy = copyWithNewID(_ID);
        ArrayList<MappingTask> mappingTasks = new ArrayList<>(fedMapCopy._fedMap.size());
        for (Map.Entry<FederatedRange, FederatedData> e : fedMapCopy._fedMap.entrySet()) {
            mappingTasks.add(new MappingTask(e.getKey(), e.getValue(), mappingFunction, newVarID));
        }
        // create the pool only once the copy succeeded, so a failing copy cannot leak it
        ExecutorService pool = CommonThreadPool.get(_fedMap.size());
        CommonThreadPool.invokeAndShutdown(pool, mappingTasks);
        fedMapCopy._ID = newVarID;
        return fedMapCopy;
    }

    /** Sets {@code tid} on every request of every non-null request array. */
    private static void setThreadID(long tid, FederatedRequest[] ... frsets) {
        for (FederatedRequest[] frset : frsets) {
            if (frset == null) {
                continue;
            }
            Arrays.stream(frset).forEach(fr -> fr.setTID(tid));
        }
    }

    /** Task applying a mapping function to one (range, data) entry, then setting the data's var ID. */
    private static class MappingTask implements Callable<Void> {
        private final FederatedRange _range;
        private final FederatedData _data;
        private final BiFunction<FederatedRange, FederatedData, Void> _mappingFunction;
        private final long _varID;

        public MappingTask(FederatedRange range, FederatedData data,
            BiFunction<FederatedRange, FederatedData, Void> mappingFunction, long varID) {
            this._range = range;
            this._data = data;
            this._mappingFunction = mappingFunction;
            this._varID = varID;
        }

        @Override
        public Void call() throws Exception {
            _mappingFunction.apply(_range, _data);
            _data.setVarID(_varID);
            return null;
        }
    }

    /**
     * Federation type. {@link #transpose()} swaps ROW and COL; presumably these denote
     * row-wise and column-wise partitioning (TODO confirm against callers).
     */
    public static enum FType {
        ROW,
        COL,
        OTHER
    }
}

