diff --git a/core/src/main/kotlin/ch/dissem/bitmessage/factory/BufferPool.kt b/core/src/main/kotlin/ch/dissem/bitmessage/factory/BufferPool.kt index 8b1d031..548ec2e 100644 --- a/core/src/main/kotlin/ch/dissem/bitmessage/factory/BufferPool.kt +++ b/core/src/main/kotlin/ch/dissem/bitmessage/factory/BufferPool.kt @@ -18,9 +18,11 @@ package ch.dissem.bitmessage.factory import ch.dissem.bitmessage.constants.Network.HEADER_SIZE import ch.dissem.bitmessage.constants.Network.MAX_PAYLOAD_SIZE +import ch.dissem.bitmessage.exception.NodeException import org.slf4j.LoggerFactory import java.nio.ByteBuffer import java.util.* +import kotlin.math.max /** * A pool for [ByteBuffer]s. As they may use up a lot of memory, @@ -29,22 +31,44 @@ import java.util.* object BufferPool { private val LOG = LoggerFactory.getLogger(BufferPool::class.java) + private var limit: Int? = null + private var strictLimit = false + + /** + * Sets a limit to how many buffers the pool handles. If strict is set to true, it will not issue any + * buffers once the limit is reached and will throw a NodeException instead. Otherwise, it will simply + * ignore returned buffers once the limit is reached (and therefore garbage collected) + */ + fun setLimit(limit: Int, strict: Boolean = false) { + this.limit = limit + this.strictLimit = strict + pools.values.forEach { it.limit = limit } + pools[HEADER_SIZE]!!.limit = 2 * limit + pools[MAX_PAYLOAD_SIZE]!!.limit = max(limit / 2, 1) + } + private val pools = mapOf( - HEADER_SIZE to Stack(), - 54 to Stack(), - 1000 to Stack(), - 60000 to Stack(), - MAX_PAYLOAD_SIZE to Stack() + HEADER_SIZE to Pool(), + 54 to Pool(), + 1000 to Pool(), + 60000 to Pool(), + MAX_PAYLOAD_SIZE to Pool() ) - @Synchronized fun allocate(capacity: Int): ByteBuffer { + @Synchronized + fun allocate(capacity: Int): ByteBuffer { val targetSize = getTargetSize(capacity) val pool = pools[targetSize] ?: throw IllegalStateException("No pool for size $targetSize available") - if (pool.isEmpty()) { - LOG.trace("Creating new buffer of size $targetSize") - return ByteBuffer.allocate(targetSize) + + return if (pool.isEmpty) { + if (pool.hasCapacity || !strictLimit) { + LOG.trace("Creating new buffer of size $targetSize") + ByteBuffer.allocate(targetSize) + } else { + throw NodeException("pool limit for capacity $capacity is reached") + } } else { - return pool.pop() + pool.pop() } } @@ -53,18 +77,26 @@ object BufferPool { * @return a buffer of size 24 */ - @Synchronized fun allocateHeaderBuffer(): ByteBuffer { - val pool = pools[HEADER_SIZE] - if (pool == null || pool.isEmpty()) { - return ByteBuffer.allocate(HEADER_SIZE) + @Synchronized + fun allocateHeaderBuffer(): ByteBuffer { + val pool = pools[HEADER_SIZE] ?: throw IllegalStateException("No pool for header available") + return if (pool.isEmpty) { + if (pool.hasCapacity || !strictLimit) { + LOG.trace("Creating new buffer of header") + ByteBuffer.allocate(HEADER_SIZE) + } else { + throw NodeException("pool limit for header buffer is reached") + } } else { - return pool.pop() + pool.pop() } } - @Synchronized fun deallocate(buffer: ByteBuffer) { + @Synchronized + fun deallocate(buffer: ByteBuffer) { buffer.clear() - val pool = pools[buffer.capacity()] ?: throw IllegalArgumentException("Illegal buffer capacity ${buffer.capacity()} one of ${pools.keys} expected.") + val pool = pools[buffer.capacity()] + ?: throw IllegalArgumentException("Illegal buffer capacity ${buffer.capacity()} one of ${pools.keys} expected.") pool.push(buffer) } @@ -74,4 +106,40 @@ object BufferPool { } throw IllegalArgumentException("Requested capacity too large: requested=$capacity; max=$MAX_PAYLOAD_SIZE") } + + /** + * There is a race condition where the limit could be ignored for an allocation, but I think the consequences + * are benign. + */ + class Pool { + private val stack = Stack() + private var capacity = 0 + internal var limit: Int? = null + set(value) { + capacity = value ?: 0 + field = value + } + + val isEmpty + get() = stack.isEmpty() + + val hasCapacity + @Synchronized + get() = limit == null || capacity > 0 + + @Synchronized + fun pop(): ByteBuffer { + capacity-- + return stack.pop() + } + + @Synchronized + fun push(buffer: ByteBuffer) { + if (hasCapacity) { + stack.push(buffer) + } + // else, let it be collected by the garbage collector + capacity++ + } + } } diff --git a/core/src/main/kotlin/ch/dissem/bitmessage/factory/V3MessageReader.kt b/core/src/main/kotlin/ch/dissem/bitmessage/factory/V3MessageReader.kt index 5b81621..ce152bf 100644 --- a/core/src/main/kotlin/ch/dissem/bitmessage/factory/V3MessageReader.kt +++ b/core/src/main/kotlin/ch/dissem/bitmessage/factory/V3MessageReader.kt @@ -90,11 +90,11 @@ class V3MessageReader { state = ReaderState.DATA this.headerBuffer = null BufferPool.deallocate(headerBuffer) - val dataBuffer = BufferPool.allocate(length) - this.dataBuffer = dataBuffer - dataBuffer.clear() - dataBuffer.limit(length) - data(dataBuffer) + this.dataBuffer = BufferPool.allocate(length).apply { + clear() + limit(length) + data(this) + } } private fun data(dataBuffer: ByteBuffer) { diff --git a/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/Connection.kt b/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/Connection.kt index 22293c9..57040db 100644 --- a/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/Connection.kt +++ b/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/Connection.kt @@ -58,7 +58,7 @@ class Connection( private var lastObjectTime: Long = 0 lateinit var streams: LongArray - protected set + private set @Volatile var state = State.CONNECTING private set diff --git a/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/ConnectionIO.kt b/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/ConnectionIO.kt index a0789f0..63ca774 100644 --- a/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/ConnectionIO.kt +++ b/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/ConnectionIO.kt @@ -16,6 +16,7 @@ package ch.dissem.bitmessage.networking.nio +import ch.dissem.bitmessage.constants.Network.HEADER_SIZE import ch.dissem.bitmessage.entity.GetData import ch.dissem.bitmessage.entity.MessagePayload import ch.dissem.bitmessage.entity.NetworkMessage @@ -39,7 +40,7 @@ class ConnectionIO( private val getState: () -> Connection.State, private val handleMessage: (MessagePayload) -> Unit ) { - private val headerOut: ByteBuffer = ByteBuffer.allocate(24) + private val headerOut: ByteBuffer = ByteBuffer.allocate(HEADER_SIZE) private var payloadOut: ByteBuffer? = null private var reader: V3MessageReader? = V3MessageReader() internal val sendingQueue: Deque = ConcurrentLinkedDeque() diff --git a/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/NioNetworkHandler.kt b/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/NioNetworkHandler.kt index 01ef19f..b483d9b 100644 --- a/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/NioNetworkHandler.kt +++ b/networking/src/main/kotlin/ch/dissem/bitmessage/networking/nio/NioNetworkHandler.kt @@ -17,7 +17,6 @@ package ch.dissem.bitmessage.networking.nio import ch.dissem.bitmessage.InternalContext -import ch.dissem.bitmessage.constants.Network.HEADER_SIZE import ch.dissem.bitmessage.constants.Network.NETWORK_MAGIC_NUMBER import ch.dissem.bitmessage.entity.CustomMessage import ch.dissem.bitmessage.entity.GetData @@ -25,6 +24,7 @@ import ch.dissem.bitmessage.entity.NetworkMessage import ch.dissem.bitmessage.entity.valueobject.InventoryVector import ch.dissem.bitmessage.entity.valueobject.NetworkAddress import ch.dissem.bitmessage.exception.NodeException +import ch.dissem.bitmessage.factory.BufferPool import ch.dissem.bitmessage.factory.V3MessageReader import ch.dissem.bitmessage.networking.nio.Connection.Mode.* import ch.dissem.bitmessage.ports.NetworkHandler @@ -48,7 +48,8 @@ import java.util.concurrent.* /** * Network handler using java.nio, resulting in less threads. */ -class NioNetworkHandler : NetworkHandler, InternalContext.ContextHolder { +class NioNetworkHandler(private val magicNetworkNumber: Int = NETWORK_MAGIC_NUMBER) : NetworkHandler, + InternalContext.ContextHolder { private val threadPool = Executors.newCachedThreadPool( pool("network") @@ -93,12 +94,13 @@ class NioNetworkHandler : NetworkHandler, InternalContext.ContextHolder { override fun send(server: InetAddress, port: Int, request: CustomMessage): CustomMessage { SocketChannel.open(InetSocketAddress(server, port)).use { channel -> channel.configureBlocking(true) - val headerBuffer = ByteBuffer.allocate(HEADER_SIZE) + val headerBuffer = BufferPool.allocateHeaderBuffer() val payloadBuffer = NetworkMessage(request).writer().writeHeaderAndGetPayloadBuffer(headerBuffer) headerBuffer.flip() while (headerBuffer.hasRemaining()) { channel.write(headerBuffer) } + BufferPool.deallocate(headerBuffer) while (payloadBuffer.hasRemaining()) { channel.write(payloadBuffer) } @@ -108,12 +110,14 @@ class NioNetworkHandler : NetworkHandler, InternalContext.ContextHolder { if (channel.read(reader.getActiveBuffer()) > 0) { reader.update() } else { + reader.cleanup() throw NodeException("No response from node $server") } } val networkMessage: NetworkMessage? if (reader.getMessages().isEmpty()) { - throw NodeException("No response from node " + server) + reader.cleanup() + throw NodeException("No response from node $server") } else { networkMessage = reader.getMessages().first() } @@ -121,13 +125,14 @@ class NioNetworkHandler : NetworkHandler, InternalContext.ContextHolder { if (networkMessage.payload is CustomMessage) { return networkMessage.payload as CustomMessage } else { + reader.cleanup() throw NodeException("Unexpected response from node $server: ${networkMessage.payload.javaClass}") } } } override fun start() { - if (selector?.isOpen ?: false) { + if (selector?.isOpen == true) { throw IllegalStateException("Network already running - you need to stop first.") } val selector = Selector.open() @@ -137,7 +142,7 @@ class NioNetworkHandler : NetworkHandler, InternalContext.ContextHolder { starter = thread("connection manager") { while (selector.isOpen) { - var missing = NETWORK_MAGIC_NUMBER + var missing = magicNetworkNumber for ((connection, _) in connections) { if (connection.state == Connection.State.ACTIVE) { missing-- @@ -229,10 +234,8 @@ class NioNetworkHandler : NetworkHandler, InternalContext.ContextHolder { ), requestedObjects, 0 ) - connections.put( - connection, + connections[connection] = accepted.register(selector, OP_READ or OP_WRITE, connection) - ) } catch (e: AsynchronousCloseException) { LOG.trace(e.message) } catch (e: IOException) { @@ -260,13 +263,13 @@ class NioNetworkHandler : NetworkHandler, InternalContext.ContextHolder { if (key.isReadable) { read(channel, connection.io) } - if (connection.state == Connection.State.DISCONNECTED) { - key.interestOps(0) - channel.close() - } else if (connection.io.isWritePending) { - key.interestOps(OP_READ or OP_WRITE) - } else { - key.interestOps(OP_READ) + when { + connection.state == Connection.State.DISCONNECTED -> { + key.interestOps(0) + channel.close() + } + connection.io.isWritePending -> key.interestOps(OP_READ or OP_WRITE) + else -> key.interestOps(OP_READ) } } catch (e: CancelledKeyException) { connection.disconnect() @@ -361,7 +364,7 @@ class NioNetworkHandler : NetworkHandler, InternalContext.ContextHolder { override fun offer(iv: InventoryVector) { val targetConnections = connections.keys.filter { it.state == Connection.State.ACTIVE && !it.knowsOf(iv) } - selectRandom(NETWORK_MAGIC_NUMBER, targetConnections).forEach { it.offer(iv) } + selectRandom(magicNetworkNumber, targetConnections).forEach { it.offer(iv) } } override fun request(inventoryVectors: MutableCollection) {