Spring Cloud Gateway 自定义 Filter 以及负载均衡

自定义全局 Filter（GlobalFilter）：记录请求耗时，只放行路径中包含 success 或 rest 的请求，其余请求直接返回一段 JSON。

package com.example.demo;

import java.nio.charset.StandardCharsets;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Global gateway filter that records the request start time, logs whether a {@code token}
 * header is present, and then either forwards the request or rejects it.
 *
 * <p>Requests whose path contains {@code "success"} or {@code "rest"} are passed down the
 * filter chain, and the elapsed time is logged once the chain completes. Every other request
 * is answered directly with a fixed JSON body and never reaches a downstream service.
 *
 * @author woniu
 * @date 2019/10/11 16:07
 */
@Component
public class CustomerTokenFilter implements GlobalFilter, Ordered {

    /** Logger registered under this class so its output can be filtered by class name. */
    private static final Log log = LogFactory.getLog(CustomerTokenFilter.class);

    /** Exchange attribute key under which the start time (epoch millis) is stored. */
    private static final String REQUEST_TIME_BEGIN = "requestTimeBegin";

    /** Request header whose presence is logged. */
    private static final String TOKEN_HEADER = "token";

    /** JSON body written for rejected requests. */
    private static final String REJECTION_BODY =
            "{\"status\":429,\"msg\":\"Too Many Requests\",\"data\":{}}";

    /**
     * Forwards requests whose path contains {@code "success"} or {@code "rest"} and rejects all others.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filter chain
     * @return completion signal for this exchange
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        exchange.getAttributes().put(REQUEST_TIME_BEGIN, System.currentTimeMillis());
        // Only log whether the token exists; the raw value is a credential and must not end up in logs.
        log.info("contain token " + exchange.getRequest().getHeaders().containsKey(TOKEN_HEADER));

        String path = exchange.getRequest().getPath().value();
        // NOTE(review): plain substring match, so a path such as "/interest" also passes because it contains "rest".
        if (path.contains("success") || path.contains("rest")) {
            return chain.filter(exchange).then(Mono.fromRunnable(() -> logElapsed(exchange)));
        }
        return reject(exchange);
    }

    /**
     * Logs the raw request path and the milliseconds elapsed since {@link #REQUEST_TIME_BEGIN} was recorded.
     *
     * @param exchange the exchange whose timing attribute is read
     */
    private static void logElapsed(ServerWebExchange exchange) {
        Long startTime = exchange.getAttribute(REQUEST_TIME_BEGIN);
        if (startTime != null) {
            log.info(exchange.getRequest().getURI().getRawPath() + ": " + (System.currentTimeMillis() - startTime) + "ms");
        }
    }

    /**
     * Completes the exchange with the fixed JSON rejection body.
     *
     * <p>The HTTP status is kept at 200 OK, and the error code is carried only in the body's
     * {@code "status"} field. NOTE(review): confirm clients rely on this rather than expecting
     * a real HTTP 429.
     *
     * @param exchange the exchange to answer
     * @return completion signal of the response write
     */
    private static Mono<Void> reject(ServerWebExchange exchange) {
        ServerHttpResponse response = exchange.getResponse();
        response.setStatusCode(HttpStatus.OK);
        // Declare the payload as JSON so clients do not have to sniff the media type.
        response.getHeaders().set("Content-Type", "application/json;charset=UTF-8");
        DataBuffer buffer = response.bufferFactory().wrap(REJECTION_BODY.getBytes(StandardCharsets.UTF_8));
        return response.writeWith(Flux.just(buffer));
    }

    /**
     * Returns the position of this filter in the global filter ordering.
     *
     * @return {@code 0}
     */
    @Override
    public int getOrder() {
        return 0;
    }
}

自定义负载均衡规则（LoadBalancerRule）：在可达的服务实例中随机选择一个。

package com.example.demo;

import java.util.List;
import org.apache.commons.lang.math.RandomUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.AbstractLoadBalancerRule;
import com.netflix.loadbalancer.Server;
/**
 * Ribbon load-balancing rule that picks one of the load balancer's reachable servers at random.
 *
 * @author woniu
 * @date 2019/10/12 16:20
 */
public class CustomerLoadBalanceRule extends AbstractLoadBalancerRule {

    private static final Log log = LogFactory.getLog(CustomerLoadBalanceRule.class);

    /**
     * Chooses a server for the current request.
     *
     * @param key routing key supplied by Ribbon; it is only logged, not used for selection
     * @return the only reachable server, a random one when several are reachable, or
     *         {@code null} when no load balancer is attached or no server is reachable
     */
    @Override
    public Server choose(Object key) {
        log.info("key is " + key);
        if (this.getLoadBalancer() == null) {
            // The rule has not been wired to a load balancer yet, so there is nothing to choose from.
            return null;
        }
        List<Server> servers = this.getLoadBalancer().getReachableServers();
        log.info("servers " + servers);
        if (servers == null || servers.isEmpty()) {
            return null;
        }
        if (servers.size() == 1) {
            return servers.get(0);
        }
        return randomChoose(servers);
    }

    /**
     * <p>随机返回一个服务实例 (returns a randomly selected server instance).</p>
     *
     * @param servers non-empty list of candidate servers
     * @return the server at a random index of {@code servers}
     */
    private Server randomChoose(List<Server> servers) {
        int randomIndex = RandomUtils.nextInt(servers.size());
        return servers.get(randomIndex);
    }

    /**
     * No client-specific configuration is read by this rule, so this is intentionally a no-op.
     *
     * @param iClientConfig the Ribbon client configuration (unused)
     */
    @Override
    public void initWithNiwsConfig(IClientConfig iClientConfig) {
        // intentionally empty
    }
}

application.properties 配置

# Verbose TRACE logging for the gateway and the reactive web stack: useful while debugging
# routing, but far too noisy to keep enabled in production.
logging.level.org.springframework.cloud.gateway: TRACE
logging.level.org.springframework.http.server.reactive: TRACE
logging.level.org.springframework.web.reactive: TRACE
logging.level.org.springframework.boot.autoconfigure.web: TRACE

# Route "1bcd": requests whose path matches /rest/** are sent to "my-load-balanced-service".
# The lb:// scheme resolves the target through the client-side load balancer instead of a fixed host.
spring.cloud.gateway.routes[0].id: 1bcd
spring.cloud.gateway.routes[0].uri:  lb://my-load-balanced-service
spring.cloud.gateway.routes[0].predicates[0].name: Path
spring.cloud.gateway.routes[0].predicates[0].args[0]: /rest/**


# Ribbon client settings for "my-load-balanced-service": a static server list, with
# com.example.demo.CustomerLoadBalanceRule as the server-selection rule.
# With only one server listed, the rule's single-server branch is always taken and the
# random choice is never exercised.
my-load-balanced-service.ribbon.listOfServers: http://woniu.com
my-load-balanced-service.ribbon.NFLoadBalancerRuleClassName: com.example.demo.CustomerLoadBalanceRule

pom.xml 中引入的依赖（dependency）

<!-- NOTE(review): the Spring Cloud starters below each pin 2.1.2.RELEASE explicitly; importing the
     matching spring-cloud-dependencies BOM in <dependencyManagement> would keep them aligned - confirm. -->
<!-- Ribbon: provides the com.netflix.loadbalancer types that CustomerLoadBalanceRule extends and the
     *.ribbon.* client settings used in application.properties. -->
<dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter-netflix-ribbon</artifactId>
            <version>2.1.2.RELEASE</version>
        </dependency>

        <!-- No version given: presumably managed by a Spring Boot parent/BOM not shown here. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
        </dependency>
        <!-- Test-only dependencies. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <!-- Spring Cloud Gateway: supplies GlobalFilter, GatewayFilterChain and the route configuration. -->
        <dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter-gateway</artifactId>
            <version>2.1.2.RELEASE</version>
        </dependency>
        <!-- NOTE(review): none of the code in this post references Hystrix; confirm this starter is still needed. -->
        <dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter-netflix-hystrix</artifactId>
            <version>2.1.2.RELEASE</version>
        </dependency>

 

GitHub 地址：https://github.com/baishi6582/wns/tree/master/springgatewaydemo

posted @ 2019-10-12 19:59  woniu4  阅读(2421)  评论(0)  编辑  收藏  举报