文章目录


在上一篇中我们已经完成了SqlSession的构造,本章的主要内容是自定义Executor执行器。

1.自定义Executor执行器接口

package org.apache.ibatis.executor;

import org.apache.ibatis.configration.Configuration;
import org.apache.ibatis.configration.MappedStatement;

import java.util.List;

/**
 * Contract for the component that actually runs a mapped SQL statement.
 */
public interface Executor {

/**
 * Executes the statement described by {@code mappedStatement} and returns its outcome as a list.
 *
 * @param configuration   framework configuration handed to the implementation
 * @param mappedStatement the statement to execute
 * @param params          arguments for the statement
 * @param <E>             element type the caller expects in the returned list
 * @return the results of the execution
 * @throws Exception if the execution fails
 */
<E> List<E> query(Configuration configuration, MappedStatement mappedStatement, Object... params) throws Exception;
}

2.自定义BoundSql类

package org.apache.ibatis.configration;

import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;
import lombok.Data;

import java.util.List;

/**
 * Holder for a parsed SQL statement: the SQL text plus the list of parameter names
 * associated with it.
 */
@Data
@AllArgsConstructor
public class BoundSql {

private String sqlText; // SQL text after parsing

// Parameter names belonging to sqlText.
// NOTE(review): with @AllArgsConstructor as the only constructor, this initializer looks like
// it is always overwritten by the constructor argument - confirm whether it is still needed.
private List<String> parameterMappingList = Lists.newArrayList();

}

3.自定义Executor执行器接口实现类SimpleExecutor

SimpleExecutor类主要负责sql语句的执行动作

package org.apache.ibatis.executor;

import com.google.common.collect.Lists;
import org.apache.ibatis.configration.BoundSql;
import org.apache.ibatis.configration.Configuration;
import org.apache.ibatis.configration.MappedStatement;
import org.apache.ibatis.session.CommandType;

import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.util.*;

/**
 * Minimal JDBC-backed {@link Executor}.
 *
 * <p>For every call it rewrites the {@code #{name}} placeholders of the mapped statement into
 * JDBC {@code ?} markers, binds the values by reading the same-named fields of the first
 * argument via reflection, runs the statement, and (for queries) copies every row onto a new
 * instance of the declared result type.
 *
 * <p>The executor keeps no state between calls, and every JDBC resource opened by a call is
 * closed before that call returns.
 */
public class SimpleExecutor implements Executor {

    private static final String OPEN_TOKEN = "#{";
    private static final String CLOSE_TOKEN = "}";

    /**
     * Executes the given mapped statement.
     *
     * @param configuration   supplies the {@code DataSource} the connection is obtained from
     * @param mappedStatement supplies the SQL text, the statement id and the parameter/result type names
     * @param params          call arguments; only {@code params[0]} is read, as the parameter object
     * @param <E>             element type expected by the caller
     * @return for a write statement, a one-element list holding the affected-row count;
     *         otherwise one result-type instance per row
     * @throws IllegalArgumentException if the SQL contains an unterminated {@code #{} placeholder,
     *                                  or placeholders are present but no parameter object was passed
     * @throws Exception                on any JDBC or reflection failure
     */
    @Override
    @SuppressWarnings("unchecked")
    public <E> List<E> query(Configuration configuration, MappedStatement mappedStatement, Object... params) throws Exception {
        // e.g. "select * from user where id = #{id}"  ->  "select * from user where id = ?"
        // Parsed before opening the connection so malformed SQL never costs a connection.
        BoundSql boundSql = getBoundSql(mappedStatement.getSql());

        // try-with-resources closes the statement and the connection on every exit path.
        try (Connection connection = configuration.getDataSource().getConnection();
             PreparedStatement preparedStatement = connection.prepareStatement(boundSql.getSqlText())) {

            bindParameters(preparedStatement, mappedStatement, boundSql.getParameterMappingList(), params);

            // The statement id decides between executeUpdate and executeQuery.
            // NOTE(review): CommandType.sqlCommand presumably lists the ids treated as writes - confirm.
            if (Arrays.asList(CommandType.sqlCommand).contains(mappedStatement.getId())) {
                List<Integer> affectedRows = new ArrayList<>();
                affectedRows.add(preparedStatement.executeUpdate());
                return (List<E>) affectedRows;
            }

            Class<?> resultTypeClass = getClassType(mappedStatement.getResultType());
            try (ResultSet resultSet = preparedStatement.executeQuery()) {
                return (List<E>) mapRows(resultSet, resultTypeClass);
            }
        }
    }

    /** Binds one JDBC parameter per placeholder, reading each value from a field of {@code params[0]}. */
    private void bindParameters(PreparedStatement preparedStatement, MappedStatement mappedStatement,
                                List<String> parameterNames, Object[] params) throws Exception {
        // Resolved even when there are no placeholders, so a wrong parameterType is still reported.
        Class<?> parameterTypeClass = getClassType(mappedStatement.getParameterType());
        if (parameterNames.isEmpty()) {
            return;
        }
        if (params == null || params.length == 0 || params[0] == null) {
            throw new IllegalArgumentException("Statement '" + mappedStatement.getId() + "' has "
                    + parameterNames.size() + " placeholder(s) but no parameter object was supplied");
        }
        Object parameterObject = params[0];
        if (parameterTypeClass == null) {
            // No parameterType declared: fall back to the runtime class of the argument.
            parameterTypeClass = parameterObject.getClass();
        }
        for (int i = 0; i < parameterNames.size(); i++) {
            Field field = parameterTypeClass.getDeclaredField(parameterNames.get(i));
            // The field may be private.
            field.setAccessible(true);
            // JDBC parameter indexes are 1-based.
            preparedStatement.setObject(i + 1, field.get(parameterObject));
        }
    }

    /** Copies every row of {@code resultSet} onto a new {@code resultTypeClass} instance via its property setters. */
    private List<Object> mapRows(ResultSet resultSet, Class<?> resultTypeClass) throws Exception {
        List<Object> rows = new ArrayList<>();
        // The metadata does not change from row to row, so read it once.
        ResultSetMetaData metaData = resultSet.getMetaData();
        int columnCount = metaData.getColumnCount();
        while (resultSet.next()) {
            if (resultTypeClass == null) {
                throw new IllegalStateException("The statement returned rows but declares no resultType");
            }
            Object row = resultTypeClass.getDeclaredConstructor().newInstance();
            for (int i = 1; i <= columnCount; i++) {
                String columnName = metaData.getColumnName(i);
                Object value = resultSet.getObject(columnName);
                // Introspection: the column name must match a writable bean property of the result type.
                PropertyDescriptor propertyDescriptor = new PropertyDescriptor(columnName, resultTypeClass);
                Method writeMethod = propertyDescriptor.getWriteMethod();
                writeMethod.invoke(row, value);
            }
            rows.add(row);
        }
        return rows;
    }

    /** Loads the class with the given fully-qualified name; returns {@code null} when the name is {@code null}. */
    private Class<?> getClassType(String parameterType) throws ClassNotFoundException {
        if (parameterType == null) {
            return null;
        }
        return Class.forName(parameterType);
    }

    /**
     * Parses the {@code #{...}} placeholders of a statement:
     * each placeholder is replaced by {@code ?} and its (trimmed) content is recorded, in order.
     *
     * @param sql raw SQL as written in the mapper
     * @return the rewritten SQL together with the ordered placeholder names
     * @throws IllegalArgumentException if a {@code #{} has no closing {@code }}
     */
    private BoundSql getBoundSql(String sql) {
        List<String> parameterNames = new ArrayList<>();
        StringBuilder parsedSql = new StringBuilder(sql.length());
        int cursor = 0;
        int open = sql.indexOf(OPEN_TOKEN, cursor);
        while (open != -1) {
            // Look for the closing brace only after the opening token that was just found.
            int close = sql.indexOf(CLOSE_TOKEN, open + OPEN_TOKEN.length());
            if (close == -1) {
                throw new IllegalArgumentException("SQL语句中参数错误.. (unterminated #{ placeholder): " + sql);
            }
            parsedSql.append(sql, cursor, open).append('?');
            parameterNames.add(sql.substring(open + OPEN_TOKEN.length(), close).trim());
            cursor = close + CLOSE_TOKEN.length();
            open = sql.indexOf(OPEN_TOKEN, cursor);
        }
        parsedSql.append(sql, cursor, sql.length());
        return new BoundSql(parsedSql.toString(), parameterNames);
    }

}

4.补全DefaultSqlSession类中的接口实现方法

/**
 * Runs the statement registered under {@code statementId} through a {@code SimpleExecutor}
 * and returns everything the executor produced.
 *
 * @param statementId key of the statement in the configuration's mapped-statement map
 * @param params      arguments forwarded unchanged to the executor
 * @param <E>         element type expected by the caller
 * @return the list returned by the executor
 * @throws IllegalArgumentException if no statement is registered under {@code statementId}
 * @throws Exception                if the executor fails
 */
@Override
public <E> List<E> selectList(String statementId, Object... params) throws Exception {
    MappedStatement mappedStatement = this.configuration.getMappedStatementMap().get(statementId);
    if (mappedStatement == null) {
        // Fail here with the offending id instead of a NullPointerException inside the executor.
        throw new IllegalArgumentException("No mapped statement found for id: " + statementId);
    }
    SimpleExecutor simpleExecutor = new SimpleExecutor();
    return simpleExecutor.<E>query(this.configuration, mappedStatement, params);
}

/**
 * Runs the statement through {@code selectList} and returns its single row.
 *
 * @param statementId id of the statement to run
 * @param params      arguments forwarded to {@code selectList}
 * @param <T>         type expected by the caller
 * @return the only element of the result list
 * @throws RuntimeException if the result list does not contain exactly one element
 * @throws Exception        if {@code selectList} fails
 */
@Override
public <T> T selectOne(String statementId, Object... params) throws Exception {
    List<Object> rows = selectList(statementId, params);
    // Anything other than exactly one row is an error.
    if (rows.size() != 1) {
        throw new RuntimeException("查询结果为空或者结果过多!");
    }
    return (T) rows.get(0);
}

/**
 * Runs the write statement registered under {@code statementId} through a {@code SimpleExecutor}.
 *
 * <p>The value returned is the executor's result list itself, cast to {@code T}; for a write
 * statement {@code SimpleExecutor} returns a one-element list holding the affected-row count,
 * so callers must expect a {@code List}, not a bare number.
 *
 * @param statementId key of the statement in the configuration's mapped-statement map
 * @param params      arguments forwarded unchanged to the executor
 * @param <T>         type expected by the caller
 * @return the executor's result list, cast to {@code T}
 * @throws IllegalArgumentException if no statement is registered under {@code statementId}
 * @throws Exception                if the executor fails
 */
@Override
@SuppressWarnings("unchecked")
public <T> T update(String statementId, Object... params) throws Exception {
    MappedStatement mappedStatement = this.configuration.getMappedStatementMap().get(statementId);
    if (mappedStatement == null) {
        // Fail here with the offending id instead of a NullPointerException inside the executor.
        throw new IllegalArgumentException("No mapped statement found for id: " + statementId);
    }
    SimpleExecutor simpleExecutor = new SimpleExecutor();
    return (T) simpleExecutor.query(this.configuration, mappedStatement, params);
}

5.测试全部查询

package com.bruce.test;

import com.bruce.dao.UserDao;
import com.bruce.pojo.User;
import org.apache.ibatis.io.Resources;
import org.apache.ibatis.session.SqlSession;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.session.SqlSessionFactoryBuilder;
import org.junit.Before;
import org.junit.Test;

import java.io.InputStream;
import java.util.List;

/**
 * JUnit tests that exercise the hand-written persistence framework through the
 * {@code UserDao} mapper.
 */
public class PersistenceTest {

    /** Mapper obtained from a freshly opened session before each test. */
    UserDao userDao;

    /** Loads the configuration file, builds the session factory and fetches the mapper. */
    @Before
    public void before() throws Exception {
        InputStream configStream = Resources.getResourceAsStream("mybatis-config.xml");
        SqlSessionFactoryBuilder factoryBuilder = new SqlSessionFactoryBuilder();
        SqlSession session = factoryBuilder.build(configStream).openSession();
        userDao = session.getMapper(UserDao.class);
    }

    /** Calls {@code findAll} and prints the returned list. */
    @Test
    public void testSelect() throws Exception {
        List<User> users = userDao.findAll();
        System.out.println(users);
    }
}

运行结果:
《Java手写系列》-手写MyBatis框架(四)_java

6.测试条件查询

/** Calls {@code selectOne} with a populated parameter object and prints the result. */
@Test
public void testSelect1() throws Exception {
    // Parameter object handed to the mapper.
    User criteria = new User();
    criteria.setId(1);
    criteria.setUsername("bruce");

    User found = userDao.selectOne(criteria);
    System.out.println(found);
}

运行结果:
《Java手写系列》-手写MyBatis框架(四)_java_02

7.测试新增

/** Calls {@code insert} with a new user and prints a confirmation message. */
@Test
public void testInsert() throws Exception {
    // Record to insert.
    User newUser = new User();
    newUser.setId(4);
    newUser.setUsername("jack");

    userDao.insert(newUser);
    System.out.println("新增成功");
}

8.测试新增

/** Calls {@code insert} with a new user and prints a confirmation message. */
@Test
public void testInsert() throws Exception {
    // Record to insert.
    User newUser = new User();
    newUser.setId(4);
    newUser.setUsername("jack");

    userDao.insert(newUser);
    System.out.println("新增成功");
}

9.测试更新

/** Calls {@code update} with a modified user and prints a confirmation message. */
@Test
public void testUpdate() throws Exception {
    // Parameter object carrying the id and the new username.
    User changedUser = new User();
    changedUser.setId(1);
    changedUser.setUsername("jackUpdate");

    userDao.update(changedUser);
    System.out.println("更新成功!");
}

10.测试删除

/** Calls {@code delete} with a user carrying only an id and prints a confirmation message. */
@Test
public void testDelete() throws Exception {
    // Parameter object identifying the record by id.
    User target = new User();
    target.setId(4);

    userDao.delete(target);
    System.out.println("删除成功");
}

至此为止,手写MyBatis大功告成!