Files
broswer-automation/app/chrome-extension/workers/similarity.worker.js
nasir@endelospay.com d97cad1736 first commit
2025-08-12 02:54:17 +05:00

385 lines
13 KiB
JavaScript
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/* eslint-disable */
// js/similarity.worker.js
//
// Web worker hosting an ONNX Runtime (ort) inference session for
// embedding/similarity computation. The main thread drives it through
// postMessage; see the `self.onmessage` dispatcher at the bottom of the file.
importScripts('../libs/ort.min.js'); // adjust this path to match your file layout

// --- Global worker state ---
let session = null; // ort.InferenceSession; null until 'init' succeeds
let modelPathInternal = null; // model URL (or '[Cached ArrayBuffer]' placeholder) kept for debugging
let ortEnvConfigured = false; // guard: the ORT environment is configured only once
let sessionOptions = null; // session options built by configureOrtEnv()
let modelInputNames = null; // input names of the loaded model (decides whether token_type_ids is fed)

// Reusable TypedArray buffers, to reduce per-inference memory allocations.
let reusableBuffers = {
  inputIds: null,
  attentionMask: null,
  tokenTypeIds: null,
};

// Performance statistics, reported to the main thread via 'getStats'.
let workerStats = {
  totalInferences: 0,
  totalInferenceTime: 0,
  averageInferenceTime: 0,
  memoryAllocations: 0,
};
/**
 * Configure the ONNX Runtime environment and build the shared session options.
 * Runs at most once per worker lifetime; later calls are no-ops, so the
 * thread-count / provider arguments of subsequent calls are ignored.
 *
 * @param {number} [numThreads=1] - WASM backend thread count.
 * @param {string[]} [executionProviders=['wasm']] - ORT execution providers.
 * @throws {Error} re-thrown configuration failure so the main thread is informed.
 */
function configureOrtEnv(numThreads = 1, executionProviders = ['wasm']) {
  if (ortEnvConfigured) return;
  try {
    Object.assign(ort.env.wasm, {
      numThreads,
      simd: true, // enable SIMD where the runtime supports it
      proxy: false, // no proxy worker needed — we already run inside a worker
    });
    ort.env.logLevel = 'warning'; // one of 'verbose' | 'info' | 'warning' | 'error' | 'fatal'
    ortEnvConfigured = true;
    sessionOptions = {
      executionProviders,
      graphOptimizationLevel: 'all',
      enableCpuMemArena: true,
      enableMemPattern: true,
    };
  } catch (error) {
    console.error('Worker: Failed to configure ORT environment', error);
    throw error; // propagate so the main thread learns about the failure
  }
}
/**
 * Create the ONNX inference session from either cached model bytes or a URL.
 *
 * @param {ArrayBuffer|string} modelPathOrData - cached model bytes, or a URL to fetch from.
 * @param {number} numThreads - forwarded to configureOrtEnv (no-op after first call).
 * @param {string[]} executionProviders - forwarded to configureOrtEnv.
 * @returns {Promise<{status: string, message: string}>} success descriptor.
 * @throws {Error} with a serializable message when initialization fails.
 */
async function initializeModel(modelPathOrData, numThreads, executionProviders) {
  try {
    configureOrtEnv(numThreads, executionProviders); // ensure env is configured
    if (!modelPathOrData) {
      throw new Error('Worker: Model path or data is not provided.');
    }
    const fromCache = modelPathOrData instanceof ArrayBuffer;
    if (fromCache) {
      console.log(
        `Worker: Initializing model from cached ArrayBuffer (${modelPathOrData.byteLength} bytes)`,
      );
      modelPathInternal = '[Cached ArrayBuffer]'; // debugging placeholder
      session = await ort.InferenceSession.create(modelPathOrData, sessionOptions);
    } else {
      console.log(`Worker: Initializing model from URL: ${modelPathOrData}`);
      modelPathInternal = modelPathOrData; // kept for debugging / potential reload
      session = await ort.InferenceSession.create(modelPathInternal, sessionOptions);
    }
    // Remember the model's declared inputs so inference can decide whether
    // token_type_ids must be fed.
    modelInputNames = session.inputNames;
    console.log(`Worker: ONNX session created successfully for model: ${modelPathInternal}`);
    console.log(`Worker: Model input names:`, modelInputNames);
    return { status: 'success', message: 'Model initialized' };
  } catch (error) {
    console.error(`Worker: Model initialization failed:`, error);
    // Reset state so a partially-initialized session is never used.
    session = null;
    modelInputNames = null;
    // Re-wrap with a plain message: Error objects may not postMessage cleanly.
    throw new Error(`Worker: Model initialization failed - ${error.message}`);
  }
}
/**
 * Return a cached typed-array buffer of at least `requiredLength` elements,
 * allocating a new one only when the cache entry is missing, too small, or of
 * a different typed-array class.
 *
 * Fix: the original ignored `type` when a cached buffer was long enough,
 * silently returning a wrongly-typed buffer; we now also reallocate on a
 * constructor mismatch.
 *
 * @param {string} name - key into `reusableBuffers` ('inputIds' | 'attentionMask' | 'tokenTypeIds').
 * @param {number} requiredLength - minimum element count needed.
 * @param {Function} [type=BigInt64Array] - typed-array constructor to use.
 * @returns {BigInt64Array|*} the cached or newly allocated buffer (may be longer than requested).
 */
function getOrCreateBuffer(name, requiredLength, type = BigInt64Array) {
  const cached = reusableBuffers[name];
  if (!cached || cached.length < requiredLength || !(cached instanceof type)) {
    reusableBuffers[name] = new type(requiredLength);
    workerStats.memoryAllocations++; // track allocations for the stats report
  }
  return reusableBuffers[name];
}
/**
 * Build the ONNX Runtime `feeds` object from tokenizer output.
 *
 * Extracted because runInference and runBatchInference previously duplicated
 * this logic line-for-line (~70 lines each).
 *
 * @param {object} data - tokenizer output: `input_ids`, `attention_mask`,
 *   optional `token_type_ids` (arrays of numbers), and a `dims` object holding
 *   the matching tensor shapes.
 * @returns {Object<string, ort.Tensor>} feeds keyed by model input name.
 */
function buildFeeds(data) {
  const feeds = {};
  const inputIdsLength = data.input_ids.length;
  const attentionMaskLength = data.attention_mask.length;

  // Reuse module-level BigInt64Array buffers to cut per-call allocations.
  const inputIdsBuffer = getOrCreateBuffer('inputIds', inputIdsLength);
  const attentionMaskBuffer = getOrCreateBuffer('attentionMask', attentionMaskLength);

  // Fill element-by-element (avoids the intermediate array a .map() would create).
  for (let i = 0; i < inputIdsLength; i++) {
    inputIdsBuffer[i] = BigInt(data.input_ids[i]);
  }
  for (let i = 0; i < attentionMaskLength; i++) {
    attentionMaskBuffer[i] = BigInt(data.attention_mask[i]);
  }

  // .slice() snapshots exactly the needed length; the tensor owns its copy,
  // so the reusable buffer may be overwritten by the next call.
  feeds['input_ids'] = new ort.Tensor(
    'int64',
    inputIdsBuffer.slice(0, inputIdsLength),
    data.dims.input_ids,
  );
  feeds['attention_mask'] = new ort.Tensor(
    'int64',
    attentionMaskBuffer.slice(0, attentionMaskLength),
    data.dims.attention_mask,
  );

  // token_type_ids is only fed when the loaded model declares that input.
  if (modelInputNames && modelInputNames.includes('token_type_ids')) {
    if (data.token_type_ids && data.dims.token_type_ids) {
      const tokenTypeIdsLength = data.token_type_ids.length;
      const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', tokenTypeIdsLength);
      for (let i = 0; i < tokenTypeIdsLength; i++) {
        tokenTypeIdsBuffer[i] = BigInt(data.token_type_ids[i]);
      }
      feeds['token_type_ids'] = new ort.Tensor(
        'int64',
        tokenTypeIdsBuffer.slice(0, tokenTypeIdsLength),
        data.dims.token_type_ids,
      );
    } else {
      // Fall back to all-zero token_type_ids shaped like input_ids.
      const tokenTypeIdsBuffer = getOrCreateBuffer('tokenTypeIds', inputIdsLength);
      tokenTypeIdsBuffer.fill(0n, 0, inputIdsLength);
      feeds['token_type_ids'] = new ort.Tensor(
        'int64',
        tokenTypeIdsBuffer.slice(0, inputIdsLength),
        data.dims.input_ids,
      );
    }
  } else {
    console.log('Worker: Skipping token_type_ids as model does not require it');
  }
  return feeds;
}

/**
 * Run one batched forward pass through the session.
 *
 * @param {object} batchData - tokenizer output (see buildFeeds); dims.input_ids
 *   is [batchSize, seqLength].
 * @returns {Promise<{status: string, output: object, transferList: ArrayBuffer[], stats: object}>}
 * @throws {Error} if the session is not initialized or inference fails.
 */
async function runBatchInference(batchData) {
  if (!session) {
    throw new Error("Worker: Session not initialized. Call 'initializeModel' first.");
  }
  const startTime = performance.now();
  try {
    const batchSize = batchData.dims.input_ids[0];
    const seqLength = batchData.dims.input_ids[1];
    const feeds = buildFeeds(batchData);

    const results = await session.run(feeds);
    // Prefer the conventional transformer output name; fall back to the first output.
    const outputTensor = results.last_hidden_state || results[Object.keys(results)[0]];
    // Copy into a fresh Float32Array so its buffer can be transferred to the main thread.
    const outputData = new Float32Array(outputTensor.data);

    // A batch counts as batchSize inferences in the running statistics.
    workerStats.totalInferences += batchSize;
    const inferenceTime = performance.now() - startTime;
    workerStats.totalInferenceTime += inferenceTime;
    workerStats.averageInferenceTime = workerStats.totalInferenceTime / workerStats.totalInferences;

    return {
      status: 'success',
      output: {
        data: outputData,
        dims: outputTensor.dims,
        batchSize: batchSize,
        seqLength: seqLength,
      },
      transferList: [outputData.buffer], // marked transferable by the dispatcher
      stats: {
        inferenceTime,
        totalInferences: workerStats.totalInferences,
        averageInferenceTime: workerStats.averageInferenceTime,
        memoryAllocations: workerStats.memoryAllocations,
        batchSize: batchSize,
      },
    };
  } catch (error) {
    console.error('Worker: Batch inference failed:', error);
    throw new Error(`Worker: Batch inference failed - ${error.message}`);
  }
}

/**
 * Run a single forward pass through the session.
 *
 * @param {object} inputData - tokenizer output (see buildFeeds).
 * @returns {Promise<{status: string, output: object, transferList: ArrayBuffer[], stats: object}>}
 * @throws {Error} if the session is not initialized or inference fails.
 */
async function runInference(inputData) {
  if (!session) {
    throw new Error("Worker: Session not initialized. Call 'initializeModel' first.");
  }
  const startTime = performance.now();
  try {
    const feeds = buildFeeds(inputData);

    const results = await session.run(feeds);
    // Prefer the conventional transformer output name; fall back to the first output.
    const outputTensor = results.last_hidden_state || results[Object.keys(results)[0]];
    // Copy into a fresh Float32Array so its buffer can be transferred to the main thread.
    const outputData = new Float32Array(outputTensor.data);

    workerStats.totalInferences++;
    const inferenceTime = performance.now() - startTime;
    workerStats.totalInferenceTime += inferenceTime;
    workerStats.averageInferenceTime = workerStats.totalInferenceTime / workerStats.totalInferences;

    return {
      status: 'success',
      output: {
        data: outputData,
        dims: outputTensor.dims,
      },
      transferList: [outputData.buffer], // marked transferable by the dispatcher
      stats: {
        inferenceTime,
        totalInferences: workerStats.totalInferences,
        averageInferenceTime: workerStats.averageInferenceTime,
        memoryAllocations: workerStats.memoryAllocations,
      },
    };
  } catch (error) {
    console.error('Worker: Inference failed:', error);
    throw new Error(`Worker: Inference failed - ${error.message}`);
  }
}
/**
 * Message dispatcher: the worker's entire public protocol.
 *
 * Each request carries { id, type, payload }. Replies echo the `id` and use
 * `<type>_complete` (status 'success') or `<type>_error` (status 'error').
 * Fix: `const` declarations previously sat bare inside `switch` cases
 * (ESLint no-case-declarations — all cases share one lexical scope); each
 * case now has its own block.
 */
self.onmessage = async (event) => {
  const { id, type, payload } = event.data;
  try {
    switch (type) {
      case 'init': {
        // Accept either cached model bytes (modelData) or a URL (modelPath).
        const modelInput = payload.modelData || payload.modelPath;
        await initializeModel(modelInput, payload.numThreads, payload.executionProviders);
        self.postMessage({ id, type: 'init_complete', status: 'success' });
        break;
      }
      case 'infer': {
        const result = await runInference(payload);
        // Transfer the output buffer instead of structured-cloning it.
        self.postMessage(
          {
            id,
            type: 'infer_complete',
            status: 'success',
            payload: result.output,
            stats: result.stats,
          },
          result.transferList || [],
        );
        break;
      }
      case 'batchInfer': {
        const batchResult = await runBatchInference(payload);
        // Transfer the output buffer instead of structured-cloning it.
        self.postMessage(
          {
            id,
            type: 'batchInfer_complete',
            status: 'success',
            payload: batchResult.output,
            stats: batchResult.stats,
          },
          batchResult.transferList || [],
        );
        break;
      }
      case 'getStats': {
        self.postMessage({
          id,
          type: 'stats_complete',
          status: 'success',
          payload: workerStats,
        });
        break;
      }
      case 'clearBuffers': {
        // Drop the cached typed arrays so the GC can reclaim their memory.
        reusableBuffers = {
          inputIds: null,
          attentionMask: null,
          tokenTypeIds: null,
        };
        workerStats.memoryAllocations = 0;
        self.postMessage({
          id,
          type: 'clear_complete',
          status: 'success',
          payload: { message: 'Buffers cleared' },
        });
        break;
      }
      default: {
        console.warn(`Worker: Unknown message type: ${type}`);
        self.postMessage({
          id,
          type: 'error',
          status: 'error',
          payload: { message: `Unknown message type: ${type}` },
        });
      }
    }
  } catch (error) {
    // Error objects may not survive structured clone — send plain fields.
    self.postMessage({
      id,
      type: `${type}_error`, // e.g. 'init_error' or 'infer_error'
      status: 'error',
      payload: {
        message: error.message,
        stack: error.stack, // optional, for debugging
        name: error.name,
      },
    });
  }
};