From mboxrd@z Thu Jan 1 00:00:00 1970
Date: Mon, 10 Mar 2025 20:55:26 +0100
Message-ID: <004401db91f6$5e4bf460$1ae3dd20$@gmail.com>
Subject: [FFmpeg-devel] [PATCH v2 FFmpeg 16/20] libavfilter/dnn/dnn_backend_torch: CLIP/CLAP Inference handling and support for detection bboxes from dnn_detect filter
List-Id: FFmpeg development discussions and patches
MIME-Version: 1.0
Content-Type: text/plain; charset="us-ascii"
Signed-off-by: MaximilianKaindl
---
 libavfilter/dnn/dnn_backend_torch.cpp | 411 +++++++++++++++++++-------
 1 file changed, 311 insertions(+), 100 deletions(-)

diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 1d2bfb191a..26b57f08f3 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -1,22 +1,22 @@
 /*
- * Copyright (c) 2024
- *
- * This file is part of FFmpeg.
- *
- * FFmpeg is free software; you can redistribute it and/or
- * modify it under the terms of the GNU Lesser General Public
- * License as published by the Free Software Foundation; either
- * version 2.1 of the License, or (at your option) any later version.
- *
- * FFmpeg is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
- * Lesser General Public License for more details.
- *
- * You should have received a copy of the GNU Lesser General Public
- * License along with FFmpeg; if not, write to the Free Software
- * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
- */
+* Copyright (c) 2024
+*
+* This file is part of FFmpeg.
+*
+* FFmpeg is free software; you can redistribute it and/or
+* modify it under the terms of the GNU Lesser General Public
+* License as published by the Free Software Foundation; either
+* version 2.1 of the License, or (at your option) any later version.
+*
+* FFmpeg is distributed in the hope that it will be useful,
+* but WITHOUT ANY WARRANTY; without even the implied warranty of
+* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+* Lesser General Public License for more details.
+*
+* You should have received a copy of the GNU Lesser General Public
+* License along with FFmpeg; if not, write to the Free Software
+* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+*/

 /**
  * @file
@@ -31,7 +31,9 @@ extern "C" {
 #include "dnn_io_proc.h"
 #include "dnn_backend_common.h"
 #include "libavutil/opt.h"
+#include "libavutil/avassert.h"
 #include "libavutil/avstring.h"
+#include "libavutil/detection_bbox.h"
 #include "libavutil/mem.h"
 #include "queue.h"
 #include "safe_queue.h"
@@ -65,11 +67,11 @@ typedef struct THInferRequest {

 typedef struct THRequestItem {
     THInferRequest *infer_request;
-    LastLevelTaskItem *lltask;
+    LastLevelTaskItem **lltasks;
+    int lltask_count;
     DNNAsyncExecModule exec_module;
 } THRequestItem;

-
 #define OFFSET(x) offsetof(THOptions, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM
 static const AVOption dnn_th_options[] = {
@@ -77,24 +79,95 @@ static const AVOption dnn_th_options[] = {
     { NULL }
 };

-static int extract_lltask_from_task(TaskItem *task, Queue *lltask_queue)
+static int extract_lltask_from_task(DNNFunctionType func_type, TaskItem *task, Queue *lltask_queue, DNNExecBaseParams *exec_params)
 {
     THModel *th_model = (THModel *)task->model;
     DnnContext *ctx = th_model->ctx;
-    LastLevelTaskItem *lltask = (LastLevelTaskItem *)av_malloc(sizeof(*lltask));
-    if (!lltask) {
-        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for LastLevelTaskItem\n");
-        return AVERROR(ENOMEM);
-    }
-    task->inference_todo = 1;
-    task->inference_done = 0;
-    lltask->task = task;
-    if (ff_queue_push_back(lltask_queue, lltask) < 0) {
-        av_log(ctx, AV_LOG_ERROR, "Failed to push back lltask_queue.\n");
-        av_freep(&lltask);
-        return AVERROR(ENOMEM);
+
+    switch (func_type) {
+    case DFT_PROCESS_FRAME:
+    case DFT_ANALYTICS_CLAP: {
+        LastLevelTaskItem *lltask = (LastLevelTaskItem *)av_malloc(sizeof(*lltask));
+        if (!lltask) {
+            av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for LastLevelTaskItem\n");
+            return AVERROR(ENOMEM);
+        }
+        task->inference_todo = 1;
+        task->inference_done = 0;
+        lltask->bbox_index = 0;
+        lltask->task = task;
+        if (ff_queue_push_back(lltask_queue, lltask) < 0) {
+            av_log(ctx, AV_LOG_ERROR, "Failed to push back lltask_queue.\n");
+            av_freep(&lltask);
+            return AVERROR(ENOMEM);
+        }
+        return 0;
+    }
+    case DFT_ANALYTICS_CLIP: {
+        const AVDetectionBBoxHeader *header;
+        AVFrame *frame = task->in_frame;
+        AVFrameSideData *sd;
+        LastLevelTaskItem *lltask;
+        DNNExecZeroShotClassificationParams *params = (DNNExecZeroShotClassificationParams *)exec_params;
+
+        if (params->target == NULL) {
+            LastLevelTaskItem *lltask = (LastLevelTaskItem *)av_malloc(sizeof(*lltask));
+            if (!lltask) {
+                av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for LastLevelTaskItem\n");
+                return AVERROR(ENOMEM);
+            }
+            task->inference_todo = 1;
+            task->inference_done = 0;
+            lltask->bbox_index = 0;
+            lltask->task = task;
+            if (ff_queue_push_back(lltask_queue, lltask) < 0) {
+                av_log(ctx, AV_LOG_ERROR, "Failed to push back lltask_queue.\n");
+                av_freep(&lltask);
+                return AVERROR(ENOMEM);
+            }
+            return 0;
+        }
+
+        task->inference_todo = 0;
+        task->inference_done = 0;
+
+        if (!ff_dnn_contain_valid_detection_bbox(frame)) {
+            return 0;
+        }
+
+        sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
+        header = (const AVDetectionBBoxHeader *)sd->data;
+
+        for (uint32_t i = 0; i < header->nb_bboxes; i++) {
+            const AVDetectionBBox *bbox = av_get_detection_bbox(header, i);
+            if (bbox->w * bbox->h <= 0) {
+                continue;
+            }
+            if (params->target) {
+                if (av_strncasecmp(bbox->detect_label, params->target, sizeof(bbox->detect_label)) != 0) {
+                    continue;
+                }
+            }
+
+            lltask = (LastLevelTaskItem *)av_malloc(sizeof(*lltask));
+            if (!lltask) {
+                return AVERROR(ENOMEM);
+            }
+            task->inference_todo++;
+            lltask->task = task;
+            lltask->bbox_index = i;
+            if (ff_queue_push_back(lltask_queue, lltask) < 0) {
+                av_freep(&lltask);
+                return AVERROR(ENOMEM);
+            }
+        }
+        return 0;
+    }
+    default: {
+        av_assert0(!"should not reach here");
+        return AVERROR(EINVAL);
+    }
     }
-    return 0;
 }

 static void th_free_request(THInferRequest *request)
@@ -121,7 +194,7 @@ static inline void destroy_request_item(THRequestItem **arg) {
     item = *arg;
     th_free_request(item->infer_request);
     av_freep(&item->infer_request);
-    av_freep(&item->lltask);
+    av_freep(&item->lltasks);
     ff_dnn_async_module_cleanup(&item->exec_module);
     av_freep(arg);
 }
@@ -187,9 +260,9 @@ static int get_input_th(DNNModel *model, DNNData *input, const char *input_name)
     return 0;
 }

-static void deleter(void *arg) 
-{ 
-    av_freep(&arg); 
+static void deleter(void *arg)
+{
+    av_freep(&arg);
 }

 #if (CONFIG_LIBTOKENIZERS == 1)
@@ -590,19 +663,11 @@ static int fill_model_input_th(THModel *th_model, THRequestItem *request)
 {
     LastLevelTaskItem *lltask = NULL;
     TaskItem *task = NULL;
-    THInferRequest *infer_request = NULL;
+    THInferRequest *infer_request = request->infer_request;
     DNNData input = { 0 };
     DnnContext *ctx = th_model->ctx;
     int ret, width_idx, height_idx, channel_idx;
-
-    lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
-    if (!lltask) {
-        ret = AVERROR(EINVAL);
-        goto err;
-    }
-    request->lltask = lltask;
-    task = lltask->task;
-    infer_request = request->infer_request;
+    std::vector<torch::Tensor> batch_tensors;

     ret = get_input_th(&th_model->model, &input, NULL);
     if ( ret != 0) {
@@ -611,36 +676,92 @@ static int fill_model_input_th(THModel *th_model, THRequestItem *request)
     width_idx = dnn_get_width_idx_by_layout(input.layout);
     height_idx = dnn_get_height_idx_by_layout(input.layout);
     channel_idx = dnn_get_channel_idx_by_layout(input.layout);
-    input.dims[height_idx] = task->in_frame->height;
-    input.dims[width_idx] = task->in_frame->width;
-    input.data = av_malloc(input.dims[height_idx] * input.dims[width_idx] *
-                           input.dims[channel_idx] * sizeof(float));
-    if (!input.data)
-        return AVERROR(ENOMEM);
     infer_request->input_tensor = new torch::Tensor();
     infer_request->output = new torch::Tensor();

-    switch (th_model->model.func_type) {
-    case DFT_PROCESS_FRAME:
-        input.scale = 255;
-        if (task->do_ioproc) {
-            if (th_model->model.frame_pre_proc != NULL) {
-                th_model->model.frame_pre_proc(task->in_frame, &input, th_model->model.filter_ctx);
-            } else {
-                ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx);
+    // Handle audio input
+    if (th_model->model.func_type == DFT_ANALYTICS_CLAP) {
+        lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
+        if (!lltask) {
+            return -1;
+        }
+        request->lltasks[request->lltask_count++] = lltask;
+        task = lltask->task;
+        return prepare_audio_tensor(th_model, request);
+    }
+
+    while (ff_queue_size(th_model->lltask_queue) != 0) {
+        lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
+        if (!lltask) {
+            break;
+        }
+        request->lltasks[request->lltask_count++] = lltask;
+        task = lltask->task;
+
+        input.dims[height_idx] = task->in_frame->height;
+        input.dims[width_idx] = task->in_frame->width;
+        input.data = av_malloc(input.dims[height_idx] * input.dims[width_idx] *
+                               input.dims[channel_idx] * sizeof(float));
+        if (!input.data) {
+            ret = AVERROR(ENOMEM);
+            goto err;
+        }
+        switch (th_model->model.func_type) {
+        case DFT_PROCESS_FRAME:
+        case DFT_ANALYTICS_CLIP:
+            input.scale = 255;
+            if (task->do_ioproc) {
+                if (th_model->model.frame_pre_proc != NULL) {
+                    th_model->model.frame_pre_proc(task->in_frame, &input, th_model->model.filter_ctx);
+                } else {
+                    ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx);
+                }
+            }
+            break;
+        default:
+            avpriv_report_missing_feature(NULL, "model function type %d", th_model->model.func_type);
+            ret = AVERROR(EINVAL);
+            goto err;
+        }
+
+        try {
+            auto tensor = torch::from_blob(input.data,
+                                           {1, input.dims[channel_idx], input.dims[height_idx], input.dims[width_idx]},
+                                           deleter, torch::kFloat32).clone();
+            if (th_model->model.func_type == DFT_ANALYTICS_CLIP) {
+                preprocess_image_tensor(th_model, &tensor, torch::kCPU);
             }
+            batch_tensors.push_back(tensor);
+            input.data = NULL;
+        } catch (const c10::Error &e) {
+            av_log(ctx, AV_LOG_ERROR, "Error creating tensor: %s\n", e.what());
+            ret = AVERROR(EINVAL);
+            goto err;
         }
-        break;
-    default:
-        avpriv_report_missing_feature(NULL, "model function type %d", th_model->model.func_type);
-        break;
+
+        av_freep(&input.data);
+    }
+
+    // Stack tensors into batch
+    try {
+        if (!batch_tensors.empty()) {
+            *infer_request->input_tensor = torch::cat(batch_tensors, 0);
+        } else {
+            av_log(ctx, AV_LOG_ERROR, "No tensors to process\n");
+            ret = AVERROR(EINVAL);
+            goto err;
+        }
+    } catch (const c10::Error &e) {
+        av_log(ctx, AV_LOG_ERROR, "Error creating batch tensor: %s\n", e.what());
+        ret = AVERROR(EINVAL);
+        goto err;
     }

-    *infer_request->input_tensor = torch::from_blob(input.data,
-        {1, input.dims[channel_idx], input.dims[height_idx], input.dims[width_idx]},
-        deleter, torch::kFloat32);
     return 0;
 err:
+    if (input.data) {
+        av_freep(&input.data);
+    }
     th_free_request(infer_request);
     return ret;
 }
@@ -661,7 +782,7 @@ static int th_start_inference(void *args)
         return AVERROR(EINVAL);
     }
     infer_request = request->infer_request;
-    lltask = request->lltask;
+    lltask = request->lltasks[0];
     task = lltask->task;
     th_model = (THModel *)task->model;
     ctx = th_model->ctx;
@@ -681,16 +802,54 @@ static int th_start_inference(void *args)
     *infer_request->input_tensor = infer_request->input_tensor->to(device);
     inputs.push_back(*infer_request->input_tensor);

-    *infer_request->output = th_model->jit_model->forward(inputs).toTensor();
+#if (CONFIG_LIBTOKENIZERS == 1)
+    if (th_model->model.func_type == DFT_ANALYTICS_CLIP) {
+        inputs.push_back(*th_model->clxp_ctx->tokenized_text);
+    } else if (th_model->model.func_type == DFT_ANALYTICS_CLAP) {
+        inputs.push_back(*th_model->clxp_ctx->tokenized_text);
+        inputs.push_back(*th_model->clxp_ctx->attention_mask);
+    }
+#endif
+
+    auto result = th_model->jit_model->forward(inputs);
+
+    if (th_model->model.func_type == DFT_PROCESS_FRAME) {
+        *infer_request->output = result.toTensor();
+    } else if (th_model->model.func_type == DFT_ANALYTICS_CLIP || th_model->model.func_type == DFT_ANALYTICS_CLAP) {
+        if (result.isTuple()) {
+            auto result_tuple = result.toTuple();
+            torch::Tensor media_embeddings;
+            torch::Tensor text_embeddings;
+            float logit_scale = th_model->ctx->torch_option.logit_scale;
+            if (th_model->ctx->torch_option.forward_order == 1) {
+                media_embeddings = result_tuple->elements()[1].toTensor();
+                text_embeddings = result_tuple->elements()[0].toTensor();
+                *infer_request->output = calculate_similarity(media_embeddings, text_embeddings,
+                                                              th_model->ctx->torch_option.normalize, logit_scale, ctx);
+            } else {
+                media_embeddings = result_tuple->elements()[0].toTensor();
+                text_embeddings = result_tuple->elements()[1].toTensor();
+                *infer_request->output = calculate_similarity(media_embeddings, text_embeddings,
+                                                              th_model->ctx->torch_option.normalize, logit_scale, ctx);
+            }
+            *infer_request->output =
+                apply_softmax(*infer_request->output, th_model->ctx->torch_option.temperature,
+                              th_model->clxp_ctx->softmax_units, th_model->clxp_ctx->softmax_units_count, ctx);
+        }
+    } else {
+        avpriv_report_missing_feature(ctx, "model function type %d", th_model->model.func_type);
+        return AVERROR(EINVAL);
+    }

     return 0;
 }

-static void infer_completion_callback(void *args) {
-    THRequestItem *request = (THRequestItem*)args;
-    LastLevelTaskItem *lltask = request->lltask;
+static void infer_completion_callback(void *args)
+{
+    THRequestItem *request = (THRequestItem *)args;
+    LastLevelTaskItem *lltask = request->lltasks[0];
     TaskItem *task = lltask->task;
-    DNNData outputs = { 0 };
+    DNNData outputs = {0};
     THInferRequest *infer_request = request->infer_request;
     THModel *th_model = (THModel *)task->model;
     torch::Tensor *output = infer_request->output;
@@ -699,7 +858,17 @@ static void infer_completion_callback(void *args) {
     outputs.order = DCO_RGB;
     outputs.layout = DL_NCHW;
     outputs.dt = DNN_FLOAT;
-    if (sizes.size() == 4) {
+    if (th_model->model.func_type == DFT_ANALYTICS_CLIP || th_model->model.func_type == DFT_ANALYTICS_CLAP) {
+        // CLIP outputs are similarity scores [batch_size, num_labels]
+        if (sizes.size() != 2) {
+            av_log(th_model->ctx, AV_LOG_ERROR, "Invalid CLIP output dimensions\n");
+            goto err;
+        }
+        outputs.dims[0] = sizes[0]; // batch_size
+        outputs.dims[1] = sizes[1]; // number of labels
+        outputs.order = th_model->model.func_type == DFT_ANALYTICS_CLIP ? DCO_RGB : DCO_NONE;
+        outputs.dt = DNN_FLOAT;
+    } else if (sizes.size() == 4 && th_model->model.func_type == DFT_PROCESS_FRAME) {
         // 4 dimensions: [batch_size, channel, height, width]
         // this format of data is normally used for video frame SR
         outputs.dims[0] = sizes.at(0); // N
@@ -711,31 +880,62 @@ static void infer_completion_callback(void *args) {
-    switch (th_model->model.func_type) {
-    case DFT_PROCESS_FRAME:
-        if (task->do_ioproc) {
+    // Process all Tasks
+    for (int i = 0; i < request->lltask_count; i++) {
+        LastLevelTaskItem *lltask = request->lltasks[i];
+        TaskItem *task = lltask->task;
+
+        // Extract single output
+        torch::Tensor single_output;
+        try {
+            single_output = output->select(0, i);
+
             // Post process can only deal with CPU memory.
-            if (output->device() != torch::kCPU)
-                *output = output->to(torch::kCPU);
-            outputs.scale = 255;
-            outputs.data = output->data_ptr();
-            if (th_model->model.frame_post_proc != NULL) {
-                th_model->model.frame_post_proc(task->out_frame, &outputs, th_model->model.filter_ctx);
-            } else {
-                ff_proc_from_dnn_to_frame(task->out_frame, &outputs, th_model->ctx);
+            if (single_output.device() != torch::kCPU) {
+                single_output = single_output.to(torch::kCPU);
             }
-        } else {
-            task->out_frame->width = outputs.dims[dnn_get_width_idx_by_layout(outputs.layout)];
-            task->out_frame->height = outputs.dims[dnn_get_height_idx_by_layout(outputs.layout)];
+
+            outputs.data = single_output.data_ptr();
+        } catch (const c10::Error &e) {
+            av_log(th_model->ctx, AV_LOG_ERROR, "Error processing output tensor: %s\n", e.what());
+            goto err;
         }
-        break;
-    default:
-        avpriv_report_missing_feature(th_model->ctx, "model function type %d", th_model->model.func_type);
-        goto err;
+
+        switch (th_model->model.func_type) {
+        case DFT_PROCESS_FRAME:
+            if (task->do_ioproc) {
+                outputs.scale = 255;
+                if (th_model->model.frame_post_proc != NULL) {
+                    th_model->model.frame_post_proc(task->out_frame, &outputs, th_model->model.filter_ctx);
+                } else {
+                    ff_proc_from_dnn_to_frame(task->out_frame, &outputs, th_model->ctx);
+                }
+            } else {
+                task->out_frame->width = outputs.dims[dnn_get_width_idx_by_layout(outputs.layout)];
+                task->out_frame->height = outputs.dims[dnn_get_height_idx_by_layout(outputs.layout)];
+            }
+            break;
+        case DFT_ANALYTICS_CLIP:
+        case DFT_ANALYTICS_CLAP:
+            if (task->do_ioproc) {
+                if (!th_model->model.classify_post_proc) {
+                    av_log(th_model->ctx, AV_LOG_ERROR, "CLIP/CLAP filter needs to provide post proc\n");
+                    goto err;
+                }
+                th_model->model.classify_post_proc(task->in_frame, &outputs, lltask->bbox_index,
+                                                   th_model->model.filter_ctx);
+            }
+            break;
+        default:
+            avpriv_report_missing_feature(th_model->ctx, "model function type %d", th_model->model.func_type);
+            goto err;
+        }
+
+        task->inference_done++;
+        av_freep(&request->lltasks[i]);
     }
-    task->inference_done++;
-    av_freep(&request->lltask);
 err:
+    av_freep(&request->lltasks);
+    request->lltask_count = 0;
     th_free_request(infer_request);

     if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
@@ -789,7 +989,7 @@ err:
 }

 static int get_output_th(DNNModel *model, const char *input_name, int input_width, int input_height,
-                        const char *output_name, int *output_width, int *output_height)
+                         const char *output_name, int *output_width, int *output_height)
 {
     int ret = 0;
     THModel *th_model = (THModel*) model;
@@ -808,7 +1008,7 @@ static int get_output_th(DNNModel *model, const char *input_name, int input_widt
         goto err;
     }

-    ret = extract_lltask_from_task(&task, th_model->lltask_queue);
+    ret = extract_lltask_from_task(th_model->model.func_type, &task, th_model->lltask_queue, NULL);
     if ( ret != 0) {
         av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n");
         goto err;
@@ -918,7 +1118,7 @@ static THModel *init_model_th(DnnContext *ctx, DNNFunctionType func_type, AVFilt
     if (!item) {
         goto fail;
     }
-    item->lltask = NULL;
+    item->lltasks = NULL;
     item->infer_request = th_create_inference_request();
     if (!item->infer_request) {
         av_log(NULL, AV_LOG_ERROR, "Failed to allocate memory for Torch inference request\n");
@@ -1078,18 +1278,29 @@ static int dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_p
         return ret;
     }

-    ret = extract_lltask_from_task(task, th_model->lltask_queue);
+    ret = extract_lltask_from_task(model->func_type, task, th_model->lltask_queue, exec_params);
     if (ret != 0) {
        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n");
        return ret;
     }

+    if (task->inference_todo == 0) {
+        return 0;
+    }
+
     request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
     if (!request) {
         av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
         return AVERROR(EINVAL);
     }
+    request->lltasks = (LastLevelTaskItem **)av_malloc_array(task->inference_todo, sizeof(*request->lltasks));
+    if (!request->lltasks) {
+        av_log(ctx, AV_LOG_ERROR, "unable to create lltasks.\n");
+        return AVERROR(EINVAL);
+    }
+    request->lltask_count = 0;
+
     return execute_model_th(request, th_model->lltask_queue);
 }
--
2.34.1

_______________________________________________
ffmpeg-devel mailing list
ffmpeg-devel@ffmpeg.org
https://ffmpeg.org/mailman/listinfo/ffmpeg-devel

To unsubscribe, visit link above, or email
ffmpeg-devel-request@ffmpeg.org with subject "unsubscribe".