[RFC PATCH v2 16/29] ntsync: Introduce alertable waits.

From: Elizabeth Figura
Date: Tue Jan 30 2024 - 21:21:30 EST


NT waits can optionally be made "alertable". This is a special channel for
thread wakeup, mildly similar to SIGIO. A thread has a single internal bit
of "alerted" state, and if a thread is alerted while performing an
alertable wait, the wait returns a special value, consumes the "alerted"
state, and does not consume any of its objects.
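
For illustration only, not part of the patch: because the alert entry is
queued after the regular objects (see below), a wakeup on the alert can be
recognized by a returned index equal to count. A minimal sketch of a
waiter, assuming dev_fd, objs, count, tid, timeout_ns, and alert_fd
already exist:

	struct ntsync_wait_args args = {
		.timeout = timeout_ns,
		.objs = (__u64)(uintptr_t)objs,	/* array of object fds */
		.count = count,
		.owner = tid,
		.alert = alert_fd,		/* 0 means "no alert" */
	};

	ret = ioctl(dev_fd, NTSYNC_IOC_WAIT_ANY, &args);
	if (!ret && args.index == args.count) {
		/*
		 * Alerted: the "alerted" state was consumed, and none of
		 * the waited-on objects were.
		 */
	}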

Alerts are implemented using events; the user-space NT emulator is expected to
create an internal ntsync event for each thread and pass that event to wait
functions.
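
A sketch of that arrangement, assuming the event ioctls introduced earlier
in this series (the wrapper names below are hypothetical; an auto-reset
event matches the consume-on-wake semantics described above):

	/* Once per thread: create its internal alert event. */
	thread->alert_fd = create_ntsync_event(dev_fd);	/* NTSYNC_IOC_CREATE_EVENT */

	/* Alertable waits pass that event in the new field... */
	args.alert = alertable ? thread->alert_fd : 0;

	/* ...and alerting a thread is just setting its event. */
	set_ntsync_event(target->alert_fd);	/* NTSYNC_IOC_EVENT_SET */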

Signed-off-by: Elizabeth Figura <zfigura@xxxxxxxxxxxxxxx>
---
drivers/misc/ntsync.c | 68 ++++++++++++++++++++++++++++++++-----
include/uapi/linux/ntsync.h | 2 +-
2 files changed, 60 insertions(+), 10 deletions(-)

diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c
index 5439c1c9e90f..1e619e1ce6a6 100644
--- a/drivers/misc/ntsync.c
+++ b/drivers/misc/ntsync.c
@@ -784,22 +784,29 @@ static int setup_wait(struct ntsync_device *dev,
const struct ntsync_wait_args *args, bool all,
struct ntsync_q **ret_q)
{
+ int fds[NTSYNC_MAX_WAIT_COUNT + 1];
const __u32 count = args->count;
- int fds[NTSYNC_MAX_WAIT_COUNT];
struct ntsync_q *q;
+ __u32 total_count;
__u32 i, j;

- if (!args->owner || args->pad)
+ if (!args->owner)
return -EINVAL;

if (args->count > NTSYNC_MAX_WAIT_COUNT)
return -EINVAL;

+ total_count = count;
+ if (args->alert)
+ total_count++;
+
if (copy_from_user(fds, u64_to_user_ptr(args->objs),
array_size(count, sizeof(*fds))))
return -EFAULT;
+ if (args->alert)
+ fds[count] = args->alert;

- q = kmalloc(struct_size(q, entries, count), GFP_KERNEL);
+ q = kmalloc(struct_size(q, entries, total_count), GFP_KERNEL);
if (!q)
return -ENOMEM;
q->task = current;
@@ -809,7 +816,7 @@ static int setup_wait(struct ntsync_device *dev,
q->ownerdead = false;
q->count = count;

- for (i = 0; i < count; i++) {
+ for (i = 0; i < total_count; i++) {
struct ntsync_q_entry *entry = &q->entries[i];
struct ntsync_obj *obj = get_obj(dev, fds[i]);

@@ -860,9 +867,9 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
{
struct ntsync_wait_args args;
struct ntsync_q *q;
+ __u32 i, total_count;
ktime_t timeout;
int signaled;
- __u32 i;
int ret;

if (copy_from_user(&args, argp, sizeof(args)))
@@ -872,9 +879,13 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
if (ret < 0)
return ret;

+ total_count = args.count;
+ if (args.alert)
+ total_count++;
+
/* queue ourselves */

- for (i = 0; i < args.count; i++) {
+ for (i = 0; i < total_count; i++) {
struct ntsync_q_entry *entry = &q->entries[i];
struct ntsync_obj *obj = entry->obj;

@@ -883,9 +894,15 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
spin_unlock(&obj->lock);
}

- /* check if we are already signaled */
+ /*
+ * Check if we are already signaled.
+ *
+ * Note that the API requires that normal objects are checked before
+ * the alert event. Hence we queue the alert event last, and check
+ * objects in order.
+ */

- for (i = 0; i < args.count; i++) {
+ for (i = 0; i < total_count; i++) {
struct ntsync_obj *obj = q->entries[i].obj;

if (atomic_read(&q->signaled) != -1)
@@ -903,7 +920,7 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)

/* and finally, unqueue */

- for (i = 0; i < args.count; i++) {
+ for (i = 0; i < total_count; i++) {
struct ntsync_q_entry *entry = &q->entries[i];
struct ntsync_obj *obj = entry->obj;

@@ -964,6 +981,14 @@ static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
*/
list_add_tail(&entry->node, &obj->all_waiters);
}
+ if (args.alert) {
+ struct ntsync_q_entry *entry = &q->entries[args.count];
+ struct ntsync_obj *obj = entry->obj;
+
+ spin_lock_nest_lock(&obj->lock, &dev->wait_all_lock);
+ list_add_tail(&entry->node, &obj->any_waiters);
+ spin_unlock(&obj->lock);
+ }

/* check if we are already signaled */

@@ -971,6 +996,21 @@ static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)

spin_unlock(&dev->wait_all_lock);

+ /*
+ * Check if the alert event is signaled, making sure to do so only
+ * after checking if the other objects are signaled.
+ */
+
+ if (args.alert) {
+ struct ntsync_obj *obj = q->entries[args.count].obj;
+
+ if (atomic_read(&q->signaled) == -1) {
+ spin_lock(&obj->lock);
+ try_wake_any_obj(obj);
+ spin_unlock(&obj->lock);
+ }
+ }
+
/* sleep */

timeout = ns_to_ktime(args.timeout);
@@ -994,6 +1034,16 @@ static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)

put_obj(obj);
}
+ if (args.alert) {
+ struct ntsync_q_entry *entry = &q->entries[args.count];
+ struct ntsync_obj *obj = entry->obj;
+
+ spin_lock_nest_lock(&obj->lock, &dev->wait_all_lock);
+ list_del(&entry->node);
+ spin_unlock(&obj->lock);
+
+ put_obj(obj);
+ }

spin_unlock(&dev->wait_all_lock);

diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h
index 582d33b0dcac..7c91af7011e4 100644
--- a/include/uapi/linux/ntsync.h
+++ b/include/uapi/linux/ntsync.h
@@ -34,7 +34,7 @@ struct ntsync_wait_args {
__u32 count;
__u32 owner;
__u32 index;
- __u32 pad;
+ __u32 alert;
};

#define NTSYNC_MAX_WAIT_COUNT 64
--
2.43.0