first commit

Author: nasir@endelospay.com
Date: 2025-08-12 02:54:17 +05:00
Commit: d97cad1736
225 changed files with 137626 additions and 0 deletions

@@ -0,0 +1,384 @@
/* eslint-disable */
// js/similarity.worker.js
importScripts('../libs/ort.min.js'); // Adjust the path to match your file structure
// Global worker state
let session = null;
let modelPathInternal = null;
let ortEnvConfigured = false;
let sessionOptions = null;
let modelInputNames = null; // Stores the model's input names
// Reusable TypedArray buffers to reduce memory allocations
let reusableBuffers = {
inputIds: null,
attentionMask: null,
tokenTypeIds: null,
};
// Performance statistics
let workerStats = {
totalInferences: 0,
totalInferenceTime: 0,
averageInferenceTime: 0,
memoryAllocations: 0,
};
// Configure the ONNX Runtime environment (runs only once)
function configureOrtEnv(numThreads = 1, executionProviders = ['wasm']) {
if (ortEnvConfigured) return;
try {
ort.env.wasm.numThreads = numThreads;
ort.env.wasm.simd = true; // Enable SIMD when available
ort.env.wasm.proxy = false; // A proxy is usually unnecessary inside a worker
ort.env.logLevel = 'warning'; // 'verbose', 'info', 'warning', 'error', 'fatal'
ortEnvConfigured = true;
sessionOptions = {
executionProviders: executionProviders,
graphOptimizationLevel: 'all',
enableCpuMemArena: true,
enableMemPattern: true,
// executionMode: 'sequential' // a worker usually runs one task at a time anyway
};
} catch (error) {
console.error('Worker: Failed to configure ORT environment', error);
throw error; // Rethrow so the main thread is notified
}
}
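// Note (assumption about the deployment environment): numThreads > 1 only
// takes effect when the page is cross-origin isolated (COOP/COEP headers set),
// because multi-threaded WASM needs SharedArrayBuffer; without it, execution
// typically stays single-threaded.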
async function initializeModel(modelPathOrData, numThreads, executionProviders) {
try {
configureOrtEnv(numThreads, executionProviders); // Ensure the environment is configured
if (!modelPathOrData) {
throw new Error('Worker: Model path or data is not provided.');
}
// Check if input is ArrayBuffer (cached model data) or string (URL path)
if (modelPathOrData instanceof ArrayBuffer) {
console.log(
`Worker: Initializing model from cached ArrayBuffer (${modelPathOrData.byteLength} bytes)`,
);
session = await ort.InferenceSession.create(modelPathOrData, sessionOptions);
modelPathInternal = '[Cached ArrayBuffer]'; // For debugging purposes
} else {
console.log(`Worker: Initializing model from URL: ${modelPathOrData}`);
modelPathInternal = modelPathOrData; // Keep the model path for debugging or reloading (if needed)
session = await ort.InferenceSession.create(modelPathInternal, sessionOptions);
}
// Read the model's input names to decide whether token_type_ids is required
modelInputNames = session.inputNames;
console.log(`Worker: ONNX session created successfully for model: ${modelPathInternal}`);
console.log(`Worker: Model input names:`, modelInputNames);
return { status: 'success', message: 'Model initialized' };
} catch (error) {
console.error(`Worker: Model initialization failed:`, error);
session = null; // Clear the session in case of partial initialization
modelInputNames = null;
// Serialize the error message, since Error objects may not survive postMessage directly
throw new Error(`Worker: Model initialization failed - ${error.message}`);
}
}
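// Example (hypothetical main-thread side): one way to drive 'init' from the
// page, inside an async function. The worker path and model URL below are
// assumptions, not part of this module.
//
//   const worker = new Worker('js/similarity.worker.js');
//   const modelData = await (await fetch('models/model.onnx')).arrayBuffer();
//   // Transfer the buffer to avoid a structured-clone copy:
//   worker.postMessage(
//     { id: 1, type: 'init', payload: { modelData, numThreads: 1, executionProviders: ['wasm'] } },
//     [modelData],
//   );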
// Buffer management: reuse a typed array, growing it only when needed
function getOrCreateBuffer(name, requiredLength, type = BigInt64Array) {
if (!reusableBuffers[name] || reusableBuffers[name].length < requiredLength) {
reusableBuffers[name] = new type(requiredLength);
workerStats.memoryAllocations++;
}
return reusableBuffers[name];
}
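// Note: pooled buffers only grow, so a long-lived worker converges on the
// largest batch seen. The slice() calls below copy just the used prefix, so
// each ort.Tensor owns its data and the pooled buffer is free for reuse.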
// Optimized batched inference
async function runBatchInference(batchData) {
if (!session) {
throw new Error("Worker: Session not initialized. Call 'initializeModel' first.");
}
const startTime = performance.now();
try {
const feeds = {};
const batchSize = batchData.dims.input_ids[0];
const seqLength = batchData.dims.input_ids[1];
// Optimization: reuse buffers to reduce memory allocations
const inputIdsLength = batchData.input_ids.length;
const attentionMaskLength = batchData.attention_mask.length;
// Reuse or create the BigInt64Array buffers
const inputIdsBuffer = getOrCreateBuffer('inputIds', inputIdsLength);
const attentionMaskBuffer = getOrCreateBuffer('attentionMask', attentionMaskLength);
// Fill the buffers in a plain loop (avoids map allocations)
for (let i = 0; i < inputIdsLength; i++) {
inputIdsBuffer[i] = BigInt(batchData.input_ids[i]);
}
for (let i = 0; i < attentionMaskLength; i++) {
attentionMaskBuffer[i] = BigInt(batchData.attention_mask[i]);
}
feeds['input_ids'] = new ort.Tensor(
'int64',
inputIdsBuffer.slice(0, inputIdsLength),
batchData.dims.input_ids,
);
feeds['attention_mask'] = new ort.Tensor(
'int64',
attentionMaskBuffer.slice(0, attentionMaskLength),
batchData.dims.attention_mask,
);
// token_type_ids: provide it only when the model requires it
if (modelInputNames && modelInputNames.includes('token_type_ids')) {
if (batchData.token_type_ids && batchData.dims.token_type_ids) {
const tokenTypeIdsLength = batchData.token_type_ids.length;
const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', tokenTypeIdsLength);
for (let i = 0; i < tokenTypeIdsLength; i++) {
tokenTypeIdsBuffer[i] = BigInt(batchData.token_type_ids[i]);
}
feeds['token_type_ids'] = new ort.Tensor(
'int64',
tokenTypeIdsBuffer.slice(0, tokenTypeIdsLength),
batchData.dims.token_type_ids,
);
} else {
// Create a default all-zero token_type_ids
const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', inputIdsLength);
tokenTypeIdsBuffer.fill(0n, 0, inputIdsLength);
feeds['token_type_ids'] = new ort.Tensor(
'int64',
tokenTypeIdsBuffer.slice(0, inputIdsLength),
batchData.dims.input_ids,
);
}
} else {
console.log('Worker: Skipping token_type_ids as model does not require it');
}
// Run the batched inference
const results = await session.run(feeds);
const outputTensor = results.last_hidden_state || results[Object.keys(results)[0]];
// Use Transferable Objects to optimize data transfer
const outputData = new Float32Array(outputTensor.data);
// Update statistics
workerStats.totalInferences += batchSize; // a batch counts as multiple inferences
const inferenceTime = performance.now() - startTime;
workerStats.totalInferenceTime += inferenceTime;
workerStats.averageInferenceTime = workerStats.totalInferenceTime / workerStats.totalInferences;
return {
status: 'success',
output: {
data: outputData,
dims: outputTensor.dims,
batchSize: batchSize,
seqLength: seqLength,
},
transferList: [outputData.buffer],
stats: {
inferenceTime,
totalInferences: workerStats.totalInferences,
averageInferenceTime: workerStats.averageInferenceTime,
memoryAllocations: workerStats.memoryAllocations,
batchSize: batchSize,
},
};
} catch (error) {
console.error('Worker: Batch inference failed:', error);
throw new Error(`Worker: Batch inference failed - ${error.message}`);
}
}
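// Example (hypothetical): the flattened, padded payload this function expects
// for a batch of two sequences padded to length 4. The token ids are made up;
// a real caller would get them from its tokenizer.
//
//   worker.postMessage({
//     id: 2,
//     type: 'batchInfer',
//     payload: {
//       input_ids:      [101, 2023, 2003, 102, 101, 7592, 102, 0],
//       attention_mask: [1, 1, 1, 1, 1, 1, 1, 0],
//       dims: { input_ids: [2, 4], attention_mask: [2, 4] },
//     },
//   });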
async function runInference(inputData) {
if (!session) {
throw new Error("Worker: Session not initialized. Call 'initializeModel' first.");
}
const startTime = performance.now();
try {
const feeds = {};
// Optimization: reuse buffers to reduce memory allocations
const inputIdsLength = inputData.input_ids.length;
const attentionMaskLength = inputData.attention_mask.length;
// Reuse or create the BigInt64Array buffers
const inputIdsBuffer = getOrCreateBuffer('inputIds', inputIdsLength);
const attentionMaskBuffer = getOrCreateBuffer('attentionMask', attentionMaskLength);
// Fill the buffers in a plain loop (avoids map allocations)
for (let i = 0; i < inputIdsLength; i++) {
inputIdsBuffer[i] = BigInt(inputData.input_ids[i]);
}
for (let i = 0; i < attentionMaskLength; i++) {
attentionMaskBuffer[i] = BigInt(inputData.attention_mask[i]);
}
feeds['input_ids'] = new ort.Tensor(
'int64',
inputIdsBuffer.slice(0, inputIdsLength),
inputData.dims.input_ids,
);
feeds['attention_mask'] = new ort.Tensor(
'int64',
attentionMaskBuffer.slice(0, attentionMaskLength),
inputData.dims.attention_mask,
);
// token_type_ids: provide it only when the model requires it
if (modelInputNames && modelInputNames.includes('token_type_ids')) {
if (inputData.token_type_ids && inputData.dims.token_type_ids) {
const tokenTypeIdsLength = inputData.token_type_ids.length;
const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', tokenTypeIdsLength);
for (let i = 0; i < tokenTypeIdsLength; i++) {
tokenTypeIdsBuffer[i] = BigInt(inputData.token_type_ids[i]);
}
feeds['token_type_ids'] = new ort.Tensor(
'int64',
tokenTypeIdsBuffer.slice(0, tokenTypeIdsLength),
inputData.dims.token_type_ids,
);
} else {
// Create a default all-zero token_type_ids
const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', inputIdsLength);
tokenTypeIdsBuffer.fill(0n, 0, inputIdsLength);
feeds['token_type_ids'] = new ort.Tensor(
'int64',
tokenTypeIdsBuffer.slice(0, inputIdsLength),
inputData.dims.input_ids,
);
}
} else {
console.log('Worker: Skipping token_type_ids as model does not require it');
}
const results = await session.run(feeds);
const outputTensor = results.last_hidden_state || results[Object.keys(results)[0]];
// Use Transferable Objects to optimize data transfer
const outputData = new Float32Array(outputTensor.data);
// Update statistics
workerStats.totalInferences++;
const inferenceTime = performance.now() - startTime;
workerStats.totalInferenceTime += inferenceTime;
workerStats.averageInferenceTime = workerStats.totalInferenceTime / workerStats.totalInferences;
return {
status: 'success',
output: {
data: outputData, // Return the Float32Array directly
dims: outputTensor.dims,
},
transferList: [outputData.buffer], // Mark the buffer as transferable
stats: {
inferenceTime,
totalInferences: workerStats.totalInferences,
averageInferenceTime: workerStats.averageInferenceTime,
memoryAllocations: workerStats.memoryAllocations,
},
};
} catch (error) {
console.error('Worker: Inference failed:', error);
throw new Error(`Worker: Inference failed - ${error.message}`);
}
}
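// Example (hypothetical main-thread post-processing): reducing the returned
// last_hidden_state of shape [1, seqLength, hidden] to a sentence embedding
// via attention-masked mean pooling, then comparing two embeddings with
// cosine similarity. The pooling strategy is an assumption; this worker only
// returns raw hidden states.
//
//   function meanPool(data, dims, mask) {
//     const [, seqLen, hidden] = dims;
//     const out = new Float32Array(hidden);
//     let count = 0;
//     for (let t = 0; t < seqLen; t++) {
//       if (!mask[t]) continue;
//       count++;
//       for (let h = 0; h < hidden; h++) out[h] += data[t * hidden + h];
//     }
//     for (let h = 0; h < hidden; h++) out[h] /= count || 1;
//     return out;
//   }
//   function cosine(a, b) {
//     let dot = 0, na = 0, nb = 0;
//     for (let i = 0; i < a.length; i++) { dot += a[i] * b[i]; na += a[i] * a[i]; nb += b[i] * b[i]; }
//     return dot / (Math.sqrt(na * nb) || 1);
//   }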
self.onmessage = async (event) => {
const { id, type, payload } = event.data;
try {
switch (type) {
case 'init':
// Support both modelPath (URL string) and modelData (ArrayBuffer)
const modelInput = payload.modelData || payload.modelPath;
await initializeModel(modelInput, payload.numThreads, payload.executionProviders);
self.postMessage({ id, type: 'init_complete', status: 'success' });
break;
case 'infer':
const result = await runInference(payload);
// Use Transferable Objects to optimize data transfer
self.postMessage(
{
id,
type: 'infer_complete',
status: 'success',
payload: result.output,
stats: result.stats,
},
result.transferList || [],
);
break;
case 'batchInfer':
const batchResult = await runBatchInference(payload);
// Use Transferable Objects to optimize data transfer
self.postMessage(
{
id,
type: 'batchInfer_complete',
status: 'success',
payload: batchResult.output,
stats: batchResult.stats,
},
batchResult.transferList || [],
);
break;
case 'getStats':
self.postMessage({
id,
type: 'stats_complete',
status: 'success',
payload: workerStats,
});
break;
case 'clearBuffers':
// Clear the buffers to release memory
reusableBuffers = {
inputIds: null,
attentionMask: null,
tokenTypeIds: null,
};
workerStats.memoryAllocations = 0;
self.postMessage({
id,
type: 'clear_complete',
status: 'success',
payload: { message: 'Buffers cleared' },
});
break;
default:
console.warn(`Worker: Unknown message type: ${type}`);
self.postMessage({
id,
type: 'error',
status: 'error',
payload: { message: `Unknown message type: ${type}` },
});
}
} catch (error) {
// Send the error as a plain object, since Error objects may not serialize correctly
self.postMessage({
id,
type: `${type}_error`, // e.g. 'init_error' or 'infer_error'
status: 'error',
payload: {
message: error.message,
stack: error.stack, // optional, for debugging
name: error.name,
},
});
}
};
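// Example (hypothetical main-thread wrapper): matching replies to requests by
// id, which this worker echoes on every message. All names here are
// illustrative, not part of this module.
//
//   let nextId = 0;
//   const pending = new Map();
//   worker.onmessage = ({ data }) => {
//     const entry = pending.get(data.id);
//     if (!entry) return;
//     pending.delete(data.id);
//     if (data.status === 'success') entry.resolve(data.payload);
//     else entry.reject(data.payload);
//   };
//   function request(type, payload, transfer = []) {
//     return new Promise((resolve, reject) => {
//       const id = ++nextId;
//       pending.set(id, { resolve, reject });
//       worker.postMessage({ id, type, payload }, transfer);
//     });
//   }
//   // Usage: const { data, dims } = await request('infer', tokenizedInput);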