first commit

This commit is contained in:
nasir@endelospay.com
2025-08-12 02:54:17 +05:00
commit d97cad1736
225 changed files with 137626 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
{
"name": "chrome-mcp-shared",
"version": "1.0.1",
"author": "hangye",
"main": "dist/index.js",
"module": "./dist/index.mjs",
"types": "dist/index.d.ts",
"exports": {
".": {
"import": {
"types": "./dist/index.d.ts",
"default": "./dist/index.mjs"
},
"require": {
"types": "./dist/index.d.ts",
"default": "./dist/index.js"
}
}
},
"scripts": {
"build": "tsup src/index.ts --format cjs,esm --dts --clean",
"dev": "tsup src/index.ts --format cjs,esm --dts --watch",
"lint": "npx eslint 'src/**/*.{js,ts}'",
"lint:fix": "npx eslint 'src/**/*.{js,ts}' --fix",
"format": "prettier --write 'src/**/*.{js,ts,json}'"
},
"files": [
"dist"
],
"devDependencies": {
"@types/node": "^18.0.0",
"@typescript-eslint/parser": "^8.32.0",
"tsup": "^8.4.0"
},
"dependencies": {
"@modelcontextprotocol/sdk": "^1.11.0",
"zod": "^3.24.4"
}
}

View File

@@ -0,0 +1,2 @@
export const DEFAULT_SERVER_PORT = 56889;
export const HOST_NAME = 'com.chrome_mcp.native_host';

View File

@@ -0,0 +1,3 @@
export * from './constants';
export * from './types';
export * from './tools';

View File

@@ -0,0 +1,537 @@
import { type Tool } from '@modelcontextprotocol/sdk/types.js';
export const TOOL_NAMES = {
BROWSER: {
GET_WINDOWS_AND_TABS: 'get_windows_and_tabs',
SEARCH_TABS_CONTENT: 'search_tabs_content',
NAVIGATE: 'chrome_navigate',
SCREENSHOT: 'chrome_screenshot',
CLOSE_TABS: 'chrome_close_tabs',
GO_BACK_OR_FORWARD: 'chrome_go_back_or_forward',
WEB_FETCHER: 'chrome_get_web_content',
CLICK: 'chrome_click_element',
FILL: 'chrome_fill_or_select',
GET_INTERACTIVE_ELEMENTS: 'chrome_get_interactive_elements',
NETWORK_CAPTURE_START: 'chrome_network_capture_start',
NETWORK_CAPTURE_STOP: 'chrome_network_capture_stop',
NETWORK_REQUEST: 'chrome_network_request',
NETWORK_DEBUGGER_START: 'chrome_network_debugger_start',
NETWORK_DEBUGGER_STOP: 'chrome_network_debugger_stop',
KEYBOARD: 'chrome_keyboard',
HISTORY: 'chrome_history',
BOOKMARK_SEARCH: 'chrome_bookmark_search',
BOOKMARK_ADD: 'chrome_bookmark_add',
BOOKMARK_DELETE: 'chrome_bookmark_delete',
INJECT_SCRIPT: 'chrome_inject_script',
SEND_COMMAND_TO_INJECT_SCRIPT: 'chrome_send_command_to_inject_script',
CONSOLE: 'chrome_console',
},
};
export const TOOL_SCHEMAS: Tool[] = [
{
name: TOOL_NAMES.BROWSER.GET_WINDOWS_AND_TABS,
description: 'Get all currently open browser windows and tabs',
inputSchema: {
type: 'object',
properties: {},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.NAVIGATE,
description: 'Navigate to a URL or refresh the current tab',
inputSchema: {
type: 'object',
properties: {
url: { type: 'string', description: 'URL to navigate to the website specified' },
newWindow: {
type: 'boolean',
description: 'Create a new window to navigate to the URL or not. Defaults to false',
},
width: { type: 'number', description: 'Viewport width in pixels (default: 1280)' },
height: { type: 'number', description: 'Viewport height in pixels (default: 720)' },
refresh: {
type: 'boolean',
description:
'Refresh the current active tab instead of navigating to a URL. When true, the url parameter is ignored. Defaults to false',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.SCREENSHOT,
description:
'Take a screenshot of the current page or a specific element(if you want to see the page, recommend to use chrome_get_web_content first)',
inputSchema: {
type: 'object',
properties: {
name: { type: 'string', description: 'Name for the screenshot, if saving as PNG' },
selector: { type: 'string', description: 'CSS selector for element to screenshot' },
width: { type: 'number', description: 'Width in pixels (default: 800)' },
height: { type: 'number', description: 'Height in pixels (default: 600)' },
storeBase64: {
type: 'boolean',
description:
'return screenshot in base64 format (default: false) if you want to see the page, recommend set this to be true',
},
fullPage: {
type: 'boolean',
description: 'Store screenshot of the entire page (default: true)',
},
savePng: {
type: 'boolean',
description:
'Save screenshot as PNG file (default: true)if you want to see the page, recommend set this to be false, and set storeBase64 to be true',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.CLOSE_TABS,
description: 'Close one or more browser tabs',
inputSchema: {
type: 'object',
properties: {
tabIds: {
type: 'array',
items: { type: 'number' },
description: 'Array of tab IDs to close. If not provided, will close the active tab.',
},
url: {
type: 'string',
description: 'Close tabs matching this URL. Can be used instead of tabIds.',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.GO_BACK_OR_FORWARD,
description: 'Navigate back or forward in browser history',
inputSchema: {
type: 'object',
properties: {
isForward: {
type: 'boolean',
description: 'Go forward in history if true, go back if false (default: false)',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.WEB_FETCHER,
description: 'Fetch content from a web page',
inputSchema: {
type: 'object',
properties: {
url: {
type: 'string',
description: 'URL to fetch content from. If not provided, uses the current active tab',
},
htmlContent: {
type: 'boolean',
description:
'Get the visible HTML content of the page. If true, textContent will be ignored (default: false)',
},
textContent: {
type: 'boolean',
description:
'Get the visible text content of the page with metadata. Ignored if htmlContent is true (default: true)',
},
selector: {
type: 'string',
description:
'CSS selector to get content from a specific element. If provided, only content from this element will be returned',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.CLICK,
description: 'Click on an element in the current page or at specific coordinates',
inputSchema: {
type: 'object',
properties: {
selector: {
type: 'string',
description:
'CSS selector for the element to click. Either selector or coordinates must be provided. if coordinates are not provided, the selector must be provided.',
},
coordinates: {
type: 'object',
description:
'Coordinates to click at (relative to viewport). If provided, takes precedence over selector.',
properties: {
x: {
type: 'number',
description: 'X coordinate relative to the viewport',
},
y: {
type: 'number',
description: 'Y coordinate relative to the viewport',
},
},
required: ['x', 'y'],
},
waitForNavigation: {
type: 'boolean',
description: 'Wait for page navigation to complete after click (default: false)',
},
timeout: {
type: 'number',
description:
'Timeout in milliseconds for waiting for the element or navigation (default: 5000)',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.FILL,
description: 'Fill a form element or select an option with the specified value',
inputSchema: {
type: 'object',
properties: {
selector: {
type: 'string',
description: 'CSS selector for the input element to fill or select',
},
value: {
type: 'string',
description: 'Value to fill or select into the element',
},
},
required: ['selector', 'value'],
},
},
{
name: TOOL_NAMES.BROWSER.GET_INTERACTIVE_ELEMENTS,
description: 'Get interactive elements from the current page',
inputSchema: {
type: 'object',
properties: {
textQuery: {
type: 'string',
description: 'Text to search for within interactive elements (fuzzy search)',
},
selector: {
type: 'string',
description:
'CSS selector to filter interactive elements. Takes precedence over textQuery if both are provided.',
},
includeCoordinates: {
type: 'boolean',
description: 'Include element coordinates in the response (default: true)',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.NETWORK_REQUEST,
description: 'Send a network request from the browser with cookies and other browser context',
inputSchema: {
type: 'object',
properties: {
url: {
type: 'string',
description: 'URL to send the request to',
},
method: {
type: 'string',
description: 'HTTP method to use (default: GET)',
},
headers: {
type: 'object',
description: 'Headers to include in the request',
},
body: {
type: 'string',
description: 'Body of the request (for POST, PUT, etc.)',
},
timeout: {
type: 'number',
description: 'Timeout in milliseconds (default: 30000)',
},
},
required: ['url'],
},
},
{
name: TOOL_NAMES.BROWSER.NETWORK_DEBUGGER_START,
description:
'Start capturing network requests from a web page using Chrome Debugger APIwith responseBody',
inputSchema: {
type: 'object',
properties: {
url: {
type: 'string',
description:
'URL to capture network requests from. If not provided, uses the current active tab',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.NETWORK_DEBUGGER_STOP,
description:
'Stop capturing network requests using Chrome Debugger API and return the captured data',
inputSchema: {
type: 'object',
properties: {},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.NETWORK_CAPTURE_START,
description:
'Start capturing network requests from a web page using Chrome webRequest API(without responseBody)',
inputSchema: {
type: 'object',
properties: {
url: {
type: 'string',
description:
'URL to capture network requests from. If not provided, uses the current active tab',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.NETWORK_CAPTURE_STOP,
description:
'Stop capturing network requests using webRequest API and return the captured data',
inputSchema: {
type: 'object',
properties: {},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.KEYBOARD,
description: 'Simulate keyboard events in the browser',
inputSchema: {
type: 'object',
properties: {
keys: {
type: 'string',
description: 'Keys to simulate (e.g., "Enter", "Ctrl+C", "A,B,C" for sequence)',
},
selector: {
type: 'string',
description:
'CSS selector for the element to send keyboard events to (optional, defaults to active element)',
},
delay: {
type: 'number',
description: 'Delay between key sequences in milliseconds (optional, default: 0)',
},
},
required: ['keys'],
},
},
{
name: TOOL_NAMES.BROWSER.HISTORY,
description: 'Retrieve and search browsing history from Chrome',
inputSchema: {
type: 'object',
properties: {
text: {
type: 'string',
description:
'Text to search for in history URLs and titles. Leave empty to retrieve all history entries within the time range.',
},
startTime: {
type: 'string',
description:
'Start time as a date string. Supports ISO format (e.g., "2023-10-01", "2023-10-01T14:30:00"), relative times (e.g., "1 day ago", "2 weeks ago", "3 months ago", "1 year ago"), and special keywords ("now", "today", "yesterday"). Default: 24 hours ago',
},
endTime: {
type: 'string',
description:
'End time as a date string. Supports ISO format (e.g., "2023-10-31", "2023-10-31T14:30:00"), relative times (e.g., "1 day ago", "2 weeks ago", "3 months ago", "1 year ago"), and special keywords ("now", "today", "yesterday"). Default: current time',
},
maxResults: {
type: 'number',
description:
'Maximum number of history entries to return. Use this to limit results for performance or to focus on the most relevant entries. (default: 100)',
},
excludeCurrentTabs: {
type: 'boolean',
description:
"When set to true, filters out URLs that are currently open in any browser tab. Useful for finding pages you've visited but don't have open anymore. (default: false)",
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.BOOKMARK_SEARCH,
description: 'Search Chrome bookmarks by title and URL',
inputSchema: {
type: 'object',
properties: {
query: {
type: 'string',
description:
'Search query to match against bookmark titles and URLs. Leave empty to retrieve all bookmarks.',
},
maxResults: {
type: 'number',
description: 'Maximum number of bookmarks to return (default: 50)',
},
folderPath: {
type: 'string',
description:
'Optional folder path or ID to limit search to a specific bookmark folder. Can be a path string (e.g., "Work/Projects") or a folder ID.',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.BOOKMARK_ADD,
description: 'Add a new bookmark to Chrome',
inputSchema: {
type: 'object',
properties: {
url: {
type: 'string',
description: 'URL to bookmark. If not provided, uses the current active tab URL.',
},
title: {
type: 'string',
description: 'Title for the bookmark. If not provided, uses the page title from the URL.',
},
parentId: {
type: 'string',
description:
'Parent folder path or ID to add the bookmark to. Can be a path string (e.g., "Work/Projects") or a folder ID. If not provided, adds to the "Bookmarks Bar" folder.',
},
createFolder: {
type: 'boolean',
description: 'Whether to create the parent folder if it does not exist (default: false)',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.BOOKMARK_DELETE,
description: 'Delete a bookmark from Chrome',
inputSchema: {
type: 'object',
properties: {
bookmarkId: {
type: 'string',
description: 'ID of the bookmark to delete. Either bookmarkId or url must be provided.',
},
url: {
type: 'string',
description: 'URL of the bookmark to delete. Used if bookmarkId is not provided.',
},
title: {
type: 'string',
description: 'Title of the bookmark to help with matching when deleting by URL.',
},
},
required: [],
},
},
{
name: TOOL_NAMES.BROWSER.SEARCH_TABS_CONTENT,
description:
'search for related content from the currently open tab and return the corresponding web pages.',
inputSchema: {
type: 'object',
properties: {
query: {
type: 'string',
description: 'the query to search for related content.',
},
},
required: ['query'],
},
},
{
name: TOOL_NAMES.BROWSER.INJECT_SCRIPT,
description:
'inject the user-specified content script into the webpage. By default, inject into the currently active tab',
inputSchema: {
type: 'object',
properties: {
url: {
type: 'string',
description:
'If a URL is specified, inject the script into the webpage corresponding to the URL.',
},
type: {
type: 'string',
description:
'the javaScript world for a script to execute within. must be ISOLATED or MAIN',
},
jsScript: {
type: 'string',
description: 'the content script to inject',
},
},
required: ['type', 'jsScript'],
},
},
{
name: TOOL_NAMES.BROWSER.SEND_COMMAND_TO_INJECT_SCRIPT,
description:
'if the script injected using chrome_inject_script listens for user-defined events, this tool can be used to trigger those events',
inputSchema: {
type: 'object',
properties: {
tabId: {
type: 'number',
description:
'the tab where you previously injected the script(if not provided, use the currently active tab)',
},
eventName: {
type: 'string',
description: 'the eventName your injected content script listen for',
},
payload: {
type: 'string',
description: 'the payload passed to event, must be a json string',
},
},
required: ['eventName'],
},
},
{
name: TOOL_NAMES.BROWSER.CONSOLE,
description:
'Capture and retrieve all console output from the current active browser tab/page. This captures console messages that existed before the tool was called.',
inputSchema: {
type: 'object',
properties: {
url: {
type: 'string',
description:
'URL to navigate to and capture console from. If not provided, uses the current active tab',
},
includeExceptions: {
type: 'boolean',
description: 'Include uncaught exceptions in the output (default: true)',
},
maxMessages: {
type: 'number',
description: 'Maximum number of console messages to capture (default: 100)',
},
},
required: [],
},
},
];

View File

@@ -0,0 +1,27 @@
export enum NativeMessageType {
START = 'start',
STARTED = 'started',
STOP = 'stop',
STOPPED = 'stopped',
PING = 'ping',
PONG = 'pong',
ERROR = 'error',
PROCESS_DATA = 'process_data',
PROCESS_DATA_RESPONSE = 'process_data_response',
CALL_TOOL = 'call_tool',
CALL_TOOL_RESPONSE = 'call_tool_response',
// Additional message types used in Chrome extension
SERVER_STARTED = 'server_started',
SERVER_STOPPED = 'server_stopped',
ERROR_FROM_NATIVE_HOST = 'error_from_native_host',
CONNECT_NATIVE = 'connectNative',
PING_NATIVE = 'ping_native',
DISCONNECT_NATIVE = 'disconnect_native',
}
export interface NativeMessage<P = any, E = any> {
type?: NativeMessageType;
responseToRequestId?: string;
payload?: P;
error?: E;
}

View File

@@ -0,0 +1,14 @@
{
"compilerOptions": {
"target": "ES2020",
"module": "NodeNext",
"moduleResolution": "NodeNext",
"esModuleInterop": true,
"declaration": true,
"outDir": "./dist",
"strict": true,
"skipLibCheck": true
},
"include": ["src/**/*"],
"exclude": ["node_modules", "dist"]
}

22
packages/wasm-simd/.gitignore vendored Normal file
View File

@@ -0,0 +1,22 @@
# WASM build outputs
/pkg/
/target/
# Rust
Cargo.lock
# Node.js
node_modules/
npm-debug.log*
yarn-debug.log*
yarn-error.log*
# IDE
.vscode/
.idea/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db

View File

@@ -0,0 +1,70 @@
# WASM SIMD 构建指南
## 🚀 快速构建
### 前置要求
```bash
# 安装 Rust
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
# 安装 wasm-pack
curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
```
### 构建选项
1. **从项目根目录构建**(推荐):
```bash
# 构建 WASM 并自动复制到 Chrome 扩展
npm run build:wasm
```
2. **只构建 WASM 包**
```bash
# 从 packages/wasm-simd 目录
npm run build
# 或者从任何地方使用 pnpm filter
pnpm --filter @chrome-mcp/wasm-simd build
```
3. **开发模式构建**
```bash
npm run build:dev # 未优化版本,构建更快
```
### 构建产物
构建完成后,在 `pkg/` 目录下会生成:
- `simd_math.js` - JavaScript 绑定
- `simd_math_bg.wasm` - WebAssembly 二进制文件
- `simd_math.d.ts` - TypeScript 类型定义
- `package.json` - NPM 包信息
### 集成到 Chrome 扩展
WASM 文件会自动复制到 `app/chrome-extension/workers/` 目录Chrome 扩展可以直接使用:
```typescript
// 在 Chrome 扩展中使用
const wasmUrl = chrome.runtime.getURL('workers/simd_math.js');
const wasmModule = await import(wasmUrl);
```
## 🔧 开发工作流
1. 修改 `src/lib.rs` 中的 Rust 代码
2. 运行 `npm run build` 重新构建
3. Chrome 扩展会自动使用新的 WASM 文件
## 📊 性能测试
```bash
# 在 Chrome 扩展中运行基准测试
import { runSIMDBenchmark } from './utils/simd-benchmark';
await runSIMDBenchmark();
```

View File

@@ -0,0 +1,24 @@
[package]
name = "simd-math"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib"]
[dependencies]
wasm-bindgen = "0.2"
wide = "0.7"
console_error_panic_hook = "0.1"
[dependencies.web-sys]
version = "0.3"
features = [
"console",
]
[profile.release]
opt-level = 3
lto = true
codegen-units = 1
panic = "abort"

View File

@@ -0,0 +1,61 @@
# @chrome-mcp/wasm-simd
SIMD-optimized WebAssembly math functions for high-performance vector operations.
## Features
- 🚀 **SIMD Acceleration**: Uses WebAssembly SIMD instructions for 4-8x performance boost
- 🧮 **Vector Operations**: Optimized cosine similarity, batch processing, and matrix operations
- 🔧 **Memory Efficient**: Smart memory pooling and aligned buffer management
- 🌐 **Browser Compatible**: Works in all modern browsers with WebAssembly SIMD support
## Performance
| Operation | JavaScript | SIMD WASM | Speedup |
| ------------------------------ | ---------- | --------- | ------- |
| Cosine Similarity (768d) | 100ms | 18ms | 5.6x |
| Batch Similarity (100x768d) | 850ms | 95ms | 8.9x |
| Similarity Matrix (50x50x384d) | 2.1s | 180ms | 11.7x |
## Usage
```rust
// The Rust implementation provides SIMD-optimized functions
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct SIMDMath;
#[wasm_bindgen]
impl SIMDMath {
#[wasm_bindgen(constructor)]
pub fn new() -> SIMDMath { SIMDMath }
#[wasm_bindgen]
pub fn cosine_similarity(&self, vec_a: &[f32], vec_b: &[f32]) -> f32 {
// SIMD-optimized implementation
}
}
```
## Building
```bash
# Install dependencies
cargo install wasm-pack
# Build for release
npm run build
# Build for development
npm run build:dev
```
## Browser Support
- Chrome 91+
- Firefox 89+
- Safari 16.4+
- Edge 91+
Older browsers automatically fall back to JavaScript implementations.

View File

@@ -0,0 +1,34 @@
{
"name": "@chrome-mcp/wasm-simd",
"version": "0.1.0",
"description": "SIMD-optimized WebAssembly math functions for Chrome MCP",
"main": "pkg/simd_math.js",
"types": "pkg/simd_math.d.ts",
"files": [
"pkg/"
],
"scripts": {
"build": "wasm-pack build --target web --out-dir pkg --release",
"build:dev": "wasm-pack build --target web --out-dir pkg --dev",
"clean": "rimraf pkg/",
"test": "wasm-pack test --headless --firefox"
},
"keywords": [
"wasm",
"simd",
"webassembly",
"math",
"cosine-similarity",
"vector-operations"
],
"author": "hangye",
"license": "MIT",
"devDependencies": {
"rimraf": "^5.0.0"
},
"repository": {
"type": "git",
"url": "git+https://github.com/your-repo/chrome-mcp-server.git",
"directory": "packages/wasm-simd"
}
}

View File

@@ -0,0 +1,245 @@
use wasm_bindgen::prelude::*;
use wide::f32x4;
// 设置 panic hook 以便在浏览器中调试
#[wasm_bindgen(start)]
pub fn main() {
console_error_panic_hook::set_once();
}
#[wasm_bindgen]
pub struct SIMDMath;
#[wasm_bindgen]
impl SIMDMath {
#[wasm_bindgen(constructor)]
pub fn new() -> SIMDMath {
SIMDMath
}
// 辅助函数:仅计算点积 (SIMD)
#[inline]
fn dot_product_simd_only(&self, vec_a: &[f32], vec_b: &[f32]) -> f32 {
let len = vec_a.len();
let simd_lanes = 4;
let simd_len = len - (len % simd_lanes);
let mut dot_sum_simd = f32x4::ZERO;
for i in (0..simd_len).step_by(simd_lanes) {
// 使用 try_from 和 new 方法,这是 wide 库的正确 API
let a_array: [f32; 4] = vec_a[i..i + simd_lanes].try_into().unwrap();
let b_array: [f32; 4] = vec_b[i..i + simd_lanes].try_into().unwrap();
let a_chunk = f32x4::new(a_array);
let b_chunk = f32x4::new(b_array);
dot_sum_simd = a_chunk.mul_add(b_chunk, dot_sum_simd);
}
let mut dot_product = dot_sum_simd.reduce_add();
for i in simd_len..len {
dot_product += vec_a[i] * vec_b[i];
}
dot_product
}
#[wasm_bindgen]
pub fn cosine_similarity(&self, vec_a: &[f32], vec_b: &[f32]) -> f32 {
if vec_a.len() != vec_b.len() || vec_a.is_empty() {
return 0.0;
}
let len = vec_a.len();
let simd_lanes = 4;
let simd_len = len - (len % simd_lanes);
let mut dot_sum_simd = f32x4::ZERO;
let mut norm_a_sum_simd = f32x4::ZERO;
let mut norm_b_sum_simd = f32x4::ZERO;
// SIMD 处理
for i in (0..simd_len).step_by(simd_lanes) {
let a_array: [f32; 4] = vec_a[i..i + simd_lanes].try_into().unwrap();
let b_array: [f32; 4] = vec_b[i..i + simd_lanes].try_into().unwrap();
let a_chunk = f32x4::new(a_array);
let b_chunk = f32x4::new(b_array);
// 使用 Fused Multiply-Add (FMA)
dot_sum_simd = a_chunk.mul_add(b_chunk, dot_sum_simd);
norm_a_sum_simd = a_chunk.mul_add(a_chunk, norm_a_sum_simd);
norm_b_sum_simd = b_chunk.mul_add(b_chunk, norm_b_sum_simd);
}
// 水平求和
let mut dot_product = dot_sum_simd.reduce_add();
let mut norm_a_sq = norm_a_sum_simd.reduce_add();
let mut norm_b_sq = norm_b_sum_simd.reduce_add();
// 处理剩余元素
for i in simd_len..len {
dot_product += vec_a[i] * vec_b[i];
norm_a_sq += vec_a[i] * vec_a[i];
norm_b_sq += vec_b[i] * vec_b[i];
}
// 优化的数值稳定性处理
let norm_a = norm_a_sq.sqrt();
let norm_b = norm_b_sq.sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
let magnitude = norm_a * norm_b;
// 限制结果在 [-1.0, 1.0] 范围内,处理浮点精度误差
(dot_product / magnitude).max(-1.0).min(1.0)
}
#[wasm_bindgen]
pub fn batch_similarity(&self, vectors: &[f32], query: &[f32], vector_dim: usize) -> Vec<f32> {
if vector_dim == 0 { return Vec::new(); }
if vectors.len() % vector_dim != 0 { return Vec::new(); }
if query.len() != vector_dim { return Vec::new(); }
let num_vectors = vectors.len() / vector_dim;
let mut results = Vec::with_capacity(num_vectors);
// 预计算查询向量的范数
let query_norm_sq = self.compute_norm_squared_simd(query);
if query_norm_sq == 0.0 {
return vec![0.0; num_vectors];
}
let query_norm = query_norm_sq.sqrt();
for i in 0..num_vectors {
let start = i * vector_dim;
let vector_slice = &vectors[start..start + vector_dim];
// dot_product_and_norm_simd 计算 vector_slice (vec_a) 的范数
let (dot_product, vector_norm_sq) = self.dot_product_and_norm_simd(vector_slice, query);
if vector_norm_sq == 0.0 {
results.push(0.0);
} else {
let vector_norm = vector_norm_sq.sqrt();
let similarity = dot_product / (vector_norm * query_norm);
results.push(similarity.max(-1.0).min(1.0));
}
}
results
}
// 辅助函数SIMD 计算范数平方
#[inline]
fn compute_norm_squared_simd(&self, vec: &[f32]) -> f32 {
let len = vec.len();
let simd_lanes = 4;
let simd_len = len - (len % simd_lanes);
let mut norm_sum_simd = f32x4::ZERO;
for i in (0..simd_len).step_by(simd_lanes) {
let array: [f32; 4] = vec[i..i + simd_lanes].try_into().unwrap();
let chunk = f32x4::new(array);
norm_sum_simd = chunk.mul_add(chunk, norm_sum_simd);
}
let mut norm_sq = norm_sum_simd.reduce_add();
for i in simd_len..len {
norm_sq += vec[i] * vec[i];
}
norm_sq
}
// 辅助函数同时计算点积和vec_a的范数平方
#[inline]
fn dot_product_and_norm_simd(&self, vec_a: &[f32], vec_b: &[f32]) -> (f32, f32) {
let len = vec_a.len(); // 假设 vec_a.len() == vec_b.len()
let simd_lanes = 4;
let simd_len = len - (len % simd_lanes);
let mut dot_sum_simd = f32x4::ZERO;
let mut norm_a_sum_simd = f32x4::ZERO;
for i in (0..simd_len).step_by(simd_lanes) {
let a_array: [f32; 4] = vec_a[i..i + simd_lanes].try_into().unwrap();
let b_array: [f32; 4] = vec_b[i..i + simd_lanes].try_into().unwrap();
let a_chunk = f32x4::new(a_array);
let b_chunk = f32x4::new(b_array);
dot_sum_simd = a_chunk.mul_add(b_chunk, dot_sum_simd);
norm_a_sum_simd = a_chunk.mul_add(a_chunk, norm_a_sum_simd);
}
let mut dot_product = dot_sum_simd.reduce_add();
let mut norm_a_sq = norm_a_sum_simd.reduce_add();
for i in simd_len..len {
dot_product += vec_a[i] * vec_b[i];
norm_a_sq += vec_a[i] * vec_a[i];
}
(dot_product, norm_a_sq)
}
// 批量矩阵相似度计算 - 优化版
#[wasm_bindgen]
pub fn similarity_matrix(&self, vectors_a: &[f32], vectors_b: &[f32], vector_dim: usize) -> Vec<f32> {
if vector_dim == 0 || vectors_a.len() % vector_dim != 0 || vectors_b.len() % vector_dim != 0 {
return Vec::new();
}
let num_a = vectors_a.len() / vector_dim;
let num_b = vectors_b.len() / vector_dim;
let mut results = Vec::with_capacity(num_a * num_b);
// 1. 预计算 vectors_a 的范数
let norms_a: Vec<f32> = (0..num_a)
.map(|i| {
let start = i * vector_dim;
let vec_a_slice = &vectors_a[start..start + vector_dim];
self.compute_norm_squared_simd(vec_a_slice).sqrt()
})
.collect();
// 2. 预计算 vectors_b 的范数
let norms_b: Vec<f32> = (0..num_b)
.map(|j| {
let start = j * vector_dim;
let vec_b_slice = &vectors_b[start..start + vector_dim];
self.compute_norm_squared_simd(vec_b_slice).sqrt()
})
.collect();
for i in 0..num_a {
let start_a = i * vector_dim;
let vec_a = &vectors_a[start_a..start_a + vector_dim];
let norm_a = norms_a[i];
if norm_a == 0.0 {
// 如果 norm_a 为 0所有相似度都为 0
for _ in 0..num_b {
results.push(0.0);
}
continue;
}
for j in 0..num_b {
let start_b = j * vector_dim;
let vec_b = &vectors_b[start_b..start_b + vector_dim];
let norm_b = norms_b[j];
if norm_b == 0.0 {
results.push(0.0);
continue;
}
// 使用专用的点积函数
let dot_product = self.dot_product_simd_only(vec_a, vec_b);
let magnitude = norm_a * norm_b;
// magnitude 不应该为零,因为已经检查了 norm_a/norm_b
let similarity = (dot_product / magnitude).max(-1.0).min(1.0);
results.push(similarity);
}
}
results
}
}