1. 基本概念

1.1. CAS

CAS全称 Compare And Swap(比较与交换),是一种无锁算法。在不使用锁(没有线程被阻塞)的情况下实现多线程之间的变量同步。java.util.concurrent包中的原子类就是通过CAS来实现了乐观锁。

CAS算法涉及到三个操作数:

  • 需要读写的内存值 V。
  • 进行比较的值 A。
  • 要写入的新值 B。

当且仅当 V 的值等于 A 时,CAS通过原子方式用新值B来更新V的值(“比较+更新”整体是一个原子操作),否则不会执行任何操作。一般情况下,“更新”是一个不断重试的操作。详见Unsafe类中的getAndAddInt方法

1.2. ABA问题

CAS虽然很高效,但存在ABA问题

ABA问题。CAS需要在操作值的时候检查内存值是否发生变化,没有发生变化才会更新内存值。但是如果内存值原来是A,后来变成了B,然后又变成了A,那么CAS进行检查时会发现值没有发生变化,但是实际上是有变化的。ABA问题的解决思路就是在变量前面添加版本号,每次变量更新的时候都把版本号加一,这样变化过程就从“A-B-A”变成了“1A-2B-3A”。

2. ABA问题解决之道

JDK从1.5开始提供了AtomicStampedReference类来解决ABA问题,具体操作封装在compareAndSet()中。compareAndSet()首先检查当前引用和当前标志与预期引用和预期标志是否相等,如果都相等,则以原子方式将引用值和标志的值设置为给定的更新值

2.1 重现ABA问题

对ABA问题进行场景重现,示例代码如下:两个线程启动时,获取AtomicInteger变量的当前值,然后一个线程将值修改为B再改回A,但线程B感知不到这种变化,仍然将值修改为12

线程A 启动,当前值是:10

线程B 启动,当前值是:10

线程A:10->11

线程A:10->11->10

线程B: index是预期的10:true,新值是:12

public class ABATest {
    public static void main(String[] args) throws InterruptedException {
        abaRepeat();
    }

    private static void abaRepeat() throws InterruptedException {
        AtomicInteger index = new AtomicInteger(10);
        Thread threadA = new Thread(() -> {
            try {
                System.out.println(String.format("%s 启动,当前值是:%s", Thread.currentThread().getName(), index.get()));
                TimeUnit.MILLISECONDS.sleep(1000);
                index.compareAndSet(10, 11);
                System.out.println(String.format("%s:%s", Thread.currentThread().getName(), "10->11"));
                index.compareAndSet(11, 10);
                System.out.println(String.format("%s:%s", Thread.currentThread().getName(), "10->11->10"));
            } catch (InterruptedException e) {
            }
        }, "线程A");

        Thread threadB = new Thread(() -> {
            try {
                System.out.println(String.format("%s 启动,当前值是:%s", Thread.currentThread().getName(), index.get()));
                TimeUnit.MILLISECONDS.sleep(2000);
                boolean isSucc = index.compareAndSet(10, 12);
                System.out.println(String.format("%s: index是预期的10:%s,新值是:%s", Thread.currentThread().getName(), isSucc, index.get()));
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }, "线程B");

        threadA.start();
        threadB.start();

        threadA.join();
        threadB.join();
    }
}

2.1 ABA问题修正

AtomicStampedReference是一个带有时间戳的对象引用,能很好的解决CAS机制中的ABA问题

示例使用AtomicStampedReference解决这个问题的,注意:代码通过 ref.get(stamp); 同时获取时间戳和数据,防止获取到数据和版本不是一致的,不能先取值再取时间戳,否则仍将有ABA问题

示例代码输出如下:

线程B 启动,当前值是:10,版本:0
       线程A 启动,当前值是:10,版本:0
       线程A:10->11,1
       线程A:10->11->10,2
       线程B: index是预期的10:false,新值是:10,版本:0

public class ABATest {
    public static void main(String[] args) throws InterruptedException {
        abaCorrect();
    }

    private static void abaCorrect() throws InterruptedException {
        AtomicStampedReference<Integer> ref = new AtomicStampedReference<Integer>(10, 0);
        Thread threadA = new Thread(() -> {
            try {
                int[] stamp = new int[1];
                Integer value = ref.get(stamp); //同时获取时间戳和数据,防止获取到数据和版本不是一致的

                System.out.println(String.format("%s 启动,当前值是:%s,版本:%s", Thread.currentThread().getName(), ref.getReference(), stamp[0]));
                TimeUnit.MILLISECONDS.sleep(1000);

                int newValue = value + 1;
                boolean writeOk = ref.compareAndSet(value, newValue, stamp[0], stamp[0] + 1);

                System.out.println(String.format("%s:%s,%s", Thread.currentThread().getName(), "10->11", writeOk ? stamp[0] + 1 : stamp[0]));
                stamp = new int[1];
                value = ref.get(stamp); //同时获取时间戳和数据,防止获取到数据和版本不是一致的
                newValue = value - 1;
                writeOk = ref.compareAndSet(value, newValue, stamp[0], stamp[0] + 1);
                System.out.println(String.format("%s:%s,%s", Thread.currentThread().getName(), "10->11->10", writeOk ? stamp[0] + 1 : stamp[0]));
            } catch (InterruptedException e) {
            }
        }, "线程A");

        Thread threadB = new Thread(() -> {
            try {
                int[] stamp = new int[1];
                Integer value = ref.get(stamp); //同时获取时间戳和数据,防止获取到数据和版本不是一致的

                System.out.println(String.format("%s 启动,当前值是:%s,版本:%s", Thread.currentThread().getName(), ref.getReference(), stamp[0]));
                TimeUnit.MILLISECONDS.sleep(2000);

                int newValue = value + 2;
                boolean writeOk = ref.compareAndSet(value, newValue, stamp[0], stamp[0] + 1);

                System.out.println(String.format("%s: index是预期的10:%s,新值是:%s,版本:%s", Thread.currentThread().getName(), writeOk, ref.getReference(), writeOk ? stamp[0] + 1 : stamp[0]));
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }, "线程B");

        threadA.start();
        threadB.start();

        threadA.join();
        threadB.join();
    }

}

2.3. 源码解析

AtomicStampedReference通过Pair对象维护了一个对象的引用 及一个时间戳,在更新时,通过UNSAFE.compareAndSwapObject 实现对象修改的原子操作。

private static class Pair<T> {
        final T reference;
        final int stamp;
        private Pair(T reference, int stamp) {
            this.reference = reference;
            this.stamp = stamp;
        }
        static <T> Pair<T> of(T reference, int stamp) {
            return new Pair<T>(reference, stamp);
        }
    }
 

/**
     * Atomically sets the value of both the reference and stamp
     * to the given update values if the
     * current reference is {@code ==} to the expected reference
     * and the current stamp is equal to the expected stamp.
     *
     * @param expectedReference the expected value of the reference
     * @param newReference the new value for the reference
     * @param expectedStamp the expected value of the stamp
     * @param newStamp the new value for the stamp
     * @return {@code true} if successful
     */
    public boolean compareAndSet(V   expectedReference,
                                 V   newReference,
                                 int expectedStamp,
                                 int newStamp) {
        Pair<V> current = pair;
        return
            expectedReference == current.reference &&
            expectedStamp == current.stamp &&
            ((newReference == current.reference &&
              newStamp == current.stamp) ||
             casPair(current, Pair.of(newReference, newStamp)));
    }

private boolean casPair(Pair<V> cmp, Pair<V> val) {
        return UNSAFE.compareAndSwapObject(this, pairOffset, cmp, val);
    }