简单的写一下吧,用python实现的,因为代码量少,可以更清晰的看原理

redis是NIO+IO多路复用无锁模式,利用这一特性,我们可以保证setnx命令不会出现多个线程同时成功,保证我们确实获得了该锁.
代码中有一些幂等性的细节,并且为不可重入锁,如果递归你们自己改一改,用个标记记录ID之类的。

锁的实现

import redis
import time
import uuid
import threading
conn = redis.StrictRedis(host='127.0.0.1', port=6379, db=0, password=123456)


def acquire_lock_with_timeout(lock_name, timeout=10, lock_timeout=20):
    # 这里演示了 给锁加上过期时间,自动删除
    # 这里的lock_timeout ,并不是申请锁的超时时间,而是自动删除锁的时间
    # 一定要合理设置,确保这个timeout会比你所申请锁后处理业务的时间要长
    identifier = str(uuid.uuid4())
    timeout = time.time() + timeout
    lock_name = "lock:" + lock_name
    while time.time() < timeout:
        if conn.setnx(lock_name, identifier): 
            conn.expire(lock_name, lock_timeout)
            return identifier
        elif not conn.ttl(lock_name):  # 如果这个锁存在,但是没有被设置过期时间,那么我们设置过期时间
            conn.expire(lock_name, lock_timeout)
        time.sleep(0.001)
    return False


# 用watch的原因是为了适用于有过期时间的锁
def release_lock(lock_name, identifier):
    pipe = conn.pipeline(True)
    lock_name = "lock:" + lock_name
    while True:
        try:
            pipe.watch(lock_name)  # 监视lock_name 没有被篡改
            if pipe.get(lock_name) == identifier:  # 如果这个键的值还是我们申请锁时返回的标识符
                pipe.multi()
                pipe.delete(lock_name)  # 删除这个锁
                pipe.execute()
                return True
            pipe.unwatch()  # 如果没有达成上面的条件,说明锁已经被移除了
            break   # 停止监视,退出循环  最后会返回False
        except redis.exceptions.WatchError:  # 如果锁被篡改过,pass 继续执行while 重试
            pass
        return False


def test():
    lock_res = acquire_lock_with_timeout("test")
    if lock_res:
        print(uuid.uuid4())
        release_lock("test", lock_res)
    


if __name__ == '__main__':
    t1 = threading.Thread(target=test)
    t2 = threading.Thread(target=test)
    t1.start()
    t2.start()
    time.sleep(10)

信号量实现

关于信号量的解释,总有人和多线程/进程的max数搞混

  • 信号量:如果你有30个线程启动,那么同时就有30个线程运行了,只是这30个线程,被阻塞住,确保只有X个信号量个线程处理业务
  • 递归锁:如果你有30个线程启动,那么同时就有30个线程运行了,只是每个线程中针对上锁的地方,都会有引用计数,当前线程
    可以申请很多次这个锁,但是如果引用计数不为0,其他线程就永远无法进入
  • 线程/进程池的最大数:不管你注册了多少个worker,但是同时运行的线程/进程数,不能超过设置的最大数,它和信号量总是被人搞混

1.实现信号量可以用string实现,递增递减就完事了,但是这样无法设置过期时间,因为计数是在这个string里完成的
2.还是用string,但是一个信号量占用一个string,这样可以设置过期时间,例如:setnx semcount:1 identifier,但是时间复杂度跟随信号量数量递增
3.使用有序集合,将多个信号量持有者的信息存储到同一个结构里,并且可以利用分值进行移除过期信号量.
4.使用无序集合,使用scard完成计数,但是无法设置过期时间。

import redis
import time
import uuid
import threading

conn = redis.StrictRedis(host='192.168.0.6', port=6379, db=0, password=123456)

# 下面这个方法,思路是OK的,但是分布式信号量每台服务器的时间戳不可能完全相同一丝不差。
# 所造成的的结果就是,下面信号量的数量限制没有问题,但是公平性有问题
# 例如:服务器A和服务器B 同时请求5个信号量,但是服务器A的系统时间比服务器B的系统时间快10毫秒
# 那么最后,B服务器会获得更多的信号量,因为A服务器的排名永远会比B服务器的大(所以靠后)
def acquire_semaphore(semaphore_name, limit, timeout=10):
    identifier = str(uuid.uuid4())
    now = time.time()
    pipe = conn.pipeline(True)
    pipe.zremrangebyscore(semaphore_name, '-inf', now - timeout)  # 清理过期的信号量持有者
    pipe.zadd(semaphore_name, {identifier: now})  # 申请信号量
    # 获取rank值,因为有序集合zrank是取排名,那么根据identifier对应的score(时间戳)取排名,就可以确定有多少个信号量了
    pipe.zrank(semaphore_name, identifier)
    if pipe.execute()[-1] < limit:  # 执行 取最后一个元素(rank值),如果rank值<限制数量  那么返回标识符
        return identifier
    conn.zrem(semaphore_name, identifier)  # 否则 移除刚刚添加的标识符
    return None


# 有序集合移除对应的是集合内的key,所以我们不需要watch,直接删除即可
def release_semaphore(semaphore_name, identifier):
    return conn.zrem(semaphore_name, identifier)


# 公平的信号量
# 如果没有过期时间的设置,可以不要记录时间戳的有序集合
# 这个虽然公平,但是需要加分布式锁,因为我们分了2次执行(为了省事,例子中的代码我没加锁,注意了)
# 如果A服务器先自增,但是后面的还没执行
# B服务器在A服务器自增之后才开始执行,但是速度更快,已经执行完了zrank,那么B就会偷走A服务器申请的信号量
# 然后A服务器,继续往下走,rank排名又拿到一个信号量,这时候信号量可能已经超过的limit限制
# 解决了竞争关系后 还是有一个缺陷,那就是在32位平台下,每过2小时自增数值就会溢出
# 解决办法: 单开一条小线程,轮询检查(每1分钟一次),当semaphore_name集合为空的时候,清空计数器.
def acquire_fair_semaphore(semaphore_name, limit, timeout=10):
    identifier = str(uuid.uuid4())
    czset = semaphore_name + ':owner'
    ctr = semaphore_name + ':counter'
    now = time.time()
    pipe = conn.pipeline(True)
    pipe.zremrangebyscore(semaphore_name, '-inf', now - timeout)  # 清理过期的信号量持有者
    # 移除semaphore_name有序集合后 和czset按权重求交集,不需要rem czset
    pipe.zinterstore(czset, {czset: 1, semaphore_name: 0})  # ZINTERSTORE czset 2 czset semaphore_name WEIGHTS 1 0

    pipe.incr(ctr)  # 计数器自增
    counter = pipe.execute()[-1]  # 获取计数器数量
    pipe.zadd(semaphore_name, {identifier: now})  # 申请信号量
    pipe.zadd(czset, {identifier: counter})
    pipe.zrank(czset, identifier)  # 因为权重求交集后 对应的标识符在czset里  存放的是counter
    if pipe.execute()[-1] < limit:  # 按照czset的rank限制信号数量
        return identifier

    pipe.zrem(semaphore_name, identifier)
    pipe.zrem(czset, identifier)
    pipe.execute()
    return None


# 有序集合移除对应的是集合内的key,所以我们不需要watch,直接删除即可
def release_fair_semaphore(semaphore_name, identifier):
    pipe = conn.pipeline(True)
    pipe.zrem(semaphore_name, identifier)
    pipe.zrem(semaphore_name + ':owner', identifier)
    return pipe.execute()[0]


# 刷新信号量
def refresh_fair_semaphore(semaphore_name, identifier):
    if conn.zadd(semaphore_name, {identifier: time.time()}):  # 如果添加成功,说明这个键在集合里已经不存在了
        release_fair_semaphore(semaphore_name, identifier)  # 释放
        return False
    return True


# 由于上面的有序集合方法 counter可能会造成溢出,我们使用无序集合同样可以实现信号量
# 缺点: 我们无法给信号量设置持有时间,如果再用另外一个类型记录过期时间,你需要不停的遍历,性能非常差。
# 如果某一个服务器在申请信号量后宕机,那么不管过多久,真实的信号量永远会少一个
# 这是布置集群的大忌,如果不考虑服务器出问题,这么做是没有问题的
def acquire_test_semaphore(semaphore_name, limit):
    identifier = str(uuid.uuid4())
    pipe = conn.pipeline(True)
    while True:
        try:
            pipe.watch(semaphore_name)
            pipe.multi()
            pipe.sadd(semaphore_name, identifier)
            pipe.scard(semaphore_name)
            count = pipe.execute()[-1]
            # print(count)
            if count <= limit:
                return identifier
            else:
                release_test_semaphore(semaphore_name, identifier)
                break
        except redis.exceptions.WatchError:
            pass
    return None


def release_test_semaphore(semaphore_name, identifier):
    return conn.srem(semaphore_name, identifier)


class ThreadSemaphore(threading.Thread):
    def run(self):
        identifier = acquire_test_semaphore("semtest", 3)
        if identifier:
            print("申请到了")
            time.sleep(3)
            release_test_semaphore("semtest", identifier)


if __name__ == '__main__':
    for i in range(10):
        ThreadSemaphore().start()