基于 Java 的 Socket API，我们可以实现一个简单的 RPC 调用。本例包括服务接口及其远端实现、服务的消费者（客户端）以及远端的服务提供者（服务端）。基于 TCP 协议实现的 RPC 类图如下：

[图：基于 TCP 协议实现的 RPC 类图]


项目的目录结构如下:

[图：项目目录结构]


 

1. 首先编写服务端代码：

①定义接口

package com.bjsxt.tcp;

public interface SayHelloService {

	/**
	 * Greets the caller.
	 *
	 * @param helloArg the greeting sent by the caller
	 * @return the reply chosen by the implementation
	 */
	String sayHello(String helloArg);
}

②接口实现

package com.bjsxt.tcp;

public class SayHelloServiceImpl implements SayHelloService{

	public String sayHello(String helloArg) {
		if(helloArg.equals("hello")){
			return "hello";
		}else{
			return "bye bye";
		}
	}

}

③定义消息体

package com.bjsxt.tcp;

import java.io.Serializable;

/**
 * Serializable envelope describing one remote call:
 * 1. the fully-qualified name of the target interface (package + interface name),
 * 2. the method name,
 * 3. the parameter types, in declaration order,
 * 4. the argument values, in the same order.
 *
 * <p>The consumer and the provider must both have this class on their classpath,
 * because it travels over the wire with Java serialization.
 *
 * @author 316311
 */
public class TransportMessage implements Serializable {

	// Pinned explicitly: the compiler-computed default is sensitive to class details
	// and compiler choices, and a mismatch between the consumer's and the provider's
	// copy of this class makes every request fail with InvalidClassException.
	private static final long serialVersionUID = 1L;

	// Fully-qualified interface name (package + interface); the provider's service key
	private String interfaceName;

	// Name of the method to invoke
	private String methodName;

	// Parameter types in declaration order; used to select the exact overload
	private Class<?>[] parameterTypes;

	// Argument values in declaration order
	private Object[] parameters;

	/** Creates an empty message; every field starts as {@code null}. */
	public TransportMessage() {
		super();
	}

	/**
	 * Creates a fully populated message.
	 *
	 * @param interfaceName  fully-qualified name of the target interface
	 * @param methodName     name of the method to call
	 * @param parameterTypes parameter types, in declaration order
	 * @param parameters     argument values, in the same order as {@code parameterTypes}
	 */
	public TransportMessage(String interfaceName, String methodName, Class<?>[] parameterTypes,
			Object[] parameters) {
		this.interfaceName = interfaceName;
		this.methodName = methodName;
		this.parameterTypes = parameterTypes;
		this.parameters = parameters;
	}

	public String getInterfaceName() {
		return interfaceName;
	}

	public void setInterfaceName(String interfaceName) {
		this.interfaceName = interfaceName;
	}

	public String getMethodName() {
		return methodName;
	}

	public void setMethodName(String methodName) {
		this.methodName = methodName;
	}

	public Class<?>[] getParameterTypes() {
		return parameterTypes;
	}

	public void setParameterTypes(Class<?>[] parameterTypes) {
		this.parameterTypes = parameterTypes;
	}

	public Object[] getParameters() {
		return parameters;
	}

	public void setParameters(Object[] parameters) {
		this.parameters = parameters;
	}

}

④编写服务端代码，监听并处理客户端请求：

package com.bjsxt.tcp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * RPC provider (server side).
 *
 * <p>Listens on a TCP port and reads one {@link TransportMessage} per connection.
 * It invokes the matching method on a registered service object and writes the
 * return value back on the same connection.
 *
 * <p>Services are looked up in {@code servicePool} by the fully-qualified name of
 * the interface they implement.
 *
 * <p>SECURITY NOTE(review): requests are decoded with Java native deserialization.
 * Do not expose this port to untrusted clients without a deserialization filter or
 * a safer wire format.
 */
public class Provider {

	/** Pool size used when none is given. */
	private static final int DEFAULT_THREAD_SIZE = 10;

	/** TCP port used when none is given. */
	private static final int DEFAULT_PORT = 4321;

	/** Worker pool; each accepted connection is handled by one task. */
	private final ExecutorService threadPool;

	/** Interface FQN -> service implementation. Never null. */
	private final Map<String, Object> servicePool;

	/** Port the server socket binds to. */
	private final int port;

	/** Creates a provider with the default pool size and port and no registered services. */
	public Provider() {
		this(DEFAULT_THREAD_SIZE, DEFAULT_PORT, null);
	}

	/**
	 * Creates a provider with no registered services.
	 *
	 * @param threadSize size of the internal thread pool
	 * @param port       TCP port of this service
	 */
	public Provider(int threadSize, int port) {
		this(threadSize, port, null);
	}

	/**
	 * Creates a provider.
	 *
	 * @param threadSize  size of the internal thread pool
	 * @param port        TCP port of this service
	 * @param servicePool interface FQN -> implementation; used as-is (not copied).
	 *                    {@code null} means no services are registered.
	 */
	public Provider(int threadSize, int port, Map<String, Object> servicePool) {
		this.port = port;
		// Never leave the pool null: call() would otherwise fail with a NullPointerException.
		this.servicePool = servicePool != null ? servicePool : new HashMap<String, Object>();
		// Final fields are safely published; no synchronization is needed here.
		this.threadPool = Executors.newFixedThreadPool(threadSize);
	}

	/**
	 * Accepts connections on {@code port} forever and hands each one to the thread pool.
	 *
	 * <p>This method only returns by throwing.
	 *
	 * @throws IOException if the port cannot be bound or {@code accept()} fails; the
	 *                     listening socket is closed before the exception propagates
	 */
	public void service() throws IOException {
		ServerSocket serverSocket = new ServerSocket(port);
		System.out.println("Provider start....");
		try {
			while (true) {
				final Socket receiveSocket = serverSocket.accept();
				System.out.println("Provider accepted " + receiveSocket.getRemoteSocketAddress());
				threadPool.execute(new Runnable() {

					@Override
					public void run() {
						try {
							process(receiveSocket);
						} catch (Exception e) {
							// Task boundary: report the failure and drop this one request.
							// process() has already closed the socket, so the client is not
							// left blocked waiting for a reply.
							e.printStackTrace();
						}
					}
				});
			}
		} finally {
			serverSocket.close();
		}
	}

	/**
	 * Serves one request on {@code receiveSocket}.
	 *
	 * <p>Reads a {@link TransportMessage}, invokes it via {@link #call(TransportMessage)}
	 * and writes the result back. The socket is always closed before returning,
	 * including when an exception is thrown. The consumer then sees end-of-stream
	 * instead of waiting forever.
	 */
	public void process(Socket receiveSocket) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException, NoSuchMethodException, SecurityException, IllegalArgumentException, InvocationTargetException {
		try {
			// NOTE(review): native deserialization of network input — see the class comment.
			ObjectInputStream objectInputStream = new ObjectInputStream(receiveSocket.getInputStream());
			TransportMessage message = (TransportMessage) objectInputStream.readObject();

			Object result = call(message);

			ObjectOutputStream objectOutputStream = new ObjectOutputStream(receiveSocket.getOutputStream());
			objectOutputStream.writeObject(result);
			objectOutputStream.flush();
		} finally {
			// Closing the socket also closes its underlying input and output streams.
			receiveSocket.close();
		}
	}

	/**
	 * Invokes the method described by {@code message} on the registered service.
	 *
	 * @param message the decoded request
	 * @return the method's return value ({@code null} for {@code void} methods)
	 * @throws IllegalArgumentException  if no service is registered under the message's
	 *                                   interface name, or the arguments do not match
	 * @throws ClassNotFoundException    if the interface class cannot be loaded
	 * @throws NoSuchMethodException     if the interface declares no such method
	 * @throws InvocationTargetException wrapping any exception thrown by the service method
	 */
	public Object call(TransportMessage message) throws ClassNotFoundException, InstantiationException, IllegalAccessException, NoSuchMethodException, SecurityException, IllegalArgumentException, InvocationTargetException {
		String interfaceName = message.getInterfaceName();

		Object service = servicePool.get(interfaceName);
		// Check registration first, for two reasons. Only registered names ever reach
		// Class.forName. And an unknown interface fails with a clear message instead of
		// a NullPointerException from Method.invoke(null, ...).
		if (service == null) {
			throw new IllegalArgumentException("No service registered for interface: " + interfaceName);
		}

		Class<?> serviceClass = Class.forName(interfaceName);
		Method method = serviceClass.getMethod(message.getMethodName(), message.getParameterTypes());
		return method.invoke(service, message.getParameters());
	}

}

 

2. 编写客户端代码（客户端中的接口定义和消息体与服务端完全一致，此处省略）：

客户端代码如下:

package com.bjsxt.tcp;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Method;
import java.net.Socket;
import java.net.UnknownHostException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class Consumer {
	
	//服务端地址
	private String serverAddress;
	
	//服务端端口
	private int serverPort;
	
	//线程池大小
	private int threadPoolSize = 10;
	
	//线程池
	private ExecutorService executorService = null;
	
	public Consumer(){
		
	}
	
	public Consumer(String serverAddress, int serverPort){
		
		this.serverAddress = serverAddress;
		
		this.serverPort = serverPort;
		
		this.executorService = Executors.newFixedThreadPool(threadPoolSize);
	}
	
	/**
	 * 同步的请求和接收结果
	 */
	public Object sendAndReceive(TransportMessage transportMessage){
		
		Object result = null;
		
		Socket socket = null;
		
		try {
			socket = new Socket(serverAddress, serverPort);
			
			//反序列化 TransportMessage对象
			ObjectOutputStream objectOutputStream = new ObjectOutputStream(socket.getOutputStream());
			
			objectOutputStream.writeObject(transportMessage);
			
			ObjectInputStream objectInputStream = new ObjectInputStream(socket.getInputStream());
			
			//阻塞等待读取结果并返回序列化结果对象
			result = objectInputStream.readObject();
			
			objectOutputStream.close();
			objectInputStream.close();
			
		} catch (UnknownHostException e) {
			e.printStackTrace();
		} catch (IOException e) {
			e.printStackTrace();
		} catch (ClassNotFoundException e) {
			e.printStackTrace();
		} finally{
			
			try {
				socket.close();
			} catch (IOException e) {
				e.printStackTrace();
			}
		}
		
		return result;
	}
}

 

3.测试:

①编写服务端测试类：

package com.bjsxt.test;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import com.bjsxt.tcp.Provider;
import com.bjsxt.tcp.SayHelloServiceImpl;

public class ServerTest {

    /**
     * Starts the RPC provider used by the demo.
     *
     * <p>Registers a {@link SayHelloServiceImpl} under the key
     * "com.bjsxt.tcp.SayHelloService" and serves on port 4321 with 4 worker
     * threads. Runs until {@code service()} throws, then prints the exception.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Map<String, Object> registry = new HashMap<String, Object>();
        registry.put("com.bjsxt.tcp.SayHelloService", new SayHelloServiceImpl());

        final Provider provider = new Provider(4, 4321, registry);
        try {
            provider.service();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}

 ②编写客户端测试类:

package com.bjsxt.test;

import com.bjsxt.tcp.Consumer;
import com.bjsxt.tcp.TransportMessage;

public class ClientTest {

    /**
     * Sends three concurrent sayHello("hello") calls to the provider on
     * 127.0.0.1:4321 and prints each result.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        String serverAddress = "127.0.0.1";
        int serverPort = 4321;

        // One Consumer is shared by all threads: sendAndReceive opens a fresh
        // socket per call and only reads the consumer's fields.
        final Consumer client = new Consumer(serverAddress, serverPort);
        final TransportMessage transportMessage = buildTransportMessage();

        for (int i = 0; i < 3; i++) {
            new Thread(new Runnable() {

                @Override
                public void run() {
                    Object result = client.sendAndReceive(transportMessage);
                    System.out.println(result);
                }

            }).start();
        }
    }

    /**
     * Builds the request for {@code SayHelloService.sayHello(String)} with the
     * argument "hello".
     */
    private static TransportMessage buildTransportMessage() {
        String interfaceName = "com.bjsxt.tcp.SayHelloService";
        Class<?>[] paramsTypes = { String.class };
        Object[] parameters = { "hello" };
        String methodName = "sayHello";

        return new TransportMessage(interfaceName, methodName, paramsTypes, parameters);
    }

}

③先启动服务端，再运行客户端，输出结果如下：

[图：运行输出结果——客户端的三个线程各打印一次 hello]