简单的写一下吧,用python实现的,因为代码量少,可以更清晰的看原理
redis是NIO+IO多路复用无锁模式,利用这一特性,我们可以保证setnx命令不会出现多个线程同时成功,保证我们确实获得了该锁.
代码中有一些幂等性的细节,并且为不可重入锁,如果递归你们自己改一改,用个标记记录ID之类的。
锁的实现
import redis
import time
import uuid
import threading
conn = redis.StrictRedis(host='127.0.0.1', port=6379, db=0, password=123456)
def acquire_lock_with_timeout(lock_name, timeout=10, lock_timeout=20):
# 这里演示了 给锁加上过期时间,自动删除
# 这里的lock_timeout ,并不是申请锁的超时时间,而是自动删除锁的时间
# 一定要合理设置,确保这个timeout会比你所申请锁后处理业务的时间要长
identifier = str(uuid.uuid4())
timeout = time.time() + timeout
lock_name = "lock:" + lock_name
while time.time() < timeout:
if conn.setnx(lock_name, identifier):
conn.expire(lock_name, lock_timeout)
return identifier
elif not conn.ttl(lock_name): # 如果这个锁存在,但是没有被设置过期时间,那么我们设置过期时间
conn.expire(lock_name, lock_timeout)
time.sleep(0.001)
return False
# 用watch的原因是为了适用于有过期时间的锁
def release_lock(lock_name, identifier):
pipe = conn.pipeline(True)
lock_name = "lock:" + lock_name
while True:
try:
pipe.watch(lock_name) # 监视lock_name 没有被篡改
if pipe.get(lock_name) == identifier: # 如果这个键的值还是我们申请锁时返回的标识符
pipe.multi()
pipe.delete(lock_name) # 删除这个锁
pipe.execute()
return True
pipe.unwatch() # 如果没有达成上面的条件,说明锁已经被移除了
break # 停止监视,退出循环 最后会返回False
except redis.exceptions.WatchError: # 如果锁被篡改过,pass 继续执行while 重试
pass
return False
def test():
lock_res = acquire_lock_with_timeout("test")
if lock_res:
print(uuid.uuid4())
release_lock("test", lock_res)
if __name__ == '__main__':
t1 = threading.Thread(target=test)
t2 = threading.Thread(target=test)
t1.start()
t2.start()
time.sleep(10)
信号量实现
关于信号量的解释,总有人和多线程/进程的max数搞混
- 信号量:如果你有30个线程启动,那么同时就有30个线程运行了,只是这30个线程,被阻塞住,确保只有X个信号量个线程处理业务
- 递归锁:如果你有30个线程启动,那么同时就有30个线程运行了,只是每个线程中针对上锁的地方,都会有引用计数,当前线程
可以申请很多次这个锁,但是如果引用计数不为0,其他线程就永远无法进入- 线程/进程池的最大数:不管你注册了多少个worker,但是同时运行的线程/进程数,不能超过设置的最大数,它和信号量总是被人搞混
1.实现信号量可以用string实现,递增递减就完事了,但是这样无法设置过期时间,因为计数是在这个string里完成的
2.还是用string,但是一个信号量占用一个string,这样可以设置过期时间,例如:setnx semcount:1 identifier,但是时间复杂度跟随信号量数量递增
3.使用有序集合,将多个信号量持有者的信息存储到同一个结构里,并且可以利用分值进行移除过期信号量.
4.使用无序集合,使用scard完成计数,但是无法设置过期时间。
import redis
import time
import uuid
import threading
conn = redis.StrictRedis(host='192.168.0.6', port=6379, db=0, password=123456)
# 下面这个方法,思路是OK的,但是分布式信号量每台服务器的时间戳不可能完全相同一丝不差。
# 所造成的的结果就是,下面信号量的数量限制没有问题,但是公平性有问题
# 例如:服务器A和服务器B 同时请求5个信号量,但是服务器A的系统时间比服务器B的系统时间快10毫秒
# 那么最后,B服务器会获得更多的信号量,因为A服务器的排名永远会比B服务器的大(所以靠后)
def acquire_semaphore(semaphore_name, limit, timeout=10):
identifier = str(uuid.uuid4())
now = time.time()
pipe = conn.pipeline(True)
pipe.zremrangebyscore(semaphore_name, '-inf', now - timeout) # 清理过期的信号量持有者
pipe.zadd(semaphore_name, {identifier: now}) # 申请信号量
# 获取rank值,因为有序集合zrank是取排名,那么根据identifier对应的score(时间戳)取排名,就可以确定有多少个信号量了
pipe.zrank(semaphore_name, identifier)
if pipe.execute()[-1] < limit: # 执行 取最后一个元素(rank值),如果rank值<限制数量 那么返回标识符
return identifier
conn.zrem(semaphore_name, identifier) # 否则 移除刚刚添加的标识符
return None
# 有序集合移除对应的是集合内的key,所以我们不需要watch,直接删除即可
def release_semaphore(semaphore_name, identifier):
return conn.zrem(semaphore_name, identifier)
# 公平的信号量
# 如果没有过期时间的设置,可以不要记录时间戳的有序集合
# 这个虽然公平,但是需要加分布式锁,因为我们分了2次执行(为了省事,例子中的代码我没加锁,注意了)
# 如果A服务器先自增,但是后面的还没执行
# B服务器在A服务器自增之后才开始执行,但是速度更快,已经执行完了zrank,那么B就会偷走A服务器申请的信号量
# 然后A服务器,继续往下走,rank排名又拿到一个信号量,这时候信号量可能已经超过的limit限制
# 解决了竞争关系后 还是有一个缺陷,那就是在32位平台下,每过2小时自增数值就会溢出
# 解决办法: 单开一条小线程,轮询检查(每1分钟一次),当semaphore_name集合为空的时候,清空计数器.
def acquire_fair_semaphore(semaphore_name, limit, timeout=10):
identifier = str(uuid.uuid4())
czset = semaphore_name + ':owner'
ctr = semaphore_name + ':counter'
now = time.time()
pipe = conn.pipeline(True)
pipe.zremrangebyscore(semaphore_name, '-inf', now - timeout) # 清理过期的信号量持有者
# 移除semaphore_name有序集合后 和czset按权重求交集,不需要rem czset
pipe.zinterstore(czset, {czset: 1, semaphore_name: 0}) # ZINTERSTORE czset 2 czset semaphore_name WEIGHTS 1 0
pipe.incr(ctr) # 计数器自增
counter = pipe.execute()[-1] # 获取计数器数量
pipe.zadd(semaphore_name, {identifier: now}) # 申请信号量
pipe.zadd(czset, {identifier: counter})
pipe.zrank(czset, identifier) # 因为权重求交集后 对应的标识符在czset里 存放的是counter
if pipe.execute()[-1] < limit: # 按照czset的rank限制信号数量
return identifier
pipe.zrem(semaphore_name, identifier)
pipe.zrem(czset, identifier)
pipe.execute()
return None
# 有序集合移除对应的是集合内的key,所以我们不需要watch,直接删除即可
def release_fair_semaphore(semaphore_name, identifier):
pipe = conn.pipeline(True)
pipe.zrem(semaphore_name, identifier)
pipe.zrem(semaphore_name + ':owner', identifier)
return pipe.execute()[0]
# 刷新信号量
def refresh_fair_semaphore(semaphore_name, identifier):
if conn.zadd(semaphore_name, {identifier: time.time()}): # 如果添加成功,说明这个键在集合里已经不存在了
release_fair_semaphore(semaphore_name, identifier) # 释放
return False
return True
# 由于上面的有序集合方法 counter可能会造成溢出,我们使用无序集合同样可以实现信号量
# 缺点: 我们无法给信号量设置持有时间,如果再用另外一个类型记录过期时间,你需要不停的遍历,性能非常差。
# 如果某一个服务器在申请信号量后宕机,那么不管过多久,真实的信号量永远会少一个
# 这是布置集群的大忌,如果不考虑服务器出问题,这么做是没有问题的
def acquire_test_semaphore(semaphore_name, limit):
identifier = str(uuid.uuid4())
pipe = conn.pipeline(True)
while True:
try:
pipe.watch(semaphore_name)
pipe.multi()
pipe.sadd(semaphore_name, identifier)
pipe.scard(semaphore_name)
count = pipe.execute()[-1]
# print(count)
if count <= limit:
return identifier
else:
release_test_semaphore(semaphore_name, identifier)
break
except redis.exceptions.WatchError:
pass
return None
def release_test_semaphore(semaphore_name, identifier):
return conn.srem(semaphore_name, identifier)
class ThreadSemaphore(threading.Thread):
def run(self):
identifier = acquire_test_semaphore("semtest", 3)
if identifier:
print("申请到了")
time.sleep(3)
release_test_semaphore("semtest", identifier)
if __name__ == '__main__':
for i in range(10):
ThreadSemaphore().start()