文章目录

  • application.properties
  • ZonecodeFilterProperties
  • ZonecodeLineHandler
  • YqZonecodeLineHandler
  • MybatisPlusConfig
  • 测试
  • 表数据
  • 测试单元
  • 测试结果

最近在做项目的时候遇到一个需求：系统中的数据按照行政区代码（zonecode）进行隔离，简单来说每个表中都有 zonecode 字段，查询数据时几乎所有的 SQL 都要追加这个过滤条件（下面示例中的 C 即代表 zonecode 字段）。

# 正常sql
SELECT A FROM TAB WHERE B = #{B}
# 过滤后的sql
SELECT A FROM TAB WHERE B = #{B} AND C IN ('C1', 'C2')

application.properties

# 表中过滤的字段名称
zonehandle.idColumn=zonecode
# 排除的表
zonehandle.tableFilter[0]=user
zonehandle.tableFilter[1]=role

ZonecodeFilterProperties

@Data
@Component
@ConfigurationProperties(prefix = "zonehandle")
public class ZonecodeFilterProperties {
    // Name of the zonecode column used for filtering (bound from zonehandle.idColumn)
    private String idColumn;
    // Table names excluded from zonecode filtering (bound from zonehandle.tableFilter[n]).
    // NOTE(review): the field has no initializer, so it stays null when the property is not configured.
    private List<String> tableFilter;

}

ZonecodeLineHandler

public interface ZonecodeLineHandler {
    /**
     * Returns the zonecode value expressions; any number (N) of values is supported.
     *
     * @return zonecode value expressions
     */
    List<Expression> getZonecodeId();

    /**
     * Returns the name of the zonecode column.
     * <p>
     * Defaults to: zonecode
     *
     * @return zonecode column name
     */
    default String getZonecodeIdColumn() {
        return "zonecode";
    }

    /**
     * Decides, by table name, whether the zonecode condition should be applied to that table.
     * <p>
     * By default no table is handled.
     *
     * @param tableName table name
     * @return true to handle the table and append the zonecode condition, false to skip it
     */
    default boolean handleTable(String tableName) {
        return false;
    }
}

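ZonecodeLineInnerInterceptor（注：下面给出的是拦截器实现；MybatisPlusConfig 中注入的 YqZonecodeLineHandler 是 ZonecodeLineHandler 的实现类，原文未贴出其代码）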

@Slf4j
@Data
@NoArgsConstructor
@AllArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class ZonecodeLineInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {

    /**
     * Supplies the zonecode column name, the zonecode values and the per-table include/skip decision.
     */
    private ZonecodeLineHandler zonecodeLineHandler;

    /**
     * Rewrites a SELECT before it is executed by appending the zonecode condition.
     * <p>
     * NOTE(review): there is no per-statement opt-out; every query is parsed and rewritten.
     * If one is needed, MyBatis-Plus' InterceptorIgnoreHelper is a possible starting point.
     */
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        mpBs.sql(parserSingle(mpBs.sql(), null));
    }

    /**
     * Rewrites INSERT / UPDATE / DELETE statements before they are prepared.
     * Other command types are passed through unchanged.
     */
    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = mpSh.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
            PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
            mpBs.sql(parserMulti(mpBs.sql(), null));
        }
    }

    /**
     * Processes the main select body and every WITH item of a SELECT statement.
     */
    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        processSelectBody(select.getSelectBody());
        List<WithItem> withItemsList = select.getWithItemsList();
        if (!CollectionUtils.isEmpty(withItemsList)) {
            withItemsList.forEach(this::processSelectBody);
        }
    }

    /**
     * Dispatches a select body to the matching handler.
     * Unrecognized SelectBody implementations are left untouched rather than failing with a ClassCastException.
     */
    protected void processSelectBody(SelectBody selectBody) {
        if (selectBody == null) {
            return;
        }
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof WithItem) {
            processSelectBody(((WithItem) selectBody).getSelectBody());
        } else if (selectBody instanceof SetOperationList) {
            List<SelectBody> selects = ((SetOperationList) selectBody).getSelects();
            if (!CollectionUtils.isEmpty(selects)) {
                selects.forEach(this::processSelectBody);
            }
        }
    }

    /**
     * Adds the zonecode column and value to an INSERT on a handled table.
     * <p>
     * Skipped when the table is excluded, when the INSERT has no explicit column list, or when the
     * statement already lists the zonecode column (compared case-insensitively).
     */
    @Override
    protected void processInsert(Insert insert, int index, String sql, Object obj) {
        if (!zonecodeLineHandler.handleTable(insert.getTable().getName())) {
            // Table is excluded from zonecode handling.
            return;
        }
        List<Column> columns = insert.getColumns();
        if (CollectionUtils.isEmpty(columns)) {
            // An INSERT without an explicit column list cannot safely receive an extra column.
            return;
        }
        String zonecodeColumn = zonecodeLineHandler.getZonecodeIdColumn();
        if (columns.stream().map(Column::getColumnName).anyMatch(zonecodeColumn::equalsIgnoreCase)) {
            // The statement already supplies the zonecode column itself.
            return;
        }
        columns.add(new Column(zonecodeColumn));
        Select select = insert.getSelect();
        if (select != null) {
            this.processInsertSelect(select.getSelectBody());
        } else if (insert.getItemsList() != null) {
            // Exactly one column was added, so exactly one value must be appended to every row.
            Expression zonecodeValue = getInsertZonecodeValue();
            ItemsList itemsList = insert.getItemsList();
            if (itemsList instanceof MultiExpressionList) {
                ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(zonecodeValue));
            } else {
                ((ExpressionList) itemsList).getExpressions().add(zonecodeValue);
            }
        } else {
            throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
        }
    }

    /**
     * Returns the single zonecode value written into INSERT ... VALUES rows.
     * <p>
     * The handler may return several ids for IN filtering, but an INSERT only gains one column.
     * Anything other than exactly one value would produce a column/value count mismatch, so fail early.
     *
     * @return the only expression returned by {@code zonecodeLineHandler.getZonecodeId()}
     */
    private Expression getInsertZonecodeValue() {
        List<Expression> ids = zonecodeLineHandler.getZonecodeId();
        int count = ids == null ? 0 : ids.size();
        if (count != 1) {
            throw ExceptionUtils.mpe("Failed to process insert: expected exactly one zonecode value but got %s, please exclude the tableName or statementId", count);
        }
        return ids.get(0);
    }

    /**
     * Restricts an UPDATE on a handled table with the zonecode condition.
     */
    @Override
    protected void processUpdate(Update update, int index, String sql, Object obj) {
        final Table table = update.getTable();
        if (!zonecodeLineHandler.handleTable(table.getName())) {
            // Table is excluded from zonecode handling.
            return;
        }
        update.setWhere(this.andInExpression(table, update.getWhere()));
    }

    /**
     * Restricts a DELETE on a handled table with the zonecode condition.
     */
    @Override
    protected void processDelete(Delete delete, int index, String sql, Object obj) {
        if (!zonecodeLineHandler.handleTable(delete.getTable().getName())) {
            // Table is excluded from zonecode handling.
            return;
        }
        delete.setWhere(this.andInExpression(delete.getTable(), delete.getWhere()));
    }

    /**
     * Builds the WHERE clause for UPDATE / DELETE: {@code zonecode IN (...) AND <where>}.
     *
     * @param table target table (its alias, if any, qualifies the column)
     * @param where existing condition, may be null
     * @return the combined condition, or just the IN expression when {@code where} is null
     */
    protected Expression andInExpression(Table table, Expression where) {
        InExpression inExpression = buildZonecodeInExpression(table);
        if (where == null) {
            return inExpression;
        }
        // A top-level OR must be parenthesized so the zonecode filter applies to every branch.
        Expression guarded = where instanceof OrExpression ? new Parenthesis(where) : where;
        return new AndExpression(inExpression, guarded);
    }

    /**
     * Handles INSERT ... SELECT.
     * <p>
     * The SELECT must produce the zonecode column that {@link #processInsert} added to the column list.
     * For a plain FROM table it is also filtered by zonecode. UNION-style selects are handled branch by branch.
     *
     * @param selectBody SelectBody of the INSERT's SELECT
     */
    protected void processInsertSelect(SelectBody selectBody) {
        if (selectBody instanceof SetOperationList) {
            // Every branch must supply the extra column to keep the branches' column counts equal.
            List<SelectBody> selects = ((SetOperationList) selectBody).getSelects();
            if (!CollectionUtils.isEmpty(selects)) {
                selects.forEach(this::processInsertSelect);
            }
            return;
        }
        if (!(selectBody instanceof PlainSelect)) {
            throw ExceptionUtils.mpe("Failed to process insert select, unsupported select body: %s", selectBody);
        }
        PlainSelect plainSelect = (PlainSelect) selectBody;
        FromItem fromItem = plainSelect.getFromItem();
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            plainSelect.setWhere(builderInExpression(plainSelect.getWhere(), fromTable));
            appendSelectItem(plainSelect.getSelectItems());
        } else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            appendSelectItem(plainSelect.getSelectItems());
            processInsertSelect(subSelect.getSelectBody());
        }
    }

    /**
     * Appends the zonecode column to a select list.
     * A lone {@code *} or {@code t.*} already covers it and is left unchanged.
     *
     * @param selectItems select items to extend in place
     */
    protected void appendSelectItem(List<SelectItem> selectItems) {
        if (CollectionUtils.isEmpty(selectItems)) {
            return;
        }
        if (selectItems.size() == 1) {
            SelectItem item = selectItems.get(0);
            if (item instanceof AllColumns || item instanceof AllTableColumns) {
                return;
            }
        }
        selectItems.add(new SelectExpressionItem(new Column(zonecodeLineHandler.getZonecodeIdColumn())));
    }

    /**
     * Processes a PlainSelect: sub-queries in WHERE, the FROM item and every JOIN.
     */
    protected void processPlainSelect(PlainSelect plainSelect) {
        FromItem fromItem = plainSelect.getFromItem();
        Expression where = plainSelect.getWhere();
        processWhereSubSelect(where);
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            if (zonecodeLineHandler.handleTable(fromTable.getName())) {
                //#1186 github
                plainSelect.setWhere(builderInExpression(where, fromTable));
            }
        } else {
            processFromItem(fromItem);
        }
        List<Join> joins = plainSelect.getJoins();
        if (!CollectionUtils.isEmpty(joins)) {
            joins.forEach(j -> {
                processJoin(j);
                processFromItem(j.getRightItem());
            });
        }
    }

    /**
     * Processes sub-queries inside a WHERE condition.
     * <p>
     * Supported forms:
     * 1. in
     * 2. =
     * 3. >
     * 4. <
     * 5. >=
     * 6. <=
     * 7. <>
     * 8. EXISTS
     * 9. NOT EXISTS
     * <p>
     * Preconditions:
     * 1. the sub-query must be enclosed in parentheses
     * 2. the sub-query is normally on the right-hand side of the comparison operator
     *
     * @param where where condition
     */
    protected void processWhereSubSelect(Expression where) {
        if (where == null) {
            return;
        }
        if (where instanceof FromItem) {
            processFromItem((FromItem) where);
            return;
        }
        // Cheap textual pre-check: only descend when the rendered condition contains a sub-query.
        if (where.toString().indexOf("SELECT") > 0) {
            if (where instanceof BinaryExpression) {
                // comparison operators, AND, OR, ...
                BinaryExpression expression = (BinaryExpression) where;
                processWhereSubSelect(expression.getLeftExpression());
                processWhereSubSelect(expression.getRightExpression());
            } else if (where instanceof InExpression) {
                // in
                InExpression expression = (InExpression) where;
                ItemsList itemsList = expression.getRightItemsList();
                if (itemsList instanceof SubSelect) {
                    processSelectBody(((SubSelect) itemsList).getSelectBody());
                }
            } else if (where instanceof ExistsExpression) {
                // exists
                ExistsExpression expression = (ExistsExpression) where;
                processWhereSubSelect(expression.getRightExpression());
            } else if (where instanceof NotExpression) {
                // not exists
                NotExpression expression = (NotExpression) where;
                processWhereSubSelect(expression.getExpression());
            } else if (where instanceof Parenthesis) {
                Parenthesis expression = (Parenthesis) where;
                processWhereSubSelect(expression.getExpression());
            }
        }
    }

    /**
     * Processes sub-joins, sub-selects and lateral sub-selects that appear as FROM items.
     */
    protected void processFromItem(FromItem fromItem) {
        if (fromItem instanceof SubJoin) {
            SubJoin subJoin = (SubJoin) fromItem;
            if (subJoin.getJoinList() != null) {
                subJoin.getJoinList().forEach(this::processJoin);
            }
            if (subJoin.getLeft() != null) {
                processFromItem(subJoin.getLeft());
            }
        } else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelectBody(subSelect.getSelectBody());
            }
        } else if (fromItem instanceof ValuesList) {
            logger.debug("Perform a subquery, if you do not give us feedback");
        } else if (fromItem instanceof LateralSubSelect) {
            LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
            if (lateralSubSelect.getSubSelect() != null) {
                SubSelect subSelect = lateralSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelectBody(subSelect.getSelectBody());
                }
            }
        }
    }

    /**
     * Adds the zonecode condition to the ON clause of a JOIN whose right side is a handled table.
     */
    protected void processJoin(Join join) {
        if (join.getRightItem() instanceof Table) {
            Table fromTable = (Table) join.getRightItem();
            if (!zonecodeLineHandler.handleTable(fromTable.getName())) {
                // Table is excluded from zonecode handling.
                return;
            }
            join.setOnExpression(builderInExpression(join.getOnExpression(), fromTable));
        }
    }

    /**
     * Builds a SELECT / JOIN condition: {@code <currentExpression> AND zonecode IN (...)}.
     *
     * @param currentExpression existing condition, may be null
     * @param table             table whose alias (if any) qualifies the column
     * @return the combined condition, or just the IN expression when {@code currentExpression} is null
     */
    protected Expression builderInExpression(Expression currentExpression, Table table) {
        InExpression inExpression = buildZonecodeInExpression(table);
        if (currentExpression == null) {
            return inExpression;
        }
        // A top-level OR must be parenthesized so the zonecode filter applies to every branch.
        Expression guarded = currentExpression instanceof OrExpression ? new Parenthesis(currentExpression) : currentExpression;
        return new AndExpression(guarded, inExpression);
    }

    /**
     * Creates {@code [alias.]zonecode IN (id1, id2, ...)} from the handler's column name and values.
     */
    private InExpression buildZonecodeInExpression(Table table) {
        InExpression inExpression = new InExpression();
        inExpression.setLeftExpression(this.getAliasColumn(table));
        inExpression.setRightItemsList(new ExpressionList(zonecodeLineHandler.getZonecodeId()));
        return inExpression;
    }

    /**
     * Builds the zonecode column, qualified by the table alias when one is present.
     * <p>zonecode or tableAlias.zonecode</p>
     *
     * @param table table object
     * @return column
     */
    protected Column getAliasColumn(Table table) {
        StringBuilder column = new StringBuilder();
        if (table.getAlias() != null) {
            column.append(table.getAlias().getName()).append(StringPool.DOT);
        }
        column.append(zonecodeLineHandler.getZonecodeIdColumn());
        return new Column(column.toString());
    }

    /**
     * Creates the handler from the {@code zonecodeLineHandler} property (a class name) when it is set.
     */
    @Override
    public void setProperties(Properties properties) {
        PropertyMapper.newInstance(properties)
                .whenNotBlack("zonecodeLineHandler", ClassUtils::newInstance, this::setZonecodeLineHandler);
    }

}

MybatisPlusConfig

@Configuration
@MapperScan("com.example.mapper")
public class MybatisPlusConfig {

    /** The ZonecodeLineHandler implementation handed to the zonecode interceptor. */
    @Autowired
    private YqZonecodeLineHandler yqZonecodeLineHandler;

    /**
     * Registers the MyBatis-Plus interceptor chain containing the zonecode row filter.
     *
     * @return interceptor with a ZonecodeLineInnerInterceptor backed by yqZonecodeLineHandler
     */
    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        // zonecode interceptor
        ZonecodeLineInnerInterceptor zonecodeSqlParser = new ZonecodeLineInnerInterceptor();
        // The handler must be set on the same instance that is registered below.
        zonecodeSqlParser.setZonecodeLineHandler(yqZonecodeLineHandler);
        interceptor.addInnerInterceptor(zonecodeSqlParser);
        return interceptor;
    }

}

测试

表数据

Mybatis-plus插件的一次完美实践_sql

测试单元

@SpringBootTest
class TestMybatisPlusApplicationTests {

	@Autowired
	private UserMapper userMapper;

	/**
	 * Selects users whose name is "Jack" via an entity-based QueryWrapper and prints each row.
	 */
	@Test
	public void testSelect() {
		User probe = new User();
		probe.setName("Jack");
		QueryWrapper<User> wrapper = new QueryWrapper<>(probe);
		List<User> users = userMapper.selectList(wrapper);
		for (User u : users) {
			System.out.println(u);
		}
	}

}

测试结果

Mybatis-plus插件的一次完美实践_mybatis_02

1、原始sql

2、包装后的sql

3、最终执行的sql

4、执行结果