前言
我所在项目组刚好接到一个领取优惠券需求,具体需求是用户领取的各种类型的优惠券(比如:代金券、折扣券)数量不能超过某个自定义数量,因考虑到领取限制数量是动态的,另外考虑到扩展性,满足将来业务规则的增长,不只是限制领取数需要新加其他条件,为了满足不断变化的业务场景,经分析后选用规则引擎 Drools 来实现。
本场 Chat 通过一个电商行业的领取优惠券场景,详细介绍了 Drools 的常用语法、使用步骤、具体开发过程及代码分析、测试步骤、注意事项等等,并贴出完整的代码。
Drools 常用语法
package
规则的路径是必须定义并且放在文件的第一行,类似于 Java 语言的 package,但 Java 的 package 是物理路径,Drools 的 package 是逻辑路径不需要真实存在。
import
导入外部变量,类似于 Java 语言的 import,但 Java 只能导入类,Drools 不仅可以导入类还可以导入静态方法。
global
定义全局变量,通常用于返回规则结果信息,用法介绍请看下面代码:
Drl 文件:
package com.example.demo.rule.card.receive
import org.slf4j.Logger;
import com.example.demo.drools.base.CheckResult;
import com.example.demo.rule.fact.CardReceiveFact;
global Logger log;
global CheckResult result;
Java 代码:
KieSession kSession = kContainer.newKieSession();
kSession.setGlobal("log", LOG);
kSession.setGlobal("result", result);
rule
定义一个规则,rule 开头 end 结尾,一个规则会包含三个部分,分别是属性部分、条件部分(即 LHS)和结果部分(即 RHS),请看下面的代码结构:
rule "规则名称,不能重复"
no-loop true //属性部分
when //条件部分
fact : com.example.demo.rule.fact.CardReceiveFact(
getMaxReceiveNumber() !=null && getMaxReceiveNumber() <=2
);
then //结果部分
result.setCode('200');
result.setMessage('匹配成功');
log.info('match success');
update(fact);
end
no-loop true
定义当前规则是否不允许循环执行,默认是允许循环执行,如果 RHS 部分有 update 等触发规则重新执行的操作,那么就要设置为 true 代表只执行一次,避免产生死循环。
lock-on-active true
如果是 true 则设置规则只被执行一次,不受本身规则触发,也不受其他规则触发,是 no-loop 的加强版。
salience 1000
优先级,数字越大优先级越高,用于控制规则的执行顺序,一个 drl 文件可以定义多个 rule 规则。
date-expires
设置规则的失效时间。
date-effective
设置规则的生效时间。
duration
设置规则延时执行,比如:duration 5000
代表 5 秒后执行规则。
fact
fact 对象实际上就是一个 Java Bean,用来把 Java 参数传递到规则引擎。
when & then
when 用于定义条件,条件可以一个,也可以多个,通常都是判断传入的 fact 对象变量是否符合条件,假如符合条件则执行 then 的代码。
Drools 使用步骤
1. 定义规则脚本,通常脚本都是根据业务场景动态的生成并保存数据库,脚本如下:
package com.example.demo.rule.card.receive
import org.slf4j.Logger;
import com.example.demo.drools.base.CheckResult;
import com.example.demo.rule.fact.CardReceiveFact;
global Logger log;
global CheckResult result;
rule "W8V32BE36Q_1082913542028271999"
no-loop true
lock-on-active true
salience 1000
when
fact : com.example.demo.rule.fact.CardReceiveFact(getMaxReceiveNumber() != null && getMaxReceiveNumber() <= 2);
then
result.setCode('200');
result.setMessage('匹配成功');
log.info('match success');
update(fact);
end
2. 定义 Fact 对象,定义一个最大可领用数量变量,并提供 get 和 set 方法,代码如下:
package com.example.demo.rule.fact;
import com.example.demo.drools.base.RuleFact;
/**
* 领取优惠券 fact
*/
public class CardReceiveFact extends RuleFact {
/**
* 最大可领用数量
*/
private Integer maxReceiveNumber;
public Integer getMaxReceiveNumber() {
return maxReceiveNumber;
}
public void setMaxReceiveNumber(Integer maxReceiveNumber) {
this.maxReceiveNumber = maxReceiveNumber;
}
@Override
public String toString() {
return "CardReceiveFact{" +
"maxReceiveNumber=" + maxReceiveNumber +
'}';
}
}
3. 执行规则并获取执行结果,把脚本和 Fact 对象传给已封装好的规则执行器,执行器会校验参数、创建规则会话、执行规则、销毁会话和返回执行结果:
//执行规则脚本
CheckResult result = cardReceiveExecutor.excute(new RuleSchema(schemaCode, ruleScript), fact);
具体开发过程及代码分析
代码结构
pom.xml 文件配置
下面配置是工程需要使用的所有 jar 和 maven 打包策略,代码如下:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.5.6</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.example</groupId>
<artifactId>demo-drools </artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>demo-drools</name>
<description>Demo project for Spring Boot</description>
<properties>
<java.version>8</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>4.0.3</version>
</dependency>
<!--kie api 构建 kie 虚拟文件系统,关联 decisiontable 和 drl 文件 -->
<dependency>
<groupId>org.kie</groupId>
<artifactId>kie-spring</artifactId>
<version>7.0.0.Final</version>
</dependency>
<dependency>
<groupId>org.drools</groupId>
<artifactId>drools-compiler</artifactId>
<version>7.0.0.Final</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
工程配置文件和启动类
application.properties 仅仅配置了服务端口,启动类也是非常简单,具体代码如下:
application.properties 文件:
server.port=8888
package com.example.demo;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@SpringBootApplication
public class DemoDroolsApplication {
public static void main(String[] args) {
SpringApplication.run(DemoDroolsApplication.class, args);
}
}
Drools 的基础类、规则执行器类和工具类
定义了两个接口,分别是执行器接口和是否支持该执行器接口,使用抽象类封装了一个执行器基类,把执行规则时的参数校验,创建规则会话、执行规则等操作封装好了,子类使用时便极为简单了;使用 KieContainer 管理所有规则,比如发布规则、删除规则等,Kie 工具还有缓存机制,代码如下:
Drools 的基础类代码:
package com.example.demo.drools.base;
/**
* 执行规则结果
*/
public class CheckResult {
/**
* 结果描述
*/
private String message;
/**
* 结果编码
*/
private String code;
/**
* 返回数据
*/
private Object data;
/**
* 成功码
*/
private static final String SUCCESS_CODE = "200";
/**
* 失败码
*/
private static final String FAIL_CODE = "400";
public CheckResult() {
}
/**
* 转换为 result 对象
* @return
*/
public Result<Object> transf(){
Result<Object> result = new Result<>();
result.setCode(this.code);
result.setMessage(this.message);
result.setData(this.data);
return result;
}
/**
* 失败
* @param code
* @param message
* @return
*/
public static CheckResult fail(String code, String message) {
CheckResult result = new CheckResult();
result.setCode(code);
result.setMessage(message);
return result;
}
/**
* 失败, code 为 400
* @param message
* @return
*/
public static CheckResult fail(String message) {
return fail(FAIL_CODE, message);
}
/**
* 失败, code 为 400, message 为匹配不成功
* @return
*/
public static CheckResult fail() {
return fail("匹配不成功");
}
/**
* 匹配是否成功
* @return
*/
public boolean isSuccess() {
return SUCCESS_CODE.equals(code);
}
public String getMessage() {
return message;
}
public void setMessage(String message) {
this.message = message;
}
public String getCode() {
return code;
}
public void setCode(String code) {
this.code = code;
}
public Object getData() {
return data;
}
public void setData(Object data) {
this.data = data;
}
@Override
public String toString() {
return "CheckResult [message=" + message + ", code=" + code + ", data=" + data + "]";
}
}
package com.example.demo.drools.base;
import java.io.Serializable;
/**
* WebApi 接口对象.
*/
public class Result<T> implements Serializable {
private static final long serialVersionUID = 7314670530670292902L;
/**
* 结果描述
*/
private String message;
/**
* 结果编码
*/
private String code;
/**
* 业务数据
*/
private T data;
public Result() {
}
public static <T> Result<T> success() {
return success(null);
}
public static <T> Result<T> success(T data) {
Result<T> result = new Result<>();
result.setData(data);
result.setCode("200");
result.setMessage("操作成功");
return result;
}
public static <T> Result<T> fail(String code, String message) {
Result<T> result = new Result<>();
result.setCode(code);
result.setMessage(message);
return result;
}
/**
* 操作是否成功
* @return
*/
public boolean isSuccess() {
if("200".equals(code)) {
return true;
}else{
return false;
}
}
public static <T> Result<T> fail(String message) {
return fail("400", message);
}
public String getMessage() {
return message;
}
public void setMessage(String message) {
this.message = message;
}
public T getData() {
return data;
}
public void setData(T data) {
this.data = data;
}
public String getCode() {
return code;
}
public void setCode(String code) {
this.code = code;
}
@Override
public String toString() {
return "Result [message=" + message + ", code=" + code + ", data=" + data + "]";
}
}
package com.example.demo.drools.base;
/**
* 规则异常
*/
@SuppressWarnings("serial")
public class RuleException extends RuntimeException {
public RuleException() {
super("规则处理异常");
}
public RuleException(String message) {
super(message);
}
public RuleException(String message, Throwable cause) {
super(message, cause);
}
}
package com.example.demo.drools.base;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
/**
* 规则输入事实
*/
@SuppressWarnings("serial")
public class RuleFact implements Serializable {
/**
* 通用输入事实集
*/
private final Map<String, Object> facts = new HashMap<>();
/**
* 通用输出结果集
*/
private final Map<String, Object> results = new HashMap<>();
/**
* 根据属性获取值
* @param attr
*/
public Object factGet(String attr) {
return facts.get(attr);
}
/**
* 移除事实
*/
public void factRemove(String attr) {
facts.remove(attr);
}
/**
* 设置事实
*/
public final void factSet(String attr, Object value) {
facts.put(attr, value);
}
/**
* 设置事实集
*/
public final void factsSet(Map<String, Object> facts) {
this.facts.putAll(facts);
}
/**
* 清除所有事实
*/
public final void factClear() {
facts.clear();
}
/**
* 设置结果
* @param attr
* @param value
*/
public final void resultSet(String attr, Object value) {
results.put(attr, value);
}
/**
* 获取结果
* @param attr
* @return
*/
public final Object resultGet(String attr) {
return results.get(attr);
}
/**
* 获取结果集
* @return
*/
public final Map<String, Object> resultsGet() {
return results;
}
@Override
public String toString() {
return "RuleFact [facts=" + facts + ", results=" + results + "]";
}
}
package com.example.demo.drools.base;
public class RuleSchema {
/**
* 模式编码
*/
private String schemaCode;
/**
* 规则脚本
*/
private String ruleDrl;
public RuleSchema() {}
public RuleSchema(String schemaCode, String ruleDrl) {
this.schemaCode = schemaCode;
this.ruleDrl = ruleDrl;
}
public String getRuleDrl() {
return ruleDrl;
}
public void setRuleDrl(String ruleDrl) {
this.ruleDrl = ruleDrl;
}
public String getSchemaCode() {
return schemaCode;
}
public void setSchemaCode(String schemaCode) {
this.schemaCode = schemaCode;
}
}
Drools 的规则执行器类代码:
package com.example.demo.drools.executor;
import com.example.demo.drools.base.CheckResult;
import com.example.demo.drools.base.RuleFact;
import com.example.demo.drools.base.RuleSchema;
/**
* 规则执行器
*/
public interface RuleExecutor {
/**
* 执行规则
* @param ruleSchema 规则
* @param fact 事实
* @return
*/
public CheckResult excute(RuleSchema ruleSchema, RuleFact fact);
/**
* 是否支持该执行器
* @param fact
* @return
*/
public Boolean support(RuleFact fact);
}
package com.example.demo.drools.executor;
import com.example.demo.drools.base.RuleException;
import com.example.demo.drools.base.CheckResult;
import com.example.demo.drools.base.RuleFact;
import com.example.demo.drools.base.RuleSchema;
import com.example.demo.drools.util.KieUtil;
import org.drools.core.base.RuleNameStartsWithAgendaFilter;
import org.kie.api.runtime.KieContainer;
import org.kie.api.runtime.KieSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;
/**
* 执行器基类
*/
public abstract class AbstractExecutor implements RuleExecutor {
private static Logger LOG = LoggerFactory.getLogger(AbstractExecutor.class);
private static final String SEPARATOR = "_";
@Override
public final CheckResult excute(RuleSchema ruleSchema, RuleFact fact) {
String ruleDrl = ruleSchema.getRuleDrl();
Assert.hasText(ruleDrl, "规则不能为空");
Assert.notNull(fact, "规则事实不能为空");
if(!support(fact)) {
throw new RuleException("不支持的规则执行器");
}
LOG.info("创建规则会话:fact = {}", fact);
long time1 = System.currentTimeMillis();
KieContainer kContainer = KieUtil.getSchemaRuleContainer(ruleSchema.getSchemaCode(), ruleDrl);
KieSession kSession = kContainer.newKieSession();
long time2 = System.currentTimeMillis();
LOG.info("会话创建成功, 耗时: {}", time2 - time1);
//默认匹配不成功
CheckResult result = CheckResult.fail();
try {
kSession.setGlobal("log", LOG);
kSession.setGlobal("result", result);
insertFact(fact, kSession);
int count = kSession.fireAllRules(new RuleNameStartsWithAgendaFilter(ruleSchema.getSchemaCode() + SEPARATOR));
long time3 = System.currentTimeMillis();
LOG.info("执行成功, 耗时{}毫秒", time3 - time2);
LOG.info("执行了{}条规则", count);
} finally {
if(kSession != null) {
kSession.dispose();
kSession.destroy();
}
}
return result;
}
/**
* 如有需要, 可对该方法进行重写
* @param fact
* @param kSession
*/
protected void insertFact(RuleFact fact, KieSession kSession) {
kSession.insert(fact);
}
}
Drools 的工具类代码:
package com.example.demo.drools.util;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.crypto.digest.DigestUtil;
import org.kie.api.KieServices;
import org.kie.api.builder.KieBuilder;
import org.kie.api.builder.KieFileSystem;
import org.kie.api.builder.KieRepository;
import org.kie.api.builder.Message;
import org.kie.api.io.Resource;
import org.kie.api.io.ResourceType;
import org.kie.api.runtime.KieContainer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* kie 工具, 管理 kieContainer
*/
public class KieUtil {
private final static Logger LOG = LoggerFactory.getLogger(KieUtil.class);
private final static Map<String, String> RULE_KEY_MAP = new ConcurrentHashMap<String, String>();
private final static KieServices KIE_SERVICES;
private static KieContainer kieContainer;
private final static KieFileSystem KIE_FILE_SYSTEM;
/**
* 规则文件扩展名
*/
private final static String EXTENSION = ".drl";
/**
* 规则路径
*/
private static final String RULES_PATH = "src/main/resources/rules/";
/**
* 规则组前缀
*/
private final static String SCHEMA_PREFIX = "schema_";
static {
KIE_SERVICES = KieServices.Factory.get();
KIE_FILE_SYSTEM = KIE_SERVICES.newKieFileSystem();
KIE_SERVICES.newKieBuilder(KIE_FILE_SYSTEM).buildAll();
kieContainer = KIE_SERVICES.newKieContainer(KIE_SERVICES.getRepository().getDefaultReleaseId());
}
/**
* 获取规则容器,当规则不存在或有变更时更新规则
* @param schemaCode 模式编码
* @param ruleDrl 规则
* @return
*/
public static KieContainer getSchemaRuleContainer(String schemaCode, String ruleDrl) {
return getRuleContainer(SCHEMA_PREFIX + schemaCode, ruleDrl);
}
/**
* 获取规则容器
* @param key
* @param ruleDrl
* @return
*/
private static KieContainer getRuleContainer(String key, String ruleDrl) {
String ruleMd5 = DigestUtil.md5Hex(ruleDrl);
String prevSign = RULE_KEY_MAP.get(key);
if (!ruleMd5.equals(prevSign)) {
if (prevSign != null) {
removeRule(prevSign);
}
deployRule(ruleMd5, ruleDrl);
RULE_KEY_MAP.put(key, ruleMd5);
}
return kieContainer;
}
/**
* 删除规则
* @param ruleMd5
*/
private static void removeRule(String ruleMd5) {
KieRepository kr = KIE_SERVICES.getRepository();
String path = getRulePath(ruleMd5);
KIE_FILE_SYSTEM.delete(path);
LOG.info("删除规则:{}", path);
kieContainer = KIE_SERVICES.newKieContainer(kr.getDefaultReleaseId());
}
/**
* 发布规则
* @param ruleMd5
* @param ruleContent
* @return
*/
private static void deployRule(String ruleMd5, String ruleContent) {
LOG.debug("deploy rule {}", ruleContent);
KieRepository kr = KIE_SERVICES.getRepository();
Resource res = KIE_SERVICES.getResources().newByteArrayResource(ruleContent.getBytes()).setResourceType(ResourceType.DRL);
String path = getRulePath(ruleMd5);
KIE_FILE_SYSTEM.write(path, res);
LOG.info("发布规则:{}", path);
KieBuilder kieBuilder = KIE_SERVICES.newKieBuilder(KIE_FILE_SYSTEM).buildAll();
List<Message> errors = kieBuilder.getResults().getMessages(Message.Level.ERROR);
if(CollUtil.isNotEmpty(errors)) {
StringBuilder sb = new StringBuilder();
for(Message message : errors) {
sb.append("path:"+message.getPath()+", text:"+message.getText()+"||");
}
LOG.error(sb.toString());
}
kieContainer = KIE_SERVICES.newKieContainer(kr.getDefaultReleaseId());
}
/**
* 获取规则文件路径
* @param ruleKey
* @return
*/
private static String getRulePath(String ruleKey) {
return RULES_PATH + ruleKey + EXTENSION;
}
}