view src/main/kotlin/name/blackcap/passman/Database.kt @ 22:07406c4af4a5

More interactive mode stuff.
author David Barts <n5jrn@me.com>
date Tue, 02 Jul 2024 17:34:52 -0700
parents ea65ab890f66
children
line wrap: on
line source

package name.blackcap.passman

import java.nio.file.Files
import java.nio.file.Path
import java.security.GeneralSecurityException
import java.security.SecureRandom
import java.sql.*

/**
 * An open password database: a JDBC [connection] paired with the [encryption]
 * object used to encrypt and decrypt the values stored through it.
 *
 * The constructor is private; instances are obtained through [Database.open].
 */
class Database private constructor(val connection: Connection, val encryption: Encryption) {

    companion object {
        /** Known plaintext stored (encrypted) under the "test" blob to validate the key. */
        private const val PLAINTEXT = "This is a test."

        /** Number of random salt bytes generated for a new database. */
        private const val SALT_LENGTH = 16

        /** Prompt used by [open] when the caller does not supply one. */
        private const val DEFAULT_PROMPT = "Decryption key: "

        /**
         * Shared instance slot. Nothing in this class assigns it; callers must
         * set it before reading it (it is `lateinit`).
         */
        lateinit var default: Database

        /**
         * Opens the SQLite database in [fileName], prompting for the master password.
         *
         * If the file exists, its stored salt is read and the password is checked
         * against the stored test blob. If it does not exist and [create] is true,
         * the schema, a fresh salt and the test blob are written first.
         *
         * If anything fails after the JDBC connection has been obtained (bad key,
         * corrupt database, SQL error), the connection is closed before the
         * exception propagates, so a failed attempt does not leak a connection.
         *
         * @param passwordPrompt prompt text handed to `getPassword`
         * @param fileName path of the SQLite database file
         * @param create whether a missing file should be initialized rather than rejected
         * @return the opened and verified database
         * @throws DatabaseException if the file is missing and [create] is false, the
         *   password cannot be read, the database is corrupt, or the key is wrong
         */
        fun open(passwordPrompt: String = DEFAULT_PROMPT, fileName: String = DB_FILE,
                 create: Boolean = true): Database {
            val exists = Files.exists(Path.of(fileName))
            if (!exists) {
                if (create) {
                    // NOTE(review): `error` is assumed to be this package's message-reporting
                    // helper (not kotlin.error); the creation path below is only reachable
                    // if it returns. Confirm against its definition.
                    error("initializing database ${see(fileName)}")
                } else {
                    throw DatabaseException("${see(fileName)} not found")
                }
            }
            val masterPassword = try {
                // Second argument is true only for a brand-new database; presumably it asks
                // getPassword to have the user confirm the new password. TODO confirm.
                getPassword(passwordPrompt, !exists)
            } catch (e: ConsoleException) {
                throw DatabaseException(e.message, cause = e)
            }
            val conn = DriverManager.getConnection("jdbc:sqlite:$fileName")
            try {
                val enc = if (exists) reuse(conn, masterPassword) else init(conn, masterPassword)
                val ret = Database(conn, enc)
                verifyPassword(ret)
                return ret
            } catch (e: Throwable) {
                // Ownership of the connection never reached the caller: release it here,
                // keeping the original failure as the primary exception.
                try {
                    conn.close()
                } catch (closeFailure: SQLException) {
                    e.addSuppressed(closeFailure)
                }
                throw e
            }
        }

        /**
         * Builds the [Encryption] for an existing database from [masterPassword]
         * and the salt stored under the "salt" name in the `blobs` table.
         *
         * @throws DatabaseException if the salt row is missing or a SQL error occurs
         */
        private fun reuse(connection: Connection, masterPassword: CharArray): Encryption {
            try {
                connection.prepareStatement("select value from blobs where name = ?").use {
                    it.setString(1, "salt")
                    val result = it.executeQuery()
                    if (!result.next()) {
                        throw DatabaseException("corrupt database, missing salt parameter")
                    }
                    val salt = result.getBytes(1)
                    return Encryption(masterPassword, salt)
                }
            } catch (e: SQLException) {
                e.printStackTrace()
                throw DatabaseException("unable to reopen database", e)
            }
        }

        /**
         * Creates the schema of a new database, generates a random salt, and stores
         * both the salt and an encrypted copy of [PLAINTEXT] in the `blobs` table.
         *
         * NOTE(review): the statements are executed one by one with no explicit
         * transaction, so a failure part-way may leave a partially initialized file.
         *
         * @return the [Encryption] built from [masterPassword] and the new salt
         * @throws DatabaseException if a SQL error occurs
         */
        private fun init(connection: Connection, masterPassword: CharArray): Encryption {
            try {
                connection.createStatement().use { stmt ->
                    // Generic name/value tables, one per SQLite storage class.
                    stmt.executeUpdate("create table integers ( name text not null, value integer )")
                    stmt.executeUpdate("create table reals ( name text not null, value real )")
                    stmt.executeUpdate("create table strings ( name text not null, value text )")
                    stmt.executeUpdate("create table blobs ( name text not null, value blob )")
                    stmt.executeUpdate(
                        "create table passwords (" +
                                "id integer not null primary key, " +
                                "name blob not null, " +
                                "username blob not null, " +
                                "password blob not null, " +
                                "notes blob, " +
                                "created integer, " +
                                "modified integer, " +
                                "accessed integer )"
                    )
                }
                val salt = ByteArray(SALT_LENGTH).also { SecureRandom().nextBytes(it) }
                val encryption = Encryption(masterPassword, salt)
                connection.prepareStatement("insert into blobs (name, value) values (?, ?)").use {
                    it.setString(1, "salt")
                    it.setBytes(2, salt)
                    it.execute()
                }
                // Known plaintext, encrypted with the new key; verifyPassword reads it back.
                connection.prepareStatement("insert into blobs (name, value) values (?, ?)").use { stmt ->
                    stmt.setString(1, "test")
                    stmt.setEncryptedString(2, PLAINTEXT, encryption)
                    stmt.execute()
                }
                return encryption
            } catch (e: SQLException) {
                e.printStackTrace()
                throw DatabaseException("unable to initialize database", e)
            }
        }

        /**
         * Checks the key held by [database] by decrypting the "test" blob and
         * comparing the result with [PLAINTEXT].
         *
         * @throws DatabaseException if the test row is missing, a SQL error occurs,
         *   or decryption fails / yields a different plaintext
         */
        private fun verifyPassword(database: Database) {
            try {
                database.connection.prepareStatement("select value from blobs where name = ?").use { stmt ->
                    stmt.setString(1, "test")
                    val result = stmt.executeQuery()
                    if (!result.next()) {
                        throw DatabaseException("corrupt database, missing test parameter")
                    }
                    val readFromDb = result.getDecryptedString(1, database.encryption)
                    if (readFromDb != PLAINTEXT) {
                        /* might also get thrown by getDecryptedString if bad */
                        throw GeneralSecurityException("bad key!")
                    }
                }
            } catch (e: SQLException) {
                e.printStackTrace()
                throw DatabaseException("unable to verify decryption key", e)
            } catch (e: GeneralSecurityException) {
                throw DatabaseException("invalid decryption key", e)
            }
        }
    }

    /**
     * Derives a numeric key from [name]: the name is lowercased, passed through
     * `encryption.encryptFromString0`, and the result is hashed with `Hashing.hash`.
     * Names differing only in case therefore map to the same key.
     */
    fun makeKey(name: String): Long = Hashing.hash(encryption.encryptFromString0(name.lowercase()))
}

/**
 * Failure raised while opening, initializing, or verifying a [Database].
 *
 * @param message description of what went wrong
 * @param cause underlying exception, if any
 */
class DatabaseException(
    message: String,
    cause: Throwable? = null
) : MessagedException(message, cause)

/**
 * Reads the bytes in column [columnIndex] and decrypts them to a string with
 * [encryption]. Returns null when the column's bytes are null.
 */
fun ResultSet.getDecryptedString(columnIndex: Int, encryption: Encryption): String? {
    val stored: ByteArray? = getBytes(columnIndex)
    return stored?.let { encryption.decryptToString(it) }
}

/**
 * Reads the bytes in column [columnIndex] and decrypts them to a character
 * array with [encryption]. Returns null when the column's bytes are null.
 */
fun ResultSet.getDecrypted(columnIndex: Int, encryption: Encryption): CharArray? {
    val stored: ByteArray? = getBytes(columnIndex)
    return if (stored == null) null else encryption.decrypt(stored)
}

/**
 * Binds parameter [columnIndex] to [value] encrypted with [encryption],
 * or to SQL NULL (typed as BLOB) when [value] is null.
 */
fun PreparedStatement.setEncryptedString(columnIndex: Int, value: String?, encryption: Encryption) {
    if (value != null) {
        val ciphertext = encryption.encryptFromString(value)
        setBytes(columnIndex, ciphertext)
    } else {
        setNull(columnIndex, Types.BLOB)
    }
}

/**
 * Binds parameter [columnIndex] to the characters of [value] encrypted with
 * [encryption], or to SQL NULL (typed as BLOB) when [value] is null.
 */
fun PreparedStatement.setEncrypted(columnIndex: Int, value: CharArray?, encryption: Encryption) {
    if (value == null) {
        setNull(columnIndex, Types.BLOB)
        return
    }
    setBytes(columnIndex, encryption.encrypt(value))
}

/**
 * Binds parameter [columnIndex] to [value] as raw bytes, or to SQL NULL
 * (typed as BLOB) when [value] is null.
 */
fun PreparedStatement.setBytesOrNull(columnIndex: Int, value: ByteArray?) {
    if (value != null) {
        setBytes(columnIndex, value)
    } else {
        setNull(columnIndex, Types.BLOB)
    }
}

/**
 * Binds parameter [columnIndex] to [value] as a long, or to SQL NULL
 * (typed as INTEGER) when [value] is null.
 */
fun PreparedStatement.setLongOrNull(columnIndex: Int, value: Long?) {
    when {
        value != null -> setLong(columnIndex, value)
        else -> setNull(columnIndex, Types.INTEGER)
    }
}

/**
 * Binds parameter [parameterIndex] to [value] as a long. Both null and zero
 * are treated as "no date" and bound as SQL NULL (typed as INTEGER).
 */
fun PreparedStatement.setDateOrNull(parameterIndex: Int, value: Long?) {
    if (value != null && value != 0L) {
        setLong(parameterIndex, value)
    } else {
        setNull(parameterIndex, Types.INTEGER)
    }
}