如何实现包扫描
- spring中的包扫描
- 如何实现呢?
- 自己实现
- 源码
- 验证
- 反思与总结
- spring的实现
- spring的源码
- 反思与总结
spring中的包扫描
在spring中有两种方式可以实现包扫描
- 传统的xml配置方式
<!--配置扫描com.example.spring.beans下的所有bean-->
<context:component-scan base-package="com.example.spring.beans"/>
- 基于注解的方式
@Configuration
@ComponentScan("com.example.spring.beans4")
public class ComponentScanConfig {
}
如何实现呢?
今天换个思路,如果这个需求要我们来实现,那我们如何做呢?我们都知道,一个bean如果想被spring管理起来,那么一定得把BeanDefinition
交给BeanDefinitionRegistry
。那么,BeanDefinition从哪儿来呢?BeanDefinition应该从我们的扫描结果中来。那么应该如何扫描呢?我觉得应该分为以下2步:
- 使用classLoader.getResources(resourceName);查找到所有的resource。
- 遍历resource,根据protocol的不同进行不同的查找。
自己实现
根据上面的想法,我自己实现了一个简单的ClassScanner。实现需求的同时,基于“开闭原则”对接口进行封装。
源码
/**
* 一个简单的class查找工具,目前仅支持jar包查找和本地查找
*/
public class SimpleClassScan {
private final Set<Class<?>> classSet;
private final Map<String, ProtocolHandler> handlerMap;
public SimpleClassScan() {
classSet = new HashSet<>();
handlerMap = new HashMap<>();
//注册一个文件扫描器
FileProtocolHandler fileProtocolHandler = new FileProtocolHandler();
//注册一个jar包扫描器
JarProtocolHandler jarProtocolHandler = new JarProtocolHandler();
handlerMap.put(fileProtocolHandler.handleProtocol(), fileProtocolHandler);
handlerMap.put(jarProtocolHandler.handleProtocol(), jarProtocolHandler);
}
public Set<Class<?>> scan(String... basePackages) {
ClassLoader classLoader = this.getClass().getClassLoader();
for (String basePackage : basePackages) {
//将com.aa.bb 替换成 com/aa/bb
String resourceName = basePackage.replace('.', '/') + "/";
Enumeration<URL> resources = null;
try {
//通过classLoader获取所有的resources
resources = classLoader.getResources(resourceName);
} catch (IOException e) {
e.printStackTrace();
}
if (resources == null) {
continue;
}
while (resources.hasMoreElements()) {
URL url = resources.nextElement();
String protocol = url.getProtocol();
//根据url中protocol类型查找适用的解析器
ProtocolHandler protocolHandler = handlerMap.get(protocol);
if (protocolHandler == null) {
throw new RuntimeException("need support protocol [" + protocol + "]");
}
protocolHandler.handle(basePackage, url);
}
}
return classSet;
}
/**
* 将class添加到结果中
* @param classFullName 形如com.aa.bb.cc.Test.class的字符串
*/
private void addResult(String classFullName) {
Class<?> aClass = null;
try {
aClass = Class.forName(classFullName.substring(0, classFullName.length() - 6));
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
if (aClass != null) {
classSet.add(aClass);
}
}
/**
* 检查一个文件名是否是class文件名
* @param fileName 文件名
* @return
*/
private boolean checkIsNotClass(String fileName) {
//只要class类型的文件
boolean isClass = fileName.endsWith(".class");
if (!isClass) {
return true;
}
//排除内部类
return fileName.indexOf('$') != -1;
}
public Set<Class<?>> getClassSet() {
return classSet;
}
/**
* 协议处理器
*/
private interface ProtocolHandler {
/**
* 适配的协议
*
* @return
*/
String handleProtocol();
/**
* 处理url,最后需要调用{@link #addResult(String)}将结果存储到result中
*
* @param url
*/
void handle(String basePackage, URL url);
}
/**
* jar包解析器
*/
private class JarProtocolHandler implements ProtocolHandler {
@Override
public String handleProtocol() {
return "jar";
}
@Override
public void handle(String basePackage, URL url) {
try {
String resourceName = basePackage.replace('.', '/') + "/";
JarURLConnection conn = (JarURLConnection) url.openConnection();
JarFile jarFile = conn.getJarFile();
Enumeration<JarEntry> entries = jarFile.entries();
while (entries.hasMoreElements()) {
//遍历jar包中的所有项
JarEntry jarEntry = entries.nextElement();
String entryName = jarEntry.getName();
if (!entryName.startsWith(resourceName)) {
continue;
}
if (checkIsNotClass(entryName)) {
continue;
}
String classNameFullName = entryName.replace('/', '.');
addResult(classNameFullName);
}
} catch (IOException e) {
e.printStackTrace();
}
}
}
/**
* 文件解析器
*/
private class FileProtocolHandler implements ProtocolHandler {
@Override
public String handleProtocol() {
return "file";
}
@Override
public void handle(String basePackage, URL url) {
File rootFile = new File(url.getFile());
findClass(rootFile, File.separator + basePackage.replace('.', File.separatorChar) + File.separator);
}
/**
* 递归的方式查找class文件
* @param rootFile 当前文件
* @param subFilePath 子路径
*/
private void findClass(File rootFile, String subFilePath) {
if (rootFile == null) {
return;
}
//如果是文件夹
if (rootFile.isDirectory()) {
File[] files = rootFile.listFiles();
if (files == null) {
return;
}
for (File file : files) {
findClass(file, subFilePath);
}
}
String fileName = rootFile.getName();
if (checkIsNotClass(fileName)) {
return;
}
String path = rootFile.getPath();
int i = path.indexOf(subFilePath);
String subPath = path.substring(i + 1);
String fullClassPath = subPath.replace(File.separatorChar, '.');
addResult(fullClassPath);
}
}
}
验证
我们看下实际效果:
- 扫描本地包中的class
- 扫描依赖包中的class
反思与总结
- 通过
classLoader.getResources("resourceName")
可以获取到resources,其中resourceName必须是com/xxx/rrr的形式 - 文件分隔符在Windows和linux上是不同的,在Windows上是
\
,在linux上是/
。我们可以通过File.separatorChar
来获取当前操作系统中的文件分隔符 - url是通过
protocol
来区分的 - 扫描文件系统,可以使用File对象+递归的方式实现
- 扫描jar时,需要
openConnection
JarURLConnection conn = (JarURLConnection) url.openConnection();
JarFile jarFile = conn.getJarFile();
Enumeration<JarEntry> entries = jarFile.entries();
while (entries.hasMoreElements()) {
//遍历jar包中的所有项
JarEntry jarEntry = entries.nextElement();
String entryName = jarEntry.getName();
//TODO xxx
}
spring的实现
那么spring是如何实现的呢?
spring的源码
在spring中,spring使用ClassPathBeanDefinitionScanner
来实现包扫描,这个东西用起来非常方便,如果我想将com.example.spring.beans3下所有带有 @SkylineComponent注解的类注册为bean,那么可以这么写:
ClassPathBeanDefinitionScanner scanner = new ClassPathBeanDefinitionScanner(registry, false);
scanner.addIncludeFilter(new AnnotationTypeFilter(SkylineComponent.class));
scanner.scan("com.example.spring.beans3");
当然,这段代码一定要写在实现了BeanDefinitionRegistryPostProcessor
的bean中。
接下来我们看下ClassPathBeanDefinitionScanner的构造器
public ClassPathBeanDefinitionScanner(BeanDefinitionRegistry registry, boolean useDefaultFilters,Environment environment, @Nullable ResourceLoader resourceLoader) {
//最最重要的,BeanDefinitionRegistry不能空
Assert.notNull(registry, "BeanDefinitionRegistry must not be null");
this.registry = registry;
//是否使用默认的过滤器,如果使用默认的过滤器,那么仅扫描@Component注解
if (useDefaultFilters) {
registerDefaultFilters();
}
//设置环境参数
setEnvironment(environment);
//设置资源加载器
setResourceLoader(resourceLoader);
}
org.springframework.context.annotation.ClassPathBeanDefinitionScanner#scan
就是扫描方法的入口,实际的扫描逻辑是写在doScan方法中的,我们看下这个方法:
protected Set<BeanDefinitionHolder> doScan(String... basePackages) {
//首先,basePackages不能是空的
Assert.notEmpty(basePackages, "At least one base package must be specified");
Set<BeanDefinitionHolder> beanDefinitions = new LinkedHashSet<>();
//遍历所有要扫描的包
for (String basePackage : basePackages) {
//获取到待选的BeanDefinition
Set<BeanDefinition> candidates = findCandidateComponents(basePackage);
//遍历待选的BeanDefinition
for (BeanDefinition candidate : candidates) {
//设置Scope
ScopeMetadata scopeMetadata = this.scopeMetadataResolver.resolveScopeMetadata(candidate);
candidate.setScope(scopeMetadata.getScopeName());
//生成beanName
String beanName = this.beanNameGenerator.generateBeanName(candidate, this.registry);
if (candidate instanceof AbstractBeanDefinition) {
//默认值处理
postProcessBeanDefinition((AbstractBeanDefinition) candidate, beanName);
}
if (candidate instanceof AnnotatedBeanDefinition) {
//@Lazy @Primary @DependsOn @Role @Description这些注解支持
AnnotationConfigUtils.processCommonDefinitionAnnotations((AnnotatedBeanDefinition) candidate);
}
//bean冲突校验
if (checkCandidate(beanName, candidate)) {
BeanDefinitionHolder definitionHolder = new BeanDefinitionHolder(candidate, beanName);
definitionHolder =
AnnotationConfigUtils.applyScopedProxyMode(scopeMetadata, definitionHolder, this.registry);
beanDefinitions.add(definitionHolder);
//注册bean
registerBeanDefinition(definitionHolder, this.registry);
}
}
}
return beanDefinitions;
}
关键在于findCandidateComponents是如何找到这些候选的BeanDefinition的呢?接下来就会走到scanCandidateComponents中,接下来我们debug看下:
从代码中可以看到,spring将传入的"com.example.spring.beans3"解析成了"classpath*:com/example/spring/beans3/**/*.class",并通过getResourcePatternResolver().getResources(packageSearchPath)
来获取所有的resource,那么,我们看下getResources是如何处理classpath*:com/example/spring/beans3/**/*.class的。接下来,程序执行到findPathMatchingResources中,在findPathMatchingResources中通过getResource方法来返回Resource[]。
那么getResource里面是什么呢?在getResource中,最后会调用到doFindAllClassPathResources,如下图:
这段代码好眼熟…Enumeration<URL> resourceUrls = (cl != null ? cl.getResources(path) : ClassLoader.getSystemResources(path));
跟我的实现方式是一样的,也是先把package转换为资源路径,然后通过classLoader.getResources的方式来获取resource。接下来遍历这些resources,不同类型的resource走不同的逻辑,就像下面这样。
当所有的resources都获取到了之后,就开始遍历所有的resource。如下图:
这里spring的手法就比较高端了,spring通过读取resource中的class文件的字节码,生成了一个叫MetadataReader
的对象。这个MetadaReader并不是class对象,但是可以读取到class上所有的元数据信息。这是因为spring使用了ASM技术,用流的方式读取了class文件。然后就是创建ScannedGenericBeanDefinition并返回了。
反思与总结
spring中的包扫描虽然整体逻辑并不复杂,但是细节还是很多的。比如它处理了通配符**/*.class
、处理了不同协议的url、在最终读取class信息时使用了ASM技术、还支持自定义的过滤器等。spring在能扩展的地方都给我们留出了扩展点,但是在使用起来却是很方便,这一点还是很厉害的。