问题引出

当我们使用多线程计数操作的时候,我们使用如下代码

package com.mmall.example.count;

import javax.xml.bind.SchemaOutputResolver;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * 测试多线程计数类
 */
public class Test0913 {
    //总访问量
    static int count = 0;

    //模拟访问的方法
    public static void request() throws InterruptedException {
        //模拟耗时5毫秒
        TimeUnit.MILLISECONDS.sleep(5);
        count++;
    }

    //并发程序
    public static void main(String[] args) throws InterruptedException {
        //计算耗时需要获取当前时间、运行后时间,计算他们之间的差值
        long startTime = System.currentTimeMillis();
        //设置线程个数
        int threadsize = 100;
        //创建一个栅栏
        CountDownLatch countDownLatch = new CountDownLatch(threadsize);


        for (int i = 0; i < threadsize; i++) {
            //创建线程
            Thread thread = new Thread(new Runnable() {
                @Override
                public void run() {
                    //模拟用户行为,每个用户访问十次网站

                    try {
                        for (int j = 0; j < 10; j++) {
                            request();
                        }
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } finally {
                        countDownLatch.countDown();
                    }
                }
            });
            thread.start();
        }
        //只有await方法的值为0整个程序才会结束
        countDownLatch.await();
        //拿到程序结束的时间
        long endTime = System.currentTimeMillis();

        System.out.println(Thread.currentThread().getName() + ",耗时:" + (endTime - startTime) + ", count:" + count);
    }
}

但是从该类的输出结果我们发现了输出的计数值总是不确定的,并且计数值不是我们想要的。这是因为存在线程不同步问题
并发编程(1)_CAS_java

并发编程(1)_CAS_多线程_02

问题分析

这里问题出自request方法中的count++操作,因为count++语句实际上是需要三步来完成(和jvm有关)

//count操作分解
1. 获取count的值,记作A A = count
2. 将A值+1,得到B B = A+1
3. 将B值赋值给count

如果由a b 两个线程,同时执行count++,那么假如a和b同时执行上面的第一步部操作,最后执行完后a+1、b+1,虽然线程进行了两次加法,但是count相当于只加了一次。

在这里为了避免这种错误,我们希望在执行count++操作的时候,同一时刻只允许一个线程去进行这个操作,因此引出了的概念,在java中,我们利用synchronized关键字来实现对资源的加锁。
于是我们将代码改进如下

package com.mmall.example.count;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * 测试多线程计数类
 * 对request方法添加了锁
 */
public class Test09131 {
    //总访问量
    static int count = 0;

    //模拟访问的方法
    public static synchronized void request() throws InterruptedException {
        //模拟耗时5毫秒
        TimeUnit.MILLISECONDS.sleep(5);
        count++;
    }

    //并发程序
    public static void main(String[] args) throws InterruptedException {
        //计算耗时需要获取当前时间、运行后时间,计算他们之间的差值
        long startTime = System.currentTimeMillis();
        //设置线程个数
        int threadsize = 100;
        //创建一个栅栏
        CountDownLatch countDownLatch = new CountDownLatch(threadsize);


        for (int i = 0; i < threadsize; i++) {
            //创建线程
            Thread thread = new Thread(new Runnable() {
                @Override
                public void run() {
                    //模拟用户行为,每个用户访问十次网站

                    try {
                        for (int j = 0; j < 10; j++) {
                            request();
                        }
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } finally {
                        countDownLatch.countDown();
                    }
                }
            });
            thread.start();
        }
        //只有await方法的值为0整个程序才会结束
        countDownLatch.await();
        //拿到程序结束的时间
        long endTime = System.currentTimeMillis();

        System.out.println(Thread.currentThread().getName() + ",耗时:" + (endTime - startTime) + ", count:" + count);
    }
}

从这个程序的输出结果看
并发编程(1)_CAS_多线程_03

计数准确性问题确实解决了,但是出现了新问题,运行时间变得十分缓慢,这是因为我们的锁加在了整个方法上,导致执行整个request方法在同一时刻只能有一个线程在执行,实际上我们只需要对count++中的第三步操作加锁就可以,这里也引出了我们加锁的原则

在保证线程安全性的情况下,加锁的范围越小越好

对锁加以改进

我们按照如下方式来对第3步进行升级

  1. 获取锁
  2. 获取以下count最新的值,记作LV
  3. 判断LV是否等于A,如果相等,则将B的值赋值给count,并返回true,否则返回false
  4. 释放锁
    接下来,我们再次对代码进行修改
package com.mmall.example.count;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * 测试多线程计数类
 * 对request方法添加了锁
 */
public class Test091311 {
    //总访问量
    //这里的volatile关键字能够保证变量是可见的,在线程获取该变量的时候,会从主内存中获取,而不是共享内存中获取
    volatile static int count = 0;

    //模拟访问的方法
    public static void request() throws InterruptedException {
        //模拟耗时5毫秒
        TimeUnit.MILLISECONDS.sleep(5);
        //   count++;

        int exceptCount;//表示期望值
        while (!compareAndSwap(exceptCount = getCount(), exceptCount + 1)) ;
    }
    //新建一个新方法来加锁

    /**
     * @param expectCount 期望值count
     * @param newCount    需要给count赋值的新值
     * @return 成功返回 true 失败返回false
     */
    public static synchronized boolean compareAndSwap(int expectCount, int newCount) {
        //判断count当前值是否和期望值expectCount一致,如果一致,将newCount赋值给count
        if (getCount() == expectCount) {
            count = newCount;
            return true;
        }
        return false;
    }

    public static int getCount() {
        return count;
    }


    //并发程序
    public static void main(String[] args) throws InterruptedException {
        //计算耗时需要获取当前时间、运行后时间,计算他们之间的差值
        long startTime = System.currentTimeMillis();
        //设置线程个数
        int threadsize = 100;
        //创建一个栅栏
        CountDownLatch countDownLatch = new CountDownLatch(threadsize);


        for (int i = 0; i < threadsize; i++) {
            //创建线程
            Thread thread = new Thread(new Runnable() {
                @Override
                public void run() {
                    //模拟用户行为,每个用户访问十次网站

                    try {
                        for (int j = 0; j < 10; j++) {
                            request();
                        }
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } finally {
                        countDownLatch.countDown();
                    }
                }
            });
            thread.start();
        }
        //只有await方法的值为0整个程序才会结束
        countDownLatch.await();
        //拿到程序结束的时间
        long endTime = System.currentTimeMillis();

        System.out.println(Thread.currentThread().getName() + ",耗时:" + (endTime - startTime) + ", count:" + count);
    }
}

执行结果
并发编程(1)_CAS_赋值_04

这里的更新相当于把count++的第三步进行手动拆分,可以发现程序现在既保证了性能又保证了速度。