strongswan/src/libcharon/plugins/vici/vici_socket.c

724 lines
14 KiB
C

/*
* Copyright (C) 2014 Martin Willi
* Copyright (C) 2014 revosec AG
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* for more details.
*/
#include "vici_socket.h"
#include <threading/mutex.h>
#include <threading/condvar.h>
#include <threading/thread.h>
#include <collections/array.h>
#include <collections/linked_list.h>
#include <processing/jobs/callback_job.h>
#include <errno.h>
#include <string.h>
typedef struct private_vici_socket_t private_vici_socket_t;
/**
* Private members of vici_socket_t
*/
struct private_vici_socket_t {
/**
* public functions
*/
vici_socket_t public;
/**
* Inbound message callback
*/
vici_inbound_cb_t inbound;
/**
* Client connect callback
*/
vici_connect_cb_t connect;
/**
* Client disconnect callback
*/
vici_disconnect_cb_t disconnect;
/**
* Next client connection identifier
*/
u_int nextid;
/**
* User data for callbacks
*/
void *user;
/**
* Service accepting vici connections
*/
stream_service_t *service;
/**
* Client connections, as entry_t
*/
linked_list_t *connections;
/**
* mutex for client connections
*/
mutex_t *mutex;
};
/**
* Data to securely reference an entry
*/
typedef struct {
/* reference to socket instance */
private_vici_socket_t *this;
/** connection identifier of entry */
u_int id;
} entry_selector_t;
/**
* Partially processed message
*/
typedef struct {
/** bytes of length header sent/received */
u_char hdrlen;
/** bytes of length header */
char hdr[sizeof(uint32_t)];
/** send/receive buffer on heap */
chunk_t buf;
/** bytes sent/received in buffer */
uint32_t done;
} msg_buf_t;
/**
* Client connection entry
*/
typedef struct {
/** reference to socket */
private_vici_socket_t *this;
/** associated stream */
stream_t *stream;
/** queued messages to send, as msg_buf_t pointers */
array_t *out;
/** input message buffer */
msg_buf_t in;
/** queued input messages to process, as chunk_t */
array_t *queue;
/** do we have job processing input queue? */
bool has_processor;
/** is this client disconnecting */
bool disconnecting;
/** client connection identifier */
u_int id;
/** any users reading over this connection? */
int readers;
/** any users writing over this connection? */
int writers;
/** condvar to wait for usage */
condvar_t *cond;
} entry_t;
/**
* Destroy an connection entry
*/
CALLBACK(destroy_entry, void,
entry_t *entry)
{
msg_buf_t *out;
chunk_t chunk;
entry->stream->destroy(entry->stream);
entry->this->disconnect(entry->this->user, entry->id);
entry->cond->destroy(entry->cond);
while (array_remove(entry->out, ARRAY_TAIL, &out))
{
chunk_clear(&out->buf);
free(out);
}
array_destroy(entry->out);
while (array_remove(entry->queue, ARRAY_TAIL, &chunk))
{
chunk_clear(&chunk);
}
array_destroy(entry->queue);
chunk_clear(&entry->in.buf);
free(entry);
}
/**
* Find entry by stream (if given) or id, claim use
*/
static entry_t* find_entry(private_vici_socket_t *this, stream_t *stream,
u_int id, bool reader, bool writer)
{
enumerator_t *enumerator;
entry_t *entry, *found = NULL;
bool candidate = TRUE;
this->mutex->lock(this->mutex);
while (candidate && !found)
{
candidate = FALSE;
enumerator = this->connections->create_enumerator(this->connections);
while (enumerator->enumerate(enumerator, &entry))
{
if (stream)
{
if (entry->stream != stream)
{
continue;
}
}
else
{
if (entry->id != id)
{
continue;
}
}
if (entry->disconnecting)
{
entry->cond->signal(entry->cond);
continue;
}
candidate = TRUE;
if ((reader && entry->readers) ||
(writer && entry->writers))
{
entry->cond->wait(entry->cond, this->mutex);
break;
}
if (reader)
{
entry->readers++;
}
if (writer)
{
entry->writers++;
}
found = entry;
break;
}
enumerator->destroy(enumerator);
}
this->mutex->unlock(this->mutex);
return found;
}
/**
* Remove entry by id, claim use
*/
static entry_t* remove_entry(private_vici_socket_t *this, u_int id)
{
enumerator_t *enumerator;
entry_t *entry, *found = NULL;
bool candidate = TRUE;
this->mutex->lock(this->mutex);
while (candidate && !found)
{
candidate = FALSE;
enumerator = this->connections->create_enumerator(this->connections);
while (enumerator->enumerate(enumerator, &entry))
{
if (entry->id == id)
{
candidate = TRUE;
if (entry->readers || entry->writers)
{
entry->cond->wait(entry->cond, this->mutex);
break;
}
this->connections->remove_at(this->connections, enumerator);
entry->cond->broadcast(entry->cond);
found = entry;
break;
}
}
enumerator->destroy(enumerator);
}
this->mutex->unlock(this->mutex);
return found;
}
/**
* Release a claimed entry
*/
static void put_entry(private_vici_socket_t *this, entry_t *entry,
bool reader, bool writer)
{
this->mutex->lock(this->mutex);
if (reader)
{
entry->readers--;
}
if (writer)
{
entry->writers--;
}
entry->cond->signal(entry->cond);
this->mutex->unlock(this->mutex);
}
/**
* Asynchronous callback to disconnect client
*/
CALLBACK(disconnect_async, job_requeue_t,
entry_selector_t *sel)
{
entry_t *entry;
entry = remove_entry(sel->this, sel->id);
if (entry)
{
destroy_entry(entry);
}
return JOB_REQUEUE_NONE;
}
/**
* Disconnect a connected client
*/
static void disconnect(private_vici_socket_t *this, u_int id)
{
entry_selector_t *sel;
INIT(sel,
.this = this,
.id = id,
);
lib->processor->queue_job(lib->processor,
(job_t*)callback_job_create(disconnect_async, sel, free, NULL));
}
/**
* Write queued output data
*/
static bool do_write(private_vici_socket_t *this, entry_t *entry,
stream_t *stream, char *errmsg, size_t errlen, bool block)
{
msg_buf_t *out;
ssize_t len;
while (array_get(entry->out, ARRAY_HEAD, &out))
{
/* write header */
while (out->hdrlen < sizeof(out->hdr))
{
len = stream->write(stream, out->hdr + out->hdrlen,
sizeof(out->hdr) - out->hdrlen, block);
if (len == 0)
{
return FALSE;
}
if (len < 0)
{
if (errno == EWOULDBLOCK)
{
return TRUE;
}
snprintf(errmsg, errlen, "vici header write error: %s",
strerror(errno));
return FALSE;
}
out->hdrlen += len;
}
/* write buffer buffer */
while (out->buf.len > out->done)
{
len = stream->write(stream, out->buf.ptr + out->done,
out->buf.len - out->done, block);
if (len == 0)
{
snprintf(errmsg, errlen, "premature vici disconnect");
return FALSE;
}
if (len < 0)
{
if (errno == EWOULDBLOCK)
{
return TRUE;
}
snprintf(errmsg, errlen, "vici write error: %s", strerror(errno));
return FALSE;
}
out->done += len;
}
if (array_remove(entry->out, ARRAY_HEAD, &out))
{
chunk_clear(&out->buf);
free(out);
}
}
return TRUE;
}
/**
* Send pending messages
*/
CALLBACK(on_write, bool,
private_vici_socket_t *this, stream_t *stream)
{
char errmsg[256] = "";
entry_t *entry;
bool ret = FALSE;
entry = find_entry(this, stream, 0, FALSE, TRUE);
if (entry)
{
ret = do_write(this, entry, stream, errmsg, sizeof(errmsg), FALSE);
if (ret)
{
/* unregister if we have no more messages to send */
ret = array_count(entry->out) != 0;
}
else
{
entry->disconnecting = TRUE;
disconnect(entry->this, entry->id);
}
put_entry(this, entry, FALSE, TRUE);
if (!ret && errmsg[0])
{
DBG1(DBG_CFG, errmsg);
}
}
return ret;
}
/**
* Read in available header with data, non-blocking accumulating to buffer
*/
static bool do_read(private_vici_socket_t *this, entry_t *entry,
stream_t *stream, char *errmsg, size_t errlen)
{
uint32_t msglen;
ssize_t len;
/* assemble the length header first */
while (entry->in.hdrlen < sizeof(entry->in.hdr))
{
len = stream->read(stream, entry->in.hdr + entry->in.hdrlen,
sizeof(entry->in.hdr) - entry->in.hdrlen, FALSE);
if (len == 0)
{
return FALSE;
}
if (len < 0)
{
if (errno == EWOULDBLOCK)
{
return TRUE;
}
snprintf(errmsg, errlen, "vici header read error: %s",
strerror(errno));
return FALSE;
}
entry->in.hdrlen += len;
if (entry->in.hdrlen == sizeof(entry->in.hdr))
{
msglen = untoh32(entry->in.hdr);
if (msglen > VICI_MESSAGE_SIZE_MAX)
{
snprintf(errmsg, errlen, "vici message length %u exceeds %u "
"bytes limit, ignored", msglen, VICI_MESSAGE_SIZE_MAX);
return FALSE;
}
/* header complete, continue with data */
entry->in.buf = chunk_alloc(msglen);
}
}
/* assemble buffer */
while (entry->in.buf.len > entry->in.done)
{
len = stream->read(stream, entry->in.buf.ptr + entry->in.done,
entry->in.buf.len - entry->in.done, FALSE);
if (len == 0)
{
snprintf(errmsg, errlen, "premature vici disconnect");
return FALSE;
}
if (len < 0)
{
if (errno == EWOULDBLOCK)
{
return TRUE;
}
snprintf(errmsg, errlen, "vici read error: %s", strerror(errno));
return FALSE;
}
entry->in.done += len;
}
return TRUE;
}
/**
* Callback processing incoming requests in strict order
*/
CALLBACK(process_queue, job_requeue_t,
entry_selector_t *sel)
{
entry_t *entry;
chunk_t chunk;
bool found;
u_int id;
while (TRUE)
{
entry = find_entry(sel->this, NULL, sel->id, TRUE, FALSE);
if (!entry)
{
break;
}
found = array_remove(entry->queue, ARRAY_HEAD, &chunk);
if (!found)
{
entry->has_processor = FALSE;
}
id = entry->id;
put_entry(sel->this, entry, TRUE, FALSE);
if (!found)
{
break;
}
thread_cleanup_push(free, chunk.ptr);
sel->this->inbound(sel->this->user, id, chunk);
thread_cleanup_pop(TRUE);
}
return JOB_REQUEUE_NONE;
}
/**
* Process incoming messages
*/
CALLBACK(on_read, bool,
private_vici_socket_t *this, stream_t *stream)
{
char errmsg[256] = "";
entry_selector_t *sel;
entry_t *entry;
bool ret = FALSE;
entry = find_entry(this, stream, 0, TRUE, FALSE);
if (entry)
{
ret = do_read(this, entry, stream, errmsg, sizeof(errmsg));
if (!ret)
{
entry->disconnecting = TRUE;
disconnect(this, entry->id);
}
else if (entry->in.hdrlen == sizeof(entry->in.hdr) &&
entry->in.buf.len == entry->in.done)
{
array_insert(entry->queue, ARRAY_TAIL, &entry->in.buf);
entry->in.buf = chunk_empty;
entry->in.hdrlen = entry->in.done = 0;
if (!entry->has_processor)
{
INIT(sel,
.this = this,
.id = entry->id,
);
lib->processor->queue_job(lib->processor,
(job_t*)callback_job_create(process_queue,
sel, free, NULL));
entry->has_processor = TRUE;
}
}
put_entry(this, entry, TRUE, FALSE);
if (!ret && errmsg[0])
{
DBG1(DBG_CFG, errmsg);
}
}
return ret;
}
/**
* Process connection request
*/
CALLBACK(on_accept, bool,
private_vici_socket_t *this, stream_t *stream)
{
entry_t *entry;
u_int id;
id = ref_get(&this->nextid);
INIT(entry,
.this = this,
.stream = stream,
.id = id,
.out = array_create(0, 0),
.queue = array_create(sizeof(chunk_t), 0),
.cond = condvar_create(CONDVAR_TYPE_DEFAULT),
.readers = 1,
);
this->mutex->lock(this->mutex);
this->connections->insert_last(this->connections, entry);
this->mutex->unlock(this->mutex);
stream->on_read(stream, on_read, this);
put_entry(this, entry, TRUE, FALSE);
this->connect(this->user, id);
return TRUE;
}
/**
* Async callback to enable writer
*/
CALLBACK(enable_writer, job_requeue_t,
entry_selector_t *sel)
{
entry_t *entry;
entry = find_entry(sel->this, NULL, sel->id, FALSE, TRUE);
if (entry)
{
entry->stream->on_write(entry->stream, on_write, sel->this);
put_entry(sel->this, entry, FALSE, TRUE);
}
return JOB_REQUEUE_NONE;
}
METHOD(vici_socket_t, send_, void,
private_vici_socket_t *this, u_int id, chunk_t msg)
{
if (msg.len <= VICI_MESSAGE_SIZE_MAX)
{
entry_selector_t *sel;
msg_buf_t *out;
entry_t *entry;
entry = find_entry(this, NULL, id, FALSE, TRUE);
if (entry)
{
INIT(out,
.buf = msg,
);
htoun32(out->hdr, msg.len);
array_insert(entry->out, ARRAY_TAIL, out);
if (array_count(entry->out) == 1)
{ /* asynchronously re-enable on_write callback when we get data */
INIT(sel,
.this = this,
.id = entry->id,
);
lib->processor->queue_job(lib->processor,
(job_t*)callback_job_create(enable_writer,
sel, free, NULL));
}
put_entry(this, entry, FALSE, TRUE);
}
else
{
DBG1(DBG_CFG, "vici connection %u unknown", id);
chunk_clear(&msg);
}
}
else
{
DBG1(DBG_CFG, "vici message size %zu exceeds maximum size of %u, "
"discarded", msg.len, VICI_MESSAGE_SIZE_MAX);
chunk_clear(&msg);
}
}
CALLBACK(flush_messages, void,
entry_t *entry, va_list args)
{
private_vici_socket_t *this;
char errmsg[256] = "";
bool ret;
VA_ARGS_VGET(args, this);
/* no need for any locking as no other threads are running, the connections
* all get disconnected afterwards, so error handling is simple too */
ret = do_write(this, entry, entry->stream, errmsg, sizeof(errmsg), TRUE);
if (!ret && errmsg[0])
{
DBG1(DBG_CFG, errmsg);
}
}
METHOD(vici_socket_t, destroy, void,
private_vici_socket_t *this)
{
DESTROY_IF(this->service);
this->connections->invoke_function(this->connections, flush_messages, this);
this->connections->destroy_function(this->connections, destroy_entry);
this->mutex->destroy(this->mutex);
free(this);
}
/*
* see header file
*/
vici_socket_t *vici_socket_create(char *uri, vici_inbound_cb_t inbound,
vici_connect_cb_t connect,
vici_disconnect_cb_t disconnect, void *user)
{
private_vici_socket_t *this;
INIT(this,
.public = {
.send = _send_,
.destroy = _destroy,
},
.mutex = mutex_create(MUTEX_TYPE_DEFAULT),
.connections = linked_list_create(),
.inbound = inbound,
.connect = connect,
.disconnect = disconnect,
.user = user,
);
this->service = lib->streams->create_service(lib->streams, uri, 3);
if (!this->service)
{
DBG1(DBG_CFG, "creating vici socket failed");
destroy(this);
return NULL;
}
this->service->on_accept(this->service, on_accept, this,
JOB_PRIO_CRITICAL, 0);
return &this->public;
}