手写简单的RPC

1.何为RPC

RPC(Remote Procedure Call,远程过程调用)是一种通过网络从远程计算机程序上请求服务,而不需要了解底层网络技术的协议。RPC协议假定某些传输协议的存在,如TCP或UDP,为通信程序之间携带信息数据。在OSI网络通信模型中,RPC跨越了传输层和应用层。RPC使得开发包括网络分布式多程序在内的应用程序更加容易。

2.工作原理

RPC采用客户机/服务器模式。请求程序就是一个客户机,而服务提供程序就是一个服务器。首先,客户机调用进程发送一个有进程参数的调用信息到服务进程,然后等待应答信息。在服务器端,进程保持睡眠状态直到调用信息的到达为止。当一个调用信息到达,服务器获得进程参数,计算结果,发送应答信息,然后等待下一个调用信息,最后,客户端调用进程接收应答信息,获得进程结果,然后调用执行继续进行。

3.架构描述

本例通过服务生产者(provider)服务消费者(comsumer)服务框架(framework)三个模块实现简单的RPC案例。

其中服务生产者者负责提供服务,服务消费者通过http请求去调用服务提供者提供的方法,服务框架负责处理服务消费者调用服务提供者的相关逻辑处理。

手写简单的RPC_java

本例中,消费端需要调用接口ProviderService.class 中的某个方法。生产者模块提供了其具体实现类ProviderServiceImpl.class ;最后,消费者需要通过RPC去调用生成者提供的这个方法。

废话不多说,直接上代码!

4.服务生产者具体实现

💡 主要职责:服务注册、提供实现方法

package com.myrpc;

import com.myrpc.apis.ProviderService;
import com.myrpc.domain.ServiceBean;
import com.myrpc.domain.ServiceMetaInfo;
import com.myrpc.register.ServiceRegister;
import com.myrpc.server.HttpServer;
import com.myrpc.service.impl.ProviderServiceImpl;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

/**
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 15:26
 */
public class ProviderMain {
    public static void main(String[] args) {
        // 构建服务元信息
        ServiceMetaInfo serviceMetaInfo = new ServiceMetaInfo();
        // 服务的主键信息
        serviceMetaInfo.setKey(UUID.randomUUID().toString());
        // 服务名称
        serviceMetaInfo.setServiceName("provider");
        // 版本号
        serviceMetaInfo.setVersion("1.0");
        // ip
        serviceMetaInfo.setHost("localhost");
        // 端口
        serviceMetaInfo.setPort(8080);

        // 构建提供的服务列表
        List<ServiceBean> beanList = new ArrayList<>();
        ServiceBean serviceBean = new ServiceBean();
        serviceBean.setBeanName(ProviderService.class.getName());
        serviceBean.setBeanClass(ProviderServiceImpl.class);
        beanList.add(serviceBean);

        try {
            // 服务注册
            ServiceRegister.register(serviceMetaInfo, beanList);

            // 启动服务
            HttpServer httpServer = new HttpServer();
            httpServer.start(serviceMetaInfo);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}

其中ServiceMetaInfo 是框架模块提供的服务信息类,主要用于记录服务的信息

package com.myrpc.domain;

import cn.hutool.core.util.RandomUtil;
import lombok.Data;

import java.io.Serializable;

/**
 * 服务的元数据
 * @author huliua
 * @version 1.0
 * @date 2024-04-15 21:20
 */
@Data
public class ServiceMetaInfo implements Serializable {

    private String key;

    private String serviceName;

    private String host;

    private Integer port;

    private String version;
}

ServiceBean 是框架提供的类,主要用于记录一个服务下会提供哪些服务,包含类名、以及对应实现类的类名。

package com.myrpc.domain;

import lombok.Data;

import java.io.Serializable;

/**
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 17:54
 */
@Data
public class ServiceBean implements Serializable {

    private String beanName;

    private Class<?> beanClass;
}

ServiceRegister 是框架提供的类,主要用于服务注册。本例中只实现了本地的服务注册,后续可以把服务信息注册到redis、nacos、zookeeper中。

package com.myrpc.register;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.CollectionUtil;
import com.myrpc.domain.ServiceBean;
import com.myrpc.domain.ServiceMetaInfo;

import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * 服务注册中心
 *
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 15:27
 */
public class ServiceRegister {

    private static final String filePath = "/myrpc/register.txt";

    /**
     * 本地服务元信息列表
     */
    private static final Map<String, List<ServiceMetaInfo>> localServiceMetaInfoMap = new HashMap<>();

    /**
     * 本地服务列表
     */
    private static final Map<String, List<ServiceBean>> localServiceBeanMap = new HashMap<>();

    /**
     * 服务注册
     */
    public static void register(ServiceMetaInfo serviceMetaInfo, List<ServiceBean> serviceList) throws IOException {
        // 先实现本地注册
        List<ServiceMetaInfo> services = localServiceMetaInfoMap.get(serviceMetaInfo.getServiceName());
        if (CollectionUtil.isEmpty(services)) {
            services = new ArrayList<>();
        }
        services.add(serviceMetaInfo);
        localServiceMetaInfoMap.put(serviceMetaInfo.getServiceName(), services);

        // 保存该服务名下提供的服务列表
        localServiceBeanMap.put(serviceMetaInfo.getKey(), serviceList);

        // 远程服务注册(暂时使用存入本地文件的方式代替)
        FileOutputStream fileOutputStream = new FileOutputStream(filePath);
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream);
        objectOutputStream.writeObject(localServiceMetaInfoMap);
    }

    /**
     * 根据服务名获取服务信息
     *
     * @param serviceName 服务名
     * @return 返回注册中心的服务列表
     */
    public static List<ServiceMetaInfo> getService(String serviceName) {
        // 优先从本地缓存中读取
        List<ServiceMetaInfo> serviceList = localServiceMetaInfoMap.get(serviceName);
        if (CollUtil.isNotEmpty(serviceList)) {
            return serviceList;
        }
        // 从远程注册中心中读取(暂时使用读取本地文件的方式代替)
        FileInputStream fileInputStream = null;
        ObjectInputStream objectInputStream = null;
        try {
            fileInputStream = new FileInputStream(filePath);
            objectInputStream = new ObjectInputStream(fileInputStream);
            Map<String, List<ServiceMetaInfo>> resMap = (Map<String, List<ServiceMetaInfo>>) objectInputStream.readObject();
            return resMap.get(serviceName);
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
        } finally {
            try {
                assert fileInputStream != null;
                fileInputStream.close();
                assert objectInputStream != null;
                objectInputStream.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return null;
    }

    /**
     * 根据key获取服务列表
     * @param key
     * @return
     */
    public static List<ServiceBean> getServiceBeanList(String key) {
        return localServiceBeanMap.get(key);
    }
}

HttpServer 是框架中提供的类,主要作用是启动tomcat,监听请求,并配置请求分发器DispatcherServlet

package com.myrpc.server;

import com.myrpc.dispatcher.DispatcherServlet;
import com.myrpc.domain.ServiceMetaInfo;
import org.apache.catalina.*;
import org.apache.catalina.connector.Connector;
import org.apache.catalina.core.StandardContext;
import org.apache.catalina.core.StandardEngine;
import org.apache.catalina.core.StandardHost;
import org.apache.catalina.startup.Tomcat;

/**
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 17:24
 */
public class HttpServer {

    public void start(ServiceMetaInfo service) {
        Tomcat tomcat = new Tomcat();

        Server server = tomcat.getServer();
        org.apache.catalina.Service tomcatService = server.findService("Tomcat");

        Connector connector = new Connector();
        connector.setPort(service.getPort());

        Engine engine = new StandardEngine();
        engine.setDefaultHost(service.getHost());

        Host host = new StandardHost();
        host.setName(service.getHost());

        String contextPath = "";
        Context context = new StandardContext();
        context.setPath(contextPath);
        context.addLifecycleListener(new Tomcat.FixContextListener());

        host.addChild(context);
        engine.addChild(host);

        tomcatService.setContainer(engine);
        tomcatService.addConnector(connector);

        tomcat.addServlet(contextPath, "dispatcher", new DispatcherServlet());
        context.addServletMappingDecoded("/*", "dispatcher");

        try {
            tomcat.start();
            tomcat.getServer().await();
        } catch (LifecycleException e) {
            e.printStackTrace();
        }

    }
}

DispatcherServlet 是框架提供的类,主要作用是处理请求。当有请求到达时,通过HttpServerHandler 处理请求

package com.myrpc.dispatcher;

import com.myrpc.handler.HttpServerHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServlet;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;

/**
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 17:34
 */
@Slf4j
public class DispatcherServlet extends HttpServlet {

    @Override
    public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException {
        log.info("有新的请求待处理...");
        new HttpServerHandler().handler(req, res);
    }
}

HttpServerHandler 是框架提供的类,主要作用是处理远程调用请求。根据远程服务调用信息,通过SPI机制找到对应的实现类,完成方法的调用并将返回值写入请求响应中。

💡 使用`SPI机制`需要在生产者模块`resources`目录下新建`META-INF/services/com.myrpc.apis.ProviderService` 文件,其中写明该接口的实现类全路径`com.myrpc.service.impl.ProviderServiceImpl`

package com.myrpc.handler;

import cn.hutool.core.util.ClassUtil;
import cn.hutool.core.util.ServiceLoaderUtil;
import com.alibaba.fastjson2.JSON;
import com.myrpc.domain.Invocation;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import org.apache.commons.io.IOUtils;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 17:35
 */
@SuppressWarnings("all")
public class HttpServerHandler {

    public void handler(ServletRequest req, ServletResponse res) {
        try {
            ObjectInputStream objectInputStream = new ObjectInputStream(req.getInputStream());
            Invocation invocation = (Invocation) objectInputStream.readObject();
            Class<?> serviceClass = ClassUtil.getClassLoader().loadClass(invocation.getClassName());
            Object serviceImpl = ServiceLoaderUtil.loadFirstAvailable(serviceClass);

            // 服务调用
            Method method = serviceClass.getMethod(invocation.getMethodName(), invocation.getParamTypes());
            Object result = method.invoke(serviceImpl, invocation.getArgs());

            // 写入响应
            IOUtils.write(JSON.toJSONString(result), res.getOutputStream());
        } catch (FileNotFoundException | NoSuchMethodException e) {
            throw new RuntimeException(e);
        } catch (InvocationTargetException e) {
            throw new RuntimeException(e);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }
}

LoadBalance 是框架提供的类,主要用于实现负载均衡,这里以随机的方式为例

package com.myrpc.loadbalance.impl;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.RandomUtil;
import com.myrpc.domain.ServiceMetaInfo;
import com.myrpc.loadbalance.LoadBalance;

import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

/**
 * 随机负载均衡
 *
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 16:25
 */
public class RandomLoadBalance implements LoadBalance {

    public ServiceMetaInfo loadBalance(List<ServiceMetaInfo> serviceList) {
        if (CollectionUtil.isEmpty(serviceList)) {
            return null;
        }
        ThreadLocalRandom random = RandomUtil.getRandom();
        int index = random.nextInt(serviceList.size());
        return serviceList.get(index);
    }
}

5.消费者具体实现

💡 主要职责:调用远程方法

通过JDK代理的方式,获取代理对象,然后调用代理对象的方法实现远程调用。

package com.myrpc;

import com.myrpc.apis.ProviderService;
import com.myrpc.bo.ResponseResult;
import com.myrpc.proxy.ProxyFactory;

import java.util.List;
import java.util.Map;

/**
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 15:26
 */
public class ConsumerMain {
    public static void main(String[] args) {
        ProviderService providerService = ProxyFactory.getProxy("provider", ProviderService.class);
        ResponseResult<List<Map<String, Object>>> result = providerService.say();
        System.out.println(result);
    }
}

ProxyFactory 是框架提供的类,主要用于创建代理对象。当调用代理对象的方法时,都会走到这里的invoke 逻辑中:根据调用方法的方法名、方法参数、返回值类型等信息构建远程方法调用参数,然后发起http请求去实现远程方法调用。

package com.myrpc.proxy;

import com.myrpc.client.HttpClient;
import com.myrpc.domain.Invocation;
import com.myrpc.domain.RpcResponse;
import com.myrpc.retry.Retryer;

import java.lang.reflect.Proxy;

/**
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 21:59
 */
public class ProxyFactory {

    public static <T> T getProxy(String serviceName, Class<?> interfaceClass) {
        Object proxyInstance = Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class[]{interfaceClass}, (proxy, method, args) -> {
            // 构建方法调用信息
            Invocation invocation = new Invocation();
            invocation.setServiceName(serviceName);
            invocation.setClassName(interfaceClass.getName());
            invocation.setMethodName(method.getName());
            invocation.setParamTypes(method.getParameterTypes());
            invocation.setArgs(args);
            invocation.setReturnType(method.getReturnType());

            HttpClient httpClient = new HttpClient();
            // 服务重试
            RpcResponse response = Retryer.doRetry(() -> httpClient.send(invocation));
            if (response.getData() != null) {
                return response.getData();
            } else {
                // TODO: 重试多次后,服务降级
                throw new RuntimeException(response.getException());
            }
        });
        return (T) proxyInstance;
    }
}

Invocation是框架提供的类,主要用于保存方法调用的信息,比如方法名、参数、返回值类型等

package com.myrpc.domain;

import lombok.Data;

import java.io.Serializable;

/**
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 17:38
 */
@Data
public class Invocation implements Serializable {

    private String serviceName;

    private String className;

    private String methodName;

    private Class[] paramTypes;

    private Object[] args;

    private Class returnType;
}

HttpClient是框架提供的类,是客户端的核心类。主要用于根据方法调用参数发现服务,再通过负载均衡获取具体的服务,然后根据服务的元数据(主要为主机、端口信息)发起http请求,实现服务的远程调用

💡 注意:负载均衡、重试、服务降级都是在客户端实现的

package com.myrpc.client;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.map.MapUtil;
import com.alibaba.fastjson2.JSON;
import com.myrpc.domain.Invocation;
import com.myrpc.domain.ServiceMetaInfo;
import com.myrpc.loadbalance.LoadBalance;
import com.myrpc.loadbalance.impl.RandomLoadBalance;
import com.myrpc.register.ServiceRegister;
import org.apache.commons.io.IOUtils;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * 提供给服务调用端使用
 *
 * @author huliua
 * @version 1.0
 * @date 2024-04-14 19:41
 */
public class HttpClient {

    /**
     * 客户端--服务列表缓存
     */
    private static final Map<String, List<ServiceMetaInfo>> serviceCacheMap = new HashMap<>();

    private final LoadBalance loadBalance;

    public HttpClient() {
        loadBalance = new RandomLoadBalance();
    }

    public Object send(Invocation invocation) {
        try {
            // 优先从本地缓存中获取服务
            List<ServiceMetaInfo> serviceList = serviceCacheMap.get(invocation.getServiceName());
            if (CollUtil.isEmpty(serviceList)) {
                // 本地缓存没有,则从注册中心获取
                serviceList = ServiceRegister.getService(invocation.getServiceName());
            }
						
						// 负载均衡
            ServiceMetaInfo service = loadBalance.loadBalance(serviceList);
            if (null == service) {
                throw new RuntimeException("service not found");
            }

						// 发起请求
            URL url = new URL("http", service.getHost(), service.getPort(), "/");
            HttpURLConnection httpURLConnection = (HttpURLConnection) url.openConnection();

            httpURLConnection.setRequestMethod("POST");
            httpURLConnection.setDoOutput(true);

            OutputStream outputStream = httpURLConnection.getOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(outputStream);

            oos.writeObject(invocation);
            oos.flush();
            oos.close();

            InputStream inputStream = httpURLConnection.getInputStream();
            // 返回响应
            return JSON.parseObject(IOUtils.toString(inputStream), invocation.getReturnType());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}

Retryer是框架提供的类,主要用于实现重试。当服务异常时,通过重试机制多次重新请求。保证服务的高可用。本例中默认会进行3次重试,每次重试直接间隔1秒。

package com.myrpc.retry;

import com.myrpc.domain.RpcResponse;
import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;

/**
 * 服务重试机制
 * @author huliua
 * @version 1.0
 * @date 2024-04-16 15:55
 */
@Slf4j
public class Retryer {

    /**
     * 最大重试次数
     */
    public static final int MAX_RETRY_TIMES = 3;
    /**
     * 重试间隔时间,单位:秒
     */
    public static final int RETRY_SLEEP_SECOND = 1;

    public static RpcResponse doRetry(Callable<?> callable) throws InterruptedException {
        RpcResponse res = new RpcResponse();
        int retryTimes = 0;
        while (retryTimes < MAX_RETRY_TIMES) {
            try {
                Object callResult = callable.call();
                res.setData(callResult);
                break;
            } catch (Exception e) {
                retryTimes++;
                log.info("retrying......retry times: {}", retryTimes);
                TimeUnit.SECONDS.sleep(RETRY_SLEEP_SECOND);
                res.setException(e);
            }
        }
        return res;
    }
}

6.启动,测试!

6.1 先启动服务生产者

手写简单的RPC_List_02

6.2 再启动消费者

手写简单的RPC_java_03

大功告成~~

7.Github仓库

Github-myrpc