diff --git a/include/ncrypto.h b/include/ncrypto.h index be9e0ca..f8000de 100644 --- a/include/ncrypto.h +++ b/include/ncrypto.h @@ -1,5 +1,15 @@ #pragma once +#include "root.h" + +#ifdef ASSERT_ENABLED +#define NCRYPTO_DEVELOPMENT_CHECKS 1 +#endif + +#include +#include +#include + #include #include #include @@ -61,30 +71,11 @@ namespace ncrypto { #if NCRYPTO_DEVELOPMENT_CHECKS #define NCRYPTO_STR(x) #x -#define NCRYPTO_REQUIRE(EXPR) \ - { \ - if (!(EXPR) { abort(); }) } - -#define NCRYPTO_FAIL(MESSAGE) \ - do { \ - std::cerr << "FAIL: " << (MESSAGE) << std::endl; \ - abort(); \ - } while (0); -#define NCRYPTO_ASSERT_EQUAL(LHS, RHS, MESSAGE) \ - do { \ - if (LHS != RHS) { \ - std::cerr << "Mismatch: '" << LHS << "' - '" << RHS << "'" << std::endl; \ - NCRYPTO_FAIL(MESSAGE); \ - } \ - } while (0); -#define NCRYPTO_ASSERT_TRUE(COND) \ - do { \ - if (!(COND)) { \ - std::cerr << "Assert at line " << __LINE__ << " of file " << __FILE__ \ - << std::endl; \ - NCRYPTO_FAIL(NCRYPTO_STR(COND)); \ - } \ - } while (0); +#define NCRYPTO_REQUIRE(EXPR) ASSERT_WITH_MESSAGE(EXPR, "Assertion failed") +#define NCRYPTO_FAIL(MESSAGE) ASSERT_WITH_MESSAGE(false, MESSAGE) +#define NCRYPTO_ASSERT_EQUAL(LHS, RHS, MESSAGE) \ + ASSERT_WITH_MESSAGE(LHS == RHS, MESSAGE) +#define NCRYPTO_ASSERT_TRUE(COND) ASSERT_WITH_MESSAGE(COND, NCRYPTO_STR(COND)) #else #define NCRYPTO_FAIL(MESSAGE) #define NCRYPTO_ASSERT_EQUAL(LHS, RHS, MESSAGE) @@ -131,9 +122,9 @@ class CryptoErrorList final { void capture(); // Add an error message to the end of the stack. - void add(std::string message); + void add(WTF::String message); - inline const std::string& peek_back() const { return errors_.back(); } + inline const WTF::String& peek_back() const { return errors_.back(); } inline size_t size() const { return errors_.size(); } inline bool empty() const { return errors_.empty(); } @@ -142,11 +133,11 @@ class CryptoErrorList final { inline auto rbegin() const noexcept { return errors_.rbegin(); } inline auto rend() const noexcept { return errors_.rend(); } - std::optional pop_back(); - std::optional pop_front(); + std::optional pop_back(); + std::optional pop_front(); private: - std::list errors_; + std::list errors_; }; // Forcibly clears the error stack on destruction. This stops stale errors @@ -277,12 +268,12 @@ class Cipher final { int getIvLength() const; int getKeyLength() const; int getBlockSize() const; - std::string_view getModeLabel() const; - std::string_view getName() const; + WTF::ASCIILiteral getModeLabel() const; + WTF::String getName() const; bool isSupportedAuthenticatedMode() const; - static const Cipher FromName(std::string_view name); + static const Cipher FromName(WTF::StringView name); static const Cipher FromNid(int nid); static const Cipher FromCtx(const CipherCtxPointer& ctx); @@ -336,6 +327,8 @@ class Dsa final { }; class BignumPointer final { + WTF_MAKE_TZONE_ALLOCATED(BignumPointer); + public: BignumPointer() = default; explicit BignumPointer(BIGNUM* bignum); @@ -429,8 +422,8 @@ class Rsa final { const BIGNUM* qi; }; struct PssParams { - std::string_view digest = "sha1"; - std::optional mgf1_digest = "sha1"; + WTF::StringView digest = "sha1"_s; + std::optional mgf1_digest = "sha1"_s; int64_t salt_length = 20; }; @@ -465,7 +458,7 @@ class Ec final { const EC_GROUP* getGroup() const; int getCurve() const; uint32_t getDegree() const; - std::string getCurveName() const; + WTF::String getCurveName() const; const EC_POINT* getPublicKey() const; const BIGNUM* getPrivateKey() const; @@ -535,13 +528,15 @@ class DataPointer final { }; class BIOPointer final { + WTF_MAKE_TZONE_ALLOCATED(BIOPointer); + public: static BIOPointer NewMem(); static BIOPointer NewSecMem(); static BIOPointer New(const BIO_METHOD* method); static BIOPointer New(const void* data, size_t len); static BIOPointer New(const BIGNUM* bn); - static BIOPointer NewFile(std::string_view filename, std::string_view mode); + static BIOPointer NewFile(WTF::StringView filename, WTF::StringView mode); static BIOPointer NewFp(FILE* fd, int flags); template @@ -575,7 +570,7 @@ class BIOPointer final { bool resetBio() const; - static int Write(BIOPointer* bio, std::string_view message); + static int Write(BIOPointer* bio, WTF::StringView message); template static void Printf(BIOPointer* bio, const char* format, Args... args) { @@ -588,6 +583,8 @@ class BIOPointer final { }; class CipherCtxPointer final { + WTF_MAKE_TZONE_ALLOCATED(CipherCtxPointer); + public: static CipherCtxPointer New(); @@ -630,6 +627,8 @@ class CipherCtxPointer final { }; class EVPKeyCtxPointer final { + WTF_MAKE_TZONE_ALLOCATED(EVPKeyCtxPointer); + public: EVPKeyCtxPointer(); explicit EVPKeyCtxPointer(EVP_PKEY_CTX* ctx); @@ -697,6 +696,8 @@ class EVPKeyCtxPointer final { }; class EVPKeyPointer final { + WTF_MAKE_TZONE_ALLOCATED(EVPKeyPointer); + public: static EVPKeyPointer New(); static EVPKeyPointer NewRawPublic(int id, @@ -821,6 +822,8 @@ class EVPKeyPointer final { }; class DHPointer final { + WTF_MAKE_TZONE_ALLOCATED(DHPointer); + public: enum class FindGroupOption { NONE, @@ -833,9 +836,9 @@ class DHPointer final { static BignumPointer GetStandardGenerator(); static BignumPointer FindGroup( - const std::string_view name, + const WTF::StringView name, FindGroupOption option = FindGroupOption::NONE); - static DHPointer FromGroup(const std::string_view name, + static DHPointer FromGroup(const WTF::StringView name, FindGroupOption option = FindGroupOption::NONE); static DHPointer New(BignumPointer&& p, BignumPointer&& g); @@ -910,6 +913,8 @@ struct StackOfX509Deleter { using StackOfX509 = std::unique_ptr; class SSLCtxPointer final { + WTF_MAKE_TZONE_ALLOCATED(SSLCtxPointer); + public: SSLCtxPointer() = default; explicit SSLCtxPointer(SSL_CTX* ctx); @@ -943,6 +948,8 @@ class SSLCtxPointer final { }; class SSLPointer final { + WTF_MAKE_TZONE_ALLOCATED(SSLPointer); + public: SSLPointer() = default; explicit SSLPointer(SSL* ssl); @@ -961,31 +968,33 @@ class SSLPointer final { bool setSession(const SSLSessionPointer& session); bool setSniContext(const SSLCtxPointer& ctx) const; - const std::string_view getClientHelloAlpn() const; - const std::string_view getClientHelloServerName() const; + const WTF::StringView getClientHelloAlpn() const; + const WTF::StringView getClientHelloServerName() const; - std::optional getServerName() const; + std::optional getServerName() const; X509View getCertificate() const; EVPKeyPointer getPeerTempKey() const; const SSL_CIPHER* getCipher() const; bool isServer() const; - std::optional getCipherName() const; - std::optional getCipherStandardName() const; - std::optional getCipherVersion() const; + std::optional getCipherName() const; + std::optional getCipherStandardName() const; + std::optional getCipherVersion() const; std::optional verifyPeerCertificate() const; - void getCiphers(std::function cb) const; + void getCiphers(WTF::Function&& cb) const; static SSLPointer New(const SSLCtxPointer& ctx); - static std::optional GetServerName(const SSL* ssl); + static std::optional GetServerName(const SSL* ssl); private: DeleteFnPtr ssl_; }; class X509Name final { + WTF_MAKE_TZONE_ALLOCATED(X509Name); + public: X509Name(); explicit X509Name(const X509_NAME* name); @@ -1007,7 +1016,7 @@ class X509Name final { operator bool() const; bool operator==(const Iterator& other) const; bool operator!=(const Iterator& other) const; - std::pair operator*() const; + std::pair operator*() const; private: const X509Name& name_; @@ -1062,7 +1071,7 @@ class X509View final { bool checkPrivateKey(const EVPKeyPointer& pkey) const; bool checkPublicKey(const EVPKeyPointer& pkey) const; - std::optional getFingerprint(const EVP_MD* method) const; + std::optional getFingerprint(const EVP_MD* method) const; X509Pointer clone() const; @@ -1072,16 +1081,16 @@ class X509View final { INVALID_NAME, OPERATION_FAILED, }; - CheckMatch checkHost(const std::string_view host, int flags, + CheckMatch checkHost(const std::span host, int flags, DataPointer* peerName = nullptr) const; - CheckMatch checkEmail(const std::string_view email, int flags) const; - CheckMatch checkIp(const std::string_view ip, int flags) const; + CheckMatch checkEmail(const std::span email, int flags) const; + CheckMatch checkIp(const char* ip, int flags) const; - using UsageCallback = std::function; + using UsageCallback = WTF::Function)>; bool enumUsages(UsageCallback callback) const; template - using KeyCallback = std::function; + using KeyCallback = WTF::Function; bool ifRsa(KeyCallback callback) const; bool ifEc(KeyCallback callback) const; @@ -1090,6 +1099,8 @@ class X509View final { }; class X509Pointer final { + WTF_MAKE_TZONE_ALLOCATED(X509Pointer); + public: static Result Parse(Buffer buffer); static X509Pointer IssuerFrom(const SSLPointer& ssl, const X509View& view); @@ -1114,14 +1125,16 @@ class X509Pointer final { X509View view() const; operator X509View() const { return view(); } - static std::string_view ErrorCode(int32_t err); - static std::optional ErrorReason(int32_t err); + static WTF::ASCIILiteral ErrorCode(int32_t err); + static std::optional ErrorReason(int32_t err); private: DeleteFnPtr cert_; }; class ECDSASigPointer final { + WTF_MAKE_TZONE_ALLOCATED(ECDSASigPointer); + public: explicit ECDSASigPointer(); explicit ECDSASigPointer(ECDSA_SIG* sig); @@ -1154,6 +1167,8 @@ class ECDSASigPointer final { }; class ECGroupPointer final { + WTF_MAKE_TZONE_ALLOCATED(ECGroupPointer); + public: explicit ECGroupPointer(); explicit ECGroupPointer(EC_GROUP* group); @@ -1176,6 +1191,8 @@ class ECGroupPointer final { }; class ECPointPointer final { + WTF_MAKE_TZONE_ALLOCATED(ECPointPointer); + public: ECPointPointer(); explicit ECPointPointer(EC_POINT* point); @@ -1202,6 +1219,8 @@ class ECPointPointer final { }; class ECKeyPointer final { + WTF_MAKE_TZONE_ALLOCATED(ECKeyPointer); + public: ECKeyPointer(); explicit ECKeyPointer(EC_KEY* key); @@ -1242,6 +1261,8 @@ class ECKeyPointer final { }; class EVPMDCtxPointer final { + WTF_MAKE_TZONE_ALLOCATED(EVPMDCtxPointer); + public: EVPMDCtxPointer(); explicit EVPMDCtxPointer(EVP_MD_CTX* ctx); @@ -1286,6 +1307,8 @@ class EVPMDCtxPointer final { }; class HMACCtxPointer final { + WTF_MAKE_TZONE_ALLOCATED(HMACCtxPointer); + public: HMACCtxPointer(); explicit HMACCtxPointer(HMAC_CTX* ctx); @@ -1331,7 +1354,7 @@ class EnginePointer final { bool setAsDefault(uint32_t flags, CryptoErrorList* errors = nullptr); bool init(bool finish_on_exit = false); - EVPKeyPointer loadPrivateKey(const std::string_view key_name); + EVPKeyPointer loadPrivateKey(const WTF::StringView key_name); // Release ownership of the ENGINE* pointer. ENGINE* release(); @@ -1339,7 +1362,7 @@ class EnginePointer final { // Retrieve an OpenSSL Engine instance by name. If the name does not // identify a valid named engine, the returned EnginePointer will be // empty. - static EnginePointer getEngineByName(const std::string_view name, + static EnginePointer getEngineByName(const WTF::StringView name, CryptoErrorList* errors = nullptr); // Call once when initializing OpenSSL at startup for the process. @@ -1396,8 +1419,8 @@ DataPointer ExportChallenge(const Buffer& buf); // ============================================================================ // KDF -const EVP_MD* getDigestByName(const std::string_view name); -const EVP_CIPHER* getCipherByName(const std::string_view name); +const EVP_MD* getDigestByName(const WTF::StringView name); +const EVP_CIPHER* getCipherByName(const WTF::StringView name); // Verify that the specified HKDF output length is valid for the given digest. // The maximum length for HKDF output for a given digest is 255 times the diff --git a/src/ncrypto.cpp b/src/ncrypto.cpp index 2e411ce..2315eb5 100644 --- a/src/ncrypto.cpp +++ b/src/ncrypto.cpp @@ -1,3 +1,8 @@ +#include "root.h" +#include "wtf/text/ASCIILiteral.h" +#include "wtf/text/StringImpl.h" +#include "wtf/text/WTFString.h" + #include "ncrypto.h" #include @@ -75,22 +80,22 @@ void CryptoErrorList::capture() { while (const auto err = ERR_get_error()) { char buf[256]; ERR_error_string_n(err, buf, sizeof(buf)); - errors_.emplace_front(buf); + errors_.emplace_front(WTF::String::fromUTF8(buf)); } } -void CryptoErrorList::add(std::string error) { errors_.push_back(error); } +void CryptoErrorList::add(WTF::String error) { errors_.push_back(error); } -std::optional CryptoErrorList::pop_back() { +std::optional CryptoErrorList::pop_back() { if (errors_.empty()) return std::nullopt; - std::string error = errors_.back(); + WTF::String error = errors_.back(); errors_.pop_back(); return error; } -std::optional CryptoErrorList::pop_front() { +std::optional CryptoErrorList::pop_front() { if (errors_.empty()) return std::nullopt; - std::string error = errors_.front(); + WTF::String error = errors_.front(); errors_.pop_front(); return error; } @@ -1104,7 +1109,8 @@ bool X509View::checkPublicKey(const EVPKeyPointer& pkey) const { return X509_verify(const_cast(cert_), pkey.get()) == 1; } -X509View::CheckMatch X509View::checkHost(const std::string_view host, int flags, +X509View::CheckMatch X509View::checkHost(const std::span host, + int flags, DataPointer* peerName) const { ClearErrorOnReturn clearErrorOnReturn; if (cert_ == nullptr) return CheckMatch::NO_MATCH; @@ -1127,7 +1133,7 @@ X509View::CheckMatch X509View::checkHost(const std::string_view host, int flags, } } -X509View::CheckMatch X509View::checkEmail(const std::string_view email, +X509View::CheckMatch X509View::checkEmail(const std::span email, int flags) const { ClearErrorOnReturn clearErrorOnReturn; if (cert_ == nullptr) return CheckMatch::NO_MATCH; @@ -1144,11 +1150,10 @@ X509View::CheckMatch X509View::checkEmail(const std::string_view email, } } -X509View::CheckMatch X509View::checkIp(const std::string_view ip, - int flags) const { +X509View::CheckMatch X509View::checkIp(const char* ip, int flags) const { ClearErrorOnReturn clearErrorOnReturn; if (cert_ == nullptr) return CheckMatch::NO_MATCH; - switch (X509_check_ip_asc(const_cast(cert_), ip.data(), flags)) { + switch (X509_check_ip_asc(const_cast(cert_), ip, flags)) { case 0: return CheckMatch::NO_MATCH; case 1: @@ -1172,7 +1177,7 @@ X509View X509View::From(const SSLCtxPointer& ctx) { return X509View(SSL_CTX_get0_certificate(ctx.get())); } -std::optional X509View::getFingerprint( +std::optional X509View::getFingerprint( const EVP_MD* method) const { unsigned int md_size; unsigned char md[EVP_MAX_MD_SIZE]; @@ -1180,7 +1185,9 @@ std::optional X509View::getFingerprint( if (X509_digest(get(), method, md, &md_size)) { if (md_size == 0) return std::nullopt; - std::string fingerprint((md_size * 3) - 1, 0); + std::span fingerprint; + WTF::String fingerprintStr = + WTF::String::createUninitialized((md_size * 3) - 1, fingerprint); for (unsigned int i = 0; i < md_size; i++) { auto idx = 3 * i; fingerprint[idx] = hex[(md[i] & 0xf0) >> 4]; @@ -1189,7 +1196,7 @@ std::optional X509View::getFingerprint( fingerprint[idx + 2] = ':'; } - return fingerprint; + return fingerprintStr; } return std::nullopt; @@ -1299,10 +1306,10 @@ X509Pointer X509Pointer::PeerFrom(const SSLPointer& ssl) { // When adding or removing errors below, please also update the list in the API // documentation. See the "OpenSSL Error Codes" section of doc/api/errors.md // Also *please* update the respective section in doc/api/tls.md as well -std::string_view X509Pointer::ErrorCode(int32_t err) { // NOLINT(runtime/int) +WTF::ASCIILiteral X509Pointer::ErrorCode(int32_t err) { // NOLINT(runtime/int) #define CASE(CODE) \ case X509_V_ERR_##CODE: \ - return #CODE; + return #CODE##_s; switch (err) { CASE(UNABLE_TO_GET_ISSUER_CERT) CASE(UNABLE_TO_GET_CRL) @@ -1334,12 +1341,24 @@ std::string_view X509Pointer::ErrorCode(int32_t err) { // NOLINT(runtime/int) CASE(HOSTNAME_MISMATCH) } #undef CASE - return "UNSPECIFIED"; + return "UNSPECIFIED"_s; } -std::optional X509Pointer::ErrorReason(int32_t err) { +std::optional X509Pointer::ErrorReason(int32_t err) { if (err == X509_V_OK) return std::nullopt; - return X509_verify_cert_error_string(err); + + // TODO(dylan-conway): delete this switch? + switch (err) { +#define V(name, msg) \ + case X509_V_ERR_##name: \ + return msg##_s; + V(HOSTNAME_MISMATCH, "Hostname does not match certificate") + V(EMAIL_MISMATCH, "Email address does not match certificate") + V(IP_ADDRESS_MISMATCH, "IP address does not match certificate") +#undef V + } + return WTF::ASCIILiteral::fromLiteralUnsafe( + X509_verify_cert_error_string(err)); } // ============================================================================ @@ -1385,9 +1404,10 @@ BIOPointer BIOPointer::New(const void* data, size_t len) { return BIOPointer(BIO_new_mem_buf(data, len)); } -BIOPointer BIOPointer::NewFile(std::string_view filename, - std::string_view mode) { - return BIOPointer(BIO_new_file(filename.data(), mode.data())); +BIOPointer BIOPointer::NewFile(WTF::StringView filename, WTF::StringView mode) { + auto filenameUtf8 = filename.utf8(); + auto modeUtf8 = mode.utf8(); + return BIOPointer(BIO_new_file(filenameUtf8.data(), modeUtf8.data())); } BIOPointer BIOPointer::NewFp(FILE* fd, int close_flag) { @@ -1400,20 +1420,18 @@ BIOPointer BIOPointer::New(const BIGNUM* bn) { return res; } -int BIOPointer::Write(BIOPointer* bio, std::string_view message) { - if (bio == nullptr || !*bio) return 0; - return BIO_write(bio->get(), message.data(), message.size()); +int BIOPointer::Write(BIOPointer* bio, WTF::StringView message) { + auto messageUtf8 = message.utf8(); + return Write(bio, messageUtf8.span()); } // ============================================================================ // DHPointer namespace { -bool EqualNoCase(const std::string_view a, const std::string_view b) { - if (a.size() != b.size()) return false; - return std::equal(a.begin(), a.end(), b.begin(), b.end(), [](char a, char b) { - return std::tolower(a) == std::tolower(b); - }); +bool EqualNoCase(const WTF::StringView a, const WTF::StringView b) { + if (a.length() != b.length()) return false; + return a.startsWithIgnoringASCIICase(b); } } // namespace @@ -1433,23 +1451,23 @@ void DHPointer::reset(DH* dh) { dh_.reset(dh); } DH* DHPointer::release() { return dh_.release(); } -BignumPointer DHPointer::FindGroup(const std::string_view name, +BignumPointer DHPointer::FindGroup(const WTF::StringView name, FindGroupOption option) { #define V(n, p) \ if (EqualNoCase(name, n)) return BignumPointer(p(nullptr)); if (option != FindGroupOption::NO_SMALL_PRIMES) { #ifndef OPENSSL_IS_BORINGSSL // Boringssl does not support the 768 and 1024 small primes - V("modp1", BN_get_rfc2409_prime_768); - V("modp2", BN_get_rfc2409_prime_1024); + V("modp1"_s, BN_get_rfc2409_prime_768); + V("modp2"_s, BN_get_rfc2409_prime_1024); #endif - V("modp5", BN_get_rfc3526_prime_1536); + V("modp5"_s, BN_get_rfc3526_prime_1536); } - V("modp14", BN_get_rfc3526_prime_2048); - V("modp15", BN_get_rfc3526_prime_3072); - V("modp16", BN_get_rfc3526_prime_4096); - V("modp17", BN_get_rfc3526_prime_6144); - V("modp18", BN_get_rfc3526_prime_8192); + V("modp14"_s, BN_get_rfc3526_prime_2048); + V("modp15"_s, BN_get_rfc3526_prime_3072); + V("modp16"_s, BN_get_rfc3526_prime_4096); + V("modp17"_s, BN_get_rfc3526_prime_6144); + V("modp18"_s, BN_get_rfc3526_prime_8192); #undef V return {}; } @@ -1461,7 +1479,7 @@ BignumPointer DHPointer::GetStandardGenerator() { return bn; } -DHPointer DHPointer::FromGroup(const std::string_view name, +DHPointer DHPointer::FromGroup(const WTF::StringView name, FindGroupOption option) { auto group = FindGroup(name, option); if (!group) return {}; // Unable to find the named group. @@ -1469,7 +1487,7 @@ DHPointer DHPointer::FromGroup(const std::string_view name, auto generator = GetStandardGenerator(); if (!generator) return {}; // Unable to create the generator. - return New(std::move(group), std::move(generator)); + return New(WTFMove(group), WTFMove(generator)); } DHPointer DHPointer::New(BignumPointer&& p, BignumPointer&& g) { @@ -1663,17 +1681,24 @@ DataPointer DHPointer::stateless(const EVPKeyPointer& ourKey, // ============================================================================ // KDF -const EVP_MD* getDigestByName(const std::string_view name) { +const EVP_MD* getDigestByName(const WTF::StringView name) { // Historically, "dss1" and "DSS1" were DSA aliases for SHA-1 // exposed through the public API. - if (name == "dss1" || name == "DSS1") [[unlikely]] { + if (name == "dss1"_s || name == "DSS1"_s) [[unlikely]] { return EVP_sha1(); } - return EVP_get_digestbyname(name.data()); + + // if (name == "ripemd160WithRSA"_s || name == "RSA-RIPEMD160"_s) { + // return EVP_ripemd160(); + // } + + auto nameUtf8 = name.utf8(); + return EVP_get_digestbyname(nameUtf8.data()); } -const EVP_CIPHER* getCipherByName(const std::string_view name) { - return EVP_get_cipherbyname(name.data()); +const EVP_CIPHER* getCipherByName(const WTF::StringView name) { + auto nameUtf8 = name.utf8(); + return EVP_get_cipherbyname(nameUtf8.data()); } bool checkHkdfLength(const EVP_MD* md, size_t length) { @@ -2499,7 +2524,7 @@ SSLPointer SSLPointer::New(const SSLCtxPointer& ctx) { } void SSLPointer::getCiphers( - std::function cb) const { + WTF::Function&& cb) const { if (!ssl_) return; STACK_OF(SSL_CIPHER)* ciphers = SSL_get_ciphers(get()); @@ -2507,16 +2532,16 @@ void SSLPointer::getCiphers( // document them, but since there are only 5, easier to just add them manually // and not have to explain their absence in the API docs. They are lower-cased // because the docs say they will be. - static constexpr const char* TLS13_CIPHERS[] = { - "tls_aes_256_gcm_sha384", "tls_chacha20_poly1305_sha256", - "tls_aes_128_gcm_sha256", "tls_aes_128_ccm_8_sha256", - "tls_aes_128_ccm_sha256"}; + static constexpr WTF::ASCIILiteral TLS13_CIPHERS[] = { + "tls_aes_256_gcm_sha384"_s, "tls_chacha20_poly1305_sha256"_s, + "tls_aes_128_gcm_sha256"_s, "tls_aes_128_ccm_8_sha256"_s, + "tls_aes_128_ccm_sha256"_s}; const int n = sk_SSL_CIPHER_num(ciphers); for (int i = 0; i < n; ++i) { const SSL_CIPHER* cipher = sk_SSL_CIPHER_value(ciphers, i); - cb(SSL_CIPHER_get_name(cipher)); + cb(WTF::ASCIILiteral::fromLiteralUnsafe(SSL_CIPHER_get_name(cipher))); } for (unsigned i = 0; i < 5; ++i) { @@ -2562,7 +2587,7 @@ std::optional SSLPointer::verifyPeerCertificate() const { return std::nullopt; } -const std::string_view SSLPointer::getClientHelloAlpn() const { +const WTF::StringView SSLPointer::getClientHelloAlpn() const { if (ssl_ == nullptr) return {}; #ifndef OPENSSL_IS_BORINGSSL const unsigned char* buf; @@ -2585,7 +2610,7 @@ const std::string_view SSLPointer::getClientHelloAlpn() const { #endif } -const std::string_view SSLPointer::getClientHelloServerName() const { +const WTF::StringView SSLPointer::getClientHelloServerName() const { if (ssl_ == nullptr) return {}; #ifndef OPENSSL_IS_BORINGSSL const unsigned char* buf; @@ -2613,15 +2638,14 @@ const std::string_view SSLPointer::getClientHelloServerName() const { #endif } -std::optional SSLPointer::GetServerName( - const SSL* ssl) { +std::optional SSLPointer::GetServerName(const SSL* ssl) { if (ssl == nullptr) return std::nullopt; auto res = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); if (res == nullptr) return std::nullopt; - return res; + return WTF::String::fromUTF8(res); } -std::optional SSLPointer::getServerName() const { +std::optional SSLPointer::getServerName() const { if (!ssl_) return std::nullopt; return GetServerName(get()); } @@ -2650,22 +2674,28 @@ EVPKeyPointer SSLPointer::getPeerTempKey() const { return EVPKeyPointer(raw_key); } -std::optional SSLPointer::getCipherName() const { +std::optional SSLPointer::getCipherName() const { auto cipher = getCipher(); if (cipher == nullptr) return std::nullopt; - return SSL_CIPHER_get_name(cipher); + const char* name = SSL_CIPHER_get_name(cipher); + if (!name) return std::nullopt; + return WTF::StringView::fromLatin1(name); } -std::optional SSLPointer::getCipherStandardName() const { +std::optional SSLPointer::getCipherStandardName() const { auto cipher = getCipher(); if (cipher == nullptr) return std::nullopt; - return SSL_CIPHER_standard_name(cipher); + const char* name = SSL_CIPHER_standard_name(cipher); + if (!name) return std::nullopt; + return WTF::StringView::fromLatin1(name); } -std::optional SSLPointer::getCipherVersion() const { +std::optional SSLPointer::getCipherVersion() const { auto cipher = getCipher(); if (cipher == nullptr) return std::nullopt; - return SSL_CIPHER_get_version(cipher); + auto version = SSL_CIPHER_get_version(cipher); + if (!version) return std::nullopt; + return WTF::StringView::fromLatin1(version); } SSLCtxPointer::SSLCtxPointer(SSL_CTX* ctx) : ctx_(ctx) {} @@ -2713,8 +2743,9 @@ bool SSLCtxPointer::setGroups(const char* groups) { // ============================================================================ -const Cipher Cipher::FromName(std::string_view name) { - return Cipher(EVP_get_cipherbyname(name.data())); +const Cipher Cipher::FromName(WTF::StringView name) { + auto nameUtf8 = name.utf8(); + return Cipher(EVP_get_cipherbyname(nameUtf8.data())); } const Cipher Cipher::FromNid(int nid) { @@ -2750,40 +2781,40 @@ int Cipher::getNid() const { return EVP_CIPHER_nid(cipher_); } -std::string_view Cipher::getModeLabel() const { +WTF::ASCIILiteral Cipher::getModeLabel() const { if (!cipher_) return {}; switch (getMode()) { case EVP_CIPH_CCM_MODE: - return "ccm"; + return "ccm"_s; case EVP_CIPH_CFB_MODE: - return "cfb"; + return "cfb"_s; case EVP_CIPH_CBC_MODE: - return "cbc"; + return "cbc"_s; case EVP_CIPH_CTR_MODE: - return "ctr"; + return "ctr"_s; case EVP_CIPH_ECB_MODE: - return "ecb"; + return "ecb"_s; case EVP_CIPH_GCM_MODE: - return "gcm"; + return "gcm"_s; case EVP_CIPH_OCB_MODE: - return "ocb"; + return "ocb"_s; case EVP_CIPH_OFB_MODE: - return "ofb"; + return "ofb"_s; case EVP_CIPH_WRAP_MODE: - return "wrap"; + return "wrap"_s; case EVP_CIPH_XTS_MODE: - return "xts"; + return "xts"_s; case EVP_CIPH_STREAM_CIPHER: - return "stream"; + return "stream"_s; } - return "{unknown}"; + return "{unknown}"_s; } -std::string_view Cipher::getName() const { +WTF::String Cipher::getName() const { if (!cipher_) return {}; // OBJ_nid2sn(EVP_CIPHER_nid(cipher)) is used here instead of // EVP_CIPHER_name(cipher) for compatibility with BoringSSL. - return OBJ_nid2sn(getNid()); + return WTF::String::fromUTF8(OBJ_nid2sn(getNid())); } bool Cipher::isSupportedAuthenticatedMode() const { @@ -3497,15 +3528,15 @@ const std::optional Rsa::getPssParams() const { const RSA_PSS_PARAMS* params = RSA_get0_pss_params(rsa_); if (params == nullptr) return std::nullopt; Rsa::PssParams ret{ - .digest = OBJ_nid2ln(NID_sha1), - .mgf1_digest = OBJ_nid2ln(NID_sha1), + .digest = WTF::StringView::fromLatin1(OBJ_nid2ln(NID_sha1)), + .mgf1_digest = WTF::StringView::fromLatin1(OBJ_nid2ln(NID_sha1)), .salt_length = 20, }; if (params->hashAlgorithm != nullptr) { const ASN1_OBJECT* hash_obj; X509_ALGOR_get0(&hash_obj, nullptr, nullptr, params->hashAlgorithm); - ret.digest = OBJ_nid2ln(OBJ_obj2nid(hash_obj)); + ret.digest = WTF::StringView::fromLatin1(OBJ_nid2ln(OBJ_obj2nid(hash_obj))); } if (params->maskGenAlgorithm != nullptr) { @@ -3515,7 +3546,8 @@ const std::optional Rsa::getPssParams() const { if (mgf_nid == NID_mgf1) { const ASN1_OBJECT* mgf1_hash_obj; X509_ALGOR_get0(&mgf1_hash_obj, nullptr, nullptr, params->maskHash); - ret.mgf1_digest = OBJ_nid2ln(OBJ_obj2nid(mgf1_hash_obj)); + ret.mgf1_digest = + WTF::StringView::fromLatin1(OBJ_nid2ln(OBJ_obj2nid(mgf1_hash_obj))); } } @@ -3627,8 +3659,8 @@ int Ec::getCurve() const { return EC_GROUP_get_curve_name(getGroup()); } uint32_t Ec::getDegree() const { return EC_GROUP_get_degree(getGroup()); } -std::string Ec::getCurveName() const { - return std::string(OBJ_nid2sn(getCurve())); +WTF::String Ec::getCurveName() const { + return WTF::String::fromUTF8(OBJ_nid2sn(getCurve())); } const EC_POINT* Ec::getPublicKey() const { return EC_KEY_get0_public_key(ec_); } @@ -3891,7 +3923,7 @@ bool X509Name::Iterator::operator!=(const Iterator& other) const { return loc_ != other.loc_; } -std::pair X509Name::Iterator::operator*() const { +std::pair X509Name::Iterator::operator*() const { if (loc_ == name_.total_) return {{}, {}}; X509_NAME_ENTRY* entry = X509_NAME_get_entry(name_, loc_); @@ -3906,21 +3938,22 @@ std::pair X509Name::Iterator::operator*() const { } int nid = OBJ_obj2nid(name); - std::string name_str; + WTF::String name_str; if (nid != NID_undef) { - name_str = std::string(OBJ_nid2sn(nid)); + name_str = WTF::String::fromUTF8(OBJ_nid2sn(nid)); } else { char buf[80]; OBJ_obj2txt(buf, sizeof(buf), name, 0); - name_str = std::string(buf); + name_str = WTF::String::fromUTF8(buf); } unsigned char* value_str; int value_str_size = ASN1_STRING_to_UTF8(&value_str, value); return { - std::move(name_str), - std::string(reinterpret_cast(value_str), value_str_size)}; + name_str, + WTF::String::fromUTF8(std::span(value_str, value_str_size)), + }; } // ============================================================================