@Slf4j
public class MultiThreadServer {
// 获取 cpu核心数
private static final int cpu = Runtime.getRuntime().availableProcessors();
public static void main(String[] args) throws IOException {
Thread.currentThread().setName("boss");
Selector boss = Selector.open();
ServerSocketChannel ssc = ServerSocketChannel.open();
ssc.bind(new InetSocketAddress(8080));
ssc.configureBlocking(false);
SelectionKey bossKey = ssc.register(boss, SelectionKey.OP_ACCEPT, null);
// 1.创建 cpu核心数目的 worker
Worker[] workers = new Worker[cpu];
for (int i = 0; i < workers.length; i ++) {
workers[i] = new Worker("$worker" + i);
}
// Round Robin
AtomicInteger cnt = new AtomicInteger();
while (true) {
boss.select();
Iterator<SelectionKey> iter = boss.selectedKeys().iterator();
while (iter.hasNext()) {
SelectionKey key = iter.next();
iter.remove();
if (key.isAcceptable()) {
SocketChannel sc = ssc.accept();
sc.configureBlocking(false);
log.debug("connected...{}", sc.getRemoteAddress());
// 2.关联
log.debug("before register...{}", sc.getRemoteAddress());
// round robin
workers[cnt.getAndIncrement() % workers.length].register(sc, SelectionKey.OP_READ, null);
log.debug("after register...{}", sc.getRemoteAddress());
}
}
}
}
// Worker职责:监测读写事件
// 每个 worker负责去监听一部分的 channel
static class Worker implements Runnable{
private Thread thread;
private Selector selector;
private String name;
// 使用任务队列,完成线程间的通信
private ConcurrentLinkedQueue<Runnable> queue = new ConcurrentLinkedQueue<>();
// 这里来完成全部的初始化操作
public Worker(String name) throws IOException {
this.name = name;
thread = new Thread(this, name);
selector = Selector.open();
}
// worker开始工作
public void register(SocketChannel sc, int opRead, Object att) throws IOException {
if (!thread.isAlive()) {
thread.start();
}
// 添加的任务并没有立刻被执行
queue.add(() -> {
try {
// 注册读事件, selector监听需要
sc.register(selector, opRead, att); // boss
} catch (ClosedChannelException e) {
e.printStackTrace();
}
});
// 这里需要 wakeup selector
// wakeup这个方法有点特殊 - 就算是 wakeup方法先执行, selector.select后执行时, 也不会阻塞c
selector.wakeup();
}
@Override
public void run() {
while (true) {
try {
// selector阻塞时是无法往其中去注册 channel的!
selector.select();
// 取任务, 并且去执行任务 -- 队列为空时, poll 不会报异常,会返回 null
Runnable task = queue.poll();
if (task != null) {
task.run(); // work-0
}
Iterator<SelectionKey> iter = selector.selectedKeys().iterator();
while (iter.hasNext()) {
SelectionKey key = iter.next();
iter.remove();
if (key.isReadable()) {
SocketChannel channel = (SocketChannel) key.channel();
ByteBuffer bf = ByteBuffer.allocate(32);
log.debug("read...{}", channel.getRemoteAddress());
channel.read(bf);
bf.flip();
debugAll(bf);
}
}
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
}
public class Client {
    /**
     * Connects to the server on localhost:8080 and sends one newline-terminated UTF-8 message.
     *
     * @param args ignored
     * @throws IOException if the connection or the write fails
     */
    public static void main(String[] args) throws IOException {
        // 1. Create the client channel; try-with-resources closes it when main finishes.
        try (SocketChannel client = SocketChannel.open()) {
            // 2. Connect to the server (host name and port).
            client.connect(new InetSocketAddress("localhost", 8080));
            // The channel is in blocking mode, so write() returns only after all bytes are written.
            client.write(StandardCharsets.UTF_8.encode("12345678912345678920\n"));
            System.out.println("waiting...");
        }
    }
}