基于RedisTemplate实现分布式锁+守护线程

  • 1. 前言
  • 2. 实现的过程
  • 3.示例


1. 前言

最近由于项目需要一个比较轻量化的分布式锁,开始考虑使用Redisson,简单调研了一下发现对于我们的项目而言有点重,所以就想上网找一下比较轻量化的分布式锁,但是一圈下来大多数都是说思路有具体实现的比较少,下面分享一下我回忆之前看过的一个大神的源码,然后自己改造的一个轻量化的分布式锁的源码。

2. 实现的过程

2.1 首先是基于RedisTemplate实现的redis操作工具类RedisService,将需要用到的redis操作进行封装,其中最重要的就是redisTemplate.opsForValue().setIfAbsent()方法,这个方法底层是基于redis的setNX实现的,redis的setNX是现实分布式锁的基础(原理这里就不过多赘述了,想了解的可以自行百度)。

/**
 * Description   :redis工具类
 * Date          :2023-03-17 16:13:55
 * Author        :wandy
 */
@SuppressWarnings(value = {"unchecked", "rawtypes"})
@Component
public class RedisService
{
    @Autowired
    public RedisTemplate redisTemplate;

    /**
     * 缓存基本的对象,Integer、String、实体类等
     *
     * @param key   缓存的键值
     * @param value 缓存的值
     */
    public <T> void setCacheObject(final String key, final T value)
    {
        redisTemplate.opsForValue().set(key, value);
    }

    /**
     * 基于setNx实现的set方法
     * 
     * @param key   缓存的键值
     * @param value 缓存的值
     * @return true=设置成功;false=设置失败
     */
    public boolean setIfAbsent(final String key, final String value)
    {
        return redisTemplate.opsForValue().setIfAbsent(key, value);
    }

    /**
     * 缓存基本的对象,Integer、String、实体类等
     *
     * @param key      缓存的键值
     * @param value    缓存的值
     * @param timeout  时间
     * @param timeUnit 时间颗粒度
     */
    public <T> void setCacheObject(final String key, final T value, final Long timeout, final TimeUnit timeUnit)
    {
        redisTemplate.opsForValue().set(key, value, timeout, timeUnit);
    }

    /**
     * 设置有效时间
     *
     * @param key     Redis键
     * @param timeout 超时时间
     * @return true=设置成功;false=设置失败
     */
    public boolean expire(final String key, final long timeout)
    {
        return expire(key, timeout, TimeUnit.SECONDS);
    }

    /**
     * 设置有效时间
     *
     * @param key     Redis键
     * @param timeout 超时时间
     * @param unit    时间单位
     * @return true=设置成功;false=设置失败
     */
    public boolean expire(final String key, final long timeout, final TimeUnit unit)
    {
        return redisTemplate.expire(key, timeout, unit);
    }

    /**
     * 获取有效时间
     *
     * @param key Redis键
     * @return 有效时间
     */
    public long getExpire(final String key)
    {
        return redisTemplate.getExpire(key);
    }

    /**
     * 判断 key是否存在
     *
     * @param key 键
     * @return true 存在 false不存在
     */
    public Boolean hasKey(String key)
    {
        return redisTemplate.hasKey(key);
    }

    /**
     * 获得缓存的基本对象。
     *
     * @param key 缓存键值
     * @return 缓存键值对应的数据
     */
    public <T> T getCacheObject(final String key)
    {
        ValueOperations<String, T> operation = redisTemplate.opsForValue();
        return operation.get(key);
    }

    /**
     * 删除单个对象
     *
     * @param key
     */
    public boolean deleteObject(final String key)
    {
        return redisTemplate.delete(key);
    }

    /**
     * 删除集合对象
     *
     * @param collection 多个对象
     * @return
     */
    public boolean deleteObject(final Collection collection)
    {
        return redisTemplate.delete(collection) > 0;
    }

    /**
     * 缓存List数据
     *
     * @param key      缓存的键值
     * @param dataList 待缓存的List数据
     * @return 缓存的对象
     */
    public <T> long setCacheList(final String key, final List<T> dataList)
    {
        Long count = redisTemplate.opsForList().rightPushAll(key, dataList);
        return count == null ? 0 : count;
    }

    /**
     * 获得缓存的list对象
     *
     * @param key 缓存的键值
     * @return 缓存键值对应的数据
     */
    public <T> List<T> getCacheList(final String key)
    {
        return redisTemplate.opsForList().range(key, 0, -1);
    }

    /**
     * 缓存Set
     *
     * @param key     缓存键值
     * @param dataSet 缓存的数据
     * @return 缓存数据的对象
     */
    public <T> BoundSetOperations<String, T> setCacheSet(final String key, final Set<T> dataSet)
    {
        BoundSetOperations<String, T> setOperation = redisTemplate.boundSetOps(key);
        Iterator<T> it = dataSet.iterator();
        while (it.hasNext())
        {
            setOperation.add(it.next());
        }
        return setOperation;
    }

    /**
     * 获得缓存的set
     *
     * @param key
     * @return
     */
    public <T> Set<T> getCacheSet(final String key)
    {
        return redisTemplate.opsForSet().members(key);
    }

    /**
     * 缓存Map
     *
     * @param key
     * @param dataMap
     */
    public <T> void setCacheMap(final String key, final Map<String, T> dataMap)
    {
        if (dataMap != null)
        {
            redisTemplate.opsForHash().putAll(key, dataMap);
        }
    }

    /**
     * 获得缓存的Map
     *
     * @param key
     * @return
     */
    public <T> Map<String, T> getCacheMap(final String key)
    {
        return redisTemplate.opsForHash().entries(key);
    }

    /**
     * 往Hash中存入数据
     *
     * @param key   Redis键
     * @param hKey  Hash键
     * @param value 值
     */
    public <T> void setCacheMapValue(final String key, final String hKey, final T value)
    {
        redisTemplate.opsForHash().put(key, hKey, value);
    }

    /**
     * 获取Hash中的数据
     *
     * @param key  Redis键
     * @param hKey Hash键
     * @return Hash中的对象
     */
    public <T> T getCacheMapValue(final String key, final String hKey)
    {
        HashOperations<String, String, T> opsForHash = redisTemplate.opsForHash();
        return opsForHash.get(key, hKey);
    }

    /**
     * 获取多个Hash中的数据
     *
     * @param key   Redis键
     * @param hKeys Hash键集合
     * @return Hash对象集合
     */
    public <T> List<T> getMultiCacheMapValue(final String key, final Collection<Object> hKeys)
    {
        return redisTemplate.opsForHash().multiGet(key, hKeys);
    }

    /**
     * 删除Hash中的某条数据
     *
     * @param key  Redis键
     * @param hKey Hash键
     * @return 是否成功
     */
    public boolean deleteCacheMapValue(final String key, final String hKey)
    {
        return redisTemplate.opsForHash().delete(key, hKey) > 0;
    }

    /**
     * 获得缓存的基本对象列表
     *
     * @param pattern 字符串前缀
     * @return 对象列表
     */
    public Collection<String> keys(final String pattern)
    {
        return redisTemplate.keys(pattern);
    }
}

2.2 基于RedisService对Redis进行操作实现分布式锁RedisLock,其中包括:创建锁对象、获取锁、释放锁、锁延时等方法。

/**
 * Description   :Redis锁
 * Date          :2023/3/16 16:34
 * Author        :wandy
 */
@Slf4j
public class RedisLock {

    private RedisService redisService= SpringUtils.getBean(RedisService.class);

    /** 锁前缀*/
    private static final String LOCK_PREFIX = "RedisLock:";
    /** 加锁标志 */
    private String lockedFlag;
    /** 毫秒与毫微秒的换算单位 1毫秒=1000000毫微秒. */
    private static final long MILLI_NANO_CONVERSION = 1000*1000L;
    /** 默认超时时间(毫秒). */
    private static final long DEFAULT_TIME_OUT = 3*60*1000;
    private static Random RANDOM = new Random();
    /** 锁的超时时间(秒),过期删除. */
    private static final int EXPIRE = 3*60;
    /**锁键**/
    private String key;
    /** 锁状态标志. */
    private boolean locked = false;

    /**
     * 获取锁对象
     * @param key
     * @return
     */
    public static RedisLock getRedisLock(String key){
        return new RedisLock(key, UUID.randomUUID().toString());
    }

    /**
     * 构造函数.
     * @param key key
     */
    private RedisLock(String key,String lockedFlag) {
        this.key = LOCK_PREFIX+key;
        this.lockedFlag=lockedFlag;
    }

    /**
     * 获取锁
     * @return true:成功;false:失败
     */
    public boolean lock() {
        return lock(DEFAULT_TIME_OUT);
    }

    /**
     * 释放锁
     * 无论加锁是否成功,都需要调用该方法进行释放锁
     */
    public void unlock() {
        if (locked) {
            //判断锁lockedFlag是否一致,防止释放同key的其他锁
            String value = redisService.getCacheObject(key);
            if(lockedFlag.equals(value)) redisService.deleteObject(key);
            locked=false;
        }
        log.info("RedisLock解锁:{}", key);
    }

    /**
     * 加锁
     * @param timeout 获取锁自旋等待时间(毫秒)
     * @return true:成功;false:失败
     */
    public boolean lock(long timeout) {
        return lock(timeout, EXPIRE);
    }

    /**
     * 加锁.
     * @param timeout 获取锁自旋等待时间(毫秒)
     * @param expire 锁的超时时间(秒),过期删除
     * @return true:成功;false:失败
     */
    public boolean lock(long timeout, int expire) {
        long nano = System.nanoTime();
        timeout *= MILLI_NANO_CONVERSION;
        try {
            if(timeout>0){
                while ((System.nanoTime()-nano)<timeout) {
                    if (redisService.setIfAbsent(key,lockedFlag)) {
                        if(expire>0){
                            redisService.expire(key, expire, TimeUnit.SECONDS);
                        }
                        locked = true;
                        log.info("RedisLock锁定:{},线程{}", key,Thread.currentThread().getName());
                        //开启守护线程
                        startGuard(expire);
                        return locked;
                    }
                    // 短暂休眠,避免出现死锁
                    Thread.sleep(3, RANDOM.nextInt(500));
                    log.info("RedisLock等待资源:{},线程{}", key,Thread.currentThread().getName());
                }
            }else{
                if (redisService.setIfAbsent(key,lockedFlag)) {
                    if(expire>0){
                        redisService.expire(key, expire, TimeUnit.SECONDS);
                    }
                    locked = true;
                    //开启守护线程
                    startGuard(expire);
                    return locked;
                }
            }

        } catch (Exception e) {
            throw new RuntimeException("Locking error", e);
        }
        return false;
    }

    //开启守护线程
    private void startGuard(int expire){
        RedisLockGuard redisLockGuard = new RedisLockGuard(key, expire, this);
        Thread thread = new Thread(redisLockGuard);
        //设置为守护线程,所服务的用户线程不存在时自动无效
        thread.setDaemon(Boolean.TRUE);
        thread.start();
    }

    /**
     * 锁延时
     * @param expireTime
     * @return
     */
    public Boolean lockDelay(int expireTime) {
        if(locked){
            //判断锁lockedFlag是否一致
            String value = redisService.getCacheObject(key);
            if(lockedFlag.equals(value)) {
                redisService.redisTemplate.expire(key, expireTime, TimeUnit.SECONDS);
                return true;
            }
        }
        return false;
    }

}

2.3 通过实现Runnable接口为RedisLock加一个守护线程RedisLockGuard,防止锁过期被自动释放,但是业务逻辑还没有执行完的情况。

/**
 * Description   :Redis锁延时守护线程
 * Date          :2023/3/16 17:05
 * Author        :wandy
 */
@Slf4j
public class RedisLockGuard implements Runnable{

    private String key;
    private int expireTime;
    private boolean isRunning;
    private RedisLock redisLock;


    public RedisLockGuard(String key,int expireTime,RedisLock redisLock) {
        this.expireTime = expireTime;
        this.isRunning = Boolean.TRUE;
        this.redisLock = redisLock;
        this.key = key;
    }

    @Override
    public void run() {
        long nano = System.nanoTime();
        long timeout = 3 * expireTime * 1000 * 1000 * 1000L;//最多持续锁超时时间的3倍,expireTime时间单位是秒,转化成微秒
        long waitTime = expireTime * 1000 * 2 / 3;// 线程等待多长时间后执行,expireTime时间单位是秒,转化成毫秒
        while (isRunning) {
            if((System.nanoTime()-nano)>=timeout){
                this.stop();
                log.error("RedisLock延时守护已超时:{}",key);
                continue;
            }
            try {
                Thread.sleep(waitTime);
                if (redisLock.lockDelay(expireTime)) {
                    log.info("RedisLock延时守护延时成功:{}",key);
                } else {
                    log.info("RedisLock延时守护已关闭:{}",key);
                    this.stop();
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
    private void stop() {
        this.isRunning = Boolean.FALSE;
    }

}
3.示例
@GetMapping("testRedisLock")
    public AjaxResult testRedisLock (){

        RedisLock redisLock = RedisLock.getRedisLock("test-1");
        if(redisLock.lock()){
            try {
                //锁定逻辑
                Thread.sleep(200*1000);

            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                //释放锁
                redisLock.unlock();
            }
        }else {
            logger.error("获取锁失败,线程:{}",Thread.currentThread().getName());
        }
        return AjaxResult.success();
    }

欢迎各路大神多多指正,大家一起学习