[PATCH v2 2/5] platform/x86: wmi: Fix refcounting of WMI devices in legacy functions

From: Armin Wolf
Date: Fri Oct 20 2023 - 17:10:37 EST


Until now, legacy GUID-based functions were using find_guid() when
searching for WMI devices, which did no refcounting on the returned
WMI device. This meant that the WMI device could disappear at any
moment, potentially leading to various errors. Fix this by using
bus_find_device() which returns an actual reference to the found
WMI device.

Signed-off-by: Armin Wolf <W_Armin@xxxxxx>
---
drivers/platform/x86/wmi.c | 167 ++++++++++++++++++++++++-------------
1 file changed, 107 insertions(+), 60 deletions(-)

diff --git a/drivers/platform/x86/wmi.c b/drivers/platform/x86/wmi.c
index 1dbef16acdeb..e3984801883a 100644
--- a/drivers/platform/x86/wmi.c
+++ b/drivers/platform/x86/wmi.c
@@ -109,33 +109,13 @@ static const char * const allow_duplicates[] = {
NULL
};

+#define dev_to_wblock(__dev) container_of_const(__dev, struct wmi_block, dev.dev)
+#define dev_to_wdev(__dev) container_of_const(__dev, struct wmi_device, dev)
+
/*
* GUID parsing functions
*/

-static acpi_status find_guid(const char *guid_string, struct wmi_block **out)
-{
- guid_t guid_input;
- struct wmi_block *wblock;
-
- if (!guid_string)
- return AE_BAD_PARAMETER;
-
- if (guid_parse(guid_string, &guid_input))
- return AE_BAD_PARAMETER;
-
- list_for_each_entry(wblock, &wmi_block_list, list) {
- if (guid_equal(&wblock->gblock.guid, &guid_input)) {
- if (out)
- *out = wblock;
-
- return AE_OK;
- }
- }
-
- return AE_NOT_FOUND;
-}
-
static bool guid_parse_and_compare(const char *string, const guid_t *guid)
{
guid_t guid_input;
@@ -245,6 +225,41 @@ static acpi_status get_event_data(const struct wmi_block *wblock, struct acpi_bu
return acpi_evaluate_object(wblock->acpi_device->handle, "_WED", &input, out);
}

+static int wmidev_match_guid(struct device *dev, const void *data)
+{
+ struct wmi_block *wblock = dev_to_wblock(dev);
+ const guid_t *guid = data;
+
+ if (guid_equal(guid, &wblock->gblock.guid))
+ return 1;
+
+ return 0;
+}
+
+static struct bus_type wmi_bus_type;
+
+static struct wmi_device *wmi_find_device_by_guid(const char *guid_string)
+{
+ struct device *dev;
+ guid_t guid;
+ int ret;
+
+ ret = guid_parse(guid_string, &guid);
+ if (ret < 0)
+ return ERR_PTR(ret);
+
+ dev = bus_find_device(&wmi_bus_type, NULL, &guid, wmidev_match_guid);
+ if (!dev)
+ return ERR_PTR(-ENODEV);
+
+ return dev_to_wdev(dev);
+}
+
+static void wmi_device_put(struct wmi_device *wdev)
+{
+ put_device(&wdev->dev);
+}
+
/*
* Exported WMI functions
*/
@@ -279,18 +294,17 @@ EXPORT_SYMBOL_GPL(set_required_buffer_size);
*/
int wmi_instance_count(const char *guid_string)
{
- struct wmi_block *wblock;
- acpi_status status;
+ struct wmi_device *wdev;
+ int ret;

- status = find_guid(guid_string, &wblock);
- if (ACPI_FAILURE(status)) {
- if (status == AE_BAD_PARAMETER)
- return -EINVAL;
+ wdev = wmi_find_device_by_guid(guid_string);
+ if (IS_ERR(wdev))
+ return PTR_ERR(wdev);

- return -ENODEV;
- }
+ ret = wmidev_instance_count(wdev);
+ wmi_device_put(wdev);

- return wmidev_instance_count(&wblock->dev);
+ return ret;
}
EXPORT_SYMBOL_GPL(wmi_instance_count);

@@ -325,15 +339,18 @@ EXPORT_SYMBOL_GPL(wmidev_instance_count);
acpi_status wmi_evaluate_method(const char *guid_string, u8 instance, u32 method_id,
const struct acpi_buffer *in, struct acpi_buffer *out)
{
- struct wmi_block *wblock = NULL;
+ struct wmi_device *wdev;
acpi_status status;

- status = find_guid(guid_string, &wblock);
- if (ACPI_FAILURE(status))
- return status;
+ wdev = wmi_find_device_by_guid(guid_string);
+ if (IS_ERR(wdev))
+ return AE_ERROR;
+
+ status = wmidev_evaluate_method(wdev, instance, method_id, in, out);

- return wmidev_evaluate_method(&wblock->dev, instance, method_id,
- in, out);
+ wmi_device_put(wdev);
+
+ return status;
}
EXPORT_SYMBOL_GPL(wmi_evaluate_method);

@@ -472,13 +489,19 @@ acpi_status wmi_query_block(const char *guid_string, u8 instance,
struct acpi_buffer *out)
{
struct wmi_block *wblock;
+ struct wmi_device *wdev;
acpi_status status;

- status = find_guid(guid_string, &wblock);
- if (ACPI_FAILURE(status))
- return status;
+ wdev = wmi_find_device_by_guid(guid_string);
+ if (IS_ERR(wdev))
+ return AE_ERROR;

- return __query_block(wblock, instance, out);
+ wblock = container_of(wdev, struct wmi_block, dev);
+ status = __query_block(wblock, instance, out);
+
+ wmi_device_put(wdev);
+
+ return status;
}
EXPORT_SYMBOL_GPL(wmi_query_block);

@@ -516,8 +539,9 @@ EXPORT_SYMBOL_GPL(wmidev_block_query);
acpi_status wmi_set_block(const char *guid_string, u8 instance,
const struct acpi_buffer *in)
{
- struct wmi_block *wblock = NULL;
+ struct wmi_block *wblock;
struct guid_block *block;
+ struct wmi_device *wdev;
acpi_handle handle;
struct acpi_object_list input;
union acpi_object params[2];
@@ -527,19 +551,26 @@ acpi_status wmi_set_block(const char *guid_string, u8 instance,
if (!in)
return AE_BAD_DATA;

- status = find_guid(guid_string, &wblock);
- if (ACPI_FAILURE(status))
- return status;
+ wdev = wmi_find_device_by_guid(guid_string);
+ if (IS_ERR(wdev))
+ return AE_ERROR;

+ wblock = container_of(wdev, struct wmi_block, dev);
block = &wblock->gblock;
handle = wblock->acpi_device->handle;

- if (block->instance_count <= instance)
- return AE_BAD_PARAMETER;
+ if (block->instance_count <= instance) {
+ status = AE_BAD_PARAMETER;
+
+ goto err_wdev_put;
+ }

/* Check GUID is a data block */
- if (block->flags & (ACPI_WMI_EVENT | ACPI_WMI_METHOD))
- return AE_ERROR;
+ if (block->flags & (ACPI_WMI_EVENT | ACPI_WMI_METHOD)) {
+ status = AE_ERROR;
+
+ goto err_wdev_put;
+ }

input.count = 2;
input.pointer = params;
@@ -551,7 +582,12 @@ acpi_status wmi_set_block(const char *guid_string, u8 instance,

get_acpi_method_name(wblock, 'S', method);

- return acpi_evaluate_object(handle, method, &input, NULL);
+ status = acpi_evaluate_object(handle, method, &input, NULL);
+
+err_wdev_put:
+ wmi_device_put(wdev);
+
+ return status;
}
EXPORT_SYMBOL_GPL(wmi_set_block);

@@ -742,7 +778,15 @@ EXPORT_SYMBOL_GPL(wmi_get_event_data);
*/
bool wmi_has_guid(const char *guid_string)
{
- return ACPI_SUCCESS(find_guid(guid_string, NULL));
+ struct wmi_device *wdev;
+
+ wdev = wmi_find_device_by_guid(guid_string);
+ if (IS_ERR(wdev))
+ return false;
+
+ wmi_device_put(wdev);
+
+ return true;
}
EXPORT_SYMBOL_GPL(wmi_has_guid);

@@ -756,20 +800,23 @@ EXPORT_SYMBOL_GPL(wmi_has_guid);
*/
char *wmi_get_acpi_device_uid(const char *guid_string)
{
- struct wmi_block *wblock = NULL;
- acpi_status status;
+ struct wmi_block *wblock;
+ struct wmi_device *wdev;
+ char *uid;

- status = find_guid(guid_string, &wblock);
- if (ACPI_FAILURE(status))
+ wdev = wmi_find_device_by_guid(guid_string);
+ if (IS_ERR(wdev))
return NULL;

- return acpi_device_uid(wblock->acpi_device);
+ wblock = container_of(wdev, struct wmi_block, dev);
+ uid = acpi_device_uid(wblock->acpi_device);
+
+ wmi_device_put(wdev);
+
+ return uid;
}
EXPORT_SYMBOL_GPL(wmi_get_acpi_device_uid);

-#define dev_to_wblock(__dev) container_of_const(__dev, struct wmi_block, dev.dev)
-#define dev_to_wdev(__dev) container_of_const(__dev, struct wmi_device, dev)
-
static inline struct wmi_driver *drv_to_wdrv(struct device_driver *drv)
{
return container_of(drv, struct wmi_driver, driver);
--
2.39.2