view src/main/kotlin/name/blackcap/passman/Database.kt @ 18:8f3ddebb4295

Was using wrong db object to decrypt, fixed.
author David Barts <n5jrn@me.com>
date Tue, 04 Apr 2023 20:38:52 -0700
parents 7a74ae668665
children 7d80cbcb67bb
line wrap: on
line source

package name.blackcap.passman

import java.nio.file.Files
import java.nio.file.Path
import java.security.GeneralSecurityException
import java.security.SecureRandom
import java.sql.*

/**
 * An opened password database: the JDBC [connection] to the SQLite file plus the
 * [encryption] built from the master password and the salt stored in that file.
 *
 * The constructor is private; instances are obtained through the companion's
 * `open` function.
 */
class Database private constructor(val connection: Connection, val encryption: Encryption) {

    /**
     * Derives the numeric key for an entry [name]: the name is lowercased,
     * run through `encryption.encryptFromString0`, and the result is passed to
     * `Hashing.hash`.
     */
    fun makeKey(name: String): Long = Hashing.hash(encryption.encryptFromString0(name.lowercase()))

    companion object {
        /** Known plaintext stored encrypted under the "test" blob and used to validate the key. */
        private const val PLAINTEXT = "This is a test."

        /** Length in bytes of the random salt generated for a new database. */
        private const val SALT_LENGTH = 16

        /** Prompt shown when the caller does not supply one. */
        private const val DEFAULT_PROMPT = "Decryption key: "

        /**
         * Opens the database in [fileName], initializing a new one when the file
         * does not exist and [create] is true.
         *
         * @param passwordPrompt prompt handed to `getPassword`
         * @param fileName path of the SQLite database file
         * @param create whether a missing file may be created; when false a
         *   missing file is reported through `die`
         * @return the opened database, after the password has been verified
         *   against the stored test blob
         */
        fun open(passwordPrompt: String = DEFAULT_PROMPT, fileName: String = DB_FILE,
                 create: Boolean = true): Database {
            val exists = Files.exists(Path.of(fileName))
            if (!exists) {
                if (create) {
                    // NOTE(review): presumably the project's own message helper rather
                    // than kotlin.error (which would throw here) -- confirm.
                    error("initializing database ${see(fileName)}")
                } else {
                    die("${see(fileName)} not found")
                }
            }
            // Second argument is true only when a new database is being created.
            // NOTE(review): presumably asks for the new password to be confirmed -- verify.
            val masterPassword = getPassword(passwordPrompt, !exists)
            val conn = DriverManager.getConnection("jdbc:sqlite:$fileName")
            val enc = if (exists) { reuse(conn, masterPassword) } else { init(conn, masterPassword) }
            val ret = Database(conn, enc)
            verifyPassword(ret)
            return ret
        }

        /**
         * Builds the [Encryption] for an existing database by reading the "salt"
         * row of the `blobs` table and combining it with [masterPassword].
         */
        private fun reuse(connection: Connection, masterPassword: CharArray): Encryption {
            try {
                connection.prepareStatement("select value from blobs where name = ?").use { stmt ->
                    stmt.setString(1, "salt")
                    val result = stmt.executeQuery()
                    if (!result.next()) {
                        die("corrupt database, missing salt parameter")
                    }
                    val salt = result.getBytes(1)
                    return Encryption(masterPassword, salt)
                }
            } catch (e: SQLException) {
                e.printStackTrace()
                die("unable to reopen database")
                // die() is expected not to return; this throw only gives the
                // catch branch a terminating statement.
                throw RuntimeException("this will never happen")
            }
        }

        /**
         * Creates the schema of a new database, generates and stores a random
         * salt, and stores the encrypted [PLAINTEXT] under the "test" blob so the
         * key can be verified on later opens.
         *
         * @return the [Encryption] built from [masterPassword] and the new salt
         */
        private fun init(connection: Connection, masterPassword: CharArray): Encryption {
            try {
                connection.createStatement().use { stmt ->
                    // One key/value table per value type. The value column's declared
                    // type must match the table's purpose: SQLite applies the column's
                    // type affinity on insert, so a REAL column would silently turn a
                    // numeric-looking string into a floating-point number.
                    stmt.executeUpdate("create table integers ( name string not null, value integer )")
                    stmt.executeUpdate("create table reals ( name string not null, value real )")
                    stmt.executeUpdate("create table strings ( name string not null, value text )")
                    stmt.executeUpdate("create table blobs ( name string not null, value blob )")
                    stmt.executeUpdate(
                        "create table passwords (" +
                                "id integer not null primary key, " +
                                "name blob not null, " +
                                "username blob not null, " +
                                "password blob not null, " +
                                "notes blob, " +
                                "created integer, " +
                                "modified integer, " +
                                "accessed integer )"
                    )
                }
                val salt = ByteArray(SALT_LENGTH).also { SecureRandom().nextBytes(it) }
                val encryption = Encryption(masterPassword, salt)
                connection.prepareStatement("insert into blobs (name, value) values (?, ?)").use { stmt ->
                    stmt.setString(1, "salt")
                    stmt.setBytes(2, salt)
                    stmt.execute()
                }
                connection.prepareStatement("insert into blobs (name, value) values (?, ?)").use { stmt ->
                    stmt.setString(1, "test")
                    stmt.setEncryptedString(2, PLAINTEXT, encryption)
                    stmt.execute()
                }
                return encryption
            } catch (e: SQLException) {
                e.printStackTrace()
                die("unable to initialize database")
                // die() is expected not to return; this throw only gives the
                // catch branch a terminating statement.
                throw RuntimeException("this will never happen")
            }
        }

        /**
         * Checks the key held by [database] by decrypting the "test" blob and
         * comparing it with [PLAINTEXT]; any mismatch or decryption failure is
         * reported through `die`.
         */
        private fun verifyPassword(database: Database) {
            try {
                database.connection.prepareStatement("select value from blobs where name = ?").use { stmt ->
                    stmt.setString(1, "test")
                    val result = stmt.executeQuery()
                    if (!result.next()) {
                        die("corrupt database, missing test parameter")
                    }
                    val readFromDb = result.getDecryptedString(1, database.encryption)
                    if (readFromDb != PLAINTEXT) {
                        /* might also get thrown by getDecryptedString if bad */
                        throw GeneralSecurityException("bad key!")
                    }
                }
            } catch (e: SQLException) {
                e.printStackTrace()
                die("unable to verify decryption key")
            } catch (e: GeneralSecurityException) {
                die("invalid decryption key")
            }
        }
    }
}

/**
 * Reads the bytes of column [columnIndex] and decrypts them to a string with
 * [encryption]. Returns null when the column yields no bytes.
 */
fun ResultSet.getDecryptedString(columnIndex: Int, encryption: Encryption): String? {
    val ciphertext: ByteArray? = getBytes(columnIndex)
    return if (ciphertext == null) null else encryption.decryptToString(ciphertext)
}

/**
 * Reads the bytes of column [columnIndex] and decrypts them with [encryption].
 * Returns null when the column yields no bytes.
 */
fun ResultSet.getDecrypted(columnIndex: Int, encryption: Encryption): CharArray? {
    val ciphertext: ByteArray? = getBytes(columnIndex)
    if (ciphertext == null) {
        return null
    }
    return encryption.decrypt(ciphertext)
}

/**
 * Binds parameter [columnIndex] to the encrypted form of [value], or to SQL
 * NULL (typed as BLOB) when [value] is null.
 */
fun PreparedStatement.setEncryptedString(columnIndex: Int, value: String?, encryption: Encryption) {
    when (value) {
        null -> setNull(columnIndex, Types.BLOB)
        else -> setBytes(columnIndex, encryption.encryptFromString(value))
    }
}

/**
 * Binds parameter [columnIndex] to the encrypted form of the characters in
 * [value], or to SQL NULL (typed as BLOB) when [value] is null.
 */
fun PreparedStatement.setEncrypted(columnIndex: Int, value: CharArray?, encryption: Encryption) {
    if (value != null) {
        setBytes(columnIndex, encryption.encrypt(value))
        return
    }
    setNull(columnIndex, Types.BLOB)
}

/**
 * Binds parameter [columnIndex] to [value], or to SQL NULL (typed as BLOB)
 * when [value] is null.
 */
fun PreparedStatement.setBytesOrNull(columnIndex: Int, value: ByteArray?) {
    when (value) {
        null -> setNull(columnIndex, Types.BLOB)
        else -> setBytes(columnIndex, value)
    }
}

/**
 * Binds parameter [columnIndex] to [value], or to SQL NULL (typed as INTEGER)
 * when [value] is null.
 */
fun PreparedStatement.setLongOrNull(columnIndex: Int, value: Long?) {
    if (value != null) {
        setLong(columnIndex, value)
    } else {
        setNull(columnIndex, Types.INTEGER)
    }
}

/**
 * Binds parameter [parameterIndex] to the timestamp [value]. Both null and a
 * value of zero are stored as SQL NULL (typed as INTEGER).
 */
fun PreparedStatement.setDateOrNull(parameterIndex: Int, value: Long?) {
    // Zero is treated the same as "no date".
    val timestamp = value?.takeUnless { it == 0L }
    if (timestamp != null) {
        setLong(parameterIndex, timestamp)
    } else {
        setNull(parameterIndex, Types.INTEGER)
    }
}