前言

在GitHub中学习druid连接池的时候,发现其提供了一个数据库密码加密的功能,如下所示的配置:

<bean id="dataSource" class="com.alibaba.druid.pool.DruidDataSource"
     init-method="init" destroy-method="close">
     <property name="url" value="jdbc:derby:memory:spring-test;create=true" />
     <property name="username" value="sa" />
     <property name="password" value="${password}" />
     <property name="filters" value="config" />
     <property name="connectionProperties" value="config.decrypt=true;config.decrypt.key=${publickey}" />
</bean>

除了正常配置加密后的password之外,还需要配置属性connectionProperties,并且指定如下两个参数:

config.decrypt=true;
config.decrypt.key=${publickey}

那么druid是怎么实现的呢?首先我们跟着文档中提供的ConfigTools类来一探究竟。

com.alibaba.druid.filter.config.ConfigTools

ConfigTools

/*
 * Copyright 1999-2101 Alibaba Group Holding Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.alibaba.druid.filter.config;

import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.RSAPrivateKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.security.spec.X509EncodedKeySpec;

import javax.crypto.Cipher;

import com.alibaba.druid.util.Base64;
import com.alibaba.druid.util.JdbcUtils;

/**
 * Static helpers used to protect a data source password with RSA.
 *
 * <p>
 * The scheme is the reverse of the usual RSA usage: the password is encrypted
 * with the <em>private</em> key ({@link #encrypt(String, String)}) and decrypted
 * with the <em>public</em> key ({@link #decrypt(String, String)}). Keys and
 * cipher text are exchanged as Base64 strings.
 * </p>
 *
 * <p>
 * When no key is supplied, the built-in default key pair below is used. Both
 * halves of that pair are embedded in this source file, so the default pair
 * hides the password from casual readers only; generate your own pair (see
 * {@link #main(String[])}) for anything stronger.
 * </p>
 */
public class ConfigTools {

	private static final String DEFAULT_PRIVATE_KEY_STRING = "MIIBVAIBADANBgkqhkiG9w0BAQEFAASCAT4wggE6AgEAAkEAocbCrurZGbC5GArEHKlAfDSZi7gFBnd4yxOt0rwTqKBFzGyhtQLu5PRKjEiOXVa95aeIIBJ6OhC2f8FjqFUpawIDAQABAkAPejKaBYHrwUqUEEOe8lpnB6lBAsQIUFnQI/vXU4MV+MhIzW0BLVZCiarIQqUXeOhThVWXKFt8GxCykrrUsQ6BAiEA4vMVxEHBovz1di3aozzFvSMdsjTcYRRo82hS5Ru2/OECIQC2fAPoXixVTVY7bNMeuxCP4954ZkXp7fEPDINCjcQDywIgcc8XLkkPcs3Jxk7uYofaXaPbg39wuJpEmzPIxi3k0OECIGubmdpOnin3HuCP/bbjbJLNNoUdGiEmFL5hDI4UdwAdAiEAtcAwbm08bKN7pwwvyqaCBC//VnEWaq39DCzxr+Z2EIk=";
	public static final String DEFAULT_PUBLIC_KEY_STRING = "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAKHGwq7q2RmwuRgKxBypQHw0mYu4BQZ3eMsTrdK8E6igRcxsobUC7uT0SoxIjl1WveWniCASejoQtn/BY6hVKWsCAwEAAQ==";

	/**
	 * Command line entry point: generates a fresh 512-bit RSA key pair, then
	 * prints the private key, the public key and the password encrypted with
	 * the new private key.
	 *
	 * @param args {@code args[0]} is the plain-text password to encrypt
	 * @throws IllegalArgumentException if no password argument is given
	 * @throws Exception if key generation or encryption fails
	 */
	public static void main(String[] args) throws Exception {
		if (args == null || args.length == 0) {
			// Fail with a usable message instead of a bare ArrayIndexOutOfBoundsException.
			throw new IllegalArgumentException("Usage: java "
					+ ConfigTools.class.getName() + " <password>");
		}

		String password = args[0];
		String[] arr = genKeyPair(512);
		System.out.println("privateKey:" + arr[0]);
		System.out.println("publicKey:" + arr[1]);
		System.out.println("password:" + encrypt(arr[0], password));
	}

	/**
	 * Decrypts {@code cipherText} with the built-in default public key.
	 *
	 * @param cipherText Base64 encoded cipher text
	 * @return the decrypted text, or {@code cipherText} itself when it is
	 *         {@code null} or empty
	 * @throws Exception if decryption fails
	 */
	public static String decrypt(String cipherText) throws Exception {
		return decrypt((String) null, cipherText);
	}

	/**
	 * Decrypts {@code cipherText} with the given Base64 encoded public key.
	 *
	 * @param publicKeyText Base64 text of an X.509 encoded RSA public key;
	 *                      {@code null} or empty selects the default key
	 * @param cipherText    Base64 encoded cipher text
	 * @return the decrypted text, or {@code cipherText} itself when it is
	 *         {@code null} or empty
	 * @throws Exception if the key cannot be parsed or decryption fails
	 */
	public static String decrypt(String publicKeyText, String cipherText)
			throws Exception {
		PublicKey publicKey = getPublicKey(publicKeyText);

		return decrypt(publicKey, cipherText);
	}

	/**
	 * Reads the public key out of an X.509 certificate file.
	 *
	 * @param x509File path of the certificate file; {@code null} or empty
	 *                 selects the default public key
	 * @return the public key
	 * @throws IllegalArgumentException if the certificate cannot be read or parsed
	 */
	public static PublicKey getPublicKeyByX509(String x509File) {
		if (x509File == null || x509File.length() == 0) {
			return ConfigTools.getPublicKey(null);
		}

		FileInputStream in = null;
		try {
			in = new FileInputStream(x509File);

			CertificateFactory factory = CertificateFactory
					.getInstance("X.509");
			Certificate cer = factory.generateCertificate(in);
			return cer.getPublicKey();
		} catch (Exception e) {
			throw new IllegalArgumentException("Failed to get public key", e);
		} finally {
			JdbcUtils.close(in);
		}
	}

	/**
	 * Builds an RSA public key from its Base64 text form.
	 *
	 * @param publicKeyText Base64 text of an X.509 encoded RSA public key;
	 *                      {@code null} or empty selects
	 *                      {@link #DEFAULT_PUBLIC_KEY_STRING}
	 * @return the public key
	 * @throws IllegalArgumentException if the text cannot be turned into a key
	 */
	public static PublicKey getPublicKey(String publicKeyText) {
		if (publicKeyText == null || publicKeyText.length() == 0) {
			publicKeyText = ConfigTools.DEFAULT_PUBLIC_KEY_STRING;
		}

		try {
			byte[] publicKeyBytes = Base64.base64ToByteArray(publicKeyText);
			X509EncodedKeySpec x509KeySpec = new X509EncodedKeySpec(
					publicKeyBytes);

			KeyFactory keyFactory = KeyFactory.getInstance("RSA");
			return keyFactory.generatePublic(x509KeySpec);
		} catch (Exception e) {
			throw new IllegalArgumentException("Failed to get public key", e);
		}
	}

	/**
	 * Builds an RSA public key from a file holding the raw (binary) X.509
	 * encoded key bytes.
	 *
	 * @param publicKeyFile path of the key file; {@code null} or empty selects
	 *                      the default public key
	 * @return the public key
	 * @throws IllegalArgumentException if the file cannot be read or parsed
	 */
	public static PublicKey getPublicKeyByPublicKeyFile(String publicKeyFile) {
		if (publicKeyFile == null || publicKeyFile.length() == 0) {
			return ConfigTools.getPublicKey(null);
		}

		FileInputStream in = null;
		try {
			in = new FileInputStream(publicKeyFile);
			ByteArrayOutputStream out = new ByteArrayOutputStream();
			int len = 0;
			byte[] b = new byte[512 / 8];
			// Drain the whole file into memory; the key is expected to be small.
			while ((len = in.read(b)) != -1) {
				out.write(b, 0, len);
			}

			byte[] publicKeyBytes = out.toByteArray();
			X509EncodedKeySpec spec = new X509EncodedKeySpec(publicKeyBytes);
			KeyFactory factory = KeyFactory.getInstance("RSA");
			return factory.generatePublic(spec);
		} catch (Exception e) {
			throw new IllegalArgumentException("Failed to get public key", e);
		} finally {
			JdbcUtils.close(in);
		}
	}

	/**
	 * Decrypts {@code cipherText} with the given public key.
	 *
	 * @param publicKey  RSA public key matching the private key used to encrypt
	 * @param cipherText Base64 encoded cipher text
	 * @return the decrypted text decoded as UTF-8 (the charset used by
	 *         {@link #encrypt(byte[], String)}), or {@code cipherText} itself
	 *         when it is {@code null} or empty
	 * @throws Exception if the cipher cannot be initialised or decryption fails
	 */
	public static String decrypt(PublicKey publicKey, String cipherText)
			throws Exception {
		// Nothing to decrypt: return before doing any cipher set-up.
		if (cipherText == null || cipherText.length() == 0) {
			return cipherText;
		}

		Cipher cipher = Cipher.getInstance("RSA");
		try {
			cipher.init(Cipher.DECRYPT_MODE, publicKey);
		} catch (InvalidKeyException e) {
			// The IBM JDK does not support "encrypt with private key, decrypt
			// with public key", so the key roles have to be swapped: build a
			// fake private key object from the public key's parameters.
			RSAPublicKey rsaPublicKey = (RSAPublicKey) publicKey;
			RSAPrivateKeySpec spec = new RSAPrivateKeySpec(rsaPublicKey.getModulus(), rsaPublicKey.getPublicExponent());
			Key fakePrivateKey = KeyFactory.getInstance("RSA").generatePrivate(spec);
			cipher = Cipher.getInstance("RSA"); // Cipher is stateful, so get a new one.
			cipher.init(Cipher.DECRYPT_MODE, fakePrivateKey);
		}

		byte[] cipherBytes = Base64.base64ToByteArray(cipherText);
		byte[] plainBytes = cipher.doFinal(cipherBytes);

		// Must mirror the "UTF-8" used when encrypting; the platform default
		// charset would corrupt non-ASCII passwords on non-UTF-8 JVMs.
		return new String(plainBytes, "UTF-8");
	}

	/**
	 * Encrypts {@code plainText} with the built-in default private key.
	 *
	 * @param plainText text to encrypt
	 * @return Base64 encoded cipher text
	 * @throws Exception if encryption fails
	 */
	public static String encrypt(String plainText) throws Exception {
		return encrypt((String) null, plainText);
	}

	/**
	 * Encrypts {@code plainText} with the given Base64 encoded private key.
	 *
	 * @param key       Base64 text of a PKCS#8 encoded RSA private key;
	 *                  {@code null} selects the default private key
	 * @param plainText text to encrypt
	 * @return Base64 encoded cipher text
	 * @throws Exception if the key cannot be parsed or encryption fails
	 */
	public static String encrypt(String key, String plainText) throws Exception {
		if (key == null) {
			key = DEFAULT_PRIVATE_KEY_STRING;
		}

		byte[] keyBytes = Base64.base64ToByteArray(key);
		return encrypt(keyBytes, plainText);
	}

	/**
	 * Encrypts {@code plainText} (as UTF-8 bytes) with the given private key.
	 *
	 * @param keyBytes  PKCS#8 encoded RSA private key
	 * @param plainText text to encrypt
	 * @return Base64 encoded cipher text
	 * @throws Exception if the key cannot be parsed or encryption fails
	 */
	public static String encrypt(byte[] keyBytes, String plainText)
			throws Exception {
		PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(keyBytes);
		KeyFactory factory = KeyFactory.getInstance("RSA");
		PrivateKey privateKey = factory.generatePrivate(spec);
		Cipher cipher = Cipher.getInstance("RSA");
		try {
			cipher.init(Cipher.ENCRYPT_MODE, privateKey);
		} catch (InvalidKeyException e) {
			// IBM JDK workaround, mirror image of the one in decrypt(): wrap
			// the private key's parameters in a fake public key object.
			RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) privateKey;
			RSAPublicKeySpec publicKeySpec = new RSAPublicKeySpec(rsaPrivateKey.getModulus(), rsaPrivateKey.getPrivateExponent());
			Key fakePublicKey = KeyFactory.getInstance("RSA").generatePublic(publicKeySpec);
			cipher = Cipher.getInstance("RSA");
			cipher.init(Cipher.ENCRYPT_MODE, fakePublicKey);
		}

		byte[] encryptedBytes = cipher.doFinal(plainText.getBytes("UTF-8"));
		String encryptedString = Base64.byteArrayToBase64(encryptedBytes);

		return encryptedString;
	}

	/**
	 * Generates a new RSA key pair.
	 *
	 * @param keySize key size in bits
	 * @return a two element array: index 0 is the encoded private key,
	 *         index 1 is the encoded public key
	 * @throws NoSuchAlgorithmException if no RSA key pair generator is available
	 */
	public static byte[][] genKeyPairBytes(int keySize)
			throws NoSuchAlgorithmException {
		byte[][] keyPairBytes = new byte[2][];

		KeyPairGenerator gen = KeyPairGenerator.getInstance("RSA");
		gen.initialize(keySize, new SecureRandom());
		KeyPair pair = gen.generateKeyPair();

		keyPairBytes[0] = pair.getPrivate().getEncoded();
		keyPairBytes[1] = pair.getPublic().getEncoded();

		return keyPairBytes;
	}

	/**
	 * Generates a new RSA key pair and returns it in Base64 text form.
	 *
	 * @param keySize key size in bits
	 * @return a two element array: index 0 is the private key, index 1 is the
	 *         public key, both Base64 encoded
	 * @throws NoSuchAlgorithmException if no RSA key pair generator is available
	 */
	public static String[] genKeyPair(int keySize)
			throws NoSuchAlgorithmException {
		byte[][] keyPairBytes = genKeyPairBytes(keySize);
		String[] keyPairs = new String[2];

		keyPairs[0] = Base64.byteArrayToBase64(keyPairBytes[0]);
		keyPairs[1] = Base64.byteArrayToBase64(keyPairBytes[1]);

		return keyPairs;
	}

}

我们从其提供的main方法入口,可以看到包括genKeyPair和encrypt两个方法,其中genKeyPair主要是生成RSA算法的KeyPair,包含一个公钥和私钥。而encrypt主要是利用私钥对password进行加密。

整体逻辑大概就是这样。那么解密是在哪里做的呢?我们来看下ConfigFilter。

ConfigFilter

/*
 * Copyright 1999-2101 Alibaba Group Holding Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.alibaba.druid.filter.config;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.net.URL;
import java.security.PublicKey;
import java.sql.SQLException;
import java.util.Properties;

import com.alibaba.druid.filter.FilterAdapter;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.pool.DruidDataSourceFactory;
import com.alibaba.druid.proxy.jdbc.DataSourceProxy;
import com.alibaba.druid.support.logging.Log;
import com.alibaba.druid.support.logging.LogFactory;
import com.alibaba.druid.util.JdbcUtils;
import com.alibaba.druid.util.StringUtils;

/**
 * <pre>
 * This filter is responsible for two things: decrypting the data source
 * password, and loading an external (possibly remote) configuration file.
 *
 * [Decryption]
 *
 * DruidDataSource dataSource = new DruidDataSource();
 * // dataSource.setXXX other settings
 * // the next two steps are the important ones
 * // enable the config filter
 * dataSource.setFilters("config");
 * // decrypt with RSA (using the default key)
 * dataSource.setConnectionProperties("config.decrypt=true");
 * dataSource.setPassword("encrypted cipher text");
 *
 * [Remote configuration file]
 * DruidDataSource dataSource = new DruidDataSource();
 * // the next two steps are the important ones
 * // enable the config filter
 * dataSource.setFilters("config");
 * // point the filter at the configuration file
 * dataSource.setConnectionProperties("config.file=http://localhost:8080/remote.properties;");
 *
 * [Decryption configured in Spring]
 *
 * <bean id="dataSource" class="com.alibaba.druid.pool.DruidDataSource" init-method="init" destroy-method="close">
 *     <property name="password" value="encrypted cipher text" />
 *     <!-- other property settings -->
 *     <property name="filters" value="config" />
 *     <property name="connectionProperties" value="config.decrypt=true" />
 * </bean>
 *
 * [Remote configuration file configured in Spring]
 *
 * <bean id="dataSource" class="com.alibaba.druid.pool.DruidDataSource" init-method="init" destroy-method="close">
 *     <property name="filters" value="config" />
 *     <property name="connectionProperties" value="config.file=http://localhost:8080/remote.properties;" />
 * </bean>
 *
 * [Configuration file given as a system property]
 * java -Ddruid.config.file=file:///home/test/my.properties ...
 *
 * Format of the configuration file:
 * 1. For the other property keys see com.alibaba.druid.pool.DruidDataSourceFactory
 * 2. Settings related to the config filter:
 * # path of the configuration file
 * config.file=http://xxxxx (starting with http:// or file:)
 *
 * # RSA decryption; when no key is given the default one is used
 * config.decrypt=true
 * config.decrypt.key=key string
 *
 * NOTE: earlier documentation also listed config.decrypt.keyFile and
 * config.decrypt.x509File, but this class only reads config.decrypt.key.
 *
 * </pre>
 *
 * @author Jonas Yang
 */
public class ConfigFilter extends FilterAdapter {

    private static final Log   LOG                     = LogFactory.getLog(ConfigFilter.class);

    public static final String CONFIG_FILE             = "config.file";
    public static final String CONFIG_DECRYPT          = "config.decrypt";
    public static final String CONFIG_KEY              = "config.decrypt.key";

    public static final String SYS_PROP_CONFIG_FILE    = "druid.config.file";
    public static final String SYS_PROP_CONFIG_DECRYPT = "druid.config.decrypt";
    public static final String SYS_PROP_CONFIG_KEY     = "druid.config.decrypt.key";

    public ConfigFilter(){
    }

    /**
     * Applies the external configuration file (if any) and decrypts the
     * password (if enabled) on the given data source.
     *
     * @param dataSourceProxy the data source being initialised; anything other
     *                        than a {@link DruidDataSource} is logged and ignored
     * @throws IllegalArgumentException if the configuration file cannot be
     *                                  loaded, decryption fails, or the loaded
     *                                  properties cannot be applied
     */
    public void init(DataSourceProxy dataSourceProxy) {
        if (!(dataSourceProxy instanceof DruidDataSource)) {
            LOG.error("ConfigLoader only support DruidDataSource");
            // Without this return the cast below would throw ClassCastException.
            return;
        }

        DruidDataSource dataSource = (DruidDataSource) dataSourceProxy;
        Properties connectinProperties = dataSource.getConnectProperties();

        Properties configFileProperties = loadPropertyFromConfigFile(connectinProperties);

        // Decide whether the password has to be decrypted, and decrypt it if so.
        // With no config file, configFileProperties is null and the plain-text
        // password is written straight back to the data source.
        if (isDecrypt(connectinProperties, configFileProperties)) {
            decrypt(dataSource, configFileProperties);
        }

        if (configFileProperties == null) {
            return;
        }

        try {
            DruidDataSourceFactory.config(dataSource, configFileProperties);
        } catch (SQLException e) {
            throw new IllegalArgumentException("Config DataSource error.", e);
        }
    }

    /**
     * Tells whether password decryption is enabled. The flag is looked up in
     * the connection properties, then in the configuration file properties,
     * then in the {@code druid.config.decrypt} system property; only the exact
     * value {@code "true"} enables decryption.
     *
     * @param connectinProperties  connection properties of the data source
     * @param configFileProperties properties loaded from the configuration
     *                             file, may be {@code null}
     * @return {@code true} if decryption is enabled
     */
    public boolean isDecrypt(Properties connectinProperties, Properties configFileProperties) {
        String flag = connectinProperties.getProperty(CONFIG_DECRYPT);

        if ((flag == null || flag.length() == 0) && configFileProperties != null) {
            flag = configFileProperties.getProperty(CONFIG_DECRYPT);
        }

        if (flag == null || flag.length() == 0) {
            flag = System.getProperty(SYS_PROP_CONFIG_DECRYPT);
        }

        return "true".equals(flag);
    }

    /**
     * Loads the configuration file named by the {@code config.file} connection
     * property or, when that property is absent, by the
     * {@code druid.config.file} system property.
     *
     * @param connectinProperties connection properties of the data source
     * @return the loaded properties, or {@code null} when no file is configured
     * @throws IllegalArgumentException if a file is configured but cannot be loaded
     */
    Properties loadPropertyFromConfigFile(Properties connectinProperties) {
        String configFile = connectinProperties.getProperty(CONFIG_FILE);

        if (configFile == null) {
            configFile = System.getProperty(SYS_PROP_CONFIG_FILE);
        }

        if (configFile != null && configFile.length() > 0) {
            if (LOG.isInfoEnabled()) {
                LOG.info("DruidDataSource Config File load from : " + configFile);
            }

            Properties info = loadConfig(configFile);

            if (info == null) {
                throw new IllegalArgumentException("Cannot load remote config file from the [config.file=" + configFile
                                                   + "].");
            }

            return info;
        }

        return null;
    }

    /**
     * Decrypts the configured password. The cipher text is taken from
     * {@code info}, then from the data source's connection properties, then
     * from the data source's password. The plain text is stored in
     * {@code info} when it is non-null, otherwise on the data source.
     *
     * @param dataSource data source whose password is decrypted
     * @param info       properties loaded from the configuration file, may be
     *                   {@code null}
     * @throws IllegalArgumentException if anything goes wrong while decrypting
     */
    public void decrypt(DruidDataSource dataSource, Properties info) {
        try {
            String cipherText = (info == null) ? null : info.getProperty(DruidDataSourceFactory.PROP_PASSWORD);

            if (cipherText == null || cipherText.length() == 0) {
                cipherText = dataSource.getConnectProperties().getProperty(DruidDataSourceFactory.PROP_PASSWORD);
            }

            if (cipherText == null || cipherText.length() == 0) {
                cipherText = dataSource.getPassword();
            }

            PublicKey key = getPublicKey(dataSource.getConnectProperties(), info);
            String plainText = ConfigTools.decrypt(key, cipherText);

            if (info == null) {
                dataSource.setPassword(plainText);
            } else {
                info.setProperty(DruidDataSourceFactory.PROP_PASSWORD, plainText);
            }
        } catch (Exception e) {
            throw new IllegalArgumentException("Failed to decrypt.", e);
        }
    }

    /**
     * Resolves the public key used for decryption. The key text is looked up
     * in the configuration file properties, then in the connection properties,
     * then in the {@code druid.config.decrypt.key} system property; whatever
     * is found (possibly nothing) is handed to
     * {@link ConfigTools#getPublicKey(String)}.
     *
     * @param connectinProperties  connection properties, may be {@code null}
     * @param configFileProperties configuration file properties, may be {@code null}
     * @return the public key
     */
    public PublicKey getPublicKey(Properties connectinProperties, Properties configFileProperties) {
        String key = null;
        if (configFileProperties != null) {
            key = configFileProperties.getProperty(CONFIG_KEY);
        }

        if (StringUtils.isEmpty(key) && connectinProperties != null) {
            key = connectinProperties.getProperty(CONFIG_KEY);
        }

        if (StringUtils.isEmpty(key)) {
            key = System.getProperty(SYS_PROP_CONFIG_KEY);
        }

        return ConfigTools.getPublicKey(key);
    }

    /**
     * Loads a properties file from a {@code file://} path, an
     * {@code http://} / {@code https://} URL, a {@code classpath:} resource,
     * or a plain path (file system first, then class path). Locations ending
     * in {@code .xml} are read in the XML properties format.
     *
     * @param filePath location of the properties file
     * @return the loaded properties, or {@code null} if the file is missing or
     *         cannot be read (the failure is logged)
     */
    public Properties loadConfig(String filePath) {
        Properties properties = new Properties();

        InputStream inStream = null;
        try {
            boolean xml = false;
            if (filePath.startsWith("file://")) {
                filePath = filePath.substring("file://".length());
                inStream = getFileAsStream(filePath);
                xml = filePath.endsWith(".xml");
            } else if (filePath.startsWith("http://") || filePath.startsWith("https://")) {
                URL url = new URL(filePath);
                inStream = url.openStream();
                xml = url.getPath().endsWith(".xml");
            } else if (filePath.startsWith("classpath:")) {
                String resourcePath = filePath.substring("classpath:".length());
                inStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(resourcePath);
                // XML files are presumably allowed on the class path as well.
                xml = resourcePath.endsWith(".xml");
            } else {
                inStream = getFileAsStream(filePath);
                xml = filePath.endsWith(".xml");
            }

            if (inStream == null) {
                LOG.error("load config file error, file : " + filePath);
                return null;
            }

            if (xml) {
                properties.loadFromXML(inStream);
            } else {
                properties.load(inStream);
            }

            return properties;
        } catch (Exception ex) {
            LOG.error("load config file error, file : " + filePath, ex);
            return null;
        } finally {
            JdbcUtils.close(inStream);
        }
    }

    /**
     * Opens {@code filePath} from the file system when such a file exists,
     * otherwise as a resource of the context class loader.
     *
     * @return the stream, or {@code null} if the resource is not found either
     */
    private InputStream getFileAsStream(String filePath) throws FileNotFoundException {
        InputStream inStream = null;
        File file = new File(filePath);
        if (file.exists()) {
            inStream = new FileInputStream(file);
        } else {
            inStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(filePath);
        }
        return inStream;
    }
}

我们前面提到,在实际使用中,除了加密后的密码password,还有额外的connectionProperties属性要配置,包括config.decrypt和config.decrypt.key两个参数。

从代码中可以看到,config.decrypt主要是用来判断是否需要对密码进行解密(只有取值为"true"时才会解密)。

/**
 * Tells whether password decryption is enabled. The flag is looked up in the
 * connection properties, then in the configuration file properties, then in
 * the system property; only the exact value "true" enables decryption.
 */
public boolean isDecrypt(Properties connectinProperties, Properties configFileProperties) {
    String flag = connectinProperties.getProperty(CONFIG_DECRYPT);

    // Fall back to the configuration file when the connection properties have no value.
    if ((flag == null || flag.length() == 0) && configFileProperties != null) {
        flag = configFileProperties.getProperty(CONFIG_DECRYPT);
    }

    // Last resort: the JVM system property.
    if (flag == null || flag.length() == 0) {
        flag = System.getProperty(SYS_PROP_CONFIG_DECRYPT);
    }

    return "true".equals(flag);
}

而config.decrypt.key配置的是上节中介绍到的公钥,利用公钥对加密后的password进行解密,然后把解密后的密码设置到DataSource的属性中。

/**
 * Decrypts the configured password. The cipher text is taken from info, then
 * from the data source's connection properties, then from the data source's
 * password. The plain text is stored in info when it is non-null, otherwise
 * on the data source. Any failure is rethrown as IllegalArgumentException.
 */
public void decrypt(DruidDataSource dataSource, Properties info) {
    try {
        String cipherText = (info == null) ? null : info.getProperty(DruidDataSourceFactory.PROP_PASSWORD);

        if (cipherText == null || cipherText.length() == 0) {
            cipherText = dataSource.getConnectProperties().getProperty(DruidDataSourceFactory.PROP_PASSWORD);
        }

        if (cipherText == null || cipherText.length() == 0) {
            cipherText = dataSource.getPassword();
        }

        PublicKey key = getPublicKey(dataSource.getConnectProperties(), info);
        String plainText = ConfigTools.decrypt(key, cipherText);

        if (info == null) {
            dataSource.setPassword(plainText);
        } else {
            info.setProperty(DruidDataSourceFactory.PROP_PASSWORD, plainText);
        }
    } catch (Exception e) {
        throw new IllegalArgumentException("Failed to decrypt.", e);
    }
}

深入源码之后发现整体逻辑并不复杂,后续我们有类似的加解密的需求也可以参照其设计。