负载均衡-哈希一致性算法

负载均衡

负载均衡通常用于服务器集群当中,就是当有多个服务器可以提供服务,负载均衡能够将请求均衡的分配到不同的服务器进行处理。避免了单个服务器流量过高而导致崩溃,或单个服务器长期空闲导致浪费资源。

哈希一致性算法

在了解哈希一致性算法之前,我们可以先看一下普通的hash算法。hash算法的主要作用就是散列,将一系列在形式上具有相似性质的数据,打散成随机的均匀分布的数据。在JDK中,许多类都有重写hashcode()方法。

普通的hash算法也是可以用来实现负载均衡的。例如我们有一张图片,现在要缓存到服务器中,而在缓存服务器集群中,我们有3台服务器可以作为目标,我们暂且将服务器编号设为0,1,2。

我们首先将图片的名称作为key值,获取其hash得到hash(key)。再用该值与服务器数量取模,得到的值,便可以对应集群当中服务器的编号。获取图片时,我们同样可以通过hash计算,获取图片存储的位置。

这样可以初步实现一个负载均衡,但是这样做也有一个缺点,那就是当服务器数量增加时,我们同样按照hash计算后的值进行取模,那么得到的值是前后不一的,导致所有hash都失效。

这样就会导致无法获取到正确的缓存数据,导致大量请求从数据库中获取数据,数据库压力过大出现异常,导致雪崩。

而哈希一致性算法的好处就是能够尽量减少该问题出现的机率。

想象有这么一个哈希环,其周长为2的32次方。

我们同样将上述例子中的服务器通过哈希计算后对2的32次方取模,得到这么一个值,就是该服务器再哈希环上的位置。

当我们要缓存图片时,就可以将图片通过哈希计算后对2的32次方取模,得到一个值,再沿顺时针方法在环上找到一个邻近的服务器,就是负载均衡的结果。

容器和负载均衡 负载均衡使用场景_哈希算法

使用一致性哈希算法,同样的服务器数量增加时,可能会导致一部分的缓存失效,但大部分的缓存还是能够正常使用。例如新增一个服务器D,原先需要到服务器A中的请求,可能会被分配到服务器D中而失效。但是大部分还是有效。有效避免的大量缓存失效,导致雪崩的局面。

但是这么做还是有一个缺点,那就是数据倾斜,也就是服务器扎堆分布在环上的某一个位置,导致缓存分布不均匀,大部分的请求,都分配了一台服务器上。

容器和负载均衡 负载均衡使用场景_算法_02

为了解决上述问题,就引入了虚拟节点的概念。

也就是将一个服务器复制出大量的虚拟节点,并分布到环上,使分布在环上尽可能的均匀。

容器和负载均衡 负载均衡使用场景_容器和负载均衡_03


在查找缓存时我们就可以先获取到虚拟服务节点,再去获取其真实节点,以获取缓存数据。

代码实现

以下代码是参考Dubbo中,对于负载均衡的哈希一致性算法的实现。可直接用于RPC框架当中。

public class ConsistentHashLoadBalance  implements LoadBalancer{
    // 存储该服务对应的环,环的数据结构为SortedMap
    private final ConcurrentHashMap<String, ConsistentHashSelector> selectors = new ConcurrentHashMap<>();

    // 每个服务对应有1000个虚拟结点,具体可视业务需要而定
    private static final int VIRTUAL_NODE_NUM = 1000;

    @Override
    public Instance select(List<Instance> instances) {
        // 服务名称
        String rpcServiceName = instances.get(0).getServiceName();
        // 服务列表的hashcode,以判断该列表是否有更新
        int identityHashCode = instances.hashCode();
        // 获取该服务对应的环
        ConsistentHashSelector selector = selectors.get(rpcServiceName);
        // 检查更新
        if (selector == null || selector.identityHashCode != identityHashCode) {
            selectors.put(rpcServiceName, new ConsistentHashSelector(instances, identityHashCode));
            selector = selectors.get(rpcServiceName);
        }
        // Arrays.stream(rpcRequest.getParameters()), 方法参数不同,分配到不同的服务
        return selector.selectForKey(HashUtil.getHash(rpcServiceName + Arrays.stream(rpcRequest.getParameters())));
    }

    static class ConsistentHashSelector {
        // 环
        private final TreeMap<Long, String> virtualServices;

        // 该环对应的服务列表的hashcode
        private final int identityHashCode;

        ConsistentHashSelector(List<Instance> instances, int identityHashCode) {
            this.virtualServices = new TreeMap<>();
            this.identityHashCode = identityHashCode;

            for (Instance instance : instances) {
                for (int i = 0; i < VIRTUAL_NODE_NUM; i++) {
                    String realName = instance.getServiceName() + ":" + instance.getIp() + ":" + instance.getPort();
                    String virtualNodeName = getVirtualNodeName(realName, i);
                    long hash = HashUtil.getHash(virtualNodeName);
                    virtualServices.put(hash, virtualNodeName);
                }
            }
        }

        public Instance selectForKey(long hashCode) {
            Map.Entry<Long, String> entry = virtualServices.tailMap(hashCode, true).firstEntry();

            // hash值在最尾部,应该映射到第一个instance上
            if (entry == null) {
                entry = virtualServices.firstEntry();
            }
            String virtualNodeName = entry.getValue();
            String realName = getRealNodeName(virtualNodeName);
            Instance instance = new Instance();
            
            String[] s = realName.split(":");
            instance.setServiceName(s[0]);
            instance.setIp(s[1]);
            instance.setPort(Integer.valueOf(s[2]));
            return instance;
        }
    }

    // 获取虚拟结点名称
    private static String getVirtualNodeName(String realName, int num) {
        return realName + "&&VN" + String.valueOf(num);
    }

    // 获取真实结点名称
    private static String getRealNodeName(String virtualName) {
        return virtualName.split("&&")[0];
    }

}
// FNV1_32_HASH算法
public class HashUtil {
    /**
     * 计算Hash值, 使用FNV1_32_HASH算法
     * @param str
     * @return
     */
    public static long getHash(String str) {
        final int p = 16777619;
        int hash = (int)2166136261L;
        for (int i = 0; i < str.length(); i++) {
            hash =( hash ^ str.charAt(i) ) * p;
        }
        hash += hash << 13;
        hash ^= hash >> 7;
        hash += hash << 3;
        hash ^= hash >> 17;
        hash += hash << 5;

        if (hash < 0) {
            hash = Math.abs(hash);
        }
        return hash;
    }
}

下述代码为Dubbo中的源码

public class ConsistentHashLoadBalance extends AbstractLoadBalance {
    public static final String NAME = "consistenthash";
    public static final String HASH_NODES = "hash.nodes";
    public static final String HASH_ARGUMENTS = "hash.arguments";
    private final ConcurrentMap<String, ConsistentHashSelector<?>> selectors = new ConcurrentHashMap<String, ConsistentHashSelector<?>>();
    @SuppressWarnings("unchecked")
    @Override
    protected <T> Invoker<T> doSelect(List<Invoker<T>> invokers, URL url, Invocation invocation) {
        String methodName = RpcUtils.getMethodName(invocation);
        String key = invokers.get(0).getUrl().getServiceKey() + "." + methodName;
        int invokersHashCode = getCorrespondingHashCode(invokers);
        ConsistentHashSelector<T> selector = (ConsistentHashSelector<T>) selectors.get(key);
        if (selector == null || selector.identityHashCode != invokersHashCode) {
            selectors.put(key, new ConsistentHashSelector<T>(invokers, methodName, invokersHashCode));
            selector = (ConsistentHashSelector<T>) selectors.get(key);
        }
        return selector.select(invocation);
    }

    public <T> int getCorrespondingHashCode(List<Invoker<T>> invokers){
        return invokers.hashCode();
    }

    private static final class ConsistentHashSelector<T> {

        private final TreeMap<Long, Invoker<T>> virtualInvokers;
        private final int replicaNumber;
        private final int identityHashCode;

        private final int[] argumentIndex;

        private Map<String, AtomicLong> serverRequestCountMap = new ConcurrentHashMap<>();

        private AtomicLong totalRequestCount;

        private int serverCount;

        private static final double OVERLOAD_RATIO_THREAD = 1.5F;

        ConsistentHashSelector(List<Invoker<T>> invokers, String methodName, int identityHashCode) {
            this.virtualInvokers = new TreeMap<Long, Invoker<T>>();
            this.identityHashCode = identityHashCode;
            URL url = invokers.get(0).getUrl();
            this.replicaNumber = url.getMethodParameter(methodName, HASH_NODES, 160);
            String[] index = COMMA_SPLIT_PATTERN.split(url.getMethodParameter(methodName, HASH_ARGUMENTS, "0"));
            argumentIndex = new int[index.length];
            for (int i = 0; i < index.length; i++) {
                argumentIndex[i] = Integer.parseInt(index[i]);
            }
            for (Invoker<T> invoker : invokers) {
                String address = invoker.getUrl().getAddress();
                for (int i = 0; i < replicaNumber / 4; i++) {
                    byte[] digest = Bytes.getMD5(address + i);
                    for (int h = 0; h < 4; h++) {
                        long m = hash(digest, h);
                        virtualInvokers.put(m, invoker);
                    }
                }
            }

            totalRequestCount = new AtomicLong(0);
            serverCount = invokers.size();
            serverRequestCountMap.clear();
        }

        public Invoker<T> select(Invocation invocation) {
            String key = toKey(invocation.getArguments());
            byte[] digest = Bytes.getMD5(key);
            return selectForKey(hash(digest, 0));
        }
        private String toKey(Object[] args) {
            StringBuilder buf = new StringBuilder();
            for (int i : argumentIndex) {
                if (i >= 0 && i < args.length) {
                    buf.append(args[i]);
                }
            }
            return buf.toString();
        }
        private Invoker<T> selectForKey(long hash) {
            Map.Entry<Long, Invoker<T>> entry = virtualInvokers.ceilingEntry(hash);
            if (entry == null) {
                entry = virtualInvokers.firstEntry();
            }

            String serverAddress = entry.getValue().getUrl().getAddress();

            double overloadThread = ((double) totalRequestCount.get() / (double) serverCount) * OVERLOAD_RATIO_THREAD;
            
            while (serverRequestCountMap.containsKey(serverAddress)
                && serverRequestCountMap.get(serverAddress).get() >= overloadThread) {
                entry = getNextInvokerNode(virtualInvokers, entry);
                serverAddress = entry.getValue().getUrl().getAddress();
            }
            if (!serverRequestCountMap.containsKey(serverAddress)) {
                serverRequestCountMap.put(serverAddress, new AtomicLong(1));
            } else {
                serverRequestCountMap.get(serverAddress).incrementAndGet();
            }
            totalRequestCount.incrementAndGet();

            return entry.getValue();
        }

        private Map.Entry<Long, Invoker<T>> getNextInvokerNode(TreeMap<Long, Invoker<T>> virtualInvokers, Map.Entry<Long, Invoker<T>> entry){
            Map.Entry<Long, Invoker<T>> nextEntry = virtualInvokers.higherEntry(entry.getKey());
            if(nextEntry == null){
                return virtualInvokers.firstEntry();
            }
            return nextEntry;
        }

        private long hash(byte[] digest, int number) {
            return (((long) (digest[3 + number * 4] & 0xFF) << 24)
                    | ((long) (digest[2 + number * 4] & 0xFF) << 16)
                    | ((long) (digest[1 + number * 4] & 0xFF) << 8)
                    | (digest[number * 4] & 0xFF))
                    & 0xFFFFFFFFL;
        }
    }

}

使用场景

哈希一致性算法,通常被用于需要实现会话粘滞的场景。将来源同一IP的请求分配到同一个服务器上处理,避免处理session共享问题。

当然,当这个服务宕机后,该session就会丢失。