初始化

This commit is contained in:
wangzihaogithub 2021-12-13 14:01:19 +08:00
commit 878623aafa
8 changed files with 1265 additions and 0 deletions

33
.gitignore vendored Normal file
View File

@ -0,0 +1,33 @@
HELP.md
target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
### STS ###
.apt_generated
.classpath
.factorypath
.project
.settings
.springBeans
.sts4-cache
### IntelliJ IDEA ###
.idea
*.iws
*.iml
*.ipr
### NetBeans ###
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/
build/
!**/src/main/**/build/
!**/src/test/**/build/
### VS Code ###
.vscode/

249
pom.xml Normal file
View File

@ -0,0 +1,249 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.github.wangzihaogithub</groupId>
<artifactId>sse-server</artifactId>
<version>0.0.1</version>
<name>sse-server</name>
<description>Sse server for Spring Boot</description>
<properties>
<argLine>-Dfile.encoding=UTF-8</argLine>
<!-- 文件拷贝时的编码 -->
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<!-- 编译时的编码 -->
<maven.compiler.encoding>UTF-8</maven.compiler.encoding>
<java.version>1.8</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.32</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-webmvc</artifactId>
<version>5.3.9</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot</artifactId>
<version>2.6.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
<version>2.6.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<version>2.3.5.RELEASE</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<configuration>
<encoding>UTF-8</encoding>
<source>1.8</source>
<target>1.8</target>
</configuration>
<dependencies>
<dependency>
<groupId>org.codehaus.plexus</groupId>
<artifactId>plexus-compiler-javac</artifactId>
<version>2.7</version>
</dependency>
</dependencies>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>3.0.1</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
<configuration>
<attach>true</attach>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.10.2</version>
<executions>
<execution>
<id>attach-javadoc</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
<configuration>
<show>public</show>
<charset>UTF-8</charset>
<encoding>UTF-8</encoding>
<docencoding>UTF-8</docencoding>
<links>
<link>http://docs.oracle.com/javase/8/docs/api</link>
</links>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.6</version>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<distributionManagement>
<snapshotRepository>
<id>ossrh</id>
<url>https://oss.sonatype.org/content/repositories/snapshots</url>
</snapshotRepository>
<repository>
<id>ossrh</id>
<url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url>
</repository>
<!-- <repository>-->
<!-- <id>lechun-releases</id>-->
<!-- <name>releases repository</name>-->
<!-- <url>http://101.201.223.148:9099/nexus/content/repositories/releases</url>-->
<!-- </repository>-->
<!-- <snapshotRepository>-->
<!-- <id>lechun-snapshots</id>-->
<!-- <name>snapshots repository</name>-->
<!-- <url>http://101.201.223.148:9099/nexus/content/repositories/snapshots</url>-->
<!-- </snapshotRepository>-->
</distributionManagement>
<profiles>
<profile>
<id>release</id>
<build>
<pluginManagement>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<configuration>
<encoding>UTF-8</encoding>
<source>1.8</source>
<target>1.8</target>
</configuration>
<dependencies>
<dependency>
<groupId>org.codehaus.plexus</groupId>
<artifactId>plexus-compiler-javac</artifactId>
<version>2.7</version>
</dependency>
</dependencies>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>3.0.1</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
<configuration>
<attach>true</attach>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.10.2</version>
<executions>
<execution>
<id>attach-javadoc</id>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<additionalparam>-Xdoclint:none</additionalparam>
</configuration>
</execution>
</executions>
<configuration>
<show>public</show>
<charset>UTF-8</charset>
<encoding>UTF-8</encoding>
<docencoding>UTF-8</docencoding>
<links>
<link>http://docs.oracle.com/javase/8/docs/api</link>
</links>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.6</version>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</pluginManagement>
</build>
<distributionManagement>
<snapshotRepository>
<id>ossrh</id>
<url>https://oss.sonatype.org/content/repositories/snapshots</url>
</snapshotRepository>
<repository>
<id>ossrh</id>
<url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url>
</repository>
</distributionManagement>
</profile>
</profiles>
</project>

View File

@ -0,0 +1,22 @@
package com.github.sseserver;
/**
* 当前登录用户
*
* @author hao 2021年12月13日13:48:58
*/
public interface AccessUser {
/**
* 防止循环调用 NULL值穿透
*/
AccessUser NULL = () -> "";
/**
* 使用者自己业务系统的登录连接令牌
*
* @return
*/
String getAccessToken();
}

View File

@ -0,0 +1,100 @@
package com.github.sseserver;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter.SseEventBuilder;
import java.util.Collection;
import java.util.List;
import java.util.function.Consumer;
/**
* 单机长连接(非分布式)
* 1. 如果用nginx代理, 要加下面的配置
* # 长连接配置
* proxy_buffering off;
* proxy_read_timeout 7200s;
* proxy_pass http://xx.xx.xx.xx:xxx;
* proxy_http_version 1.1; #nginx默认是http1.0, 改为1.1 支持长连接, 和后端保持长连接,复用,防止出现文件句柄打开数量过多的错误
* proxy_set_header Connection ""; # 去掉Connection的close字段
*
* @author hao 2021年12月7日19:27:41
*/
public interface LocalConnectionService {
/**
* 创建用户连接并返回 SseEmitter
*
* @param accessUser 用户令牌
* @param keepaliveTime 链接最大保持时间 0表示不过期默认30秒超过时间未完成会抛出异常AsyncRequestTimeoutException
* @return SseEmitter
*/
<ACCESS_USER extends AccessUser> SseEmitter<ACCESS_USER> connect(ACCESS_USER accessUser, Long keepaliveTime);
<ACCESS_USER extends AccessUser> void addConnectListener(String accessToken, String channel, Consumer<SseEmitter<ACCESS_USER>> consumer);
<ACCESS_USER extends AccessUser> void addConnectListener(String accessToken, Consumer<SseEmitter<ACCESS_USER>> consumer);
<ACCESS_USER extends AccessUser> void addConnectListener(Consumer<SseEmitter<ACCESS_USER>> consumer);
<ACCESS_USER extends AccessUser> void addDisConnectListener(Consumer<SseEmitter<ACCESS_USER>> consumer);
<ACCESS_USER extends AccessUser> void addDisConnectListener(String accessToken, Consumer<SseEmitter<ACCESS_USER>> consumer);
<ACCESS_USER extends AccessUser> int send(SseEmitter<ACCESS_USER> sseEmitter, SseEventBuilder message);
/**
* 给指定链接发送信息
*/
int send(long connectionId, SseEventBuilder message);
/**
* 给指定管道发送信息
*/
int sendToChannel(String channel, SseEventBuilder message);
/**
* 给指定用户发送信息
*/
int send(String accessToken, SseEventBuilder message);
/**
* 群发消息
*
* @return 发送成功几个人
*/
int send(Collection<String> accessTokens, SseEventBuilder message);
/**
* 群发所有人
*/
int sendAll(SseEventBuilder message);
/**
* 移除用户连接
*
* @return 移除了几个链接
*/
<ACCESS_USER extends AccessUser> List<SseEmitter<ACCESS_USER>> disconnect(String accessToken);
/**
* 移除用户连接
*
* @return 是否成功
*/
<ACCESS_USER extends AccessUser> SseEmitter<ACCESS_USER> disconnect(String accessToken, Long connectionId);
/**
* 获取当前连接信息
*/
List<String> getAccessTokens();
/**
* 获取当前用户数量
*/
int getUserCount();
/**
* 获取当前连接数量
*/
int getConnectionCount();
}

View File

@ -0,0 +1,321 @@
package com.github.sseserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter.SseEventBuilder;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Predicate;
/**
* 单机长连接(非分布式)
* 1. 如果用nginx代理, 要加下面的配置
* # 长连接配置
* proxy_buffering off;
* proxy_read_timeout 7200s;
* proxy_pass http://xx.xx.xx.xx:xxx;
* proxy_http_version 1.1; #nginx默认是http1.0, 改为1.1 支持长连接, 和后端保持长连接,复用,防止出现文件句柄打开数量过多的错误
* proxy_set_header Connection ""; # 去掉Connection的close字段
*
* @author hao 2021年12月7日19:27:41
*/
public class LocalConnectionServiceImpl implements LocalConnectionService {
private final static Logger log = LoggerFactory.getLogger(LocalConnectionServiceImpl.class);
/**
* 使用map对象便于根据access来获取对应的SseEmitter
*/
private final Map<String, List<SseEmitter>> connectionMap = new ConcurrentHashMap<>();
private final Map<Long, SseEmitter> connectionIdMap = new ConcurrentHashMap<>();
private final List<Consumer<SseEmitter>> connectListeners = new ArrayList<>();
private final List<Consumer<SseEmitter>> disconnectListeners = new ArrayList<>();
private final Map<String, List<Predicate<SseEmitter>>> connectListenerMap = new ConcurrentHashMap<>();
private final Map<String, List<Predicate<SseEmitter>>> disconnectListenerMap = new ConcurrentHashMap<>();
/**
* 创建用户连接并返回 SseEmitter
*
* @param accessUser 用户accessToken
* @return SseEmitter
*/
@Override
public <ACCESS_USER extends AccessUser> SseEmitter<ACCESS_USER> connect(ACCESS_USER accessUser, Long keepaliveTime) {
if (keepaliveTime == null) {
keepaliveTime = 0L;
}
// 设置超时时间0表示不过期tomcat默认30秒超过时间未完成会抛出异常AsyncRequestTimeoutException
SseEmitter<ACCESS_USER> sseEmitter = new SseEmitter<>(keepaliveTime, connectionMap, accessUser);
sseEmitter.onCompletion(completionCallBack(sseEmitter));
sseEmitter.onError(errorCallBack(sseEmitter));
sseEmitter.onTimeout(timeoutCallBack(sseEmitter));
sseEmitter.addDisConnectListener(e -> {
notifyListener(e, disconnectListeners, disconnectListenerMap);
connectionIdMap.remove(e.getId());
});
sseEmitter.addConnectListener(e -> {
connectionIdMap.put(e.getId(), e);
notifyListener(e, connectListeners, connectListenerMap);
});
try {
sseEmitter.send(SseEmitter.event()
.name("connect-finish")
.data("{\"connectionId\":" + sseEmitter.getId() + "}"));
return sseEmitter;
} catch (IOException e) {
log.error("sse send {} IOException:{}", sseEmitter, e.toString(), e);
return null;
}
}
private <ACCESS_USER extends AccessUser> void notifyListener(SseEmitter<ACCESS_USER> sseEmitter,
List<Consumer<SseEmitter>> listeners,
Map<String, List<Predicate<SseEmitter>>> listenerMap) {
for (Consumer<SseEmitter> listener : listeners) {
listener.accept(sseEmitter);
}
List<Predicate<SseEmitter>> consumerList = listenerMap.get(sseEmitter.getAccessToken());
if (consumerList != null) {
for (Predicate<SseEmitter> listener : new ArrayList<>(consumerList)) {
if (listener.test(sseEmitter)) {
consumerList.remove(listener);
}
}
}
}
@Override
public <ACCESS_USER extends AccessUser> void addConnectListener(String accessToken, String channel, Consumer<SseEmitter<ACCESS_USER>> consumer) {
List<SseEmitter> sseEmitters = connectionMap.get(accessToken);
if (sseEmitters != null) {
for (SseEmitter sseEmitter : sseEmitters) {
if (Objects.equals(channel, sseEmitter.getChannel())) {
consumer.accept(sseEmitter);
return;
}
}
}
connectListenerMap.computeIfAbsent(accessToken, e -> new ArrayList<>()).add(e -> {
if (Objects.equals(channel, e.getChannel())) {
consumer.accept(e);
return true;
}
return false;
});
}
@Override
public <ACCESS_USER extends AccessUser> void addConnectListener(String accessToken, Consumer<SseEmitter<ACCESS_USER>> consumer) {
List<SseEmitter> sseEmitters = connectionMap.get(accessToken);
if (sseEmitters != null) {
for (SseEmitter sseEmitter : sseEmitters) {
consumer.accept(sseEmitter);
}
} else {
connectListenerMap.computeIfAbsent(accessToken, e -> new ArrayList<>()).add(e -> {
consumer.accept(e);
return true;
});
}
}
@Override
public <ACCESS_USER extends AccessUser> void addConnectListener(Consumer<SseEmitter<ACCESS_USER>> consumer) {
connectListeners.add((Consumer) consumer);
}
@Override
public <ACCESS_USER extends AccessUser> void addDisConnectListener(Consumer<SseEmitter<ACCESS_USER>> consumer) {
disconnectListeners.add((Consumer) consumer);
}
@Override
public <ACCESS_USER extends AccessUser> void addDisConnectListener(String accessToken, Consumer<SseEmitter<ACCESS_USER>> consumer) {
disconnectListenerMap.computeIfAbsent(accessToken, e -> new ArrayList<>()).add(e -> {
consumer.accept(e);
return true;
});
}
@Override
public <ACCESS_USER extends AccessUser> int send(SseEmitter<ACCESS_USER> sseEmitter, SseEventBuilder message) {
int count = 0;
if (sseEmitter != null) {
if (sseEmitter.isDisconnect()) {
return 0;
}
if (sseEmitter.getChannel() != null) {
return 0;
}
try {
sseEmitter.send(message);
count++;
} catch (IOException e) {
log.warn("sse send {} io exception = {}", sseEmitter, e.toString(), e);
sseEmitter.disconnect();
}
}
return count;
}
@Override
public int send(long connectionId, SseEventBuilder message) {
SseEmitter next = connectionIdMap.get(connectionId);
return send(next, message);
}
/**
* 给指定管道发送信息
*/
@Override
public int sendToChannel(String channel, SseEventBuilder message) {
Collection<SseEmitter> sseEmitters = connectionIdMap.values();
int count = 0;
for (SseEmitter next : new ArrayList<>(sseEmitters)) {
if (next.isDisconnect()) {
continue;
}
if (Objects.equals(next.getChannel(), channel)) {
try {
next.send(message);
count++;
} catch (IOException e) {
log.warn("sse send {} io exception = {}", next, e.toString(), e);
next.disconnect();
}
}
}
return count;
}
/**
* 给指定用户发送信息
*/
@Override
public int send(String accessToken, SseEventBuilder message) {
Collection<SseEmitter> sseEmitters = connectionMap.get(accessToken);
int count = 0;
if (sseEmitters != null) {
for (SseEmitter next : new ArrayList<>(sseEmitters)) {
count += send(next, message);
}
}
return count;
}
/**
* 群发消息
*/
@Override
public int send(Collection<String> accessTokens, SseEventBuilder message) {
if (accessTokens == null) {
return 0;
}
int totalSuccessCount = 0;
for (String accessToken : accessTokens) {
int sendCount = send(accessToken, message);
if (sendCount > 0) {
totalSuccessCount++;
}
}
return totalSuccessCount;
}
/**
* 群发所有人
*/
@Override
public int sendAll(SseEventBuilder message) {
return send(new ArrayList<>(connectionMap.keySet()), message);
}
/**
* 移除用户连接
*/
@Override
public List<SseEmitter> disconnect(String accessToken) {
List<SseEmitter> sseEmitters = connectionMap.remove(accessToken);
List<SseEmitter> result = new ArrayList<>();
if (sseEmitters != null) {
for (SseEmitter next : new ArrayList<>(sseEmitters)) {
if (next.disconnect()) {
result.add(next);
}
}
}
return result;
}
@Override
public SseEmitter disconnect(String accessToken, Long connectionId) {
if (connectionId == null) {
return null;
}
Collection<SseEmitter> sseEmitters = connectionMap.get(accessToken);
if (sseEmitters != null) {
for (SseEmitter next : new ArrayList<>(sseEmitters)) {
if (Objects.equals(next.getId(), connectionId)) {
if (next.disconnect()) {
return next;
} else {
return null;
}
}
}
}
return null;
}
/**
* 获取当前连接信息
*/
@Override
public List<String> getAccessTokens() {
return new ArrayList<>(connectionMap.keySet());
}
/**
* 获取当前用户数量
*/
@Override
public int getUserCount() {
return connectionMap.size();
}
/**
* 获取当前连接数量
*/
@Override
public int getConnectionCount() {
int count = 0;
for (List<SseEmitter> value : connectionMap.values()) {
if (value != null) {
count += value.size();
}
}
return count;
}
private Runnable completionCallBack(SseEmitter sseEmitter) {
return () -> {
log.debug("sse completion 结束连接:{}", sseEmitter);
sseEmitter.disconnect();
};
}
private Runnable timeoutCallBack(SseEmitter sseEmitter) {
return () -> {
log.debug("sse timeout 超过最大连接时间:{}", sseEmitter);
sseEmitter.disconnect();
};
}
private Consumer<Throwable> errorCallBack(SseEmitter sseEmitter) {
return throwable -> {
log.debug("sse error 发生错误:{}, {}", sseEmitter, throwable.toString(), throwable);
sseEmitter.disconnect();
};
}
}

View File

@ -0,0 +1,262 @@
package com.github.sseserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
public class SseEmitter<ACCESS_USER extends AccessUser> extends org.springframework.web.servlet.mvc.method.annotation.SseEmitter {
private final static Logger log = LoggerFactory.getLogger(SseEmitter.class);
private static final AtomicLong ID_INCR = new AtomicLong();
private static final MediaType TEXT_PLAIN = new MediaType("text", "plain", StandardCharsets.UTF_8);
private final long id = ID_INCR.getAndIncrement();
private final String accessToken;
private final ACCESS_USER accessUser;
private final AtomicBoolean disconnect = new AtomicBoolean();
private final Map<String, List<SseEmitter>> connectionMap;
private final List<Consumer<SseEmitter<ACCESS_USER>>> connectListeners = new ArrayList<>();
private final List<Consumer<SseEmitter<ACCESS_USER>>> disconnectListeners = new ArrayList<>();
private final Map<String, Object> attributeMap = new LinkedHashMap<>();
private boolean connect = false;
private int count;
private String channel;
/**
* timeout = 0是永不过期
*/
public SseEmitter(Long timeout) {
this(timeout, new HashMap<>(), null);
}
/**
* timeout = 0是永不过期
*/
public SseEmitter(Long timeout, Map<String, List<SseEmitter>> connectionMap, ACCESS_USER accessUser) {
super(timeout);
this.connectionMap = connectionMap;
this.accessUser = accessUser;
this.accessToken = accessUser != null ? accessUser.getAccessToken() : null;
connectionMap.computeIfAbsent(accessToken, e -> Collections.synchronizedList(new ArrayList<>()))
.add(this);
log.info("sse connection create : {}", this);
}
public static SseEventBuilder event() {
return new SseEventBuilderImpl();
}
public long getId() {
return id;
}
public String getAccessToken() {
return accessToken;
}
public ACCESS_USER getAccessUser() {
return accessUser;
}
public Map<String, Object> getAttributeMap() {
return attributeMap;
}
public <T> T getAttribute(String key) {
return (T) attributeMap.get(key);
}
public <T> T setAttribute(String key, Object value) {
return (T) attributeMap.put(key, value);
}
public <T> T removeAttribute(String key) {
return (T) attributeMap.remove(key);
}
public String getChannel() {
return channel;
}
public void setChannel(String channel) {
this.channel = channel;
}
public void addConnectListener(Consumer<SseEmitter<ACCESS_USER>> consumer) {
if (connect) {
try {
consumer.accept(this);
} catch (Exception e) {
log.warn("addConnectListener connectListener error = {} {}", e.toString(), consumer, e);
}
} else {
connectListeners.add(consumer);
}
}
public void addDisConnectListener(Consumer<SseEmitter<ACCESS_USER>> consumer) {
if (isDisconnect()) {
consumer.accept(this);
} else {
disconnectListeners.add(consumer);
}
}
@Override
protected void extendResponse(ServerHttpResponse outputMessage) {
super.extendResponse(outputMessage);
connect = true;
for (Consumer<SseEmitter<ACCESS_USER>> connectListener : new ArrayList<>(connectListeners)) {
try {
connectListener.accept(this);
} catch (Exception e) {
log.warn("connectListener error = {} {}", e.toString(), connectListener, e);
}
}
connectListeners.clear();
}
@Override
public void send(SseEventBuilder builder) throws IOException {
if (builder instanceof SseEventBuilderImpl) {
String id = ((SseEventBuilderImpl) builder).id;
String name = ((SseEventBuilderImpl) builder).name;
log.info("sse connection send {} : {}, id = {}, name = {}", ++count, this, id, name);
} else {
log.info("sse connection send {} : {}", ++count, this);
}
super.send(builder);
}
public boolean isDisconnect() {
return disconnect.get();
}
public boolean disconnect() {
boolean remove = false;
if (disconnect.compareAndSet(false, true)) {
for (Consumer<SseEmitter<ACCESS_USER>> disconnectListener : new ArrayList<>(disconnectListeners)) {
try {
disconnectListener.accept(this);
} catch (Exception e) {
log.warn("disconnectListener error = {} {}", e.toString(), disconnectListener, e);
}
}
disconnectListeners.clear();
List<SseEmitter> sseEmitterList = connectionMap.get(accessToken);
if (sseEmitterList != null) {
try {
remove = sseEmitterList.remove(this);
log.info("sse connection disconnect : {}", this);
} catch (Exception e) {
remove = false;
}
if (sseEmitterList.isEmpty()) {
connectionMap.remove(accessToken);
}
}
try {
complete();
} catch (Exception e) {
log.info("sse connection disconnect exception : {}. {}", e.toString(), this);
}
}
return remove;
}
@Override
public String toString() {
if (accessUser == null) {
return id + "#";
} else {
return id + "#" + accessUser;
}
}
/**
* Default implementation of SseEventBuilder.
*/
private static class SseEventBuilderImpl implements SseEventBuilder {
private final Set<DataWithMediaType> dataToSend = new LinkedHashSet<>(4);
private String id;
private String name;
@Nullable
private StringBuilder sb;
@Override
public SseEventBuilder id(String id) {
this.id = id;
append("id:").append(id).append("\n");
return this;
}
@Override
public SseEventBuilder name(String name) {
this.name = name;
append("event:").append(name).append("\n");
return this;
}
@Override
public SseEventBuilder reconnectTime(long reconnectTimeMillis) {
append("retry:").append(String.valueOf(reconnectTimeMillis)).append("\n");
return this;
}
@Override
public SseEventBuilder comment(String comment) {
append(":").append(comment).append("\n");
return this;
}
@Override
public SseEventBuilder data(Object object) {
return data(object, null);
}
@Override
public SseEventBuilder data(Object object, @Nullable MediaType mediaType) {
append("data:");
saveAppendedText();
this.dataToSend.add(new DataWithMediaType(object, mediaType));
append("\n");
return this;
}
SseEventBuilderImpl append(String text) {
if (this.sb == null) {
this.sb = new StringBuilder();
}
this.sb.append(text);
return this;
}
@Override
public Set<DataWithMediaType> build() {
if (!StringUtils.hasLength(this.sb) && this.dataToSend.isEmpty()) {
return Collections.emptySet();
}
append("\n");
saveAppendedText();
return this.dataToSend;
}
private void saveAppendedText() {
if (this.sb != null) {
this.dataToSend.add(new DataWithMediaType(this.sb.toString(), TEXT_PLAIN));
this.sb = null;
}
}
}
}

View File

@ -0,0 +1,245 @@
package com.github.sseserver;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import java.io.Serializable;
import java.util.*;
/**
* 消息事件推送 (非分布式)
* : !! 这里是示例代码, 根据自己项目封装的用户逻辑, 复制到自己项目里
* <p>
* 1. 如果用nginx代理, 要加下面的配置
* # 长连接配置
* proxy_buffering off;
* proxy_read_timeout 7200s;
* proxy_pass http://xx.xx.xx.xx:xxx;
* proxy_http_version 1.1; #nginx默认是http1.0, 改为1.1 支持长连接, 和后端保持长连接,复用,防止出现文件句柄打开数量过多的错误
* proxy_set_header Connection ""; # 去掉Connection的close字段
*
* @author hao 2021年12月7日19:29:51
*/
//@RestController
//@RequestMapping("/api/sse")
public class SseWebController<ACCESS_USER extends AccessUser> {
private LocalConnectionService localConnectionService;
@Autowired
public void setLocalConnectionService(LocalConnectionService localConnectionService) {
this.localConnectionService = localConnectionService;
}
/**
* 获取当前登录用户, 这里返回后, 就可以获取 {@link SseEmitter#getAccessUser()}
*
* @return 使用者自己系统的用户
*/
protected ACCESS_USER getAccessUser() {
return (ACCESS_USER) AccessUser.NULL;
}
protected Object wrapResponse(Object result) {
return new ResponseData<>(result);
}
protected void onConnect(SseEmitter<ACCESS_USER> conncet) {
}
protected void onDisconnect(List<SseEmitter<ACCESS_USER>> disconnectList, ACCESS_USER accessUser, String accessToken, Long connectionId) {
}
/**
* 创建连接
*/
@RequestMapping("/connect")
public SseEmitter connect(@RequestParam Map query, @RequestBody(required = false) Map body,
Long keepaliveTime) {
Map message = new LinkedHashMap<>(query);
if (body != null) {
message.putAll(body);
}
ACCESS_USER accessUser = getAccessUser();
SseEmitter<ACCESS_USER> emitter = localConnectionService.connect(accessUser, keepaliveTime);
emitter.getAttributeMap().putAll(message);
String channel = Objects.toString(message.get("channel"), null);
emitter.setChannel(isBlank(channel) ? null : channel);
onConnect(emitter);
return emitter;
}
/**
* 推送给所有人
*
* @return
*/
@RequestMapping("/send")
public ResponseEntity send(@RequestParam Map query, @RequestBody(required = false) Map body) {
Map message = new LinkedHashMap<>(query);
if (body != null) {
message.putAll(body);
}
int count = localConnectionService.sendAll(buildEvent(message));
return ResponseEntity.ok(wrapResponse((Collections.singletonMap("count", count))));
}
/**
* 发送给单个人
*
* @param accessToken
* @return
*/
@RequestMapping("/send/{accessToken}")
public ResponseEntity sendOne(@RequestParam Map query, @RequestBody(required = false) Map body, @PathVariable String accessToken) {
Map message = new LinkedHashMap<>(query);
if (body != null) {
message.putAll(body);
}
int count = localConnectionService.send(accessToken, buildEvent(message));
return ResponseEntity.ok(wrapResponse(Collections.singletonMap("count", count)));
}
/**
* 关闭连接
*/
@RequestMapping("/disconnect/{connectionId}")
public ResponseEntity disconnect(@PathVariable Long connectionId) {
ACCESS_USER accessUser = getAccessUser();
String accessToken = accessUser.getAccessToken();
SseEmitter<ACCESS_USER> disconnect = localConnectionService.disconnect(accessToken, connectionId);
if (disconnect != null) {
onDisconnect(Collections.singletonList(disconnect), accessUser, accessToken, connectionId);
}
return ResponseEntity.ok(wrapResponse(Collections.singletonMap("count", disconnect != null ? 1 : 0)));
}
/**
* 关闭连接
*/
@RequestMapping("/disconnect")
public ResponseEntity disconnect() {
ACCESS_USER accessUser = getAccessUser();
String accessToken = accessUser.getAccessToken();
List<SseEmitter<ACCESS_USER>> count = localConnectionService.disconnect(accessToken);
if (count.size() > 0) {
onDisconnect(count, accessUser, accessToken, null);
}
return ResponseEntity.ok(wrapResponse(Collections.singletonMap("count", count.size())));
}
private SseEmitter.SseEventBuilder buildEvent(Map rawMessage) {
Map message = new LinkedHashMap(rawMessage);
SseEmitter.SseEventBuilder event = SseEmitter.event();
Object id = message.remove("id");
if (id != null) {
event.id(id.toString());
}
Object name = message.remove("name");
if (name != null) {
event.name(name.toString());
}
Object comment = message.remove("comment");
if (comment != null) {
event.comment(comment.toString());
}
Object reconnectTime = message.remove("reconnectTime");
if (reconnectTime != null) {
event.reconnectTime(Long.parseLong(reconnectTime.toString()));
}
if (!message.isEmpty()) {
event.data(message);
}
return event;
}
public boolean isBlank(CharSequence str) {
int strLen;
if (str == null || (strLen = str.length()) == 0) {
return true;
}
for (int i = 0; i < strLen; i++) {
if ((!Character.isWhitespace(str.charAt(i)))) {
return false;
}
}
return true;
}
public class ResponseData<T> implements Serializable {
/**
* 请求是否成功
*/
private boolean success = true;
/**
* 成功或者失败的code错误码
*/
private int code = 200;
/**
* 成功时返回的数据失败时返回具体的异常信息
*/
private T data;
/**
* 请求失败返回的提示信息给前端进行页面展示的信息
*/
private String message;
/**
* 请求失败返回的提示信息排查用的信息
*/
private String errorMessage;
public ResponseData() {
}
public ResponseData(T data) {
this.data = data;
}
public boolean isSuccess() {
return success;
}
public void setSuccess(boolean success) {
this.success = success;
}
public int getCode() {
return code;
}
public void setCode(int code) {
this.code = code;
}
public T getData() {
return data;
}
public void setData(T data) {
this.data = data;
}
public String getMessage() {
return message;
}
public void setMessage(String message) {
this.message = message;
}
public String getErrorMessage() {
return errorMessage;
}
public void setErrorMessage(String errorMessage) {
this.errorMessage = errorMessage;
}
}
}

View File

@ -0,0 +1,33 @@
package com.github.sseserver;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@SpringBootApplication
public class SseServerApplicationTests {
public static void main(String[] args) {
SpringApplication.run(SseServerApplicationTests.class, args);
}
@Bean
public LocalConnectionService localConnectionService() {
return new LocalConnectionServiceImpl();
}
/**
* http://localhost:8080/api/sse/connect
*/
@RestController
@RequestMapping("/api/sse")
public static class MyController extends SseWebController {
@Override
protected AccessUser getAccessUser() {
return super.getAccessUser();
}
}
}