//
// VMime library (http://www.vmime.org)
// Copyright (C) 2002 Vincent Richard <vincent@vmime.org>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 3 of
// the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License along
// with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
//
// Linking this library statically or dynamically with other modules is making
// a combined work based on this library. Thus, the terms and conditions of
// the GNU General Public License cover the whole combination.
//
#include "vmime/config.hpp"
#if VMIME_HAVE_MESSAGING_FEATURES && VMIME_HAVE_TLS_SUPPORT && VMIME_TLS_SUPPORT_LIB_IS_GNUTLS
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include <errno.h>
#include "vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp"
#include "vmime/net/tls/gnutls/TLSSession_GnuTLS.hpp"
#include "vmime/platform.hpp"
#include "vmime/security/cert/X509Certificate.hpp"
#include "vmime/utility/stringUtils.hpp"
#include <cstring>
namespace vmime {
namespace net {
namespace tls {
// static
shared_ptr <TLSSocket> TLSSocket::wrap(
	const shared_ptr <TLSSession>& session,
	const shared_ptr <socket>& sok
) {

	// Factory: wrap `sok` in a GnuTLS-backed TLS socket. The generic session
	// is downcast to the GnuTLS implementation that the constructor expects.
	// NOTE(review): if `session` is not a TLSSession_GnuTLS, the downcast is
	// presumably null and the constructor dereferences it — confirm callers.
	const shared_ptr <TLSSession_GnuTLS> gnutlsSession =
		dynamicCast <TLSSession_GnuTLS>(session);

	return make_shared <TLSSocket_GnuTLS>(gnutlsSession, sok);
}
TLSSocket_GnuTLS::TLSSocket_GnuTLS(
	const shared_ptr <TLSSession_GnuTLS>& session,
	const shared_ptr <socket>& sok
)
	: m_session(session),
	  m_wrapped(sok),
	  m_connected(false),
	  m_ex(NULL),
	  m_status(0),
	  m_errno(0) {

	// Route all transport I/O of the GnuTLS session through this object's
	// static callbacks; `this` is handed back to them as the transport pointer.
	gnutls_session_t gnutlsSession = *m_session->m_gnutlsSession;

	gnutls_transport_set_ptr(gnutlsSession, this);
	gnutls_transport_set_push_function(gnutlsSession, gnutlsPushFunc);
	gnutls_transport_set_pull_function(gnutlsSession, gnutlsPullFunc);
	gnutls_transport_set_errno_function(gnutlsSession, gnutlsErrnoFunc);
}
// Shuts down the TLS session (if connected) and releases any exception that
// the transport callbacks captured but which was never rethrown.
TLSSocket_GnuTLS::~TLSSocket_GnuTLS() {

	resetException();

	try {
		disconnect();
	} catch (...) {
		// Don't throw exception in destructor
	}

	// gnutls_bye() inside disconnect() may have invoked the push/pull
	// callbacks, which store a cloned exception in m_ex. Nothing will ever
	// rethrow or reset it after this point, so free it here to avoid a leak.
	resetException();
}
void TLSSocket_GnuTLS::connect(const string& address, const port_t port) {
try {
m_wrapped->connect(address, port);
handshake();
} catch (...) {
disconnect();
throw;
}
}
// Sends the TLS close notification, then disconnects the wrapped socket.
// Does nothing unless a handshake previously completed (m_connected).
void TLSSocket_GnuTLS::disconnect() {

	if (!m_connected) {
		return;
	}

	// Result deliberately ignored: shutdown is best-effort
	gnutls_bye(*m_session->m_gnutlsSession, GNUTLS_SHUT_RDWR);

	m_wrapped->disconnect();

	m_connected = false;
}
// True only when the wrapped socket is connected AND the TLS handshake has
// completed on it.
bool TLSSocket_GnuTLS::isConnected() const {

	if (!m_wrapped->isConnected()) {
		return false;
	}

	return m_connected;
}
// Preferred I/O chunk size for this socket: 16 KB.
size_t TLSSocket_GnuTLS::getBlockSize() const {

	return 16 * 1024;
}
// Peer name as reported by the wrapped socket.
const string TLSSocket_GnuTLS::getPeerName() const {

	return this->m_wrapped->getPeerName();
}
// Peer address as reported by the wrapped socket.
const string TLSSocket_GnuTLS::getPeerAddress() const {

	return this->m_wrapped->getPeerAddress();
}
// Timeout handler of the wrapped socket (may be null).
shared_ptr <timeoutHandler> TLSSocket_GnuTLS::getTimeoutHandler() {

	return this->m_wrapped->getTimeoutHandler();
}
void TLSSocket_GnuTLS::setTracer(const shared_ptr <net::tracer>& tracer) {
m_wrapped->setTracer(tracer);
}
// Tracer currently installed on the wrapped socket (may be null).
shared_ptr <net::tracer> TLSSocket_GnuTLS::getTracer() {

	return this->m_wrapped->getTracer();
}
bool TLSSocket_GnuTLS::waitForRead(const int msecs) {
return m_wrapped->waitForRead(msecs);
}
bool TLSSocket_GnuTLS::waitForWrite(const int msecs) {
return m_wrapped->waitForWrite(msecs);
}
// Reads at most one internal buffer's worth of decrypted data into `buffer`
// (replacing its contents). An empty result means no data was available.
void TLSSocket_GnuTLS::receive(string& buffer) {

	buffer = utility::stringUtils::makeStringFromBytes(
		m_buffer, receiveRaw(m_buffer, sizeof(m_buffer))
	);
}
void TLSSocket_GnuTLS::send(const string& buffer) {
sendRaw(reinterpret_cast <const byte_t*>(buffer.data()), buffer.length());
}
// Sends a NUL-terminated string (terminator excluded) through TLS,
// blocking until fully sent.
void TLSSocket_GnuTLS::send(const char* str) {

	const size_t length = ::strlen(str);

	sendRaw(reinterpret_cast <const byte_t*>(str), length);
}
// Reads up to `count` decrypted bytes into `buffer` without waiting.
// Returns the number of bytes read; returns 0 and sets STATUS_WANT_READ or
// STATUS_WANT_WRITE when GnuTLS reports GNUTLS_E_AGAIN. Other GnuTLS errors
// are raised through TLSSession_GnuTLS::throwTLSException(), and exceptions
// captured by the pull callback are rethrown.
size_t TLSSocket_GnuTLS::receiveRaw(byte_t* buffer, const size_t count) {

	m_status &= ~(STATUS_WANT_WRITE | STATUS_WANT_READ);

	resetException();

	const ssize_t received = gnutls_record_recv(*m_session->m_gnutlsSession, buffer, count);

	// Surface anything the transport callback stored during the call
	throwException();

	if (received >= 0) {
		return static_cast <size_t>(received);
	}

	if (received == GNUTLS_E_AGAIN) {

		// Tell the caller which direction GnuTLS is blocked on
		const bool blockedOnRead =
			(gnutls_record_get_direction(*m_session->m_gnutlsSession) == 0);

		if (blockedOnRead) {
			m_status |= STATUS_WANT_READ;
		} else {
			m_status |= STATUS_WANT_WRITE;
		}

		return 0;
	}

	TLSSession_GnuTLS::throwTLSException("gnutls_record_recv", static_cast <int>(received));

	// Only reached if throwTLSException() returns
	return static_cast <size_t>(received);
}
void TLSSocket_GnuTLS::sendRaw(const byte_t* buffer, const size_t count) {
m_status &= ~(STATUS_WANT_WRITE | STATUS_WANT_READ);
for (size_t size = count ; size > 0 ; ) {
resetException();
ssize_t ret = gnutls_record_send(
*m_session->m_gnutlsSession,
buffer, static_cast <size_t>(size)
);
throwException();
if (ret < 0) {
if (ret == GNUTLS_E_AGAIN) {
if (gnutls_record_get_direction(*m_session->m_gnutlsSession) == 0) {
m_wrapped->waitForRead();
} else {
m_wrapped->waitForWrite();
}
continue;
}
TLSSession_GnuTLS::throwTLSException("gnutls_record_send", static_cast <int>(ret));
} else {
buffer += ret;
size -= ret;
}
}
}
// Single non-blocking send attempt of up to `count` bytes. Returns the
// number of bytes GnuTLS accepted; returns 0 and sets STATUS_WANT_READ or
// STATUS_WANT_WRITE on GNUTLS_E_AGAIN. Other GnuTLS errors go through
// TLSSession_GnuTLS::throwTLSException(); callback exceptions are rethrown.
size_t TLSSocket_GnuTLS::sendRawNonBlocking(const byte_t* buffer, const size_t count) {

	m_status &= ~(STATUS_WANT_WRITE | STATUS_WANT_READ);

	resetException();

	const ssize_t sent = gnutls_record_send(*m_session->m_gnutlsSession, buffer, count);

	throwException();

	if (sent >= 0) {
		return static_cast <size_t>(sent);
	}

	if (sent == GNUTLS_E_AGAIN) {

		const bool blockedOnRead =
			(gnutls_record_get_direction(*m_session->m_gnutlsSession) == 0);

		if (blockedOnRead) {
			m_status |= STATUS_WANT_READ;
		} else {
			m_status |= STATUS_WANT_WRITE;
		}

		return 0;
	}

	TLSSession_GnuTLS::throwTLSException("gnutls_record_send", static_cast <int>(sent));

	// Only reached if throwTLSException() returns
	return static_cast <size_t>(sent);
}
// Union of this socket's TLS-level want-read/want-write flags and the
// wrapped socket's own status flags.
unsigned int TLSSocket_GnuTLS::getStatus() const {

	const unsigned int wrappedStatus = m_wrapped->getStatus();

	return wrappedStatus | m_status;
}
void TLSSocket_GnuTLS::handshake() {
shared_ptr <timeoutHandler> toHandler = m_wrapped->getTimeoutHandler();
if (toHandler) {
toHandler->resetTimeOut();
}
if (getTracer()) {
getTracer()->traceSend("Beginning SSL/TLS handshake");
}
// Start handshaking process
try {
string peerName = getPeerName();
gnutls_server_name_set(*m_session->m_gnutlsSession, GNUTLS_NAME_DNS, peerName.c_str(), peerName.size());
while (true) {
resetException();
const int ret = gnutls_handshake(*m_session->m_gnutlsSession);
throwException();
if (ret < 0) {
if (ret == GNUTLS_E_AGAIN) {
if (gnutls_record_get_direction(*m_session->m_gnutlsSession) == 0) {
m_wrapped->waitForRead();
} else {
m_wrapped->waitForWrite();
}
} else if (ret == GNUTLS_E_INTERRUPTED) {
// Non-fatal error
} else {
TLSSession_GnuTLS::throwTLSException("gnutls_handshake", ret);
}
} else {
// Successful handshake
break;
}
}
} catch (...) {
throw;
}
// Verify server's certificate(s)
shared_ptr <security::cert::certificateChain> certs = getPeerCertificates();
if (certs == NULL) {
throw exceptions::tls_exception("No peer certificate.");
}
m_session->getCertificateVerifier()->verify(certs, getPeerName());
m_connected = true;
}
// GnuTLS errno callback. The transport pointer is the TLSSocket_GnuTLS
// registered in the constructor; report the errno value recorded by this
// socket's push/pull callbacks.
int TLSSocket_GnuTLS::gnutlsErrnoFunc(gnutls_transport_ptr_t trspt) {

	const TLSSocket_GnuTLS* const self = static_cast <const TLSSocket_GnuTLS*>(trspt);

	return self->m_errno;
}
// GnuTLS push callback: writes `len` bytes of ciphertext to the wrapped
// socket without blocking.
// Returns the number of bytes written.
// Returns -1 with errno EAGAIN when nothing could be written yet.
// Returns -1 with errno EIO when the socket threw; the exception is stored
// in m_ex for throwException() to rethrow.
ssize_t TLSSocket_GnuTLS::gnutlsPushFunc(
	gnutls_transport_ptr_t trspt,
	const void* data,
	size_t len
) {

	TLSSocket_GnuTLS* sok = reinterpret_cast <TLSSocket_GnuTLS*>(trspt);

	try {

		const ssize_t ret = static_cast <ssize_t>(
			sok->m_wrapped->sendRawNonBlocking(reinterpret_cast <const byte_t*>(data), len)
		);

		if (ret == 0) {
			// Would block: report EAGAIN so GnuTLS returns GNUTLS_E_AGAIN
			gnutls_transport_set_errno(*sok->m_session->m_gnutlsSession, EAGAIN);
			sok->m_errno = EAGAIN;
			return -1;
		}

		return ret;

	} catch (exception& e) {

		// Workaround for non-portable behaviour when throwing C++ exceptions
		// from C functions (GNU TLS): keep a copy for throwException().
		// Free any previously stored exception instead of leaking it; reset
		// to NULL first so m_ex never dangles if clone() throws.
		delete sok->m_ex;
		sok->m_ex = NULL;
		sok->m_ex = e.clone();

		// Report a hard I/O error. Otherwise m_errno may still hold EAGAIN
		// from an earlier would-block condition, and GnuTLS would treat this
		// failure as retryable.
		gnutls_transport_set_errno(*sok->m_session->m_gnutlsSession, EIO);
		sok->m_errno = EIO;

		return -1;
	}
}
// GnuTLS pull callback: reads up to `len` bytes of ciphertext from the
// wrapped socket.
// Returns the number of bytes read.
// Returns -1 with errno EAGAIN when the socket returned no data.
// Returns -1 with errno EIO when the socket threw; the exception is stored
// in m_ex for throwException() to rethrow.
ssize_t TLSSocket_GnuTLS::gnutlsPullFunc(
	gnutls_transport_ptr_t trspt,
	void* data,
	size_t len
) {

	TLSSocket_GnuTLS* sok = reinterpret_cast <TLSSocket_GnuTLS*>(trspt);

	try {

		const ssize_t n = static_cast <ssize_t>(
			sok->m_wrapped->receiveRaw(reinterpret_cast <byte_t*>(data), len)
		);

		if (n == 0) {
			// No data yet: report EAGAIN so GnuTLS returns GNUTLS_E_AGAIN
			gnutls_transport_set_errno(*sok->m_session->m_gnutlsSession, EAGAIN);
			sok->m_errno = EAGAIN;
			return -1;
		}

		return n;

	} catch (exception& e) {

		// Workaround for non-portable behaviour when throwing C++ exceptions
		// from C functions (GNU TLS): keep a copy for throwException().
		// Free any previously stored exception instead of leaking it; reset
		// to NULL first so m_ex never dangles if clone() throws.
		delete sok->m_ex;
		sok->m_ex = NULL;
		sok->m_ex = e.clone();

		// Report a hard I/O error. Otherwise m_errno may still hold EAGAIN
		// from an earlier would-block condition, and GnuTLS would treat this
		// failure as retryable.
		gnutls_transport_set_errno(*sok->m_session->m_gnutlsSession, EIO);
		sok->m_errno = EIO;

		return -1;
	}
}
// Builds a certificate chain from the X.509 certificates the peer sent
// during the handshake.
// Returns null if the peer sent none, or if any certificate cannot be
// parsed, re-exported or converted to a security::cert::X509Certificate.
// All GnuTLS certificate handles are released on every exit path,
// including exceptional ones.
shared_ptr <security::cert::certificateChain> TLSSocket_GnuTLS::getPeerCertificates() {

	if (getTracer()) {
		getTracer()->traceSend("Getting peer certificates");
	}

	unsigned int certCount = 0;
	const gnutls_datum_t* rawData = gnutls_certificate_get_peers(
		*m_session->m_gnutlsSession, &certCount
	);

	if (rawData == NULL) {
		return null;
	}

	// Owns the GnuTLS certificate handles created below and deinitializes
	// all of them when it goes out of scope.
	struct X509Handles {

		std::vector <gnutls_x509_crt_t> list;

		~X509Handles() {

			for (size_t i = 0 ; i < list.size() ; ++i) {
				gnutls_x509_crt_deinit(list[i]);
			}
		}

	} x509Certs;

	// Reserve up front so that push_back() below cannot throw after a
	// handle has been initialized (which would leak that handle)
	x509Certs.list.reserve(certCount);

	// Try X.509
	for (unsigned int i = 0 ; i < certCount ; ++i) {

		gnutls_x509_crt_t crt;

		if (gnutls_x509_crt_init(&crt) < 0) {
			return null;
		}

		x509Certs.list.push_back(crt);

		if (gnutls_x509_crt_import(crt, rawData + i, GNUTLS_X509_FMT_DER) < 0) {
			// XXX more fine-grained error reporting?
			return null;
		}
	}

	std::vector <shared_ptr <security::cert::certificate> > certs;

	for (unsigned int i = 0 ; i < certCount ; ++i) {

		// First call only queries the size of the DER encoding
		size_t dataSize = 0;
		gnutls_x509_crt_export(x509Certs.list[i], GNUTLS_X509_FMT_DER, NULL, &dataSize);

		// Guard against an empty buffer: &data[0] would be undefined
		if (dataSize == 0) {
			return null;
		}

		std::vector <byte_t> data(dataSize);

		if (gnutls_x509_crt_export(x509Certs.list[i], GNUTLS_X509_FMT_DER, &data[0], &dataSize) < 0) {
			return null;
		}

		shared_ptr <security::cert::X509Certificate> cert =
			security::cert::X509Certificate::import(&data[0], dataSize);

		if (!cert) {
			return null;
		}

		certs.push_back(cert);
	}

	return make_shared <security::cert::certificateChain>(certs);
}
// Following is a workaround for C++ exceptions to pass correctly between
// C and C++ calls.
//
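// gnutls_record_recv() calls TLSSocket_GnuTLS::gnutlsPullFunc (and
// gnutls_record_send() calls gnutlsPushFunc). Exceptions thrown by the
// wrapped socket cannot safely propagate through GnuTLS's C code. The
// callbacks therefore store a copy, which throwException() rethrows once
// the gnutls_*() call has returned.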
void TLSSocket_GnuTLS::throwException() {
if (m_ex) {
throw *m_ex;
}
}
// Frees any exception stored by the transport callbacks and clears m_ex.
void TLSSocket_GnuTLS::resetException() {

	// delete on a null pointer is a no-op, so no separate check is needed
	delete m_ex;
	m_ex = NULL;
}
} // tls
} // net
} // vmime
#endif // VMIME_HAVE_MESSAGING_FEATURES && VMIME_HAVE_TLS_SUPPORT && VMIME_TLS_SUPPORT_LIB_IS_GNUTLS