Refactored JdbcMessageRepository so that alternative implementations can be done easier

This commit is contained in:
Christian Basler 2016-05-20 23:58:08 +02:00
parent c3d8a07e83
commit 14849a82ea
7 changed files with 206 additions and 159 deletions

1
.gitignore vendored
View File

@ -5,6 +5,7 @@
### Gradle ### ### Gradle ###
.gradle .gradle
build/ build/
classes/
# Ignore Gradle GUI config # Ignore Gradle GUI config
gradle-app.setting gradle-app.setting

View File

@ -0,0 +1,123 @@
/*
* Copyright 2016 Christian Basler
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ch.dissem.bitmessage.ports;
import ch.dissem.bitmessage.InternalContext;
import ch.dissem.bitmessage.entity.BitmessageAddress;
import ch.dissem.bitmessage.entity.Plaintext;
import ch.dissem.bitmessage.entity.valueobject.Label;
import ch.dissem.bitmessage.exception.ApplicationException;
import ch.dissem.bitmessage.utils.Strings;
import ch.dissem.bitmessage.utils.UnixTime;
import java.util.Collection;
import java.util.List;
import static ch.dissem.bitmessage.utils.SqlStrings.join;
public abstract class AbstractMessageRepository implements MessageRepository, InternalContext.ContextHolder {
protected InternalContext ctx;
@Override
public void setContext(InternalContext context) {
this.ctx = context;
}
protected void safeSenderIfNecessary(Plaintext message) {
if (message.getId() == null) {
BitmessageAddress savedAddress = ctx.getAddressRepository().getAddress(message.getFrom().getAddress());
if (savedAddress == null) {
ctx.getAddressRepository().save(message.getFrom());
} else if (savedAddress.getPubkey() == null && message.getFrom().getPubkey() != null) {
savedAddress.setPubkey(message.getFrom().getPubkey());
ctx.getAddressRepository().save(savedAddress);
}
}
}
@Override
public Plaintext getMessage(Object id) {
if (id instanceof Long) {
return single(find("id=" + id));
} else {
throw new IllegalArgumentException("Long expected for ID");
}
}
@Override
public Plaintext getMessage(byte[] initialHash) {
return single(find("initial_hash=X'" + Strings.hex(initialHash) + "'"));
}
@Override
public Plaintext getMessageForAck(byte[] ackData) {
return single(find("ack_data=X'" + Strings.hex(ackData) + "' AND status='" + Plaintext.Status.SENT + "'"));
}
@Override
public List<Plaintext> findMessages(Label label) {
return find("id IN (SELECT message_id FROM Message_Label WHERE label_id=" + label.getId() + ")");
}
@Override
public List<Plaintext> findMessages(Plaintext.Status status, BitmessageAddress recipient) {
return find("status='" + status.name() + "' AND recipient='" + recipient.getAddress() + "'");
}
@Override
public List<Plaintext> findMessages(Plaintext.Status status) {
return find("status='" + status.name() + "'");
}
@Override
public List<Plaintext> findMessages(BitmessageAddress sender) {
return find("sender='" + sender.getAddress() + "'");
}
@Override
public List<Plaintext> findMessagesToResend() {
return find("status='" + Plaintext.Status.SENT.name() + "'" +
" AND next_try < " + UnixTime.now());
}
@Override
public List<Label> getLabels() {
return findLabels("1=1");
}
@Override
public List<Label> getLabels(Label.Type... types) {
return findLabels("type IN (" + join(types) + ")");
}
protected abstract List<Label> findLabels(String where);
protected <T> T single(Collection<T> collection) {
switch (collection.size()) {
case 0:
return null;
case 1:
return collection.iterator().next();
default:
throw new ApplicationException("This shouldn't happen, found " + collection.size() +
" items, one or none was expected");
}
}
protected abstract List<Plaintext> find(String where);
}

View File

@ -0,0 +1,59 @@
/*
* Copyright 2016 Christian Basler
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ch.dissem.bitmessage.utils;
import ch.dissem.bitmessage.entity.payload.ObjectType;
import static ch.dissem.bitmessage.utils.Strings.hex;
public class SqlStrings {
public static StringBuilder join(long... objects) {
StringBuilder streamList = new StringBuilder();
for (int i = 0; i < objects.length; i++) {
if (i > 0) streamList.append(", ");
streamList.append(objects[i]);
}
return streamList;
}
public static StringBuilder join(byte[]... objects) {
StringBuilder streamList = new StringBuilder();
for (int i = 0; i < objects.length; i++) {
if (i > 0) streamList.append(", ");
streamList.append(hex(objects[i]));
}
return streamList;
}
public static StringBuilder join(ObjectType... types) {
StringBuilder streamList = new StringBuilder();
for (int i = 0; i < types.length; i++) {
if (i > 0) streamList.append(", ");
streamList.append(types[i].getNumber());
}
return streamList;
}
public static StringBuilder join(Enum... types) {
StringBuilder streamList = new StringBuilder();
for (int i = 0; i < types.length; i++) {
if (i > 0) streamList.append(", ");
streamList.append('\'').append(types[i].name()).append('\'');
}
return streamList;
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2015 Christian Basler * Copyright 2016 Christian Basler
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,16 +14,16 @@
* limitations under the License. * limitations under the License.
*/ */
package ch.dissem.bitmessage.repository; package ch.dissem.bitmessage.utils;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class JdbcHelperTest { public class SqlStringsTest {
@Test @Test
public void ensureJoinWorksWithLongArray() { public void ensureJoinWorksWithLongArray() {
long[] test = {1L, 2L}; long[] test = {1L, 2L};
assertEquals("1, 2", JdbcHelper.join(test).toString()); assertEquals("1, 2", SqlStrings.join(test).toString());
} }
} }

View File

@ -42,43 +42,7 @@ public abstract class JdbcHelper {
this.config = config; this.config = config;
} }
public static StringBuilder join(long... objects) { public static void writeBlob(PreparedStatement ps, int parameterIndex, Streamable data) throws SQLException, IOException {
StringBuilder streamList = new StringBuilder();
for (int i = 0; i < objects.length; i++) {
if (i > 0) streamList.append(", ");
streamList.append(objects[i]);
}
return streamList;
}
public static StringBuilder join(byte[]... objects) {
StringBuilder streamList = new StringBuilder();
for (int i = 0; i < objects.length; i++) {
if (i > 0) streamList.append(", ");
streamList.append(hex(objects[i]));
}
return streamList;
}
public static StringBuilder join(ObjectType... types) {
StringBuilder streamList = new StringBuilder();
for (int i = 0; i < types.length; i++) {
if (i > 0) streamList.append(", ");
streamList.append(types[i].getNumber());
}
return streamList;
}
public static StringBuilder join(Enum... types) {
StringBuilder streamList = new StringBuilder();
for (int i = 0; i < types.length; i++) {
if (i > 0) streamList.append(", ");
streamList.append('\'').append(types[i].name()).append('\'');
}
return streamList;
}
protected void writeBlob(PreparedStatement ps, int parameterIndex, Streamable data) throws SQLException, IOException {
if (data == null) { if (data == null) {
ps.setBytes(parameterIndex, null); ps.setBytes(parameterIndex, null);
} else { } else {
@ -87,12 +51,4 @@ public abstract class JdbcHelper {
ps.setBytes(parameterIndex, os.toByteArray()); ps.setBytes(parameterIndex, os.toByteArray());
} }
} }
protected <T> T single(Collection<T> collection) {
if (collection.size() > 1) {
throw new ApplicationException("This shouldn't happen, found " + collection.size() +
" messages, one or none was expected");
}
return collection.stream().findAny().orElse(null);
}
} }

View File

@ -31,6 +31,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import static ch.dissem.bitmessage.utils.SqlStrings.join;
import static ch.dissem.bitmessage.utils.UnixTime.MINUTE; import static ch.dissem.bitmessage.utils.UnixTime.MINUTE;
import static ch.dissem.bitmessage.utils.UnixTime.now; import static ch.dissem.bitmessage.utils.UnixTime.now;

View File

@ -16,16 +16,12 @@
package ch.dissem.bitmessage.repository; package ch.dissem.bitmessage.repository;
import ch.dissem.bitmessage.InternalContext;
import ch.dissem.bitmessage.entity.BitmessageAddress;
import ch.dissem.bitmessage.entity.Plaintext; import ch.dissem.bitmessage.entity.Plaintext;
import ch.dissem.bitmessage.entity.valueobject.InventoryVector; import ch.dissem.bitmessage.entity.valueobject.InventoryVector;
import ch.dissem.bitmessage.entity.valueobject.Label; import ch.dissem.bitmessage.entity.valueobject.Label;
import ch.dissem.bitmessage.exception.ApplicationException; import ch.dissem.bitmessage.exception.ApplicationException;
import ch.dissem.bitmessage.ports.AbstractMessageRepository;
import ch.dissem.bitmessage.ports.MessageRepository; import ch.dissem.bitmessage.ports.MessageRepository;
import ch.dissem.bitmessage.utils.Strings;
import ch.dissem.bitmessage.utils.TTL;
import ch.dissem.bitmessage.utils.UnixTime;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -33,34 +29,30 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.sql.*; import java.sql.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
public class JdbcMessageRepository extends JdbcHelper implements MessageRepository, InternalContext.ContextHolder { import static ch.dissem.bitmessage.repository.JdbcHelper.writeBlob;
public class JdbcMessageRepository extends AbstractMessageRepository implements MessageRepository {
private static final Logger LOG = LoggerFactory.getLogger(JdbcMessageRepository.class); private static final Logger LOG = LoggerFactory.getLogger(JdbcMessageRepository.class);
private InternalContext ctx; private final JdbcConfig config;
public JdbcMessageRepository(JdbcConfig config) { public JdbcMessageRepository(JdbcConfig config) {
super(config); this.config = config;
} }
@Override @Override
public List<Label> getLabels() { protected List<Label> findLabels(String where) {
List<Label> result = new LinkedList<>();
try ( try (
Connection connection = config.getConnection(); Connection connection = config.getConnection()
Statement stmt = connection.createStatement();
ResultSet rs = stmt.executeQuery("SELECT id, label, type, color FROM Label ORDER BY ord")
) { ) {
while (rs.next()) { return findLabels(connection, where);
result.add(getLabel(rs));
}
} catch (SQLException e) { } catch (SQLException e) {
throw new ApplicationException(e); LOG.error(e.getMessage(), e);
} }
return result; return new ArrayList<>();
} }
private Label getLabel(ResultSet rs) throws SQLException { private Label getLabel(ResultSet rs) throws SQLException {
@ -75,24 +67,6 @@ public class JdbcMessageRepository extends JdbcHelper implements MessageReposito
return label; return label;
} }
@Override
public List<Label> getLabels(Label.Type... types) {
List<Label> result = new LinkedList<>();
try (
Connection connection = config.getConnection();
Statement stmt = connection.createStatement();
ResultSet rs = stmt.executeQuery("SELECT id, label, type, color FROM Label WHERE type IN (" + join(types) +
") ORDER BY ord")
) {
while (rs.next()) {
result.add(getLabel(rs));
}
} catch (SQLException e) {
LOG.error(e.getMessage(), e);
}
return result;
}
@Override @Override
public int countUnread(Label label) { public int countUnread(Label label) {
String where; String where;
@ -120,60 +94,7 @@ public class JdbcMessageRepository extends JdbcHelper implements MessageReposito
} }
@Override @Override
public Plaintext getMessage(Object id) { protected List<Plaintext> find(String where) {
if (id instanceof Long) {
List<Plaintext> plaintexts = find("id=" + id);
switch (plaintexts.size()) {
case 0:
return null;
case 1:
return plaintexts.get(0);
default:
throw new ApplicationException("This shouldn't happen, found " + plaintexts.size() +
" messages, one or none was expected");
}
} else {
throw new IllegalArgumentException("Long expected for ID");
}
}
@Override
public Plaintext getMessage(byte[] initialHash) {
return single(find("initial_hash=X'" + Strings.hex(initialHash) + "'"));
}
@Override
public Plaintext getMessageForAck(byte[] ackData) {
return single(find("ack_data=X'" + Strings.hex(ackData) + "' AND status='" + Plaintext.Status.SENT + "'"));
}
@Override
public List<Plaintext> findMessages(Label label) {
return find("id IN (SELECT message_id FROM Message_Label WHERE label_id=" + label.getId() + ")");
}
@Override
public List<Plaintext> findMessages(Plaintext.Status status, BitmessageAddress recipient) {
return find("status='" + status.name() + "' AND recipient='" + recipient.getAddress() + "'");
}
@Override
public List<Plaintext> findMessages(Plaintext.Status status) {
return find("status='" + status.name() + "'");
}
@Override
public List<Plaintext> findMessages(BitmessageAddress sender) {
return find("sender='" + sender.getAddress() + "'");
}
@Override
public List<Plaintext> findMessagesToResend() {
return find("status='" + Plaintext.Status.SENT.name() + "'" +
" AND next_try < " + UnixTime.now());
}
private List<Plaintext> find(String where) {
List<Plaintext> result = new LinkedList<>(); List<Plaintext> result = new LinkedList<>();
try ( try (
Connection connection = config.getConnection(); Connection connection = config.getConnection();
@ -199,7 +120,8 @@ public class JdbcMessageRepository extends JdbcHelper implements MessageReposito
builder.ttl(rs.getLong("ttl")); builder.ttl(rs.getLong("ttl"));
builder.retries(rs.getInt("retries")); builder.retries(rs.getInt("retries"));
builder.nextTry(rs.getLong("next_try")); builder.nextTry(rs.getLong("next_try"));
builder.labels(findLabels(connection, id)); builder.labels(findLabels(connection,
"WHERE id IN (SELECT label_id FROM Message_Label WHERE message_id=" + id + ") ORDER BY ord"));
Plaintext message = builder.build(); Plaintext message = builder.build();
message.setInitialHash(rs.getBytes("initial_hash")); message.setInitialHash(rs.getBytes("initial_hash"));
result.add(message); result.add(message);
@ -210,12 +132,11 @@ public class JdbcMessageRepository extends JdbcHelper implements MessageReposito
return result; return result;
} }
private Collection<Label> findLabels(Connection connection, long messageId) { private List<Label> findLabels(Connection connection, String where) {
List<Label> result = new ArrayList<>(); List<Label> result = new ArrayList<>();
try ( try (
Statement stmt = connection.createStatement(); Statement stmt = connection.createStatement();
ResultSet rs = stmt.executeQuery("SELECT id, label, type, color FROM Label " + ResultSet rs = stmt.executeQuery("SELECT id, label, type, color FROM Label WHERE " + where)
"WHERE id IN (SELECT label_id FROM Message_Label WHERE message_id=" + messageId + ")")
) { ) {
while (rs.next()) { while (rs.next()) {
result.add(getLabel(rs)); result.add(getLabel(rs));
@ -228,16 +149,7 @@ public class JdbcMessageRepository extends JdbcHelper implements MessageReposito
@Override @Override
public void save(Plaintext message) { public void save(Plaintext message) {
// save from address if necessary safeSenderIfNecessary(message);
if (message.getId() == null) {
BitmessageAddress savedAddress = ctx.getAddressRepository().getAddress(message.getFrom().getAddress());
if (savedAddress == null) {
ctx.getAddressRepository().save(message.getFrom());
} else if (savedAddress.getPubkey() == null && message.getFrom().getPubkey() != null) {
savedAddress.setPubkey(message.getFrom().getPubkey());
ctx.getAddressRepository().save(savedAddress);
}
}
try (Connection connection = config.getConnection()) { try (Connection connection = config.getConnection()) {
try { try {
@ -350,9 +262,4 @@ public class JdbcMessageRepository extends JdbcHelper implements MessageReposito
LOG.error(e.getMessage(), e); LOG.error(e.getMessage(), e);
} }
} }
@Override
public void setContext(InternalContext context) {
this.ctx = context;
}
} }