[3/3] client: Change the way fd's are buffered and flushed

Submitted by Lloyd Pique on Sept. 8, 2018, 1:14 a.m.

Details

Message ID 20180908011405.102768-3-lpique@google.com
State New
Series "Series without cover letter"
Headers show

Commit Message

Lloyd Pique Sept. 8, 2018, 1:14 a.m.
To allow more fd's to be buffered, remove the use of the MAX_FDS_OUT
when deciding if there is room in fds_out. A full set of 1024 can now
be enqueued for output.

The consequence is that the flush code must be able to handle sending
more than MAX_FDS_OUT safely. The logic has been changed so that if
there are more than MAX_FDS_OUT of them, that maximum count is sent
along with the data for a single message, rather than the data for all
the messages, as fd's cannot be sent without any data. As the receiver
will always assume it can read at least one message from the input
buffers, one messasge is sent -- it just might have unused fd's.

To keep things sane, an explicit limit to the number of fd's per message
is introduced of three. An error will be generated if there are more.

Signed-off-by: Lloyd Pique <lpique@google.com>
---
 src/connection.c        |  93 +++++++++++++++++++-------
 src/wayland-client.c    |  12 ++--
 src/wayland-private.h   |   3 +-
 src/wayland-server.c    |   2 +-
 tests/connection-test.c | 141 +++++++++++++++++++++++++++++++++++-----
 5 files changed, 205 insertions(+), 46 deletions(-)

Patch hide | download patch | download mbox

diff --git a/src/connection.c b/src/connection.c
index c271fa0..bf77ef8 100644
--- a/src/connection.c
+++ b/src/connection.c
@@ -166,6 +166,36 @@  wl_buffer_size(struct wl_buffer *b)
 	return b->head - b->tail;
 }
 
+static void
+wl_buffer_get_one_msg_iov(struct wl_buffer *b, struct iovec *iov, int *count)
+{
+	uint32_t p[2];
+	int msg_size;
+	uint32_t head, tail;
+
+	wl_buffer_copy(b, p, sizeof p);
+	msg_size = p[1] >> 16;
+
+	head = MASK(b->head);
+	tail = MASK(b->tail);
+
+	if (tail + msg_size <= sizeof b->data) {
+		iov[0].iov_base = b->data + tail;
+		iov[0].iov_len = msg_size;
+		*count = 1;
+	} else if (head == 0) {
+		iov[0].iov_base = b->data + tail;
+		iov[0].iov_len = sizeof b->data - tail;
+		*count = 1;
+	} else {
+		iov[0].iov_base = b->data + tail;
+		iov[0].iov_len = sizeof b->data - tail;
+		iov[1].iov_base = b->data;
+		iov[1].iov_len = msg_size - (sizeof b->data - tail);
+		*count = 2;
+	}
+}
+
 struct wl_connection *
 wl_connection_create(int fd)
 {
@@ -325,7 +355,11 @@  wl_connection_flush(struct wl_connection *connection, int soft)
 
 	tail = connection->out.tail;
 	while (connection->out.head - connection->out.tail > 0) {
-		wl_buffer_get_iov(&connection->out, iov, &count);
+		if (wl_buffer_size(&connection->fds_out) > MAX_FDS_OUT * sizeof(int32_t)) {
+			wl_buffer_get_one_msg_iov(&connection->out, iov, &count);
+		} else {
+			wl_buffer_get_iov(&connection->out, iov, &count);
+		}
 
 		build_cmsg(&connection->fds_out, cmsg, &clen);
 
@@ -400,18 +434,18 @@  wl_connection_read(struct wl_connection *connection)
 	return wl_connection_pending_input(connection);
 }
 
-int
-wl_connection_write(struct wl_connection *connection,
-		    const void *data, size_t count)
+static int
+connection_buffer_write(struct wl_connection *connection,
+	                struct wl_buffer *buffer, const void *data,
+	                size_t count)
 {
-	if (connection->out.head - connection->out.tail +
-	    count > ARRAY_LENGTH(connection->out.data)) {
+	if (buffer->head - buffer->tail + count > ARRAY_LENGTH(buffer->data)) {
 		connection->want_flush = 1;
-		if (wl_connection_flush(connection) < 0)
+		if (wl_connection_flush(connection, 0) < 0)
 			return -1;
 	}
 
-	if (wl_buffer_put(&connection->out, data, count) < 0)
+	if (wl_buffer_put(buffer, data, count) < 0)
 		return -1;
 
 	connection->want_flush = 1;
@@ -419,6 +453,14 @@  wl_connection_write(struct wl_connection *connection,
 	return 0;
 }
 
+int
+wl_connection_write(struct wl_connection *connection,
+		    const void *data, size_t count)
+{
+	return connection_buffer_write(connection, &connection->out, data,
+				       count);
+}
+
 int
 wl_connection_queue(struct wl_connection *connection,
 		    const void *data, size_t count)
@@ -455,13 +497,8 @@  wl_connection_get_fd(struct wl_connection *connection)
 static int
 wl_connection_put_fd(struct wl_connection *connection, int32_t fd)
 {
-	if (wl_buffer_size(&connection->fds_out) == MAX_FDS_OUT * sizeof fd) {
-		connection->want_flush = 1;
-		if (wl_connection_flush(connection) < 0)
-			return -1;
-	}
-
-	return wl_buffer_put(&connection->fds_out, &fd, sizeof fd);
+	return connection_buffer_write(connection, &connection->fds_out, &fd,
+				       sizeof fd);
 }
 
 const char *
@@ -489,9 +526,10 @@  get_next_argument(const char *signature, struct argument_details *details)
 }
 
 int
-arg_count_for_signature(const char *signature)
+arg_count_for_signature(const char *signature, int *out_fd_count)
 {
 	int count = 0;
+	int fd_count = 0;
 	for(; *signature; ++signature) {
 		switch(*signature) {
 		case 'i':
@@ -501,10 +539,14 @@  arg_count_for_signature(const char *signature)
 		case 'o':
 		case 'n':
 		case 'a':
+			++count;
+			break;
 		case 'h':
 			++count;
+			++fd_count;
 		}
 	}
+	if (out_fd_count) *out_fd_count = fd_count;
 	return count;
 }
 
@@ -583,14 +625,19 @@  wl_closure_init(const struct wl_message *message, uint32_t size,
                 int *num_arrays, union wl_argument *args)
 {
 	struct wl_closure *closure;
-	int count;
+	int count, fd_count;
 
-	count = arg_count_for_signature(message->signature);
+	count = arg_count_for_signature(message->signature, &fd_count);
 	if (count > WL_CLOSURE_MAX_ARGS) {
 		wl_log("too many args (%d)\n", count);
 		errno = EINVAL;
 		return NULL;
 	}
+	if (fd_count > WL_CLOSURE_MAX_FD_ARGS) {
+		wl_log("too many fd args (%d)\n", fd_count);
+		errno = EINVAL;
+		return NULL;
+	}
 
 	if (size) {
 		*num_arrays = wl_message_count_arrays(message);
@@ -906,7 +953,7 @@  wl_closure_lookup_objects(struct wl_closure *closure, struct wl_map *objects)
 
 	message = closure->message;
 	signature = message->signature;
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 		switch (arg.type) {
@@ -1011,7 +1058,7 @@  wl_closure_invoke(struct wl_closure *closure, uint32_t flags,
 	void * ffi_args[WL_CLOSURE_MAX_ARGS + 2];
 	void (* const *implementation)(void);
 
-	count = arg_count_for_signature(closure->message->signature);
+	count = arg_count_for_signature(closure->message->signature, NULL);
 
 	ffi_types[0] = &ffi_type_pointer;
 	ffi_args[0] = &data;
@@ -1054,7 +1101,7 @@  copy_fds_to_connection(struct wl_closure *closure,
 	const char *signature = message->signature;
 	int fd;
 
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 		if (arg.type != 'h')
@@ -1083,7 +1130,7 @@  buffer_size_for_closure(struct wl_closure *closure)
 	uint32_t size, buffer_size = 0;
 
 	signature = message->signature;
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 
@@ -1140,7 +1187,7 @@  serialize_closure(struct wl_closure *closure, uint32_t *buffer,
 	end = buffer + buffer_count;
 
 	signature = message->signature;
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 
diff --git a/src/wayland-client.c b/src/wayland-client.c
index eea634d..c681458 100644
--- a/src/wayland-client.c
+++ b/src/wayland-client.c
@@ -242,7 +242,7 @@  validate_closure_objects(struct wl_closure *closure)
 	struct wl_proxy *proxy;
 
 	signature = closure->message->signature;
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 		switch (arg.type) {
@@ -270,7 +270,7 @@  destroy_queued_closure(struct wl_closure *closure)
 	int i, count;
 
 	signature = closure->message->signature;
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 		switch (arg.type) {
@@ -354,7 +354,7 @@  message_count_fds(const char *signature)
 	unsigned int count, i, fds = 0;
 	struct argument_details arg;
 
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 		if (arg.type == 'h')
@@ -638,7 +638,7 @@  create_outgoing_proxy(struct wl_proxy *proxy, const struct wl_message *message,
 	struct wl_proxy *new_proxy = NULL;
 
 	signature = message->signature;
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 
@@ -1285,7 +1285,7 @@  create_proxies(struct wl_proxy *sender, struct wl_closure *closure)
 	int count;
 
 	signature = closure->message->signature;
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 		switch (arg.type) {
@@ -1318,7 +1318,7 @@  increase_closure_args_refcount(struct wl_closure *closure)
 	struct wl_proxy *proxy;
 
 	signature = closure->message->signature;
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 		switch (arg.type) {
diff --git a/src/wayland-private.h b/src/wayland-private.h
index ba183fc..9c8f3c7 100644
--- a/src/wayland-private.h
+++ b/src/wayland-private.h
@@ -50,6 +50,7 @@ 
 #define WL_MAP_CLIENT_SIDE 1
 #define WL_SERVER_ID_START 0xff000000
 #define WL_CLOSURE_MAX_ARGS 20
+#define WL_CLOSURE_MAX_FD_ARGS 3
 
 struct wl_object {
 	const struct wl_interface *interface;
@@ -160,7 +161,7 @@  const char *
 get_next_argument(const char *signature, struct argument_details *details);
 
 int
-arg_count_for_signature(const char *signature);
+arg_count_for_signature(const char *signature, int *out_fd_count);
 
 int
 wl_message_count_arrays(const struct wl_message *message);
diff --git a/src/wayland-server.c b/src/wayland-server.c
index 43e4099..51eda78 100644
--- a/src/wayland-server.c
+++ b/src/wayland-server.c
@@ -176,7 +176,7 @@  verify_objects(struct wl_resource *resource, uint32_t opcode,
 	struct wl_resource *res;
 	int count, i;
 
-	count = arg_count_for_signature(signature);
+	count = arg_count_for_signature(signature, NULL);
 	for (i = 0; i < count; i++) {
 		signature = get_next_argument(signature, &arg);
 		switch (arg.type) {
diff --git a/tests/connection-test.c b/tests/connection-test.c
index 4248f4a..4e5fc38 100644
--- a/tests/connection-test.c
+++ b/tests/connection-test.c
@@ -43,6 +43,30 @@ 
 
 static const char message[] = "Hello, world";
 
+static int
+create_test_fd(void)
+{
+	char f[] = "/tmp/wayland-tests-XXXXXX";
+	int fd = mkstemp(f);
+	assert(fd >= 0);
+	unlink(f);
+	return fd;
+}
+
+static void
+validate_fds_same_and_close(int expected, int actual)
+{
+	struct stat expected_buf, actual_buf;
+
+	assert(actual != expected);
+	fstat(actual, &actual_buf);
+	fstat(expected, &expected_buf);
+	assert(actual_buf.st_dev == expected_buf.st_dev);
+	assert(actual_buf.st_ino == expected_buf.st_ino);
+	close(actual);
+	close(expected);
+}
+
 static struct wl_connection *
 setup(int *s)
 {
@@ -368,15 +392,7 @@  static void
 validate_demarshal_h(struct marshal_data *data,
 		     struct wl_object *object, int fd)
 {
-	struct stat buf1, buf2;
-
-	assert(fd != data->value.h);
-	fstat(fd, &buf1);
-	fstat(data->value.h, &buf2);
-	assert(buf1.st_dev == buf2.st_dev);
-	assert(buf1.st_ino == buf2.st_ino);
-	close(fd);
-	close(data->value.h);
+	validate_fds_same_and_close(data->value.h, fd);
 }
 
 static void
@@ -492,7 +508,6 @@  marshal_demarshal(struct marshal_data *data,
 TEST(connection_marshal_demarshal)
 {
 	struct marshal_data data;
-	char f[] = "/tmp/wayland-tests-XXXXXX";
 
 	setup_marshal_data(&data);
 
@@ -512,9 +527,7 @@  TEST(connection_marshal_demarshal)
 	marshal_demarshal(&data, (void *) validate_demarshal_s,
 			  28, "?s", data.value.s);
 
-	data.value.h = mkstemp(f);
-	assert(data.value.h >= 0);
-	unlink(f);
+	data.value.h = create_test_fd();
 	marshal_demarshal(&data, (void *) validate_demarshal_h,
 			  8, "h", data.value.h);
 
@@ -608,8 +621,7 @@  TEST(connection_marshal_alot)
 	 * for both regular data an fds. */
 
 	for (i = 0; i < 2000; i++) {
-		strcpy(f, "/tmp/wayland-tests-XXXXXX");
-		data.value.h = mkstemp(f);
+		data.value.h = create_test_fd();
 		assert(data.value.h >= 0);
 		unlink(f);
 		marshal_demarshal(&data, (void *) validate_demarshal_h,
@@ -637,6 +649,105 @@  TEST(connection_marshal_too_big)
 	free(big_string);
 }
 
+static void
+validate_demarshal_hhh(int* expected_fds, struct wl_object *object, int fd1,
+		       int fd2, int fd3)
+{
+	validate_fds_same_and_close(expected_fds[0], fd1);
+	validate_fds_same_and_close(expected_fds[1], fd2);
+	validate_fds_same_and_close(expected_fds[2], fd3);
+}
+
+
+
+TEST(connection_marshal_alot_of_fds)
+{
+	struct marshal_data data;
+	struct wl_closure *closure;
+	static const int opcode = 4444;
+	static struct wl_object sender = { NULL, NULL, 1234 };
+	struct wl_message message = { "test", "hhh", NULL };
+	struct wl_map objects;
+	void (*func)(int *, struct wl_object *, int, int, int) = validate_demarshal_hhh;
+	struct wl_object object = { NULL, &func, 0 };
+	uint32_t msg[1] = { 1234 };
+	static const int one_message_bytes = 8;
+	static const int total_bytes = one_message_bytes * 100;
+	int fds[300];
+	int *next_fd;
+	union wl_argument args[WL_CLOSURE_MAX_ARGS];
+	int size;
+
+	setup_marshal_data(&data);
+
+        /* This test enqueues 100 messages that each use 3 fd's, and exercises
+         * the ability of the code to deal with more than MAX_FDS_OUT (the
+         * sendmsg() limit)
+         *
+         * It verifies that all 300 fd's
+	 */
+
+	wl_map_init(&objects, WL_MAP_SERVER_SIDE);
+	object.id = msg[0];
+
+	for (int i = 0; i < 300; i++) {
+		fds[i] = create_test_fd();
+	}
+
+	next_fd = fds;
+	for (int i = 0; i < 100; i++) {
+		args[0].h = *next_fd++;
+		args[1].h = *next_fd++;
+		args[2].h = *next_fd++;
+		closure  = wl_closure_marshal(&sender, opcode, args, &message);
+		assert(closure);
+		assert(wl_closure_send(closure, data.write_connection) == 0);
+		wl_closure_destroy(closure);
+	}
+
+	assert(next_fd == &fds[300]);
+
+	assert(wl_connection_flush(data.write_connection, 0) == total_bytes);
+
+	next_fd = fds;
+	size = 0;
+	for (int i = 0; i < 100; i++) {
+		if (size == 0)
+			size = wl_connection_read(data.read_connection);
+		closure = wl_connection_demarshal(data.read_connection,
+						  one_message_bytes, &objects,
+						  &message);
+		assert(closure);
+		wl_closure_invoke(closure, WL_CLOSURE_INVOKE_SERVER, &object,
+				  0, next_fd);
+		wl_closure_destroy(closure);
+
+		size -= one_message_bytes;
+		assert(size >= 0);
+
+		next_fd += 3;
+	}
+
+	assert(next_fd == &fds[300]);
+
+	assert(wl_connection_read(data.read_connection) == -1);
+
+	release_marshal_data(&data);
+}
+
+TEST(connection_marshal_too_many_fds_in_one_message)
+{
+	struct marshal_data data;
+	int fd = create_test_fd();
+
+	setup_marshal_data(&data);
+
+	expected_fail_marshal(EINVAL, "hhhh", fd, fd, fd, fd);
+
+	release_marshal_data(&data);
+	close(fd);
+}
+
 static void
 marshal_helper(const char *format, void *handler, ...)
 {