徒手撸了一个RPC框架,理解更透彻了,代码已上传github,自取~

JAVA公众号

共 37623字,需浏览 76分钟

 ·

2021-03-05 12:05

由于公众号文章推送规则改变,所以为了大家能够准时收到我们的文章推送,请记得将公众号: JAVA 设为星标~这样就不会错过每一篇精彩的推送啦~

 来源 |https://www.cnblogs.com/2YSP/p/13545217.html


一、前言

前段时间看到一篇不错的文章《看了这篇你就会手写RPC框架了》,于是便来了兴趣对着实现了一遍,后面觉得还有很多优化的地方便对其进行了改进。

主要改动点如下:

  1. 除了Java序列化协议,增加了protobuf和kryo序列化协议,配置即用。
  2. 增加多种负载均衡算法(随机、轮询、加权轮询、平滑加权轮询),配置即用。
  3. 客户端增加本地服务列表缓存,提高性能。
  4. 修复高并发情况下,netty导致的内存泄漏问题
  5. 由原来的每个请求建立一次连接,改为建立TCP长连接,并多次复用。
  6. 服务端增加线程池提高消息处理能力

二、介绍

RPC,即 Remote Procedure Call(远程过程调用),调用远程计算机上的服务,就像调用本地服务一样。RPC可以很好的解耦系统,如WebService就是一种基于Http协议的RPC。

调用示意图

调用示意图

总的来说,就如下几个步骤:

  1. 客户端(ServerA)执行远程方法时就调用client stub传递类名、方法名和参数等信息。
  2. client stub会将参数等信息序列化为二进制流的形式,然后通过Sockect发送给服务端(ServerB)
  3. 服务端收到数据包后,server stub 需要进行解析反序列化为类名、方法名和参数等信息。
  4. server stub调用对应的本地方法,并把执行结果返回给客户端

所以一个RPC框架有如下角色:

服务消费者

远程方法的调用方,即客户端。一个服务既可以是消费者也可以是提供者。

服务提供者

远程服务的提供方,即服务端。一个服务既可以是消费者也可以是提供者。

注册中心

保存服务提供者的服务地址等信息,一般由zookeeper、redis等实现。

监控运维(可选)

监控接口的响应时间、统计请求数量等,及时发现系统问题并发出告警通知。

三、实现

本RPC框架rpc-spring-boot-starter涉及技术栈如下:

  • 使用zookeeper作为注册中心
  • 使用netty作为通信框架
  • 消息编解码:protostuff、kryo、java
  • spring
  • 使用SPI来根据配置动态选择负载均衡算法等

由于代码过多,这里只讲几处改动点。

3.1动态负载均衡算法

1.编写LoadBalance的实现类

负载均衡算法实现类

2.自定义注解 @LoadBalanceAno

/**
 * 负载均衡注解
 */

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface LoadBalanceAno {

    String value() default "";
}

/**
 * 轮询算法
 */

@LoadBalanceAno(RpcConstant.BALANCE_ROUND)
public class FullRoundBalance implements LoadBalance {

    private static Logger logger = LoggerFactory.getLogger(FullRoundBalance.class);

    private volatile int index;

    @Override
    public synchronized Service chooseOne(List<Service> services) {
        // 加锁防止多线程情况下,index超出services.size()
        if (index == services.size()) {
            index = 0;
        }
        return services.get(index++);
    }
}

3.新建在resource目录下META-INF/servers文件夹并创建文件

enter description here

4.RpcConfig增加配置项loadBalance

/**
 * @author 2YSP
 * @date 2020/7/26 15:13
 */

@ConfigurationProperties(prefix = "sp.rpc")
public class RpcConfig {

    /**
     * 服务注册中心地址
     */

    private String registerAddress = "127.0.0.1:2181";

    /**
     * 服务暴露端口
     */

    private Integer serverPort = 9999;
    /**
     * 服务协议
     */

    private String protocol = "java";
    /**
     * 负载均衡算法
     */

    private String loadBalance = "random";
    /**
     * 权重,默认为1
     */

    private Integer weight = 1;

   // 省略getter setter
}

5.在自动配置类RpcAutoConfiguration根据配置选择对应的算法实现类

/**
     * 使用spi匹配符合配置的负载均衡算法
     *
     * @param name
     * @return
     */

    private LoadBalance getLoadBalance(String name) {
        ServiceLoader<LoadBalance> loader = ServiceLoader.load(LoadBalance.class);
        Iterator<LoadBalance> iterator = loader.iterator();
        while (iterator.hasNext()) {
            LoadBalance loadBalance = iterator.next();
            LoadBalanceAno ano = loadBalance.getClass().getAnnotation(LoadBalanceAno.class);
            Assert.notNull(ano, "load balance name can not be empty!");
            if (name.equals(ano.value())) {
                return loadBalance;
            }
        }
        throw new RpcException("invalid load balance config");
    }

 @Bean
    public ClientProxyFactory proxyFactory(@Autowired RpcConfig rpcConfig) {
        ClientProxyFactory clientProxyFactory = new ClientProxyFactory();
        // 设置服务发现着
        clientProxyFactory.setServerDiscovery(new           ZookeeperServerDiscovery(rpcConfig.getRegisterAddress()));

        // 设置支持的协议
        Map<String, MessageProtocol> supportMessageProtocols = buildSupportMessageProtocols();
        clientProxyFactory.setSupportMessageProtocols(supportMessageProtocols);
        // 设置负载均衡算法
        LoadBalance loadBalance = getLoadBalance(rpcConfig.getLoadBalance());
        clientProxyFactory.setLoadBalance(loadBalance);
        // 设置网络层实现
        clientProxyFactory.setNetClient(new NettyNetClient());

        return clientProxyFactory;
    }

3.2本地服务列表缓存

使用Map来缓存数据

/**
 * 服务发现本地缓存
 */

public class ServerDiscoveryCache {
    /**
     * key: serviceName
     */

    private static final Map<String, List<Service>> SERVER_MAP = new ConcurrentHashMap<>();
    /**
     * 客户端注入的远程服务service class
     */

    public static final List<String> SERVICE_CLASS_NAMES = new ArrayList<>();

    public static void put(String serviceName, List<Service> serviceList) {
        SERVER_MAP.put(serviceName, serviceList);
    }

    /**
     * 去除指定的值
     * @param serviceName
     * @param service
     */

    public static void remove(String serviceName, Service service) {
        SERVER_MAP.computeIfPresent(serviceName, (key, value) ->
                value.stream().filter(o -> !o.toString().equals(service.toString())).collect(Collectors.toList())
        );
    }

    public static void removeAll(String serviceName) {
        SERVER_MAP.remove(serviceName);
    }


    public static boolean isEmpty(String serviceName) {
        return SERVER_MAP.get(serviceName) == null || SERVER_MAP.get(serviceName).size() == 0;
    }

    public static List<Service> get(String serviceName) {
        return SERVER_MAP.get(serviceName);
    }
}

ClientProxyFactory,先查本地缓存,缓存没有再查询zookeeper。

/**
     * 根据服务名获取可用的服务地址列表
     * @param serviceName
     * @return
     */

    private List<Service> getServiceList(String serviceName) {
        List<Service> services;
        synchronized (serviceName){
            if (ServerDiscoveryCache.isEmpty(serviceName)) {
                services = serverDiscovery.findServiceList(serviceName);
                if (services == null || services.size() == 0) {
                    throw new RpcException("No provider available!");
                }
                ServerDiscoveryCache.put(serviceName, services);
            } else {
                services = ServerDiscoveryCache.get(serviceName);
            }
        }
        return services;
    }

问题: 如果服务端因为宕机或网络问题下线了,缓存却还在就会导致客户端请求已经不可用的服务端,增加请求失败率。**解决方案:**由于服务端注册的是临时节点,所以如果服务端下线节点会被移除。只要监听zookeeper的子节点,如果新增或删除子节点就直接清空本地缓存即可。

DefaultRpcProcessor

/**
 * Rpc处理者,支持服务启动暴露,自动注入Service
 * @author 2YSP
 * @date 2020/7/26 14:46
 */

public class DefaultRpcProcessor implements ApplicationListener<ContextRefreshedEvent{

   

    @Override
    public void onApplicationEvent(ContextRefreshedEvent event) {
        // Spring启动完毕过后会收到一个事件通知
        if (Objects.isNull(event.getApplicationContext().getParent())){
            ApplicationContext context = event.getApplicationContext();
            // 开启服务
            startServer(context);
            // 注入Service
            injectService(context);
        }
    }

    private void injectService(ApplicationContext context) {
        String[] names = context.getBeanDefinitionNames();
        for(String name : names){
            Class<?> clazz = context.getType(name);
            if (Objects.isNull(clazz)){
                continue;
            }

            Field[] declaredFields = clazz.getDeclaredFields();
            for(Field field : declaredFields){
                // 找出标记了InjectService注解的属性
                InjectService injectService = field.getAnnotation(InjectService.class);
                if (injectService == null){
                    continue;
                }

                Class<?> fieldClass = field.getType();
                Object object = context.getBean(name);
                field.setAccessible(true);
                try {
                    field.set(object,clientProxyFactory.getProxy(fieldClass));
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }
    // 添加本地服务缓存
                ServerDiscoveryCache.SERVICE_CLASS_NAMES.add(fieldClass.getName());
            }
        }
        // 注册子节点监听
        if (clientProxyFactory.getServerDiscovery() instanceof ZookeeperServerDiscovery){
            ZookeeperServerDiscovery serverDiscovery = (ZookeeperServerDiscovery) clientProxyFactory.getServerDiscovery();
            ZkClient zkClient = serverDiscovery.getZkClient();
            ServerDiscoveryCache.SERVICE_CLASS_NAMES.forEach(name ->{
                String servicePath = RpcConstant.ZK_SERVICE_PATH + RpcConstant.PATH_DELIMITER + name + "/service";
                zkClient.subscribeChildChanges(servicePath, new ZkChildListenerImpl());
            });
            logger.info("subscribe service zk node successfully");
        }

    }

    private void startServer(ApplicationContext context) {
        ...

    }
}

ZkChildListenerImpl

/**
 * 子节点事件监听处理类
 */

public class ZkChildListenerImpl implements IZkChildListener {

    private static Logger logger = LoggerFactory.getLogger(ZkChildListenerImpl.class);

    /**
     * 监听子节点的删除和新增事件
     * @param parentPath /rpc/serviceName/service
     * @param childList
     * @throws Exception
     */

    @Override
    public void handleChildChange(String parentPath, List<String> childList) throws Exception {
        logger.debug("Child change parentPath:[{}] -- childList:[{}]", parentPath, childList);
        // 只要子节点有改动就清空缓存
        String[] arr = parentPath.split("/");
        ServerDiscoveryCache.removeAll(arr[2]);
    }
}

3.3nettyClient支持TCP长连接

这部分的改动最多,先增加新的sendRequest接口。

添加接口

实现类NettyNetClient

/**
 * @author 2YSP
 * @date 2020/7/25 20:12
 */

public class NettyNetClient implements NetClient {

    private static Logger logger = LoggerFactory.getLogger(NettyNetClient.class);

    private static ExecutorService threadPool = new ThreadPoolExecutor(410200,
            TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000), new ThreadFactoryBuilder()
            .setNameFormat("rpcClient-%d")
            .build());

    private EventLoopGroup loopGroup = new NioEventLoopGroup(4);

    /**
     * 已连接的服务缓存
     * key: 服务地址,格式:ip:port
     */

    public static Map<String, SendHandlerV2> connectedServerNodes = new ConcurrentHashMap<>();

    @Override
    public byte[] sendRequest(byte[] data, Service service) throws InterruptedException {
  ....
        return respData;
    }

    @Override
    public RpcResponse sendRequest(RpcRequest rpcRequest, Service service, MessageProtocol messageProtocol) {

        String address = service.getAddress();
        synchronized (address) {
            if (connectedServerNodes.containsKey(address)) {
                SendHandlerV2 handler = connectedServerNodes.get(address);
                logger.info("使用现有的连接");
                return handler.sendRequest(rpcRequest);
            }

            String[] addrInfo = address.split(":");
            final String serverAddress = addrInfo[0];
            final String serverPort = addrInfo[1];
            final SendHandlerV2 handler = new SendHandlerV2(messageProtocol, address);
            threadPool.submit(() -> {
                        // 配置客户端
                        Bootstrap b = new Bootstrap();
                        b.group(loopGroup).channel(NioSocketChannel.class)
                                .option(ChannelOption.TCP_NODELAYtrue)
                                .handler(new ChannelInitializer<SocketChannel>() 
{
                                    @Override
                                    protected void initChannel(SocketChannel socketChannel) throws Exception {
                                        ChannelPipeline pipeline = socketChannel.pipeline();
                                        pipeline
                                                .addLast(handler);
                                    }
                                });
                        // 启用客户端连接
                        ChannelFuture channelFuture = b.connect(serverAddress, Integer.parseInt(serverPort));
                        channelFuture.addListener(new ChannelFutureListener() {
                            @Override
                            public void operationComplete(ChannelFuture channelFuture) throws Exception {
                                connectedServerNodes.put(address, handler);
                            }
                        });
                    }
            );
            logger.info("使用新的连接。。。");
            return handler.sendRequest(rpcRequest);
        }
    }
}

每次请求都会调用sendRequest()方法,用线程池异步和服务端创建TCP长连接,连接成功后将SendHandlerV2缓存到ConcurrentHashMap中方便复用,后续请求的请求地址(ip+port)如果在connectedServerNodes中存在则使用connectedServerNodes中的handler处理不再重新建立连接。

SendHandlerV2

/**
 * @author 2YSP
 * @date 2020/8/19 20:06
 */

public class SendHandlerV2 extends ChannelInboundHandlerAdapter {

    private static Logger logger = LoggerFactory.getLogger(SendHandlerV2.class);

    /**
     * 等待通道建立最大时间
     */

    static final int CHANNEL_WAIT_TIME = 4;
    /**
     * 等待响应最大时间
     */

    static final int RESPONSE_WAIT_TIME = 8;

    private volatile Channel channel;

    private String remoteAddress;

    private static Map<String, RpcFuture<RpcResponse>> requestMap = new ConcurrentHashMap<>();

    private MessageProtocol messageProtocol;

    private CountDownLatch latch = new CountDownLatch(1);

    public SendHandlerV2(MessageProtocol messageProtocol,String remoteAddress) {
        this.messageProtocol = messageProtocol;
        this.remoteAddress = remoteAddress;
    }

    @Override
    public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
        this.channel = ctx.channel();
        latch.countDown();
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        logger.debug("Connect to server successfully:{}", ctx);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        logger.debug("Client reads message:{}", msg);
        ByteBuf byteBuf = (ByteBuf) msg;
        byte[] resp = new byte[byteBuf.readableBytes()];
        byteBuf.readBytes(resp);
        // 手动回收
        ReferenceCountUtil.release(byteBuf);
        RpcResponse response = messageProtocol.unmarshallingResponse(resp);
        RpcFuture<RpcResponse> future = requestMap.get(response.getRequestId());
        future.setResponse(response);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
        logger.error("Exception occurred:{}", cause.getMessage());
        ctx.close();
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        ctx.flush();
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        logger.error("channel inactive with remoteAddress:[{}]",remoteAddress);
        NettyNetClient.connectedServerNodes.remove(remoteAddress);

    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        super.userEventTriggered(ctx, evt);
    }

    public RpcResponse sendRequest(RpcRequest request) {
        RpcResponse response;
        RpcFuture<RpcResponse> future = new RpcFuture<>();
        requestMap.put(request.getRequestId(), future);
        try {
            byte[] data = messageProtocol.marshallingRequest(request);
            ByteBuf reqBuf = Unpooled.buffer(data.length);
            reqBuf.writeBytes(data);
            if (latch.await(CHANNEL_WAIT_TIME,TimeUnit.SECONDS)){
                channel.writeAndFlush(reqBuf);
                // 等待响应
                response = future.get(RESPONSE_WAIT_TIME, TimeUnit.SECONDS);
            }else {
                throw new RpcException("establish channel time out");
            }
        } catch (Exception e) {
            throw new RpcException(e.getMessage());
        } finally {
            requestMap.remove(request.getRequestId());
        }
        return response;
    }
}

RpcFuture

package cn.sp.rpc.client.net;

import java.util.concurrent.*;

/**
 * @author 2YSP
 * @date 2020/8/19 22:31
 */

public class RpcFuture<Timplements Future<T{

    private T response;
    /**
     * 因为请求和响应是一一对应的,所以这里是1
     */

    private CountDownLatch countDownLatch = new CountDownLatch(1);
    /**
     * Future的请求时间,用于计算Future是否超时
     */

    private long beginTime = System.currentTimeMillis();

    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        return false;
    }

    @Override
    public boolean isCancelled() {
        return false;
    }

    @Override
    public boolean isDone() {
        if (response != null) {
            return true;
        }
        return false;
    }

    /**
     * 获取响应,直到有结果才返回
     * @return
     * @throws InterruptedException
     * @throws ExecutionException
     */

    @Override
    public T get() throws InterruptedException, ExecutionException {
        countDownLatch.await();
        return response;
    }

    @Override
    public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
        if (countDownLatch.await(timeout,unit)){
            return response;
        }
        return null;
    }

    public void setResponse(T response) {
        this.response = response;
        countDownLatch.countDown();
    }

    public long getBeginTime() {
        return beginTime;
    }
}

此处逻辑,第一次执行 SendHandlerV2#sendRequest() 时channel需要等待通道建立好之后才能发送请求,所以用CountDownLatch来控制,等待通道建立。自定义Future+requestMap缓存来实现netty的请求和阻塞等待响应,RpcRequest对象在创建时会生成一个请求的唯一标识requestId,发送请求前先将RpcFuture缓存到requestMap中,key为requestId,读取到服务端的响应信息后(channelRead方法),将响应结果放入对应的RpcFuture中。SendHandlerV2#channelInactive() 方法中,如果连接的服务端异常断开连接了,则及时清理缓存中对应的serverNode。

四、压力测试

测试环境:

  • (英特尔)Intel(R) Core(TM) i5-6300HQ CPU @ 2.30GHz 4核
  • windows10家庭版(64位)
  • 16G内存

1.本地启动zookeeper 2.本地启动一个消费者,两个服务端,轮询算法 3.使用ab进行压力测试,4个线程发送10000个请求

ab -c 4 -n 10000 http://localhost:8080/test/user?id=1

测试结果

测试结果

从图片可以看出,10000个请求只用了11s,比之前的130+秒耗时减少了10倍以上。

代码地址:
https://github.com/2YSP/rpc-spring-boot-starter

https://github.com/2YSP/rpc-example

【END】
读 
1. 带工作流的springboot后台管理项目,一个企业级快速开发解决方案
2. 面试官问:为什么SpringBoot的 jar 可以直接运行?
3. 推荐一套超高颜值的 Spring Boot 快速开发框架【文末送书】
4. 防止删库跑路?市值缩水近 24 亿元!就靠堡垒机?这货这么吊?

5.  2020年度开发者工具Top 100名单!


浏览 23
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报