/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.confignode.manager.load.balancer.router.leader;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
import org.apache.iotdb.commons.utils.TestOnly;
import org.apache.iotdb.confignode.manager.load.balancer.router.leader.AbstractLeaderBalancer;
import org.apache.iotdb.confignode.manager.load.cache.node.NodeStatistics;
import org.apache.iotdb.confignode.manager.load.cache.region.RegionStatistics;

/**
 * Leader balancer that picks one leader per RegionGroup by solving a minimum-cost maximum-flow
 * problem over a layered network.
 *
 * <p>Vertices: the super source {@link #S_VERTEX}, the super sink {@link #T_VERTEX}, one rVertex
 * per RegionGroup, one sDVertex per (database, DataNode) pair and one tDVertex per DataNode. Every
 * edge has capacity 1:
 *
 * <ul>
 *   <li>S -> rVertex, cost 0: each RegionGroup receives at most one unit of flow (one leader).
 *   <li>rVertex -> sDVertex, cost 0 if the DataNode is the RegionGroup's current leader and 1
 *       otherwise, which favours keeping existing leaders.
 *   <li>sDVertex -> tDVertex: the k-th parallel edge of a DataNode within one database costs
 *       {@code 2k - 1}, so carrying k leaders of that database costs k^2 in total.
 *   <li>tDVertex -> T: the k-th parallel edge of a DataNode costs {@code 2k - 1} (one edge per
 *       {@code regionLocationMap} entry listing that DataNode), applying the same convex penalty
 *       to the DataNode's total leader count.
 * </ul>
 *
 * <p>The convex per-DataNode costs push the minimum-cost flow toward an even leader spread. The
 * flow is computed with repeated SPFA shortest-path searches, each followed by DFS augmentation
 * restricted to shortest-path edges (Dinic-style). All graph state lives in unsynchronized
 * instance fields, so an instance should not be used by several threads at once.
 */
public class CostFlowSelectionLeaderBalancer
extends AbstractLeaderBalancer {
    /** Distance value meaning "not reachable" in the shortest-path search. */
    private static final int INFINITY = Integer.MAX_VALUE;
    /** Index of the super source vertex. */
    private static final int S_VERTEX = 0;
    /** Index of the super sink vertex. */
    private static final int T_VERTEX = 1;
    /** Next free vertex index; indices 0..T_VERTEX are reserved for S and T. */
    private int maxVertex = T_VERTEX + 1;
    // RegionGroupId -> rVertex
    private final Map<TConsensusGroupId, Integer> rVertexMap = new TreeMap<>();
    // Database -> (DataNodeId -> sDVertex)
    private final Map<String, Map<Integer, Integer>> sDVertexMap = new TreeMap<>();
    // Database -> (sDVertex -> DataNodeId), the inverse of sDVertexMap
    private final Map<String, Map<Integer, Integer>> sDVertexReflect = new TreeMap<>();
    // DataNodeId -> tDVertex
    private final Map<Integer, Integer> tDVertexMap = new TreeMap<>();
    // Number of edges added so far, i.e. the index the next edge will receive
    private int maxEdge = 0;
    // Edges are always added in pairs, so edge i ^ 1 is the residual twin of edge i
    private final List<CostFlowEdge> costFlowEdges = new ArrayList<>();
    // vertex -> index of the most recently added outgoing edge (adjacency list head), -1 if none
    private int[] vertexHeadEdge;
    // vertex -> next outgoing edge to try in the current augmentation phase
    private int[] vertexCurrentEdge;
    // SPFA: "vertex is queued"; DFS: "vertex is on the current path or known dead"
    private boolean[] isVertexVisited;
    // vertex -> shortest residual cost from S found by the last SPFA pass
    private int[] vertexMinimumCost;
    private int maximumFlow = 0;
    private int minimumCost = 0;

    @Override
    public Map<TConsensusGroupId, Integer> generateOptimalLeaderDistribution(
            Map<String, List<TConsensusGroupId>> databaseRegionGroupMap,
            Map<TConsensusGroupId, Set<Integer>> regionLocationMap,
            Map<TConsensusGroupId, Integer> regionLeaderMap,
            Map<Integer, NodeStatistics> dataNodeStatisticsMap,
            Map<TConsensusGroupId, Map<Integer, RegionStatistics>> regionStatisticsMap) {
        try {
            this.initialize(databaseRegionGroupMap, regionLocationMap, regionLeaderMap, dataNodeStatisticsMap, regionStatisticsMap);
            this.constructFlowNetwork();
            this.dinicAlgorithm();
            return this.collectLeaderDistribution();
        } finally {
            // Always drop the per-call graph, so a failure cannot leak vertices or edges into the
            // next invocation on this instance.
            this.clear();
        }
    }

    /**
     * Resets all graph state built for one invocation. {@code maximumFlow} and {@code
     * minimumCost} are intentionally kept so the test-only getters can read them after
     * {@link #generateOptimalLeaderDistribution} returns.
     */
    @Override
    protected void clear() {
        super.clear();
        this.rVertexMap.clear();
        this.sDVertexMap.clear();
        this.sDVertexReflect.clear();
        this.tDVertexMap.clear();
        this.costFlowEdges.clear();
        this.vertexHeadEdge = null;
        this.vertexCurrentEdge = null;
        this.isVertexVisited = null;
        this.vertexMinimumCost = null;
        this.maxVertex = T_VERTEX + 1;
        this.maxEdge = 0;
    }

    /** Builds the full flow network: vertex indices, per-vertex arrays, then the four edge layers. */
    private void constructFlowNetwork() {
        this.maximumFlow = 0;
        this.minimumCost = 0;
        this.assignVertices();
        this.allocateVertexArrays();
        this.addSourceToRegionEdges();
        this.addRegionToDatabaseNodeEdges();
        this.addDatabaseNodeToNodeEdges();
        this.addNodeToSinkEdges();
    }

    /**
     * Assigns vertex indices. Each RegionGroup in the intersection gets an rVertex. Each available
     * DataNode hosting such a RegionGroup gets an sDVertex per database and one tDVertex overall.
     */
    private void assignVertices() {
        for (Map.Entry<String, List<TConsensusGroupId>> entry : this.databaseRegionGroupMap.entrySet()) {
            Map<Integer, Integer> dataNodeToVertex = new TreeMap<>();
            Map<Integer, Integer> vertexToDataNode = new TreeMap<>();
            this.sDVertexMap.put(entry.getKey(), dataNodeToVertex);
            this.sDVertexReflect.put(entry.getKey(), vertexToDataNode);
            for (TConsensusGroupId regionGroupId : entry.getValue()) {
                if (!this.regionGroupIntersection.contains(regionGroupId)) {
                    continue;
                }
                this.rVertexMap.put(regionGroupId, this.maxVertex++);
                for (Integer dataNodeId : this.regionLocationMap.get(regionGroupId)) {
                    if (!this.isDataNodeAvailable(dataNodeId)) {
                        continue;
                    }
                    if (!dataNodeToVertex.containsKey(dataNodeId)) {
                        dataNodeToVertex.put(dataNodeId, this.maxVertex);
                        vertexToDataNode.put(this.maxVertex, dataNodeId);
                        ++this.maxVertex;
                    }
                    if (!this.tDVertexMap.containsKey(dataNodeId)) {
                        this.tDVertexMap.put(dataNodeId, this.maxVertex++);
                    }
                }
            }
        }
    }

    /** Allocates the per-vertex work arrays; every adjacency list starts empty (-1). */
    private void allocateVertexArrays() {
        this.isVertexVisited = new boolean[this.maxVertex];
        this.vertexMinimumCost = new int[this.maxVertex];
        this.vertexCurrentEdge = new int[this.maxVertex];
        this.vertexHeadEdge = new int[this.maxVertex];
        Arrays.fill(this.vertexHeadEdge, -1);
    }

    /** S -> rVertex edges: capacity 1 and cost 0, so each RegionGroup elects at most one leader. */
    private void addSourceToRegionEdges() {
        for (int rVertex : this.rVertexMap.values()) {
            this.addAdjacentEdges(S_VERTEX, rVertex, 1, 0);
        }
    }

    /**
     * rVertex -> sDVertex edges for every available DataNode whose replica of the RegionGroup is
     * available. Cost 0 for the current leader and 1 otherwise, to minimise leader changes.
     */
    private void addRegionToDatabaseNodeEdges() {
        for (Map.Entry<String, List<TConsensusGroupId>> entry : this.databaseRegionGroupMap.entrySet()) {
            Map<Integer, Integer> dataNodeToVertex = this.sDVertexMap.get(entry.getKey());
            for (TConsensusGroupId regionGroupId : entry.getValue()) {
                if (!this.regionGroupIntersection.contains(regionGroupId)) {
                    continue;
                }
                int rVertex = this.rVertexMap.get(regionGroupId);
                Integer currentLeader = this.regionLeaderMap.getOrDefault(regionGroupId, -1);
                for (Integer dataNodeId : this.regionLocationMap.get(regionGroupId)) {
                    if (this.isDataNodeAvailable(dataNodeId) && this.isRegionAvailable(regionGroupId, dataNodeId)) {
                        int cost = Objects.equals(currentLeader, dataNodeId) ? 0 : 1;
                        this.addAdjacentEdges(rVertex, dataNodeToVertex.get(dataNodeId), 1, cost);
                    }
                }
            }
        }
    }

    /**
     * sDVertex -> tDVertex edges. The k-th edge of a DataNode within a database costs 2k - 1. The
     * first k edges therefore sum to k^2, a convex per-database leader-count penalty.
     */
    private void addDatabaseNodeToNodeEdges() {
        for (Map.Entry<String, List<TConsensusGroupId>> entry : this.databaseRegionGroupMap.entrySet()) {
            Map<Integer, Integer> dataNodeToVertex = this.sDVertexMap.get(entry.getKey());
            Map<Integer, Integer> leaderCounter = new TreeMap<>();
            for (TConsensusGroupId regionGroupId : entry.getValue()) {
                if (!this.regionGroupIntersection.contains(regionGroupId)) {
                    continue;
                }
                for (Integer dataNodeId : this.regionLocationMap.get(regionGroupId)) {
                    if (!this.isDataNodeAvailable(dataNodeId)) {
                        continue;
                    }
                    int leaderCount = leaderCounter.merge(dataNodeId, 1, Integer::sum);
                    this.addAdjacentEdges(dataNodeToVertex.get(dataNodeId), this.tDVertexMap.get(dataNodeId), 1, 2 * leaderCount - 1);
                }
            }
        }
    }

    /**
     * tDVertex -> T edges. The k-th edge of a DataNode costs 2k - 1, applying a convex penalty to
     * its total leader count. DataNodes without a tDVertex are skipped.
     */
    private void addNodeToSinkEdges() {
        Map<Integer, Integer> maxLeaderCounter = new TreeMap<>();
        for (Set<Integer> dataNodeIds : this.regionLocationMap.values()) {
            for (Integer dataNodeId : dataNodeIds) {
                if (this.isDataNodeAvailable(dataNodeId) && this.tDVertexMap.containsKey(dataNodeId)) {
                    int leaderCount = maxLeaderCounter.merge(dataNodeId, 1, Integer::sum);
                    this.addAdjacentEdges(this.tDVertexMap.get(dataNodeId), T_VERTEX, 1, 2 * leaderCount - 1);
                }
            }
        }
    }

    /**
     * Adds a forward edge and its residual twin (capacity 0, negated cost). Because the twin is
     * added immediately after, the two edges sit at indices {@code i} and {@code i ^ 1}.
     */
    private void addAdjacentEdges(int fromVertex, int destVertex, int capacity, int cost) {
        this.addEdge(fromVertex, destVertex, capacity, cost);
        this.addEdge(destVertex, fromVertex, 0, -cost);
    }

    /** Appends an edge and links it at the head of {@code fromVertex}'s adjacency list. */
    private void addEdge(int fromVertex, int destVertex, int capacity, int cost) {
        this.costFlowEdges.add(new CostFlowEdge(destVertex, capacity, cost, this.vertexHeadEdge[fromVertex]));
        // The new edge becomes the list head. Without this every head stays -1, no edge is ever
        // traversed, and no flow (hence no leader change) can be found.
        this.vertexHeadEdge[fromVertex] = this.maxEdge++;
    }

    /**
     * Computes shortest residual costs from S with SPFA (queue-based Bellman-Ford), which copes
     * with the negative costs on residual edges.
     *
     * @return {@code true} if T is reachable in the residual graph
     */
    private boolean SPFACheck() {
        Arrays.fill(this.isVertexVisited, false);
        Arrays.fill(this.vertexMinimumCost, INFINITY);
        ArrayDeque<Integer> queue = new ArrayDeque<>();
        this.vertexMinimumCost[S_VERTEX] = 0;
        this.isVertexVisited[S_VERTEX] = true;
        queue.offer(S_VERTEX);
        while (!queue.isEmpty()) {
            int currentVertex = queue.poll();
            this.isVertexVisited[currentVertex] = false;
            for (int currentEdge = this.vertexHeadEdge[currentVertex]; currentEdge >= 0; currentEdge = this.costFlowEdges.get(currentEdge).nextEdge) {
                CostFlowEdge edge = this.costFlowEdges.get(currentEdge);
                if (edge.capacity <= 0) {
                    continue;
                }
                int candidateCost = this.vertexMinimumCost[currentVertex] + edge.cost;
                if (candidateCost < this.vertexMinimumCost[edge.destVertex]) {
                    this.vertexMinimumCost[edge.destVertex] = candidateCost;
                    if (!this.isVertexVisited[edge.destVertex]) {
                        this.isVertexVisited[edge.destVertex] = true;
                        queue.offer(edge.destVertex);
                    }
                }
            }
        }
        return this.vertexMinimumCost[T_VERTEX] < INFINITY;
    }

    /**
     * Pushes up to {@code inputFlow} units from {@code currentVertex} to T, using only edges on a
     * shortest path from the last SPFA pass.
     *
     * @return the amount of flow actually pushed
     */
    private int dfsAugmentation(int currentVertex, int inputFlow) {
        if (currentVertex == T_VERTEX || inputFlow == 0) {
            return inputFlow;
        }
        int outputFlow = 0;
        this.isVertexVisited[currentVertex] = true;
        int currentEdge = this.vertexCurrentEdge[currentVertex];
        while (currentEdge >= 0) {
            CostFlowEdge edge = this.costFlowEdges.get(currentEdge);
            if (edge.capacity > 0
                    && !this.isVertexVisited[edge.destVertex]
                    && this.vertexMinimumCost[currentVertex] + edge.cost == this.vertexMinimumCost[edge.destVertex]) {
                int subOutputFlow = this.dfsAugmentation(edge.destVertex, Math.min(inputFlow, edge.capacity));
                this.minimumCost += subOutputFlow * edge.cost;
                edge.capacity -= subOutputFlow;
                this.costFlowEdges.get(currentEdge ^ 1).capacity += subOutputFlow;
                outputFlow += subOutputFlow;
                inputFlow -= subOutputFlow;
                if (inputFlow == 0) {
                    // Keep currentEdge: it may still have residual capacity for the next DFS.
                    break;
                }
            }
            currentEdge = edge.nextEdge;
        }
        this.vertexCurrentEdge[currentVertex] = currentEdge;
        if (outputFlow > 0) {
            // Vertices that pushed nothing stay marked, pruning them for the rest of this phase.
            this.isVertexVisited[currentVertex] = false;
        }
        return outputFlow;
    }

    /** Alternates SPFA and blocking-flow augmentation until T becomes unreachable. */
    private void dinicAlgorithm() {
        while (this.SPFACheck()) {
            System.arraycopy(this.vertexHeadEdge, 0, this.vertexCurrentEdge, 0, this.maxVertex);
            int currentFlow;
            while ((currentFlow = this.dfsAugmentation(S_VERTEX, INFINITY)) > 0) {
                this.maximumFlow += currentFlow;
            }
        }
    }

    /**
     * Reads the chosen leader of every RegionGroup from the saturated rVertex -> sDVertex edges.
     * RegionGroups outside the intersection, or without such an edge, keep
     * {@code regionLeaderMap}'s value (-1 if absent).
     */
    private Map<TConsensusGroupId, Integer> collectLeaderDistribution() {
        Map<TConsensusGroupId, Integer> result = new ConcurrentHashMap<>();
        for (Map.Entry<String, List<TConsensusGroupId>> entry : this.databaseRegionGroupMap.entrySet()) {
            Map<Integer, Integer> vertexToDataNode = this.sDVertexReflect.get(entry.getKey());
            for (TConsensusGroupId regionGroupId : entry.getValue()) {
                int originalLeader = this.regionLeaderMap.getOrDefault(regionGroupId, -1);
                Integer selectedLeader = this.regionGroupIntersection.contains(regionGroupId)
                        ? this.findSelectedLeader(this.rVertexMap.get(regionGroupId), vertexToDataNode)
                        : null;
                result.put(regionGroupId, selectedLeader != null ? selectedLeader : originalLeader);
            }
        }
        return result;
    }

    /**
     * Returns the DataNode reached by the rVertex's flow-carrying edge, or {@code null} if the
     * RegionGroup received no flow.
     */
    private Integer findSelectedLeader(int rVertex, Map<Integer, Integer> vertexToDataNode) {
        for (int currentEdge = this.vertexHeadEdge[rVertex]; currentEdge >= 0; currentEdge = this.costFlowEdges.get(currentEdge).nextEdge) {
            CostFlowEdge edge = this.costFlowEdges.get(currentEdge);
            // Forward rVertex -> sDVertex edges start at capacity 1, so capacity 0 means flow
            // passes through. The edge back to S is the residual twin of S -> rVertex; skip it.
            if (edge.destVertex != S_VERTEX && edge.capacity == 0) {
                return vertexToDataNode.get(edge.destVertex);
            }
        }
        return null;
    }

    @TestOnly
    public int getMaximumFlow() {
        return this.maximumFlow;
    }

    @TestOnly
    public int getMinimumCost() {
        return this.minimumCost;
    }

    /** One directed edge in the residual graph, stored as a singly linked adjacency list node. */
    private static class CostFlowEdge {
        private final int destVertex;
        // Remaining residual capacity; mutated during augmentation
        private int capacity;
        private final int cost;
        // Index of the next edge leaving the same vertex, or -1 at the end of the list
        private final int nextEdge;

        private CostFlowEdge(int destVertex, int capacity, int cost, int nextEdge) {
            this.destVertex = destVertex;
            this.capacity = capacity;
            this.cost = cost;
            this.nextEdge = nextEdge;
        }
    }
}

