Split LabelRepository off the MessageRepository

This commit is contained in:
Christian Basler 2017-11-26 20:30:05 +01:00
parent ddb2073c2f
commit 278d5b05e6
20 changed files with 407 additions and 236 deletions

View File

@ -73,6 +73,9 @@ class BitmessageContext private constructor(builder: BitmessageContext.Builder)
val addresses: AddressRepository val addresses: AddressRepository
@JvmName("addresses") get @JvmName("addresses") get
val labels: LabelRepository
@JvmName("labels") get
val messages: MessageRepository val messages: MessageRepository
@JvmName("messages") get @JvmName("messages") get
@ -301,6 +304,7 @@ class BitmessageContext private constructor(builder: BitmessageContext.Builder)
var nodeRegistry by Delegates.notNull<NodeRegistry>() var nodeRegistry by Delegates.notNull<NodeRegistry>()
var networkHandler by Delegates.notNull<NetworkHandler>() var networkHandler by Delegates.notNull<NetworkHandler>()
var addressRepo by Delegates.notNull<AddressRepository>() var addressRepo by Delegates.notNull<AddressRepository>()
var labelRepo by Delegates.notNull<LabelRepository>()
var messageRepo by Delegates.notNull<MessageRepository>() var messageRepo by Delegates.notNull<MessageRepository>()
var proofOfWorkRepo by Delegates.notNull<ProofOfWorkRepository>() var proofOfWorkRepo by Delegates.notNull<ProofOfWorkRepository>()
var proofOfWorkEngine: ProofOfWorkEngine? = null var proofOfWorkEngine: ProofOfWorkEngine? = null
@ -330,6 +334,11 @@ class BitmessageContext private constructor(builder: BitmessageContext.Builder)
return this return this
} }
fun labelRepo(labelRepo: LabelRepository): Builder {
this.labelRepo = labelRepo
return this
}
fun messageRepo(messageRepo: MessageRepository): Builder { fun messageRepo(messageRepo: MessageRepository): Builder {
this.messageRepo = messageRepo this.messageRepo = messageRepo
return this return this
@ -386,6 +395,7 @@ class BitmessageContext private constructor(builder: BitmessageContext.Builder)
builder.nodeRegistry, builder.nodeRegistry,
builder.networkHandler, builder.networkHandler,
builder.addressRepo, builder.addressRepo,
builder.labelRepo,
builder.messageRepo, builder.messageRepo,
builder.proofOfWorkRepo, builder.proofOfWorkRepo,
builder.proofOfWorkEngine ?: MultiThreadedPOWEngine(), builder.proofOfWorkEngine ?: MultiThreadedPOWEngine(),
@ -400,6 +410,7 @@ class BitmessageContext private constructor(builder: BitmessageContext.Builder)
builder.preferences builder.preferences
) )
this.addresses = builder.addressRepo this.addresses = builder.addressRepo
this.labels = builder.labelRepo
this.messages = builder.messageRepo this.messages = builder.messageRepo
(builder.listener as? Listener.WithContext)?.setContext(this) (builder.listener as? Listener.WithContext)?.setContext(this)
internals.proofOfWorkService.doMissingProofOfWork(builder.preferences.doMissingProofOfWorkDelayInSeconds * 1000L) internals.proofOfWorkService.doMissingProofOfWork(builder.preferences.doMissingProofOfWorkDelayInSeconds * 1000L)

View File

@ -40,11 +40,12 @@ import java.util.concurrent.Executors
*/ */
class InternalContext( class InternalContext(
val cryptography: Cryptography, val cryptography: Cryptography,
val inventory: ch.dissem.bitmessage.ports.Inventory, val inventory: Inventory,
val nodeRegistry: NodeRegistry, val nodeRegistry: NodeRegistry,
val networkHandler: NetworkHandler, val networkHandler: NetworkHandler,
val addressRepository: AddressRepository, val addressRepository: AddressRepository,
val messageRepository: ch.dissem.bitmessage.ports.MessageRepository, val labelRepository: LabelRepository,
val messageRepository: MessageRepository,
val proofOfWorkRepository: ProofOfWorkRepository, val proofOfWorkRepository: ProofOfWorkRepository,
val proofOfWorkEngine: ProofOfWorkEngine, val proofOfWorkEngine: ProofOfWorkEngine,
val customCommandHandler: CustomCommandHandler, val customCommandHandler: CustomCommandHandler,
@ -216,7 +217,9 @@ class InternalContext(
companion object { companion object {
private val LOG = LoggerFactory.getLogger(InternalContext::class.java) private val LOG = LoggerFactory.getLogger(InternalContext::class.java)
@JvmField val NETWORK_NONCE_TRIALS_PER_BYTE: Long = 1000 @JvmField
@JvmField val NETWORK_EXTRA_BYTES: Long = 1000 val NETWORK_NONCE_TRIALS_PER_BYTE: Long = 1000
@JvmField
val NETWORK_EXTRA_BYTES: Long = 1000
} }
} }

View File

@ -35,6 +35,7 @@ import java.nio.ByteBuffer
import java.util.* import java.util.*
import java.util.Collections import java.util.Collections
import kotlin.collections.HashSet import kotlin.collections.HashSet
import kotlin.collections.LinkedHashSet
private fun message(encoding: Plaintext.Encoding, subject: String, body: String): ByteArray = when (encoding) { private fun message(encoding: Plaintext.Encoding, subject: String, body: String): ByteArray = when (encoding) {
SIMPLE -> "Subject:$subject\nBody:$body".toByteArray() SIMPLE -> "Subject:$subject\nBody:$body".toByteArray()
@ -254,7 +255,7 @@ class Plaintext private constructor(
received = builder.received, received = builder.received,
initialHash = null, initialHash = null,
ttl = builder.ttl, ttl = builder.ttl,
labels = builder.labels, labels = LinkedHashSet(builder.labels),
status = builder.status ?: Status.RECEIVED status = builder.status ?: Status.RECEIVED
) { ) {
id = builder.id id = builder.id
@ -278,28 +279,30 @@ class Plaintext private constructor(
get() { get() {
val s = Scanner(ByteArrayInputStream(message), "UTF-8") val s = Scanner(ByteArrayInputStream(message), "UTF-8")
val firstLine = s.nextLine() val firstLine = s.nextLine()
if (encodingCode == EXTENDED.code) { return when (encodingCode) {
if (Message.TYPE == extendedData?.type) { EXTENDED.code -> if (Message.TYPE == extendedData?.type) {
return (extendedData!!.content as? Message)?.subject (extendedData!!.content as? Message)?.subject
} else { } else {
return null null
} }
} else if (encodingCode == SIMPLE.code) { SIMPLE.code -> firstLine.substring("Subject:".length).trim { it <= ' ' }
return firstLine.substring("Subject:".length).trim { it <= ' ' } else -> {
} else if (firstLine.length > 50) { if (firstLine.length > 50) {
return firstLine.substring(0, 50).trim { it <= ' ' } + "..." firstLine.substring(0, 50).trim { it <= ' ' } + "..."
} else { } else {
return firstLine firstLine
}
}
} }
} }
val text: String? val text: String?
get() { get() {
if (encodingCode == EXTENDED.code) { if (encodingCode == EXTENDED.code) {
if (Message.TYPE == extendedData?.type) { return if (Message.TYPE == extendedData?.type) {
return (extendedData?.content as Message?)?.body (extendedData?.content as Message?)?.body
} else { } else {
return null null
} }
} else { } else {
val text = String(message) val text = String(message)
@ -322,20 +325,20 @@ class Plaintext private constructor(
val parents: List<InventoryVector> val parents: List<InventoryVector>
get() { get() {
val extendedData = extendedData ?: return emptyList() val extendedData = extendedData ?: return emptyList()
if (Message.TYPE == extendedData.type) { return if (Message.TYPE == extendedData.type) {
return (extendedData.content as Message).parents (extendedData.content as Message).parents
} else { } else {
return emptyList() emptyList()
} }
} }
val files: List<Attachment> val files: List<Attachment>
get() { get() {
val extendedData = extendedData ?: return emptyList() val extendedData = extendedData ?: return emptyList()
if (Message.TYPE == extendedData.type) { return if (Message.TYPE == extendedData.type) {
return (extendedData.content as Message).files (extendedData.content as Message).files
} else { } else {
return emptyList() emptyList()
} }
} }
@ -378,8 +381,8 @@ class Plaintext private constructor(
override fun toString(): String { override fun toString(): String {
val subject = subject val subject = subject
if (subject?.isNotEmpty() ?: false) { if (subject?.isNotEmpty() == true) {
return subject!! return subject
} else { } else {
return Strings.hex( return Strings.hex(
initialHash ?: return super.toString() initialHash ?: return super.toString()
@ -529,32 +532,43 @@ class Plaintext private constructor(
} }
class Builder(internal val type: Type) { class Builder(internal val type: Type) {
internal var id: Any? = null var id: Any? = null
internal var inventoryVector: InventoryVector? = null var inventoryVector: InventoryVector? = null
internal var from: BitmessageAddress? = null var from: BitmessageAddress? = null
internal var to: BitmessageAddress? = null var to: BitmessageAddress? = null
private var addressVersion: Long = 0 set(value) {
private var stream: Long = 0 if (value != null) {
private var behaviorBitfield: Int = 0 if (type != MSG && to != null)
private var publicSigningKey: ByteArray? = null throw IllegalArgumentException("recipient address only allowed for msg")
private var publicEncryptionKey: ByteArray? = null field = value
private var nonceTrialsPerByte: Long = 0 }
private var extraBytes: Long = 0 }
private var destinationRipe: ByteArray? = null var addressVersion: Long = 0
private var preventAck: Boolean = false var stream: Long = 0
internal var encoding: Long = 0 var behaviorBitfield: Int = 0
internal var message = ByteArray(0) var publicSigningKey: ByteArray? = null
internal var ackData: ByteArray? = null var publicEncryptionKey: ByteArray? = null
internal var ackMessage: ByteArray? = null var nonceTrialsPerByte: Long = 0
internal var signature: ByteArray? = null var extraBytes: Long = 0
internal var sent: Long? = null var destinationRipe: ByteArray? = null
internal var received: Long? = null set(value) {
internal var status: Status? = null if (type != MSG && value != null) throw IllegalArgumentException("ripe only allowed for msg")
internal val labels = LinkedHashSet<Label>() field = value
internal var ttl: Long = 0 }
internal var retries: Int = 0 var preventAck: Boolean = false
internal var nextTry: Long? = null var encoding: Long = 0
internal var conversation: UUID? = null var message = ByteArray(0)
var ackData: ByteArray? = null
var ackMessage: ByteArray? = null
var signature: ByteArray? = null
var sent: Long? = null
var received: Long? = null
var status: Status? = null
var labels: Collection<Label> = emptySet()
var ttl: Long = 0
var retries: Int = 0
var nextTry: Long? = null
var conversation: UUID? = null
fun id(id: Any): Builder { fun id(id: Any): Builder {
this.id = id this.id = id
@ -572,11 +586,7 @@ class Plaintext private constructor(
} }
fun to(address: BitmessageAddress?): Builder { fun to(address: BitmessageAddress?): Builder {
if (address != null) {
if (type != MSG && to != null)
throw IllegalArgumentException("recipient address only allowed for msg")
to = address to = address
}
return this return this
} }
@ -616,7 +626,6 @@ class Plaintext private constructor(
} }
fun destinationRipe(ripe: ByteArray?): Builder { fun destinationRipe(ripe: ByteArray?): Builder {
if (type != MSG && ripe != null) throw IllegalArgumentException("ripe only allowed for msg")
this.destinationRipe = ripe this.destinationRipe = ripe
return this return this
} }
@ -692,7 +701,7 @@ class Plaintext private constructor(
} }
fun labels(labels: Collection<Label>): Builder { fun labels(labels: Collection<Label>): Builder {
this.labels.addAll(labels) this.labels = labels
return this return this
} }
@ -743,6 +752,12 @@ class Plaintext private constructor(
return this return this
} }
@JvmSynthetic
inline fun build(block: Builder.() -> Unit): Plaintext {
block(this)
return build()
}
fun build(): Plaintext { fun build(): Plaintext {
return Plaintext(this) return Plaintext(this)
} }
@ -774,5 +789,13 @@ class Plaintext private constructor(
.message(Decode.varBytes(input)) .message(Decode.varBytes(input))
.ackMessage(if (type == MSG) Decode.varBytes(input) else null) .ackMessage(if (type == MSG) Decode.varBytes(input) else null)
} }
@JvmSynthetic
inline fun build(type: Type, block: Builder.() -> Unit): Plaintext {
val builder = Builder(type)
block(builder)
return builder.build()
}
} }
} }

View File

@ -0,0 +1,33 @@
/*
* Copyright 2017 Christian Basler
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ch.dissem.bitmessage.ports
import ch.dissem.bitmessage.entity.valueobject.Label
import ch.dissem.bitmessage.utils.SqlStrings.join
abstract class AbstractLabelRepository : LabelRepository {
override fun getLabels(): List<Label> {
return find("1=1")
}
override fun getLabels(vararg types: Label.Type): List<Label> {
return find("type IN (${join(*types)})")
}
protected abstract fun find(where: String): List<Label>
}

View File

@ -21,8 +21,7 @@ import ch.dissem.bitmessage.entity.BitmessageAddress
import ch.dissem.bitmessage.entity.Plaintext import ch.dissem.bitmessage.entity.Plaintext
import ch.dissem.bitmessage.entity.valueobject.InventoryVector import ch.dissem.bitmessage.entity.valueobject.InventoryVector
import ch.dissem.bitmessage.entity.valueobject.Label import ch.dissem.bitmessage.entity.valueobject.Label
import ch.dissem.bitmessage.exception.ApplicationException import ch.dissem.bitmessage.utils.Collections.single
import ch.dissem.bitmessage.utils.SqlStrings.join
import ch.dissem.bitmessage.utils.Strings import ch.dissem.bitmessage.utils.Strings
import ch.dissem.bitmessage.utils.UnixTime import ch.dissem.bitmessage.utils.UnixTime
import java.util.* import java.util.*
@ -114,25 +113,6 @@ abstract class AbstractMessageRepository : MessageRepository, InternalContext.Co
return find("conversation=X'${conversationId.toString().replace("-", "")}'") return find("conversation=X'${conversationId.toString().replace("-", "")}'")
} }
override fun getLabels(): List<Label> {
return findLabels("1=1")
}
override fun getLabels(vararg types: Label.Type): List<Label> {
return findLabels("type IN (${join(*types)})")
}
protected abstract fun findLabels(where: String): List<Label>
protected fun <T> single(collection: Collection<T>): T? {
return when (collection.size) {
0 -> null
1 -> collection.iterator().next()
else -> throw ApplicationException("This shouldn't happen, found ${collection.size} items, one or none was expected")
}
}
/** /**
* Finds messages that mach the given where statement, with optional offset and limit. If the limit is set to 0, * Finds messages that mach the given where statement, with optional offset and limit. If the limit is set to 0,
* offset and limit are ignored. * offset and limit are ignored.

View File

@ -35,9 +35,9 @@ open class DefaultLabeler : Labeler, InternalContext.ContextHolder {
msg.status = RECEIVED msg.status = RECEIVED
val labelsToAdd = val labelsToAdd =
if (msg.type == BROADCAST) { if (msg.type == BROADCAST) {
ctx.messageRepository.getLabels(Label.Type.BROADCAST, Label.Type.UNREAD) ctx.labelRepository.getLabels(Label.Type.BROADCAST, Label.Type.UNREAD)
} else { } else {
ctx.messageRepository.getLabels(Label.Type.INBOX, Label.Type.UNREAD) ctx.labelRepository.getLabels(Label.Type.INBOX, Label.Type.UNREAD)
} }
msg.addLabels(labelsToAdd) msg.addLabels(labelsToAdd)
listener?.invoke(msg, labelsToAdd, emptyList()) listener?.invoke(msg, labelsToAdd, emptyList())
@ -45,7 +45,7 @@ open class DefaultLabeler : Labeler, InternalContext.ContextHolder {
override fun markAsDraft(msg: Plaintext) { override fun markAsDraft(msg: Plaintext) {
msg.status = DRAFT msg.status = DRAFT
val labelsToAdd = ctx.messageRepository.getLabels(Label.Type.DRAFT) val labelsToAdd = ctx.labelRepository.getLabels(Label.Type.DRAFT)
msg.addLabels(labelsToAdd) msg.addLabels(labelsToAdd)
listener?.invoke(msg, labelsToAdd, emptyList()) listener?.invoke(msg, labelsToAdd, emptyList())
} }
@ -58,7 +58,7 @@ open class DefaultLabeler : Labeler, InternalContext.ContextHolder {
} }
val labelsToRemove = msg.labels.filter { it.type == Label.Type.DRAFT } val labelsToRemove = msg.labels.filter { it.type == Label.Type.DRAFT }
msg.removeLabel(Label.Type.DRAFT) msg.removeLabel(Label.Type.DRAFT)
val labelsToAdd = ctx.messageRepository.getLabels(Label.Type.OUTBOX) val labelsToAdd = ctx.labelRepository.getLabels(Label.Type.OUTBOX)
msg.addLabels(labelsToAdd) msg.addLabels(labelsToAdd)
listener?.invoke(msg, labelsToAdd, labelsToRemove) listener?.invoke(msg, labelsToAdd, labelsToRemove)
} }
@ -67,7 +67,7 @@ open class DefaultLabeler : Labeler, InternalContext.ContextHolder {
msg.status = SENT msg.status = SENT
val labelsToRemove = msg.labels.filter { it.type == Label.Type.OUTBOX } val labelsToRemove = msg.labels.filter { it.type == Label.Type.OUTBOX }
msg.removeLabel(Label.Type.OUTBOX) msg.removeLabel(Label.Type.OUTBOX)
val labelsToAdd = ctx.messageRepository.getLabels(Label.Type.SENT) val labelsToAdd = ctx.labelRepository.getLabels(Label.Type.SENT)
msg.addLabels(labelsToAdd) msg.addLabels(labelsToAdd)
listener?.invoke(msg, labelsToAdd, labelsToRemove) listener?.invoke(msg, labelsToAdd, labelsToRemove)
} }
@ -83,7 +83,7 @@ open class DefaultLabeler : Labeler, InternalContext.ContextHolder {
} }
override fun markAsUnread(msg: Plaintext) { override fun markAsUnread(msg: Plaintext) {
val labelsToAdd = ctx.messageRepository.getLabels(Label.Type.UNREAD) val labelsToAdd = ctx.labelRepository.getLabels(Label.Type.UNREAD)
msg.addLabels(labelsToAdd) msg.addLabels(labelsToAdd)
listener?.invoke(msg, labelsToAdd, emptyList()) listener?.invoke(msg, labelsToAdd, emptyList())
} }
@ -91,7 +91,7 @@ open class DefaultLabeler : Labeler, InternalContext.ContextHolder {
override fun delete(msg: Plaintext) { override fun delete(msg: Plaintext) {
val labelsToRemove = msg.labels.toSet() val labelsToRemove = msg.labels.toSet()
msg.labels.clear() msg.labels.clear()
val labelsToAdd = ctx.messageRepository.getLabels(Label.Type.TRASH) val labelsToAdd = ctx.labelRepository.getLabels(Label.Type.TRASH)
msg.addLabels(labelsToAdd) msg.addLabels(labelsToAdd)
listener?.invoke(msg, labelsToAdd, labelsToRemove) listener?.invoke(msg, labelsToAdd, labelsToRemove)
} }

View File

@ -0,0 +1,27 @@
/*
* Copyright 2017 Christian Basler
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ch.dissem.bitmessage.ports
import ch.dissem.bitmessage.entity.valueobject.Label
interface LabelRepository {
fun getLabels(): List<Label>
fun getLabels(vararg types: Label.Type): List<Label>
fun save(label: Label)
}

View File

@ -24,12 +24,6 @@ import ch.dissem.bitmessage.entity.valueobject.Label
import java.util.* import java.util.*
interface MessageRepository { interface MessageRepository {
fun getLabels(): List<Label>
fun getLabels(vararg types: Label.Type): List<Label>
fun save(label: Label)
fun countUnread(label: Label?): Int fun countUnread(label: Label?): Int
fun getAllMessages(): List<Plaintext> fun getAllMessages(): List<Plaintext>

View File

@ -16,6 +16,7 @@
package ch.dissem.bitmessage.utils package ch.dissem.bitmessage.utils
import ch.dissem.bitmessage.exception.ApplicationException
import java.util.* import java.util.*
object Collections { object Collections {
@ -67,4 +68,12 @@ object Collections {
} }
throw IllegalArgumentException("Empty collection? Size: " + collection.size) throw IllegalArgumentException("Empty collection? Size: " + collection.size)
} }
@JvmStatic fun <T> single(collection: Collection<T>): T? {
return when (collection.size) {
0 -> null
1 -> collection.iterator().next()
else -> throw ApplicationException("This shouldn't happen, found ${collection.size} items, one or none was expected")
}
}
} }

View File

@ -98,6 +98,7 @@ class BitmessageContextTest {
cryptography = BouncyCryptography() cryptography = BouncyCryptography()
inventory = testInventory inventory = testInventory
listener = testListener listener = testListener
labelRepo = mock()
messageRepo = mock() messageRepo = mock()
networkHandler = mock { networkHandler = mock {
on { getNetworkStatus() } doReturn Property("test", "mocked") on { getNetworkStatus() } doReturn Property("test", "mocked")

View File

@ -113,6 +113,7 @@ object TestUtils {
nodeRegistry: NodeRegistry = mock {}, nodeRegistry: NodeRegistry = mock {},
networkHandler: NetworkHandler = mock {}, networkHandler: NetworkHandler = mock {},
addressRepository: AddressRepository = mock {}, addressRepository: AddressRepository = mock {},
labelRepository: LabelRepository = mock {},
messageRepository: MessageRepository = mock {}, messageRepository: MessageRepository = mock {},
proofOfWorkRepository: ProofOfWorkRepository = mock {}, proofOfWorkRepository: ProofOfWorkRepository = mock {},
proofOfWorkEngine: ProofOfWorkEngine = mock {}, proofOfWorkEngine: ProofOfWorkEngine = mock {},
@ -129,6 +130,7 @@ object TestUtils {
nodeRegistry, nodeRegistry,
networkHandler, networkHandler,
addressRepository, addressRepository,
labelRepository,
messageRepository, messageRepository,
proofOfWorkRepository, proofOfWorkRepository,
proofOfWorkEngine, proofOfWorkEngine,

View File

@ -285,7 +285,7 @@ public class Application {
} }
private void labels() { private void labels() {
List<Label> labels = ctx.messages().getLabels(); List<Label> labels = ctx.labels().getLabels();
String command; String command;
do { do {
System.out.println(); System.out.println();

View File

@ -35,9 +35,7 @@ import java.util.concurrent.TimeUnit;
import static ch.dissem.bitmessage.entity.payload.Pubkey.Feature.DOES_ACK; import static ch.dissem.bitmessage.entity.payload.Pubkey.Feature.DOES_ACK;
import static ch.dissem.bitmessage.utils.UnixTime.MINUTE; import static ch.dissem.bitmessage.utils.UnixTime.MINUTE;
import static com.nhaarman.mockito_kotlin.MockitoKt.spy; import static com.nhaarman.mockito_kotlin.MockitoKt.*;
import static com.nhaarman.mockito_kotlin.MockitoKt.timeout;
import static com.nhaarman.mockito_kotlin.MockitoKt.verify;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -71,6 +69,7 @@ public class SystemTest {
alice = new BitmessageContext.Builder() alice = new BitmessageContext.Builder()
.addressRepo(new JdbcAddressRepository(aliceDB)) .addressRepo(new JdbcAddressRepository(aliceDB))
.inventory(new JdbcInventory(aliceDB)) .inventory(new JdbcInventory(aliceDB))
.labelRepo(new JdbcLabelRepository(aliceDB))
.messageRepo(new JdbcMessageRepository(aliceDB)) .messageRepo(new JdbcMessageRepository(aliceDB))
.powRepo(new JdbcProofOfWorkRepository(aliceDB)) .powRepo(new JdbcProofOfWorkRepository(aliceDB))
.nodeRegistry(new TestNodeRegistry(bobPort)) .nodeRegistry(new TestNodeRegistry(bobPort))
@ -89,6 +88,7 @@ public class SystemTest {
bob = new BitmessageContext.Builder() bob = new BitmessageContext.Builder()
.addressRepo(new JdbcAddressRepository(bobDB)) .addressRepo(new JdbcAddressRepository(bobDB))
.inventory(new JdbcInventory(bobDB)) .inventory(new JdbcInventory(bobDB))
.labelRepo(new JdbcLabelRepository(bobDB))
.messageRepo(new JdbcMessageRepository(bobDB)) .messageRepo(new JdbcMessageRepository(bobDB))
.powRepo(new JdbcProofOfWorkRepository(bobDB)) .powRepo(new JdbcProofOfWorkRepository(bobDB))
.nodeRegistry(new TestNodeRegistry(alicePort)) .nodeRegistry(new TestNodeRegistry(alicePort))

View File

@ -72,6 +72,7 @@ class NetworkHandlerTest {
nodeRegistry = TestNodeRegistry() nodeRegistry = TestNodeRegistry()
networkHandler = peerNetworkHandler networkHandler = peerNetworkHandler
addressRepo = mock() addressRepo = mock()
labelRepo = mock()
messageRepo = mock() messageRepo = mock()
proofOfWorkRepo = mock() proofOfWorkRepo = mock()
customCommandHandler = object : CustomCommandHandler { customCommandHandler = object : CustomCommandHandler {
@ -102,6 +103,7 @@ class NetworkHandlerTest {
nodeRegistry = TestNodeRegistry(peerAddress) nodeRegistry = TestNodeRegistry(peerAddress)
networkHandler = nodeNetworkHandler networkHandler = nodeNetworkHandler
addressRepo = mock() addressRepo = mock()
labelRepo = mock()
messageRepo = mock() messageRepo = mock()
proofOfWorkRepo = mock() proofOfWorkRepo = mock()
customCommandHandler = object : CustomCommandHandler { customCommandHandler = object : CustomCommandHandler {

View File

@ -0,0 +1,123 @@
/*
* Copyright 2017 Christian Basler
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ch.dissem.bitmessage.repository
import ch.dissem.bitmessage.entity.valueobject.Label
import ch.dissem.bitmessage.ports.AbstractLabelRepository
import ch.dissem.bitmessage.ports.LabelRepository
import org.slf4j.LoggerFactory
import java.sql.Connection
import java.sql.ResultSet
import java.sql.SQLException
import java.util.*
class JdbcLabelRepository(private val config: JdbcConfig) : AbstractLabelRepository(), LabelRepository {
override fun find(where: String): List<Label> {
try {
config.getConnection().use { connection ->
return findLabels(connection, where)
}
} catch (e: SQLException) {
LOG.error(e.message, e)
return ArrayList()
}
}
override fun save(label: Label) {
config.getConnection().use { connection ->
if (label.id != null) {
connection.prepareStatement("UPDATE Label SET label=?, type=?, color=?, ord=? WHERE id=?").use { ps ->
ps.setString(1, label.toString())
ps.setString(2, label.type?.name)
ps.setInt(3, label.color)
ps.setInt(4, label.ord)
ps.setInt(5, label.id as Int)
ps.executeUpdate()
}
} else {
try {
connection.autoCommit = false
var exists = false
connection.prepareStatement("SELECT COUNT(1) FROM Label WHERE label=?").use { ps ->
ps.setString(1, label.toString())
val rs = ps.executeQuery()
if (rs.next()) {
exists = rs.getInt(1) > 0
}
}
if (exists) {
connection.prepareStatement("UPDATE Label SET type=?, color=?, ord=? WHERE label=?").use { ps ->
ps.setString(1, label.type?.name)
ps.setInt(2, label.color)
ps.setInt(3, label.ord)
ps.setString(4, label.toString())
ps.executeUpdate()
}
} else {
connection.prepareStatement("INSERT INTO Label (label, type, color, ord) VALUES (?, ?, ?, ?)").use { ps ->
ps.setString(1, label.toString())
ps.setString(2, label.type?.name)
ps.setInt(3, label.color)
ps.setInt(4, label.ord)
ps.executeUpdate()
}
}
connection.commit()
} catch (e: Exception) {
connection.rollback()
throw e
}
}
}
}
private fun findLabels(connection: Connection, where: String): List<Label> {
val result = ArrayList<Label>()
try {
connection.createStatement().use { stmt ->
stmt.executeQuery("SELECT id, label, type, color, ord FROM Label WHERE $where").use { rs ->
while (rs.next()) {
result.add(getLabel(rs))
}
}
}
} catch (e: SQLException) {
LOG.error(e.message, e)
}
return result
}
companion object {
private val LOG = LoggerFactory.getLogger(JdbcLabelRepository::class.java)
internal fun getLabel(rs: ResultSet): Label {
val typeName = rs.getString("type")
val type = if (typeName == null) {
null
} else {
Label.Type.valueOf(typeName)
}
val label = Label(rs.getString("label"), type, rs.getInt("color"), rs.getInt("ord"))
label.id = rs.getLong("id")
return label
}
}
}

View File

@ -16,6 +16,7 @@
package ch.dissem.bitmessage.repository package ch.dissem.bitmessage.repository
import ch.dissem.bitmessage.entity.BitmessageAddress
import ch.dissem.bitmessage.entity.Plaintext import ch.dissem.bitmessage.entity.Plaintext
import ch.dissem.bitmessage.entity.valueobject.InventoryVector import ch.dissem.bitmessage.entity.valueobject.InventoryVector
import ch.dissem.bitmessage.entity.valueobject.Label import ch.dissem.bitmessage.entity.valueobject.Label
@ -29,79 +30,6 @@ import java.util.*
class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRepository(), MessageRepository { class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRepository(), MessageRepository {
override fun findLabels(where: String): List<Label> {
try {
config.getConnection().use { connection ->
return findLabels(connection, where)
}
} catch (e: SQLException) {
LOG.error(e.message, e)
return ArrayList()
}
}
private fun getLabel(rs: ResultSet): Label {
val typeName = rs.getString("type")
val type = if (typeName == null) {
null
} else {
Label.Type.valueOf(typeName)
}
val label = Label(rs.getString("label"), type, rs.getInt("color"), rs.getInt("ord"))
label.id = rs.getLong("id")
return label
}
override fun save(label: Label) {
config.getConnection().use { connection ->
if (label.id != null) {
connection.prepareStatement("UPDATE Label SET label=?, type=?, color=?, ord=? WHERE id=?").use { ps ->
ps.setString(1, label.toString())
ps.setString(2, label.type?.name)
ps.setInt(3, label.color)
ps.setInt(4, label.ord)
ps.setInt(5, label.id as Int)
ps.executeUpdate()
}
} else {
try {
connection.autoCommit = false
var exists = false
connection.prepareStatement("SELECT COUNT(1) FROM Label WHERE label=?").use { ps ->
ps.setString(1, label.toString())
val rs = ps.executeQuery()
if (rs.next()) {
exists = rs.getInt(1) > 0
}
}
if (exists) {
connection.prepareStatement("UPDATE Label SET type=?, color=?, ord=? WHERE label=?").use { ps ->
ps.setString(1, label.type?.name)
ps.setInt(2, label.color)
ps.setInt(3, label.ord)
ps.setString(4, label.toString())
ps.executeUpdate()
}
} else {
connection.prepareStatement("INSERT INTO Label (label, type, color, ord) VALUES (?, ?, ?, ?)").use { ps ->
ps.setString(1, label.toString())
ps.setString(2, label.type?.name)
ps.setInt(3, label.color)
ps.setInt(4, label.ord)
ps.executeUpdate()
}
}
connection.commit()
} catch (e: Exception) {
connection.rollback()
throw e
}
}
}
}
override fun countUnread(label: Label?): Int { override fun countUnread(label: Label?): Int {
val where = if (label == null) { val where = if (label == null) {
"" ""
@ -136,26 +64,7 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep
"""SELECT id, iv, type, sender, recipient, data, ack_data, sent, received, initial_hash, status, ttl, retries, next_try, conversation """SELECT id, iv, type, sender, recipient, data, ack_data, sent, received, initial_hash, status, ttl, retries, next_try, conversation
FROM Message WHERE $where $limit""").use { rs -> FROM Message WHERE $where $limit""").use { rs ->
while (rs.next()) { while (rs.next()) {
val iv = rs.getBytes("iv") val message = getMessage(connection, rs)
val data = rs.getBinaryStream("data")
val type = Plaintext.Type.valueOf(rs.getString("type"))
val builder = Plaintext.readWithoutSignature(type, data)
val id = rs.getLong("id")
builder.id(id)
builder.IV(InventoryVector.fromHash(iv))
builder.from(ctx.addressRepository.getAddress(rs.getString("sender"))!!)
rs.getString("recipient")?.let { builder.to(ctx.addressRepository.getAddress(it)) }
builder.ackData(rs.getBytes("ack_data"))
builder.sent(rs.getObject("sent") as Long?)
builder.received(rs.getObject("received") as Long?)
builder.status(Plaintext.Status.valueOf(rs.getString("status")))
builder.ttl(rs.getLong("ttl"))
builder.retries(rs.getInt("retries"))
builder.nextTry(rs.getObject("next_try") as Long?)
builder.conversation(rs.getObject("conversation") as UUID? ?: UUID.randomUUID())
builder.labels(findLabels(connection,
"id IN (SELECT label_id FROM Message_Label WHERE message_id=$id) ORDER BY ord"))
val message = builder.build()
message.initialHash = rs.getBytes("initial_hash") message.initialHash = rs.getBytes("initial_hash")
result.add(message) result.add(message)
} }
@ -165,17 +74,38 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep
} catch (e: SQLException) { } catch (e: SQLException) {
LOG.error(e.message, e) LOG.error(e.message, e)
} }
return result return result
} }
private fun getMessage(connection: Connection, rs: ResultSet): Plaintext {
return Plaintext.readWithoutSignature(
Plaintext.Type.valueOf(rs.getString("type")),
rs.getBinaryStream("data")
).build {
id = rs.getLong("id")
inventoryVector = InventoryVector.fromHash(rs.getBytes("iv"))
from = rs.getString("sender")?.let { ctx.addressRepository.getAddress(it) ?: BitmessageAddress(it) }
to = rs.getString("recipient")?.let { ctx.addressRepository.getAddress(it) ?: BitmessageAddress(it) }
ackData = rs.getBytes("ack_data")
sent = rs.getObject("sent") as Long?
received = rs.getObject("received") as Long?
status = Plaintext.Status.valueOf(rs.getString("status"))
ttl = rs.getLong("ttl")
retries = rs.getInt("retries")
nextTry = rs.getObject("next_try") as Long?
conversation = rs.getObject("conversation") as UUID? ?: UUID.randomUUID()
labels = findLabels(connection,
"id IN (SELECT label_id FROM Message_Label WHERE message_id=$id) ORDER BY ord")
}
}
private fun findLabels(connection: Connection, where: String): List<Label> { private fun findLabels(connection: Connection, where: String): List<Label> {
val result = ArrayList<Label>() val result = ArrayList<Label>()
try { try {
connection.createStatement().use { stmt -> connection.createStatement().use { stmt ->
stmt.executeQuery("SELECT id, label, type, color, ord FROM Label WHERE $where").use { rs -> stmt.executeQuery("SELECT id, label, type, color, ord FROM Label WHERE $where").use { rs ->
while (rs.next()) { while (rs.next()) {
result.add(getLabel(rs)) result.add(JdbcLabelRepository.getLabel(rs))
} }
} }
} }
@ -258,19 +188,7 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep
"status, initial_hash, ttl, retries, next_try, conversation) " + "status, initial_hash, ttl, retries, next_try, conversation) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
Statement.RETURN_GENERATED_KEYS).use { ps -> Statement.RETURN_GENERATED_KEYS).use { ps ->
ps.setBytes(1, if (message.inventoryVector == null) null else message.inventoryVector!!.hash) prepare(ps, message)
ps.setString(2, message.type.name)
ps.setString(3, message.from.address)
ps.setString(4, if (message.to == null) null else message.to!!.address)
writeBlob(ps, 5, message)
ps.setBytes(6, message.ackData)
ps.setObject(7, message.sent)
ps.setObject(8, message.received)
ps.setString(9, message.status.name)
ps.setBytes(10, message.initialHash)
ps.setLong(11, message.ttl)
ps.setInt(12, message.retries)
ps.setObject(13, message.nextTry)
ps.setObject(14, message.conversationId) ps.setObject(14, message.conversationId)
try { try {
@ -291,6 +209,13 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep
"UPDATE Message SET iv=?, type=?, sender=?, recipient=?, data=?, ack_data=?, sent=?, received=?, " + "UPDATE Message SET iv=?, type=?, sender=?, recipient=?, data=?, ack_data=?, sent=?, received=?, " +
"status=?, initial_hash=?, ttl=?, retries=?, next_try=? " + "status=?, initial_hash=?, ttl=?, retries=?, next_try=? " +
"WHERE id=?").use { ps -> "WHERE id=?").use { ps ->
prepare(ps, message)
ps.setLong(14, (message.id as Long?)!!)
ps.executeUpdate()
}
}
private fun prepare(ps: PreparedStatement, message: Plaintext): Int{
ps.setBytes(1, if (message.inventoryVector == null) null else message.inventoryVector!!.hash) ps.setBytes(1, if (message.inventoryVector == null) null else message.inventoryVector!!.hash)
ps.setString(2, message.type.name) ps.setString(2, message.type.name)
ps.setString(3, message.from.address) ps.setString(3, message.from.address)
@ -304,9 +229,7 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep
ps.setLong(11, message.ttl) ps.setLong(11, message.ttl)
ps.setInt(12, message.retries) ps.setInt(12, message.retries)
ps.setObject(13, message.nextTry) ps.setObject(13, message.nextTry)
ps.setLong(14, (message.id as Long?)!!) return 14
ps.executeUpdate()
}
} }
override fun remove(message: Plaintext) { override fun remove(message: Plaintext) {
@ -332,7 +255,6 @@ class JdbcMessageRepository(private val config: JdbcConfig) : AbstractMessageRep
} catch (e: SQLException) { } catch (e: SQLException) {
LOG.error(e.message, e) LOG.error(e.message, e)
} }
} }
override fun findConversations(label: Label?): List<UUID> { override fun findConversations(label: Label?): List<UUID> {

View File

@ -0,0 +1,49 @@
/*
* Copyright 2017 Christian Basler
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ch.dissem.bitmessage.repository
import ch.dissem.bitmessage.entity.valueobject.Label
import ch.dissem.bitmessage.ports.LabelRepository
import org.junit.Assert.assertEquals
import org.junit.Before
import org.junit.Test
class JdbcLabelRepositoryTest : TestBase() {
private lateinit var repo: LabelRepository
@Before
fun setUp() {
val config = TestJdbcConfig()
config.reset()
repo = JdbcLabelRepository(config)
}
@Test
fun `ensure labels are retrieved`() {
val labels = repo.getLabels()
assertEquals(5, labels.size.toLong())
}
@Test
fun `ensure labels can be retrieved by type`() {
val labels = repo.getLabels(Label.Type.INBOX)
assertEquals(1, labels.size.toLong())
assertEquals("Inbox", labels[0].toString())
}
}

View File

@ -26,6 +26,7 @@ import ch.dissem.bitmessage.entity.valueobject.ExtendedEncoding
import ch.dissem.bitmessage.entity.valueobject.Label import ch.dissem.bitmessage.entity.valueobject.Label
import ch.dissem.bitmessage.entity.valueobject.PrivateKey import ch.dissem.bitmessage.entity.valueobject.PrivateKey
import ch.dissem.bitmessage.entity.valueobject.extended.Message import ch.dissem.bitmessage.entity.valueobject.extended.Message
import ch.dissem.bitmessage.ports.LabelRepository
import ch.dissem.bitmessage.ports.MessageRepository import ch.dissem.bitmessage.ports.MessageRepository
import ch.dissem.bitmessage.utils.TestUtils import ch.dissem.bitmessage.utils.TestUtils
import ch.dissem.bitmessage.utils.TestUtils.mockedInternalContext import ch.dissem.bitmessage.utils.TestUtils.mockedInternalContext
@ -46,6 +47,7 @@ class JdbcMessageRepositoryTest : TestBase() {
private lateinit var identity: BitmessageAddress private lateinit var identity: BitmessageAddress
private lateinit var repo: MessageRepository private lateinit var repo: MessageRepository
private lateinit var labelRepo: LabelRepository
private lateinit var inbox: Label private lateinit var inbox: Label
private lateinit var sent: Label private lateinit var sent: Label
@ -58,6 +60,7 @@ class JdbcMessageRepositoryTest : TestBase() {
config.reset() config.reset()
val addressRepo = JdbcAddressRepository(config) val addressRepo = JdbcAddressRepository(config)
repo = JdbcMessageRepository(config) repo = JdbcMessageRepository(config)
labelRepo = JdbcLabelRepository(config)
mockedInternalContext( mockedInternalContext(
cryptography = BouncyCryptography(), cryptography = BouncyCryptography(),
addressRepository = addressRepo, addressRepository = addressRepo,
@ -76,29 +79,16 @@ class JdbcMessageRepositoryTest : TestBase() {
identity = BitmessageAddress(PrivateKey(false, 1, 1000, 1000, DOES_ACK)) identity = BitmessageAddress(PrivateKey(false, 1, 1000, 1000, DOES_ACK))
addressRepo.save(identity) addressRepo.save(identity)
inbox = repo.getLabels(Label.Type.INBOX)[0] inbox = labelRepo.getLabels(Label.Type.INBOX)[0]
sent = repo.getLabels(Label.Type.SENT)[0] sent = labelRepo.getLabels(Label.Type.SENT)[0]
drafts = repo.getLabels(Label.Type.DRAFT)[0] drafts = labelRepo.getLabels(Label.Type.DRAFT)[0]
unread = repo.getLabels(Label.Type.UNREAD)[0] unread = labelRepo.getLabels(Label.Type.UNREAD)[0]
addMessage(contactA, identity, Plaintext.Status.RECEIVED, inbox, unread) addMessage(contactA, identity, Plaintext.Status.RECEIVED, inbox, unread)
addMessage(identity, contactA, Plaintext.Status.DRAFT, drafts) addMessage(identity, contactA, Plaintext.Status.DRAFT, drafts)
addMessage(identity, contactB, Plaintext.Status.DRAFT, unread) addMessage(identity, contactB, Plaintext.Status.DRAFT, unread)
} }
@Test
fun `ensure labels are retrieved`() {
val labels = repo.getLabels()
assertEquals(5, labels.size.toLong())
}
@Test
fun `ensure labels can be retrieved by type`() {
val labels = repo.getLabels(Label.Type.INBOX)
assertEquals(1, labels.size.toLong())
assertEquals("Inbox", labels[0].toString())
}
@Test @Test
fun `ensure messages can be found by label`() { fun `ensure messages can be found by label`() {
val messages = repo.findMessages(inbox) val messages = repo.findMessages(inbox)

View File

@ -53,6 +53,7 @@ class WifExporterTest {
cryptography = BouncyCryptography() cryptography = BouncyCryptography()
networkHandler = mock() networkHandler = mock()
inventory = mock() inventory = mock()
labelRepo = mock()
messageRepo = mock() messageRepo = mock()
proofOfWorkRepo = mock() proofOfWorkRepo = mock()
nodeRegistry = mock() nodeRegistry = mock()

View File

@ -54,6 +54,7 @@ class WifImporterTest {
cryptography = BouncyCryptography() cryptography = BouncyCryptography()
networkHandler = mock() networkHandler = mock()
inventory = mock() inventory = mock()
labelRepo = mock()
messageRepo = mock() messageRepo = mock()
proofOfWorkRepo = mock() proofOfWorkRepo = mock()
nodeRegistry = mock() nodeRegistry = mock()