Better memory management for the in buffer (the same TODO for the out buffer.

This commit is contained in:
2016-07-25 07:52:27 +02:00
parent 82ee4d05bb
commit 48ff975ffd
11 changed files with 427 additions and 168 deletions

View File

@ -0,0 +1,107 @@
/*
* Copyright 2016 Christian Basler
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ch.dissem.bitmessage.factory;
import ch.dissem.bitmessage.ports.NetworkHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.ByteBuffer;
import java.util.*;
/**
* A pool for {@link ByteBuffer}s. As they may use up a lot of memory,
* they should be reused as efficiently as possible.
*/
class BufferPool {
private static final Logger LOG = LoggerFactory.getLogger(BufferPool.class);
public static final BufferPool bufferPool = new BufferPool(256, 2048);
private final Map<Size, Integer> capacities = new EnumMap<>(Size.class);
private final Map<Size, Stack<ByteBuffer>> pools = new EnumMap<>(Size.class);
private BufferPool(int small, int medium) {
capacities.put(Size.HEADER, 24);
capacities.put(Size.SMALL, small);
capacities.put(Size.MEDIUM, medium);
capacities.put(Size.LARGE, NetworkHandler.MAX_PAYLOAD_SIZE);
pools.put(Size.HEADER, new Stack<ByteBuffer>());
pools.put(Size.SMALL, new Stack<ByteBuffer>());
pools.put(Size.MEDIUM, new Stack<ByteBuffer>());
pools.put(Size.LARGE, new Stack<ByteBuffer>());
}
public synchronized ByteBuffer allocate(int capacity) {
Size targetSize = getTargetSize(capacity);
Size s = targetSize;
do {
Stack<ByteBuffer> pool = pools.get(s);
if (!pool.isEmpty()) {
return pool.pop();
}
s = s.next();
} while (s != null);
LOG.debug("Creating new buffer of size " + targetSize);
return ByteBuffer.allocate(capacities.get(targetSize));
}
public synchronized ByteBuffer allocate() {
Stack<ByteBuffer> pool = pools.get(Size.HEADER);
if (!pool.isEmpty()) {
return pool.pop();
} else {
return ByteBuffer.allocate(capacities.get(Size.HEADER));
}
}
public synchronized void deallocate(ByteBuffer buffer) {
buffer.clear();
Size size = getTargetSize(buffer.capacity());
if (buffer.capacity() != capacities.get(size)) {
throw new IllegalArgumentException("Illegal buffer capacity " + buffer.capacity() +
" one of " + capacities.values() + " expected.");
}
pools.get(size).push(buffer);
}
private Size getTargetSize(int capacity) {
for (Size s : Size.values()) {
if (capacity <= capacities.get(s)) {
return s;
}
}
throw new IllegalArgumentException("Requested capacity too large: " +
"requested=" + capacity + "; max=" + capacities.get(Size.LARGE));
}
private enum Size {
HEADER, SMALL, MEDIUM, LARGE;
public Size next() {
switch (this) {
case SMALL:
return MEDIUM;
case MEDIUM:
return LARGE;
default:
return null;
}
}
}
}

View File

@ -21,15 +21,20 @@ import ch.dissem.bitmessage.entity.NetworkMessage;
import ch.dissem.bitmessage.exception.ApplicationException;
import ch.dissem.bitmessage.exception.NodeException;
import ch.dissem.bitmessage.utils.Decode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.util.LinkedList;
import java.util.List;
import java.util.UUID;
import static ch.dissem.bitmessage.entity.NetworkMessage.MAGIC_BYTES;
import static ch.dissem.bitmessage.factory.BufferPool.bufferPool;
import static ch.dissem.bitmessage.ports.NetworkHandler.MAX_PAYLOAD_SIZE;
import static ch.dissem.bitmessage.utils.Singleton.cryptography;
@ -37,53 +42,89 @@ import static ch.dissem.bitmessage.utils.Singleton.cryptography;
* Similar to the {@link V3MessageFactory}, but used for NIO buffers which may or may not contain a whole message.
*/
public class V3MessageReader {
private static final Logger LOG = LoggerFactory.getLogger(V3MessageReader.class);
private ByteBuffer headerBuffer;
private ByteBuffer dataBuffer;
private ReaderState state = ReaderState.MAGIC;
private String command;
private int length;
private byte[] checksum;
private List<NetworkMessage> messages = new LinkedList<>();
private SizeInfo sizeInfo = new SizeInfo();
public void update(ByteBuffer buffer) {
while (buffer.hasRemaining()) {
switch (state) {
case MAGIC:
if (!findMagicBytes(buffer)) return;
state = ReaderState.HEADER;
case HEADER:
if (buffer.remaining() < 20) {
return;
}
command = getCommand(buffer);
length = (int) Decode.uint32(buffer);
if (length > MAX_PAYLOAD_SIZE) {
throw new NodeException("Payload of " + length + " bytes received, no more than " +
MAX_PAYLOAD_SIZE + " was expected.");
}
checksum = new byte[4];
buffer.get(checksum);
state = ReaderState.DATA;
case DATA:
if (buffer.remaining() < length) {
return;
}
if (!testChecksum(buffer)) {
throw new NodeException("Checksum failed for message '" + command + "'");
}
try {
MessagePayload payload = V3MessageFactory.getPayload(
command,
new ByteArrayInputStream(buffer.array(), buffer.arrayOffset() + buffer.position(), length),
length);
if (payload != null) {
messages.add(new NetworkMessage(payload));
}
} catch (IOException e) {
throw new NodeException(e.getMessage());
}
state = ReaderState.MAGIC;
public ByteBuffer getActiveBuffer() {
if (state != null && state != ReaderState.DATA) {
if (headerBuffer == null) {
headerBuffer = bufferPool.allocate();
}
}
return state == ReaderState.DATA ? dataBuffer : headerBuffer;
}
public void update() {
if (state != ReaderState.DATA) {
getActiveBuffer();
headerBuffer.flip();
}
switch (state) {
case MAGIC:
if (!findMagicBytes(headerBuffer)) {
headerBuffer.compact();
return;
}
state = ReaderState.HEADER;
case HEADER:
if (headerBuffer.remaining() < 20) {
headerBuffer.compact();
headerBuffer.limit(20);
return;
}
command = getCommand(headerBuffer);
length = (int) Decode.uint32(headerBuffer);
if (length > MAX_PAYLOAD_SIZE) {
throw new NodeException("Payload of " + length + " bytes received, no more than " +
MAX_PAYLOAD_SIZE + " was expected.");
}
sizeInfo.add(length); // FIXME: remove this once we have some values to work with
checksum = new byte[4];
headerBuffer.get(checksum);
state = ReaderState.DATA;
bufferPool.deallocate(headerBuffer);
headerBuffer = null;
dataBuffer = bufferPool.allocate(length);
dataBuffer.clear();
dataBuffer.limit(length);
case DATA:
if (dataBuffer.position() < length) {
return;
} else {
dataBuffer.flip();
}
if (!testChecksum(dataBuffer)) {
state = ReaderState.MAGIC;
throw new NodeException("Checksum failed for message '" + command + "'");
}
try {
MessagePayload payload = V3MessageFactory.getPayload(
command,
new ByteArrayInputStream(dataBuffer.array(),
dataBuffer.arrayOffset() + dataBuffer.position(), length),
length);
if (payload != null) {
messages.add(new NetworkMessage(payload));
}
} catch (IOException e) {
throw new NodeException(e.getMessage());
} finally {
state = ReaderState.MAGIC;
bufferPool.deallocate(dataBuffer);
dataBuffer = null;
dataBuffer = null;
}
}
}
public List<NetworkMessage> getMessages() {
@ -129,7 +170,7 @@ public class V3MessageReader {
private boolean testChecksum(ByteBuffer buffer) {
byte[] payloadChecksum = cryptography().sha512(buffer.array(),
buffer.arrayOffset() + buffer.position(), length);
buffer.arrayOffset() + buffer.position(), length);
for (int i = 0; i < checksum.length; i++) {
if (checksum[i] != payloadChecksum[i]) {
return false;
@ -138,5 +179,52 @@ public class V3MessageReader {
return true;
}
/**
* De-allocates all buffers. This method should be called iff the reader isn't used anymore, i.e. when its
* connection is severed.
*/
public void cleanup() {
state = null;
if (headerBuffer != null) {
bufferPool.deallocate(headerBuffer);
}
if (dataBuffer != null) {
bufferPool.deallocate(dataBuffer);
}
}
private enum ReaderState {MAGIC, HEADER, DATA}
private class SizeInfo {
private FileWriter file;
private long min = Long.MAX_VALUE;
private long avg = 0;
private long max = Long.MIN_VALUE;
private long count = 0;
private SizeInfo() {
try {
file = new FileWriter("D:/message_size_info-" + UUID.randomUUID() + ".csv");
} catch (IOException e) {
LOG.error(e.getMessage(), e);
}
}
private void add(long length) {
avg = (count * avg + length) / (count + 1);
if (length < min) {
min = length;
}
if (length > max) {
max = length;
}
count++;
LOG.info("Received message with data size " + length + "; Min: " + min + "; Max: " + max + "; Avg: " + avg);
try {
file.write(length + "\n");
} catch (IOException e) {
e.printStackTrace();
}
}
}
}

View File

@ -87,7 +87,7 @@ public class MemoryNodeRegistry implements NodeRegistry {
}
}
if (result.isEmpty()) {
if (stableNodes.isEmpty()) {
if (stableNodes.isEmpty() || stableNodes.get(stream).isEmpty()) {
loadStableNodes();
}
Set<NetworkAddress> nodes = stableNodes.get(stream);
@ -108,8 +108,8 @@ public class MemoryNodeRegistry implements NodeRegistry {
synchronized (knownNodes) {
if (!knownNodes.containsKey(node.getStream())) {
knownNodes.put(
node.getStream(),
newSetFromMap(new ConcurrentHashMap<NetworkAddress, Boolean>())
node.getStream(),
newSetFromMap(new ConcurrentHashMap<NetworkAddress, Boolean>())
);
}
}