/* eslint-disable */ // js/similarity.worker.js importScripts('../libs/ort.min.js'); // 调整路径以匹配您的文件结构 // 全局Worker状态 let session = null; let modelPathInternal = null; let ortEnvConfigured = false; let sessionOptions = null; let modelInputNames = null; // 存储模型的输入名称 // 复用的 TypedArray 缓冲区,减少内存分配 let reusableBuffers = { inputIds: null, attentionMask: null, tokenTypeIds: null, }; // 性能统计 let workerStats = { totalInferences: 0, totalInferenceTime: 0, averageInferenceTime: 0, memoryAllocations: 0, }; // 配置 ONNX Runtime 环境 (仅一次) function configureOrtEnv(numThreads = 1, executionProviders = ['wasm']) { if (ortEnvConfigured) return; try { ort.env.wasm.numThreads = numThreads; ort.env.wasm.simd = true; // 尽可能启用SIMD ort.env.wasm.proxy = false; // 在Worker中,通常不需要代理 ort.env.logLevel = 'warning'; // 'verbose', 'info', 'warning', 'error', 'fatal' ortEnvConfigured = true; sessionOptions = { executionProviders: executionProviders, graphOptimizationLevel: 'all', enableCpuMemArena: true, enableMemPattern: true, // executionMode: 'sequential' // 在worker内部通常是顺序执行一个任务 }; } catch (error) { console.error('Worker: Failed to configure ORT environment', error); throw error; // 抛出错误,让主线程知道 } } async function initializeModel(modelPathOrData, numThreads, executionProviders) { try { configureOrtEnv(numThreads, executionProviders); // 确保环境已配置 if (!modelPathOrData) { throw new Error('Worker: Model path or data is not provided.'); } // Check if input is ArrayBuffer (cached model data) or string (URL path) if (modelPathOrData instanceof ArrayBuffer) { console.log( `Worker: Initializing model from cached ArrayBuffer (${modelPathOrData.byteLength} bytes)`, ); session = await ort.InferenceSession.create(modelPathOrData, sessionOptions); modelPathInternal = '[Cached ArrayBuffer]'; // For debugging purposes } else { console.log(`Worker: Initializing model from URL: ${modelPathOrData}`); modelPathInternal = modelPathOrData; // 存储模型路径以备调试或重载(如果需要) session = await ort.InferenceSession.create(modelPathInternal, sessionOptions); } // 获取模型的输入名称,用于判断是否需要token_type_ids modelInputNames = session.inputNames; console.log(`Worker: ONNX session created successfully for model: ${modelPathInternal}`); console.log(`Worker: Model input names:`, modelInputNames); return { status: 'success', message: 'Model initialized' }; } catch (error) { console.error(`Worker: Model initialization failed:`, error); session = null; // 清理session以防部分初始化 modelInputNames = null; // 将错误信息序列化,因为Error对象本身可能无法直接postMessage throw new Error(`Worker: Model initialization failed - ${error.message}`); } } // 优化的缓冲区管理函数 function getOrCreateBuffer(name, requiredLength, type = BigInt64Array) { if (!reusableBuffers[name] || reusableBuffers[name].length < requiredLength) { reusableBuffers[name] = new type(requiredLength); workerStats.memoryAllocations++; } return reusableBuffers[name]; } // 优化的批处理推理函数 async function runBatchInference(batchData) { if (!session) { throw new Error("Worker: Session not initialized. Call 'initializeModel' first."); } const startTime = performance.now(); try { const feeds = {}; const batchSize = batchData.dims.input_ids[0]; const seqLength = batchData.dims.input_ids[1]; // 优化:复用缓冲区,减少内存分配 const inputIdsLength = batchData.input_ids.length; const attentionMaskLength = batchData.attention_mask.length; // 复用或创建 BigInt64Array 缓冲区 const inputIdsBuffer = getOrCreateBuffer('inputIds', inputIdsLength); const attentionMaskBuffer = getOrCreateBuffer('attentionMask', attentionMaskLength); // 批量填充数据(避免 map 操作) for (let i = 0; i < inputIdsLength; i++) { inputIdsBuffer[i] = BigInt(batchData.input_ids[i]); } for (let i = 0; i < attentionMaskLength; i++) { attentionMaskBuffer[i] = BigInt(batchData.attention_mask[i]); } feeds['input_ids'] = new ort.Tensor( 'int64', inputIdsBuffer.slice(0, inputIdsLength), batchData.dims.input_ids, ); feeds['attention_mask'] = new ort.Tensor( 'int64', attentionMaskBuffer.slice(0, attentionMaskLength), batchData.dims.attention_mask, ); // 处理 token_type_ids - 只有当模型需要时才提供 if (modelInputNames && modelInputNames.includes('token_type_ids')) { if (batchData.token_type_ids && batchData.dims.token_type_ids) { const tokenTypeIdsLength = batchData.token_type_ids.length; const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', tokenTypeIdsLength); for (let i = 0; i < tokenTypeIdsLength; i++) { tokenTypeIdsBuffer[i] = BigInt(batchData.token_type_ids[i]); } feeds['token_type_ids'] = new ort.Tensor( 'int64', tokenTypeIdsBuffer.slice(0, tokenTypeIdsLength), batchData.dims.token_type_ids, ); } else { // 创建默认的全零 token_type_ids const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', inputIdsLength); tokenTypeIdsBuffer.fill(0n, 0, inputIdsLength); feeds['token_type_ids'] = new ort.Tensor( 'int64', tokenTypeIdsBuffer.slice(0, inputIdsLength), batchData.dims.input_ids, ); } } else { console.log('Worker: Skipping token_type_ids as model does not require it'); } // 执行批处理推理 const results = await session.run(feeds); const outputTensor = results.last_hidden_state || results[Object.keys(results)[0]]; // 使用 Transferable Objects 优化数据传输 const outputData = new Float32Array(outputTensor.data); // 更新统计信息 workerStats.totalInferences += batchSize; // 批处理计算多个推理 const inferenceTime = performance.now() - startTime; workerStats.totalInferenceTime += inferenceTime; workerStats.averageInferenceTime = workerStats.totalInferenceTime / workerStats.totalInferences; return { status: 'success', output: { data: outputData, dims: outputTensor.dims, batchSize: batchSize, seqLength: seqLength, }, transferList: [outputData.buffer], stats: { inferenceTime, totalInferences: workerStats.totalInferences, averageInferenceTime: workerStats.averageInferenceTime, memoryAllocations: workerStats.memoryAllocations, batchSize: batchSize, }, }; } catch (error) { console.error('Worker: Batch inference failed:', error); throw new Error(`Worker: Batch inference failed - ${error.message}`); } } async function runInference(inputData) { if (!session) { throw new Error("Worker: Session not initialized. Call 'initializeModel' first."); } const startTime = performance.now(); try { const feeds = {}; // 优化:复用缓冲区,减少内存分配 const inputIdsLength = inputData.input_ids.length; const attentionMaskLength = inputData.attention_mask.length; // 复用或创建 BigInt64Array 缓冲区 const inputIdsBuffer = getOrCreateBuffer('inputIds', inputIdsLength); const attentionMaskBuffer = getOrCreateBuffer('attentionMask', attentionMaskLength); // 填充数据(避免 map 操作) for (let i = 0; i < inputIdsLength; i++) { inputIdsBuffer[i] = BigInt(inputData.input_ids[i]); } for (let i = 0; i < attentionMaskLength; i++) { attentionMaskBuffer[i] = BigInt(inputData.attention_mask[i]); } feeds['input_ids'] = new ort.Tensor( 'int64', inputIdsBuffer.slice(0, inputIdsLength), inputData.dims.input_ids, ); feeds['attention_mask'] = new ort.Tensor( 'int64', attentionMaskBuffer.slice(0, attentionMaskLength), inputData.dims.attention_mask, ); // 处理 token_type_ids - 只有当模型需要时才提供 if (modelInputNames && modelInputNames.includes('token_type_ids')) { if (inputData.token_type_ids && inputData.dims.token_type_ids) { const tokenTypeIdsLength = inputData.token_type_ids.length; const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', tokenTypeIdsLength); for (let i = 0; i < tokenTypeIdsLength; i++) { tokenTypeIdsBuffer[i] = BigInt(inputData.token_type_ids[i]); } feeds['token_type_ids'] = new ort.Tensor( 'int64', tokenTypeIdsBuffer.slice(0, tokenTypeIdsLength), inputData.dims.token_type_ids, ); } else { // 创建默认的全零 token_type_ids const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', inputIdsLength); tokenTypeIdsBuffer.fill(0n, 0, inputIdsLength); feeds['token_type_ids'] = new ort.Tensor( 'int64', tokenTypeIdsBuffer.slice(0, inputIdsLength), inputData.dims.input_ids, ); } } else { console.log('Worker: Skipping token_type_ids as model does not require it'); } const results = await session.run(feeds); const outputTensor = results.last_hidden_state || results[Object.keys(results)[0]]; // 使用 Transferable Objects 优化数据传输 const outputData = new Float32Array(outputTensor.data); // 更新统计信息 workerStats.totalInferences++; const inferenceTime = performance.now() - startTime; workerStats.totalInferenceTime += inferenceTime; workerStats.averageInferenceTime = workerStats.totalInferenceTime / workerStats.totalInferences; return { status: 'success', output: { data: outputData, // 直接返回 Float32Array dims: outputTensor.dims, }, transferList: [outputData.buffer], // 标记为可转移对象 stats: { inferenceTime, totalInferences: workerStats.totalInferences, averageInferenceTime: workerStats.averageInferenceTime, memoryAllocations: workerStats.memoryAllocations, }, }; } catch (error) { console.error('Worker: Inference failed:', error); throw new Error(`Worker: Inference failed - ${error.message}`); } } self.onmessage = async (event) => { const { id, type, payload } = event.data; try { switch (type) { case 'init': // Support both modelPath (URL string) and modelData (ArrayBuffer) const modelInput = payload.modelData || payload.modelPath; await initializeModel(modelInput, payload.numThreads, payload.executionProviders); self.postMessage({ id, type: 'init_complete', status: 'success' }); break; case 'infer': const result = await runInference(payload); // 使用 Transferable Objects 优化数据传输 self.postMessage( { id, type: 'infer_complete', status: 'success', payload: result.output, stats: result.stats, }, result.transferList || [], ); break; case 'batchInfer': const batchResult = await runBatchInference(payload); // 使用 Transferable Objects 优化数据传输 self.postMessage( { id, type: 'batchInfer_complete', status: 'success', payload: batchResult.output, stats: batchResult.stats, }, batchResult.transferList || [], ); break; case 'getStats': self.postMessage({ id, type: 'stats_complete', status: 'success', payload: workerStats, }); break; case 'clearBuffers': // 清理缓冲区,释放内存 reusableBuffers = { inputIds: null, attentionMask: null, tokenTypeIds: null, }; workerStats.memoryAllocations = 0; self.postMessage({ id, type: 'clear_complete', status: 'success', payload: { message: 'Buffers cleared' }, }); break; default: console.warn(`Worker: Unknown message type: ${type}`); self.postMessage({ id, type: 'error', status: 'error', payload: { message: `Unknown message type: ${type}` }, }); } } catch (error) { // 确保将错误作为普通对象发送,因为Error对象本身可能无法正确序列化 self.postMessage({ id, type: `${type}_error`, // 如 'init_error' 或 'infer_error' status: 'error', payload: { message: error.message, stack: error.stack, // 可选,用于调试 name: error.name, }, }); } };