From mboxrd@z Thu Jan 1 00:00:00 1970
To: ffmpeg-devel@ffmpeg.org
Date: Tue, 20 Jan 2026 19:39:36 +0530
Message-ID: <20260120140939.32403-1-imraja729@gmail.com>
X-Mailer: git-send-email 2.51.0
MIME-Version: 1.0
Subject: [FFmpeg-devel] [PATCH 1/4] avfilter/dnn_backend_torch: implement common async infrastructure
Reply-To: FFmpeg development discussions and patches <ffmpeg-devel@ffmpeg.org>
From: Raja Rathour via ffmpeg-devel <ffmpeg-devel@ffmpeg.org>
Cc: Raja Rathour <imraja729@gmail.com>
Content-Type: text/plain; charset="us-ascii"
Content-Transfer-Encoding: 7bit

---
 libavfilter/dnn/dnn_backend_torch.cpp | 354 ++++++++------------------
 1 file changed, 113 insertions(+), 241 deletions(-)

diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 33809bf983..4c781cc0b6 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -25,10 +25,6 @@
 #include <torch/torch.h>
 #include <torch/script.h>

-#include <thread>
-#include <mutex>
-#include <condition_variable>
-#include <atomic>

 extern "C" {
 #include "dnn_io_proc.h"
@@ -46,11 +42,6 @@ typedef struct THModel {
     SafeQueue *request_queue;
     Queue *task_queue;
     Queue *lltask_queue;
-    SafeQueue *pending_queue;           ///< requests waiting for inference
-    std::thread *worker_thread;         ///< background worker thread
-    std::mutex *mutex;                  ///< mutex for the condition variable
-    std::condition_variable *cond;      ///< condition variable for worker wakeup
-    std::atomic<bool> worker_stop;      ///< signal for thread exit
 } THModel;

 typedef struct THInferRequest {
@@ -64,7 +55,6 @@ typedef struct THRequestItem {
     DNNAsyncExecModule exec_module;
 } THRequestItem;
-

 #define OFFSET(x) offsetof(THOptions, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM
 static const AVOption dnn_th_options[] = {
@@ -104,15 +94,17 @@ static void th_free_request(THInferRequest *request)
         delete(request->input_tensor);
         request->input_tensor = NULL;
     }
-    return;
+    if (request->input_data) {
+        av_freep(&request->input_data);
+        request->input_data_size = 0;
+    }
 }

 static inline void destroy_request_item(THRequestItem **arg)
 {
     THRequestItem *item;
-    if (!arg || !*arg) {
+    if (!arg || !*arg)
         return;
-    }
     item = *arg;
     th_free_request(item->infer_request);
     av_freep(&item->infer_request);
@@ -129,38 +121,6 @@ static void dnn_free_model_th(DNNModel **model)

     th_model = (THModel *)(*model);

-    /* 1. Stop and join the worker thread if it exists */
-    if (th_model->worker_thread) {
-        {
-            std::lock_guard<std::mutex> lock(*th_model->mutex);
-            th_model->worker_stop = true;
-        }
-        th_model->cond->notify_all();
-        th_model->worker_thread->join();
-        delete th_model->worker_thread;
-        th_model->worker_thread = NULL;
-    }
-
-    /* 2. Safely delete C++ synchronization objects */
-    if (th_model->mutex) {
-        delete th_model->mutex;
-        th_model->mutex = NULL;
-    }
-    if (th_model->cond) {
-        delete th_model->cond;
-        th_model->cond = NULL;
-    }
-
-    /* 3. Clean up the pending queue */
-    if (th_model->pending_queue) {
-        while (ff_safe_queue_size(th_model->pending_queue) > 0) {
-            THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->pending_queue);
-            destroy_request_item(&item);
-        }
-        ff_safe_queue_destroy(th_model->pending_queue);
-    }
-
-    /* 4. Clean up standard backend queues */
     if (th_model->request_queue) {
         while (ff_safe_queue_size(th_model->request_queue) != 0) {
             THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
@@ -187,7 +147,6 @@ static void dnn_free_model_th(DNNModel **model)
         ff_queue_destroy(th_model->task_queue);
     }

-    /* 5. Final model cleanup */
     if (th_model->jit_model)
         delete th_model->jit_model;
@@ -214,37 +173,55 @@ static void deleter(void *arg)

 static int fill_model_input_th(THModel *th_model, THRequestItem *request)
 {
-    LastLevelTaskItem *lltask = NULL;
-    TaskItem *task = NULL;
     THInferRequest *infer_request = NULL;
+    TaskItem *task = NULL;
+    LastLevelTaskItem *lltask = NULL;
     DNNData input = { 0 };
     DnnContext *ctx = th_model->ctx;
     int ret, width_idx, height_idx, channel_idx;
+    size_t cur_size;

     lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
-    if (!lltask) {
-        ret = AVERROR(EINVAL);
-        goto err;
-    }
+    if (!lltask)
+        return AVERROR(EINVAL);
+
     request->lltask = lltask;
     task = lltask->task;
     infer_request = request->infer_request;

     ret = get_input_th(&th_model->model, &input, NULL);
-    if ( ret != 0) {
-        goto err;
-    }
+    if (ret)
+        return ret;
+
     width_idx = dnn_get_width_idx_by_layout(input.layout);
     height_idx = dnn_get_height_idx_by_layout(input.layout);
     channel_idx = dnn_get_channel_idx_by_layout(input.layout);
     input.dims[height_idx] = task->in_frame->height;
     input.dims[width_idx] = task->in_frame->width;
-    input.data = av_malloc(input.dims[height_idx] * input.dims[width_idx] *
-                           input.dims[channel_idx] * sizeof(float));
-    if (!input.data)
-        return AVERROR(ENOMEM);
-    infer_request->input_tensor = new torch::Tensor();
-    infer_request->output = new torch::Tensor();
+
+    // Calculate required size for the current frame
+    cur_size = input.dims[height_idx] * input.dims[width_idx] *
+               input.dims[channel_idx] * sizeof(float);
+
+    /**
+     * Dynamic Resizing Logic:
+     * Only reallocate if the existing buffer is too small or doesn't exist.
+     * Removed the (float *) cast to comply with FFmpeg style guidelines.
+     */
+    if (!infer_request->input_data || infer_request->input_data_size < cur_size) {
+        av_freep(&infer_request->input_data);
+        infer_request->input_data = av_malloc(cur_size);
+        if (!infer_request->input_data)
+            return AVERROR(ENOMEM);
+        infer_request->input_data_size = cur_size;
+    }
+
+    input.data = infer_request->input_data;
+
+    if (!infer_request->input_tensor)
+        infer_request->input_tensor = new torch::Tensor();
+    if (!infer_request->output)
+        infer_request->output = new torch::Tensor();

     switch (th_model->model.func_type) {
     case DFT_PROCESS_FRAME:
@@ -261,52 +238,30 @@ static int fill_model_input_th(THModel *th_model, THRequestItem *request)
         avpriv_report_missing_feature(NULL, "model function type %d", th_model->model.func_type);
         break;
     }
+
     *infer_request->input_tensor = torch::from_blob(input.data,
                                                     {1, input.dims[channel_idx], input.dims[height_idx], input.dims[width_idx]},
                                                     deleter, torch::kFloat32);
-    return 0;
-err:
-    th_free_request(infer_request);
-    return ret;
+    return 0;
 }

 static int th_start_inference(void *args)
 {
     THRequestItem *request = (THRequestItem *)args;
-    THInferRequest *infer_request = NULL;
-    LastLevelTaskItem *lltask = NULL;
-    TaskItem *task = NULL;
-    THModel *th_model = NULL;
-    DnnContext *ctx = NULL;
+    THInferRequest *infer_request = request->infer_request;
+    LastLevelTaskItem *lltask = request->lltask;
+    TaskItem *task = lltask->task;
+    THModel *th_model = (THModel *)task->model;
     std::vector<torch::jit::IValue> inputs;
-    torch::NoGradGuard no_grad;
-
-    if (!request) {
-        av_log(NULL, AV_LOG_ERROR, "THRequestItem is NULL\n");
-        return AVERROR(EINVAL);
-    }
-    infer_request = request->infer_request;
-    lltask = request->lltask;
-    task = lltask->task;
-    th_model = (THModel *)task->model;
-    ctx = th_model->ctx;

-    if (ctx->torch_option.optimize)
-        torch::jit::setGraphExecutorOptimize(true);
-    else
-        torch::jit::setGraphExecutorOptimize(false);
+    torch::jit::setGraphExecutorOptimize(!!th_model->ctx->torch_option.optimize);

-    if (!infer_request->input_tensor || !infer_request->output) {
-        av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
-        return DNN_GENERIC_ERROR;
-    }
-
     // Transfer tensor to the same device as model
     c10::Device device = (*th_model->jit_model->parameters().begin()).device();
     if (infer_request->input_tensor->device() != device)
         *infer_request->input_tensor = infer_request->input_tensor->to(device);
-    inputs.push_back(*infer_request->input_tensor);
+    inputs.push_back(*infer_request->input_tensor);
     *infer_request->output = th_model->jit_model->forward(inputs).toTensor();

     return 0;
@@ -325,13 +280,12 @@ static void infer_completion_callback(void *args) {
     outputs.order = DCO_RGB;
     outputs.layout = DL_NCHW;
     outputs.dt = DNN_FLOAT;
+
     if (sizes.size() == 4) {
-        // 4 dimensions: [batch_size, channel, height, width]
-        // this format of data is normally used for video frame SR
-        outputs.dims[0] = sizes.at(0); // N
-        outputs.dims[1] = sizes.at(1); // C
-        outputs.dims[2] = sizes.at(2); // H
-        outputs.dims[3] = sizes.at(3); // W
+        outputs.dims[0] = sizes.at(0);
+        outputs.dims[1] = sizes.at(1);
+        outputs.dims[2] = sizes.at(2);
+        outputs.dims[3] = sizes.at(3);
     } else {
         avpriv_report_missing_feature(th_model->ctx, "Support of this kind of model");
         goto err;
@@ -340,7 +294,6 @@ static void infer_completion_callback(void *args) {
     switch (th_model->model.func_type) {
     case DFT_PROCESS_FRAME:
         if (task->do_ioproc) {
-            // Post process can only deal with CPU memory.
             if (output->device() != torch::kCPU)
                 *output = output->to(torch::kCPU);
             outputs.scale = 255;
@@ -361,35 +314,11 @@ static void infer_completion_callback(void *args) {
     }
     task->inference_done++;
     av_freep(&request->lltask);
+
 err:
     th_free_request(infer_request);
-
-    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
+    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0)
         destroy_request_item(&request);
-        av_log(th_model->ctx, AV_LOG_ERROR, "Unable to push back request_queue when failed to start inference.\n");
-    }
-}
-
-static void th_worker_thread(THModel *th_model) {
-    while (true) {
-        THRequestItem *request = NULL;
-        {
-            std::unique_lock<std::mutex> lock(*th_model->mutex);
-            th_model->cond->wait(lock, [&]{
-                return th_model->worker_stop || ff_safe_queue_size(th_model->pending_queue) > 0;
-            });
-
-            if (th_model->worker_stop && ff_safe_queue_size(th_model->pending_queue) == 0)
-                break;
-
-            request = (THRequestItem *)ff_safe_queue_pop_front(th_model->pending_queue);
-        }
-
-        if (request) {
-            th_start_inference(request);
-            infer_completion_callback(request);
-        }
-    }
 }
@@ -405,32 +334,27 @@ static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
     }

     lltask = (LastLevelTaskItem *)ff_queue_peek_front(lltask_queue);
-    if (lltask == NULL) {
-        av_log(NULL, AV_LOG_ERROR, "Failed to get LastLevelTaskItem\n");
-        ret = AVERROR(EINVAL);
-        goto err;
+    if (!lltask) {
+        destroy_request_item(&request);
+        return AVERROR(EINVAL);
     }
+
     task = lltask->task;
     th_model = (THModel *)task->model;
     ret = fill_model_input_th(th_model, request);
-    if ( ret != 0) {
-        goto err;
-    }
-    if (task->async) {
-        std::lock_guard<std::mutex> lock(*th_model->mutex);
-        if (ff_safe_queue_push_back(th_model->pending_queue, request) < 0) {
-            return AVERROR(ENOMEM);
-        }
-        th_model->cond->notify_one();
-        return 0;
+    if (ret) {
+        th_free_request(request->infer_request);
+        if (ff_safe_queue_push_back(th_model->request_queue, request) < 0)
+            destroy_request_item(&request);
+        return ret;
     }
-err:
-    th_free_request(request->infer_request);
-    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
-        destroy_request_item(&request);
-    }
+
+    if (task->async)
+        return ff_dnn_async_module_submit(&request->exec_module);
+
+    ret = th_start_inference(request);
+    infer_completion_callback(request);
     return ret;
 }
@@ -449,29 +373,29 @@ static int get_output_th(DNNModel *model, const char *input_name, int input_width
         .in_frame = NULL,
         .out_frame = NULL,
     };
+
     ret = ff_dnn_fill_gettingoutput_task(&task, &exec_params, th_model, input_height, input_width, ctx);
-    if ( ret != 0) {
-        goto err;
-    }
+    if (ret)
+        return ret;

     ret = extract_lltask_from_task(&task, th_model->lltask_queue);
-    if ( ret != 0) {
-        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n");
-        goto err;
+    if (ret) {
+        av_frame_free(&task.out_frame);
+        av_frame_free(&task.in_frame);
+        return ret;
     }

     request = (THRequestItem*) ff_safe_queue_pop_front(th_model->request_queue);
     if (!request) {
-        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
-        ret = AVERROR(EINVAL);
-        goto err;
+        av_frame_free(&task.out_frame);
+        av_frame_free(&task.in_frame);
+        return AVERROR(EINVAL);
     }

     ret = execute_model_th(request, th_model->lltask_queue);
     *output_width = task.out_frame->width;
     *output_height = task.out_frame->height;
-err:
     av_frame_free(&task.out_frame);
     av_frame_free(&task.in_frame);
     return ret;
@@ -479,105 +403,67 @@ err:

 static THInferRequest *th_create_inference_request(void)
 {
-    THInferRequest *request = (THInferRequest *)av_malloc(sizeof(THInferRequest));
-    if (!request) {
+    THInferRequest *request = av_mallocz(sizeof(THInferRequest));
+    if (!request)
         return NULL;
-    }
-    request->input_tensor = NULL;
-    request->output = NULL;
     return request;
 }

 static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
 {
-    DNNModel *model = NULL;
-    THModel *th_model = NULL;
+    THModel *th_model = av_mallocz(sizeof(THModel));
     THRequestItem *item = NULL;
-    const char *device_name = ctx->device ? ctx->device : "cpu";

-    th_model = (THModel *)av_mallocz(sizeof(THModel));
     if (!th_model)
         return NULL;
-    model = &th_model->model;
-    th_model->ctx = ctx;
-
-    c10::Device device = c10::Device(device_name);
-    if (device.is_xpu()) {
-        if (!at::hasXPU()) {
-            av_log(ctx, AV_LOG_ERROR, "No XPU device found\n");
-            goto fail;
-        }
-        at::detail::getXPUHooks().initXPU();
-    } else if (!device.is_cpu()) {
-        av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", device_name);
-        goto fail;
-    }
-
-    try {
-        th_model->jit_model = new torch::jit::Module;
-        (*th_model->jit_model) = torch::jit::load(ctx->model_filename);
-        th_model->jit_model->to(device);
-    } catch (const c10::Error& e) {
-        av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
-        goto fail;
-    }
+
+    th_model->ctx = ctx;
+    th_model->jit_model = new torch::jit::Module;
+    // Commit 1 uses the simplest loading logic
+    *th_model->jit_model = torch::jit::load(ctx->model_filename);

     th_model->request_queue = ff_safe_queue_create();
-    if (!th_model->request_queue) {
+    if (!th_model->request_queue)
        goto fail;
-    }

-    item = (THRequestItem *)av_mallocz(sizeof(THRequestItem));
-    if (!item) {
+    item = av_mallocz(sizeof(THRequestItem));
+    if (!item)
        goto fail;
-    }
-    item->lltask = NULL;
+
     item->infer_request = th_create_inference_request();
-    if (!item->infer_request) {
-        av_log(NULL, AV_LOG_ERROR, "Failed to allocate memory for Torch inference request\n");
+    if (!item->infer_request)
        goto fail;
-    }
+
+    // Infrastructure setup for Async Module
     item->exec_module.start_inference = &th_start_inference;
     item->exec_module.callback = &infer_completion_callback;
     item->exec_module.args = item;

-    if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) {
+    if (ff_safe_queue_push_back(th_model->request_queue, item) < 0)
        goto fail;
-    }
     item = NULL;

     th_model->task_queue = ff_queue_create();
-    if (!th_model->task_queue) {
+    if (!th_model->task_queue)
        goto fail;
-    }

     th_model->lltask_queue = ff_queue_create();
-    if (!th_model->lltask_queue) {
-        goto fail;
-    }
-
-    th_model->pending_queue = ff_safe_queue_create();
-    if (!th_model->pending_queue) {
+    if (!th_model->lltask_queue)
        goto fail;
-    }

-    th_model->mutex = new std::mutex();
-    th_model->cond = new std::condition_variable();
-    th_model->worker_stop = false;
-    th_model->worker_thread = new std::thread(th_worker_thread, th_model);
+    th_model->model.get_input = &get_input_th;
+    th_model->model.get_output = &get_output_th;
+    th_model->model.filter_ctx = filter_ctx;
+    th_model->model.func_type = func_type;

-    model->get_input = &get_input_th;
-    model->get_output = &get_output_th;
-    model->filter_ctx = filter_ctx;
-    model->func_type = func_type;
-    return model;
+    return &th_model->model;

 fail:
-    if (item) {
+    if (item)
         destroy_request_item(&item);
-        av_freep(&item);
-    }
-    dnn_free_model_th(&model);
+    // Passing the address of the model pointer
+    DNNModel *temp_model = &th_model->model;
+    dnn_free_model_th(&temp_model);
     return NULL;
 }
@@ -590,42 +476,31 @@ static int dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_params)
     int ret = 0;

     ret = ff_check_exec_params(ctx, DNN_TH, model->func_type, exec_params);
-    if (ret != 0) {
-        av_log(ctx, AV_LOG_ERROR, "exec parameter checking fail.\n");
+    if (ret)
         return ret;
-    }

-    task = (TaskItem *)av_malloc(sizeof(TaskItem));
-    if (!task) {
-        av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n");
+    task = av_mallocz(sizeof(TaskItem));
+    if (!task)
         return AVERROR(ENOMEM);
-    }

     ret = ff_dnn_fill_task(task, exec_params, th_model, 0, 1);
-    if (ret != 0) {
+    if (ret) {
         av_freep(&task);
-        av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n");
         return ret;
     }

-    ret = ff_queue_push_back(th_model->task_queue, task);
-    if (ret < 0) {
+    if (ff_queue_push_back(th_model->task_queue, task) < 0) {
         av_freep(&task);
-        av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n");
-        return ret;
+        return AVERROR(ENOMEM);
     }

     ret = extract_lltask_from_task(task, th_model->lltask_queue);
-    if (ret != 0) {
-        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n");
+    if (ret)
         return ret;
-    }

     request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
-    if (!request) {
-        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+    if (!request)
         return AVERROR(EINVAL);
-    }

     return execute_model_th(request, th_model->lltask_queue);
 }
@@ -642,14 +517,11 @@ static int dnn_flush_th(const DNNModel *model)
     THRequestItem *request;

     if (ff_queue_size(th_model->lltask_queue) == 0)
-        // no pending task need to flush
         return 0;

     request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
-    if (!request) {
-        av_log(th_model->ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+    if (!request)
         return AVERROR(EINVAL);
-    }

     return execute_model_th(request, th_model->lltask_queue);
 }
@@ -662,4 +534,4 @@ extern const DNNModule ff_dnn_backend_torch = {
     .get_result = dnn_get_result_th,
     .flush = dnn_flush_th,
     .free_model = dnn_free_model_th,
-};
+};
\ No newline at end of file
--
2.51.0

_______________________________________________
ffmpeg-devel mailing list -- ffmpeg-devel@ffmpeg.org
To unsubscribe send an email to ffmpeg-devel-leave@ffmpeg.org