From 3c38b509f184c5aea1f314df55e18b4b62ea20e2 Mon Sep 17 00:00:00 2001 From: godo Date: Thu, 9 Jan 2025 03:29:39 +0800 Subject: [PATCH] change embedings --- frontend/components.d.ts | 30 --- frontend/package-lock.json | 291 +++++++++++++++++++++ godo/ai/vector/collection.go | 409 +++++++++++++++++++++++++++++ godo/ai/vector/db.go | 412 +++++++++++++++++++++++++++++ godo/ai/vector/document.go | 198 +++----------- godo/ai/vector/files.go | 181 +++++++++++++ godo/ai/vector/openai.go | 125 +++++++++ godo/ai/vector/persistence.go | 208 +++++++++++++++ godo/ai/vector/query.go | 207 +++++++++++++++ godo/ai/vector/utils.go | 79 ++++++ godo/ai/vector/vector.go | 475 ++++++++++++++++++++++++---------- godo/cmd/main.go | 1 - godo/main.go | 4 +- godo/model/init.go | 5 +- godo/model/vec_doc.go | 14 + godo/model/vec_list.go | 65 +++++ godo/vector/go.mod | 17 +- godo/vector/go.sum | 28 +- godo/vector/init.go | 5 - godo/vector/main.go | 100 ------- godo/vector/run.go | 54 ++++ 21 files changed, 2455 insertions(+), 453 deletions(-) create mode 100644 godo/ai/vector/collection.go create mode 100644 godo/ai/vector/db.go create mode 100644 godo/ai/vector/files.go create mode 100644 godo/ai/vector/openai.go create mode 100644 godo/ai/vector/persistence.go create mode 100644 godo/ai/vector/query.go create mode 100644 godo/ai/vector/utils.go create mode 100644 godo/model/vec_doc.go create mode 100644 godo/model/vec_list.go delete mode 100644 godo/vector/init.go delete mode 100644 godo/vector/main.go create mode 100644 godo/vector/run.go diff --git a/frontend/components.d.ts b/frontend/components.d.ts index 4bbf6b6..18f8f32 100644 --- a/frontend/components.d.ts +++ b/frontend/components.d.ts @@ -64,50 +64,20 @@ declare module 'vue' { DownModelInfo: typeof import('./src/components/ai/DownModelInfo.vue')['default'] EditFileName: typeof import('./src/components/builtin/EditFileName.vue')['default'] EditType: typeof import('./src/components/builtin/EditType.vue')['default'] - ElAside: typeof import('element-plus/es')['ElAside'] ElAvatar: typeof import('element-plus/es')['ElAvatar'] - ElBadge: typeof import('element-plus/es')['ElBadge'] - ElBu: typeof import('element-plus/es')['ElBu'] ElButton: typeof import('element-plus/es')['ElButton'] ElCard: typeof import('element-plus/es')['ElCard'] ElCarousel: typeof import('element-plus/es')['ElCarousel'] ElCarouselItem: typeof import('element-plus/es')['ElCarouselItem'] ElCol: typeof import('element-plus/es')['ElCol'] - ElCollapse: typeof import('element-plus/es')['ElCollapse'] - ElCollapseItem: typeof import('element-plus/es')['ElCollapseItem'] - ElContainer: typeof import('element-plus/es')['ElContainer'] ElDialog: typeof import('element-plus/es')['ElDialog'] - ElDrawer: typeof import('element-plus/es')['ElDrawer'] - ElDropdown: typeof import('element-plus/es')['ElDropdown'] - ElDropdownItem: typeof import('element-plus/es')['ElDropdownItem'] - ElDropdownMenu: typeof import('element-plus/es')['ElDropdownMenu'] - ElEmpty: typeof import('element-plus/es')['ElEmpty'] - ElFooter: typeof import('element-plus/es')['ElFooter'] ElForm: typeof import('element-plus/es')['ElForm'] ElFormItem: typeof import('element-plus/es')['ElFormItem'] - ElFormItm: typeof import('element-plus/es')['ElFormItm'] - ElHeader: typeof import('element-plus/es')['ElHeader'] ElIcon: typeof import('element-plus/es')['ElIcon'] - ElImage: typeof import('element-plus/es')['ElImage'] ElInput: typeof import('element-plus/es')['ElInput'] - ElItem: typeof import('element-plus/es')['ElItem'] - ElMain: typeof import('element-plus/es')['ElMain'] - ElOption: typeof import('element-plus/es')['ElOption'] ElPagination: typeof import('element-plus/es')['ElPagination'] - ElPopover: typeof import('element-plus/es')['ElPopover'] ElProgress: typeof import('element-plus/es')['ElProgress'] ElRow: typeof import('element-plus/es')['ElRow'] - ElScrollbar: typeof import('element-plus/es')['ElScrollbar'] - ElSelect: typeof import('element-plus/es')['ElSelect'] - ElSlider: typeof import('element-plus/es')['ElSlider'] - ElSwitch: typeof import('element-plus/es')['ElSwitch'] - ElTabPane: typeof import('element-plus/es')['ElTabPane'] - ElTabs: typeof import('element-plus/es')['ElTabs'] - ElTag: typeof import('element-plus/es')['ElTag'] - ElText: typeof import('element-plus/es')['ElText'] - ElTooltip: typeof import('element-plus/es')['ElTooltip'] - ElTransfer: typeof import('element-plus/es')['ElTransfer'] - ElTree: typeof import('element-plus/es')['ElTree'] Error: typeof import('./src/components/taskbar/Error.vue')['default'] FileIcon: typeof import('./src/components/builtin/FileIcon.vue')['default'] FileIconImg: typeof import('./src/components/builtin/FileIconImg.vue')['default'] diff --git a/frontend/package-lock.json b/frontend/package-lock.json index aec4c59..1fd6b18 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -13,6 +13,7 @@ "@onlyoffice/document-editor-vue": "^1.4.0", "cherry-markdown": "^0.8.52", "dexie": "^4.0.8", + "dingtalk-jsapi": "^3.0.42", "element-plus": "^2.7.7", "file-saver": "^2.0.5", "js-md5": "^0.8.3", @@ -21,6 +22,7 @@ "moment": "^2.30.1", "pinia": "^2.1.7", "pinia-plugin-persist": "^1.0.0", + "qrcode": "^1.5.4", "swiper": "^11.1.15", "vant": "^4.9.15", "vue": "^3.4.31", @@ -1576,6 +1578,28 @@ "node": ">= 6.0.0" } }, + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmmirror.com/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmmirror.com/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, "node_modules/anymatch": { "version": "3.1.3", "resolved": "https://registry.npmmirror.com/anymatch/-/anymatch-3.1.3.tgz", @@ -1657,6 +1681,14 @@ "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", "dev": true }, + "node_modules/camelcase": { + "version": "5.3.1", + "resolved": "https://registry.npmmirror.com/camelcase/-/camelcase-5.3.1.tgz", + "integrity": "sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg==", + "engines": { + "node": ">=6" + } + }, "node_modules/cherry-markdown": { "version": "0.8.52", "resolved": "https://registry.npmmirror.com/cherry-markdown/-/cherry-markdown-0.8.52.tgz", @@ -1688,6 +1720,32 @@ "url": "https://paulmillr.com/funding/" } }, + "node_modules/cliui": { + "version": "6.0.0", + "resolved": "https://registry.npmmirror.com/cliui/-/cliui-6.0.0.tgz", + "integrity": "sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ==", + "dependencies": { + "string-width": "^4.2.0", + "strip-ansi": "^6.0.0", + "wrap-ansi": "^6.2.0" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmmirror.com/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==" + }, "node_modules/combined-stream": { "version": "1.0.8", "resolved": "https://registry.npmmirror.com/combined-stream/-/combined-stream-1.0.8.tgz", @@ -2278,6 +2336,14 @@ } } }, + "node_modules/decamelize": { + "version": "1.2.0", + "resolved": "https://registry.npmmirror.com/decamelize/-/decamelize-1.2.0.tgz", + "integrity": "sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/decimal.js": { "version": "10.4.3", "resolved": "https://registry.npmmirror.com/decimal.js/-/decimal.js-10.4.3.tgz", @@ -2318,6 +2384,19 @@ "resolved": "https://registry.npmmirror.com/dexie/-/dexie-4.0.10.tgz", "integrity": "sha512-eM2RzuR3i+M046r2Q0Optl3pS31qTWf8aFuA7H9wnsHTwl8EPvroVLwvQene/6paAs39Tbk6fWZcn2aZaHkc/w==" }, + "node_modules/dijkstrajs": { + "version": "1.0.3", + "resolved": "https://registry.npmmirror.com/dijkstrajs/-/dijkstrajs-1.0.3.tgz", + "integrity": "sha512-qiSlmBq9+BCdCA/L46dw8Uy93mloxsPSbwnm5yrKn2vMPiy8KyAskTF6zuV/j5BMsmOGZDPs7KjU+mjb670kfA==" + }, + "node_modules/dingtalk-jsapi": { + "version": "3.0.42", + "resolved": "https://registry.npmmirror.com/dingtalk-jsapi/-/dingtalk-jsapi-3.0.42.tgz", + "integrity": "sha512-cIJ+3HUnSRVAanCip5yT1rEoLPrj97BxjYKpB33sgwUDStmfPgyEzG8Hux/Sq2zYJNH6riEA9PflsDnevr1f/g==", + "dependencies": { + "promise-polyfill": "^7.1.0" + } + }, "node_modules/domexception": { "version": "4.0.0", "resolved": "https://registry.npmmirror.com/domexception/-/domexception-4.0.0.tgz", @@ -2367,6 +2446,11 @@ "integrity": "sha512-L6uRgvZTH+4OF5NE/MBbzQx/WYpru1xCBE9respNj6qznEewGUIfhzmm7horWWxbNO2M0WckQypGctR8lH79xQ==", "optional": true }, + "node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmmirror.com/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==" + }, "node_modules/esbuild": { "version": "0.21.5", "resolved": "https://registry.npmmirror.com/esbuild/-/esbuild-0.21.5.tgz", @@ -2522,6 +2606,18 @@ "node": ">=8" } }, + "node_modules/find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmmirror.com/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "dependencies": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/form-data": { "version": "4.0.1", "resolved": "https://registry.npmmirror.com/form-data/-/form-data-4.0.1.tgz", @@ -2549,6 +2645,14 @@ "node": "^8.16.0 || ^10.6.0 || >=11.0.0" } }, + "node_modules/get-caller-file": { + "version": "2.0.5", + "resolved": "https://registry.npmmirror.com/get-caller-file/-/get-caller-file-2.0.5.tgz", + "integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", + "engines": { + "node": "6.* || 8.* || >= 10.*" + } + }, "node_modules/glob-parent": { "version": "5.1.2", "resolved": "https://registry.npmmirror.com/glob-parent/-/glob-parent-5.1.2.tgz", @@ -2666,6 +2770,14 @@ "node": ">=0.10.0" } }, + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "engines": { + "node": ">=8" + } + }, "node_modules/is-glob": { "version": "4.0.3", "resolved": "https://registry.npmmirror.com/is-glob/-/is-glob-4.0.3.tgz", @@ -2808,6 +2920,17 @@ "url": "https://github.com/sponsors/antfu" } }, + "node_modules/locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmmirror.com/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "dependencies": { + "p-locate": "^4.1.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/lodash": { "version": "4.17.21", "resolved": "https://registry.npmmirror.com/lodash/-/lodash-4.17.21.tgz", @@ -2988,6 +3111,39 @@ "resolved": "https://registry.npmmirror.com/nwsapi/-/nwsapi-2.2.16.tgz", "integrity": "sha512-F1I/bimDpj3ncaNDhfyMWuFqmQDBwDB0Fogc2qpL3BWvkQteFD/8BzWuIRl83rq0DXfm8SGt/HFhLXZyljTXcQ==" }, + "node_modules/p-limit": { + "version": "2.3.0", + "resolved": "https://registry.npmmirror.com/p-limit/-/p-limit-2.3.0.tgz", + "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", + "dependencies": { + "p-try": "^2.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmmirror.com/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", + "dependencies": { + "p-limit": "^2.2.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/p-try": { + "version": "2.2.0", + "resolved": "https://registry.npmmirror.com/p-try/-/p-try-2.2.0.tgz", + "integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==", + "engines": { + "node": ">=6" + } + }, "node_modules/pako": { "version": "1.0.11", "resolved": "https://registry.npmmirror.com/pako/-/pako-1.0.11.tgz", @@ -2998,6 +3154,14 @@ "resolved": "https://registry.npmmirror.com/parse5/-/parse5-6.0.1.tgz", "integrity": "sha512-Ofn/CTFzRGTTxwpNEs9PP93gXShHcTq255nzRYSKe8AkVpZY7e1fpmTfOyoIvjP5HG7Z2ZM7VS9PPhQGW2pOpw==" }, + "node_modules/path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmmirror.com/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "engines": { + "node": ">=8" + } + }, "node_modules/pathe": { "version": "1.1.2", "resolved": "https://registry.npmmirror.com/pathe/-/pathe-1.1.2.tgz", @@ -3125,6 +3289,14 @@ "pathe": "^1.1.2" } }, + "node_modules/pngjs": { + "version": "5.0.0", + "resolved": "https://registry.npmmirror.com/pngjs/-/pngjs-5.0.0.tgz", + "integrity": "sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw==", + "engines": { + "node": ">=10.13.0" + } + }, "node_modules/postcss": { "version": "8.4.49", "resolved": "https://registry.npmmirror.com/postcss/-/postcss-8.4.49.tgz", @@ -3174,6 +3346,11 @@ "resolved": "https://registry.npmmirror.com/process-nextick-args/-/process-nextick-args-2.0.1.tgz", "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==" }, + "node_modules/promise-polyfill": { + "version": "7.1.2", + "resolved": "https://registry.npmmirror.com/promise-polyfill/-/promise-polyfill-7.1.2.tgz", + "integrity": "sha512-FuEc12/eKqqoRYIGBrUptCBRhobL19PS2U31vMNTfyck1FxPyMfgsXyW4Mav85y/ZN1hop3hOwRlUDok23oYfQ==" + }, "node_modules/psl": { "version": "1.14.0", "resolved": "https://registry.npmmirror.com/psl/-/psl-1.14.0.tgz", @@ -3190,6 +3367,22 @@ "node": ">=6" } }, + "node_modules/qrcode": { + "version": "1.5.4", + "resolved": "https://registry.npmmirror.com/qrcode/-/qrcode-1.5.4.tgz", + "integrity": "sha512-1ca71Zgiu6ORjHqFBDpnSMTR2ReToX4l1Au1VFLyVeBTFavzQnv5JxMFr3ukHVKpSrSA2MCk0lNJSykjUfz7Zg==", + "dependencies": { + "dijkstrajs": "^1.0.1", + "pngjs": "^5.0.0", + "yargs": "^15.3.1" + }, + "bin": { + "qrcode": "bin/qrcode" + }, + "engines": { + "node": ">=10.13.0" + } + }, "node_modules/querystringify": { "version": "2.2.0", "resolved": "https://registry.npmmirror.com/querystringify/-/querystringify-2.2.0.tgz", @@ -3242,6 +3435,19 @@ "url": "https://paulmillr.com/funding/" } }, + "node_modules/require-directory": { + "version": "2.1.1", + "resolved": "https://registry.npmmirror.com/require-directory/-/require-directory-2.1.1.tgz", + "integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/require-main-filename": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/require-main-filename/-/require-main-filename-2.0.0.tgz", + "integrity": "sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg==" + }, "node_modules/requires-port": { "version": "1.0.0", "resolved": "https://registry.npmmirror.com/requires-port/-/requires-port-1.0.0.tgz", @@ -3393,6 +3599,11 @@ "resolved": "https://registry.npmmirror.com/sdp/-/sdp-2.12.0.tgz", "integrity": "sha512-jhXqQAQVM+8Xj5EjJGVweuEzgtGWb3tmEEpl3CLP3cStInSbVHSg0QWOGQzNq8pSID4JkpeV2mPqlMDLrm0/Vw==" }, + "node_modules/set-blocking": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/set-blocking/-/set-blocking-2.0.0.tgz", + "integrity": "sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==" + }, "node_modules/setimmediate": { "version": "1.0.5", "resolved": "https://registry.npmmirror.com/setimmediate/-/setimmediate-1.0.5.tgz", @@ -3433,6 +3644,30 @@ "safe-buffer": "~5.1.0" } }, + "node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmmirror.com/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmmirror.com/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/strip-literal": { "version": "2.1.1", "resolved": "https://registry.npmmirror.com/strip-literal/-/strip-literal-2.1.1.tgz", @@ -4061,6 +4296,24 @@ "node": ">=12" } }, + "node_modules/which-module": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/which-module/-/which-module-2.0.1.tgz", + "integrity": "sha512-iBdZ57RDvnOR9AGBhML2vFZf7h8vmBjhoaZqODJBFWHVtKkDmKuHai3cx5PgVMrX5YDNp27AofYbAwctSS+vhQ==" + }, + "node_modules/wrap-ansi": { + "version": "6.2.0", + "resolved": "https://registry.npmmirror.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz", + "integrity": "sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA==", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/ws": { "version": "8.18.0", "resolved": "https://registry.npmmirror.com/ws/-/ws-8.18.0.tgz", @@ -4093,6 +4346,44 @@ "version": "2.2.0", "resolved": "https://registry.npmmirror.com/xmlchars/-/xmlchars-2.2.0.tgz", "integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==" + }, + "node_modules/y18n": { + "version": "4.0.3", + "resolved": "https://registry.npmmirror.com/y18n/-/y18n-4.0.3.tgz", + "integrity": "sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ==" + }, + "node_modules/yargs": { + "version": "15.4.1", + "resolved": "https://registry.npmmirror.com/yargs/-/yargs-15.4.1.tgz", + "integrity": "sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A==", + "dependencies": { + "cliui": "^6.0.0", + "decamelize": "^1.2.0", + "find-up": "^4.1.0", + "get-caller-file": "^2.0.1", + "require-directory": "^2.1.1", + "require-main-filename": "^2.0.0", + "set-blocking": "^2.0.0", + "string-width": "^4.2.0", + "which-module": "^2.0.0", + "y18n": "^4.0.0", + "yargs-parser": "^18.1.2" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs-parser": { + "version": "18.1.3", + "resolved": "https://registry.npmmirror.com/yargs-parser/-/yargs-parser-18.1.3.tgz", + "integrity": "sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ==", + "dependencies": { + "camelcase": "^5.0.0", + "decamelize": "^1.2.0" + }, + "engines": { + "node": ">=6" + } } } } diff --git a/godo/ai/vector/collection.go b/godo/ai/vector/collection.go new file mode 100644 index 0000000..924b9a2 --- /dev/null +++ b/godo/ai/vector/collection.go @@ -0,0 +1,409 @@ +package vector + +import ( + "context" + "errors" + "fmt" + "path/filepath" + "slices" + "sync" +) + +// Collection 表示一个文档集合。 +// 它还包含一个配置好的嵌入函数,当添加没有嵌入的文档时会使用该函数。 +type Collection struct { + Name string + + metadata map[string]string + documents map[string]*Document + documentsLock sync.RWMutex + embed EmbeddingFunc + + persistDirectory string + compress bool + + // ⚠️ 当添加字段时,请考虑在 [DB.Export] 和 [DB.Import] 的持久化结构中添加相应的字段 +} + +// 我们不导出这个函数,以保持 API 表面最小。 +// 用户通过 [Client.CreateCollection] 创建集合。 +func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) { + // 复制元数据以避免在创建集合后调用者修改元数据时发生数据竞争。 + m := make(map[string]string, len(metadata)) + for k, v := range metadata { + m[k] = v + } + + c := &Collection{ + Name: name, + + metadata: m, + documents: make(map[string]*Document), + embed: embed, + } + + // 持久化 + if dbDir != "" { + safeName := hash2hex(name) + c.persistDirectory = filepath.Join(dbDir, safeName) + c.compress = compress + // 持久化名称和元数据 + metadataPath := filepath.Join(c.persistDirectory, metadataFileName) + metadataPath += ".gob" + if c.compress { + metadataPath += ".gz" + } + pc := struct { + Name string + Metadata map[string]string + }{ + Name: name, + Metadata: m, + } + err := persistToFile(metadataPath, pc, compress, "") + if err != nil { + return nil, fmt.Errorf("无法持久化集合元数据: %w", err) + } + } + + return c, nil +} + +// 添加嵌入到数据存储中。 +// +// - ids: 要添加的嵌入的 ID +// - embeddings: 要添加的嵌入。如果为 nil,则基于内容使用集合的嵌入函数计算嵌入。可选。 +// - metadatas: 与嵌入关联的元数据。查询时可以过滤这些元数据。可选。 +// - contents: 与嵌入关联的内容。 +// +// 这是一个类似于 Chroma 的方法。对于更符合 Go 风格的方法,请参见 [AddDocuments]。 +func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error { + return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1) +} + +// AddConcurrently 类似于 Add,但并发地添加嵌入。 +// 这在没有传递任何嵌入时特别有用,因为需要创建嵌入。 +// 出现错误时,取消所有并发操作并返回错误。 +// +// 这是一个类似于 Chroma 的方法。对于更符合 Go 风格的方法,请参见 [AddDocuments]。 +func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error { + if len(ids) == 0 { + return errors.New("ids 为空") + } + if len(embeddings) == 0 && len(contents) == 0 { + return errors.New("必须填写 embeddings 或 contents") + } + if len(embeddings) != 0 { + if len(embeddings) != len(ids) { + return errors.New("ids 和 embeddings 的长度必须相同") + } + } else { + // 分配空切片以便稍后通过索引访问 + embeddings = make([][]float32, len(ids)) + } + if len(metadatas) != 0 { + if len(ids) != len(metadatas) { + return errors.New("当 metadatas 不为空时,其长度必须与 ids 相同") + } + } else { + // 分配空切片以便稍后通过索引访问 + metadatas = make([]map[string]string, len(ids)) + } + if len(contents) != 0 { + if len(contents) != len(ids) { + return errors.New("ids 和 contents 的长度必须相同") + } + } else { + // 分配空切片以便稍后通过索引访问 + contents = make([]string, len(ids)) + } + if concurrency < 1 { + return errors.New("并发数必须至少为 1") + } + + // 将 Chroma 风格的参数转换为文档切片 + docs := make([]Document, 0, len(ids)) + for i, id := range ids { + docs = append(docs, Document{ + ID: id, + Metadata: metadatas[i], + Embedding: embeddings[i], + Content: contents[i], + }) + } + + return c.AddDocuments(ctx, docs, concurrency) +} + +// AddDocuments 使用指定的并发数将文档添加到集合中。 +// 如果文档没有嵌入,则使用集合的嵌入函数创建嵌入。 +// 出现错误时,取消所有并发操作并返回错误。 +func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error { + if len(documents) == 0 { + // TODO: 这是否应为无操作(no-op)? + return errors.New("documents 切片为空") + } + if concurrency < 1 { + return errors.New("并发数必须至少为 1") + } + // 对于其他验证,我们依赖于 AddDocument。 + + var sharedErr error + sharedErrLock := sync.Mutex{} + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + setSharedErr := func(err error) { + sharedErrLock.Lock() + defer sharedErrLock.Unlock() + // 另一个 goroutine 可能已经设置了错误。 + if sharedErr == nil { + sharedErr = err + // 取消所有其他 goroutine 的操作。 + cancel(sharedErr) + } + } + + var wg sync.WaitGroup + semaphore := make(chan struct{}, concurrency) + for _, doc := range documents { + wg.Add(1) + go func(doc Document) { + defer wg.Done() + + // 如果另一个 goroutine 已经失败,则不开始。 + if ctx.Err() != nil { + return + } + + // 等待直到 $concurrency 个其他 goroutine 正在创建文档。 + semaphore <- struct{}{} + defer func() { <-semaphore }() + + err := c.AddDocument(ctx, doc) + if err != nil { + setSharedErr(fmt.Errorf("无法添加文档 '%s': %w", doc.ID, err)) + return + } + }(doc) + } + + wg.Wait() + + return sharedErr +} + +// AddDocument 将文档添加到集合中。 +// 如果文档没有嵌入,则使用集合的嵌入函数创建嵌入。 +func (c *Collection) AddDocument(ctx context.Context, doc Document) error { + if doc.ID == "" { + return errors.New("文档 ID 为空") + } + if len(doc.Embedding) == 0 && doc.Content == "" { + return errors.New("必须填写文档的 embedding 或 content") + } + + // 复制元数据以避免在创建文档后调用者修改元数据时发生数据竞争。 + m := make(map[string]string, len(doc.Metadata)) + for k, v := range doc.Metadata { + m[k] = v + } + + // 如果嵌入不存在,则创建嵌入,否则如果需要则规范化 + if len(doc.Embedding) == 0 { + embedding, err := c.embed(ctx, doc.Content) + if err != nil { + return fmt.Errorf("无法创建文档的嵌入: %w", err) + } + doc.Embedding = embedding + } else { + if !isNormalized(doc.Embedding) { + doc.Embedding = normalizeVector(doc.Embedding) + } + } + + c.documentsLock.Lock() + // 我们不使用 defer 解锁,因为我们希望尽早解锁。 + c.documents[doc.ID] = &doc + c.documentsLock.Unlock() + + // 持久化文档 + if c.persistDirectory != "" { + docPath := c.getDocPath(doc.ID) + err := persistToFile(docPath, doc, c.compress, "") + if err != nil { + return fmt.Errorf("无法将文档持久化到 %q: %w", docPath, err) + } + } + + return nil +} + +// Delete 从集合中删除文档。 +// +// - where: 元数据的条件过滤。可选。 +// - whereDocument: 文档的条件过滤。可选。 +// - ids: 要删除的文档的 ID。如果为空,则删除所有文档。 +func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error { + // 必须至少有一个 where、whereDocument 或 ids + if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 { + return fmt.Errorf("必须至少有一个 where、whereDocument 或 ids") + } + + if len(c.documents) == 0 { + return nil + } + + for k := range whereDocument { + if !slices.Contains(supportedFilters, k) { + return errors.New("不支持的 whereDocument 操作符") + } + } + + var docIDs []string + + c.documentsLock.Lock() + defer c.documentsLock.Unlock() + + if where != nil || whereDocument != nil { + // 元数据 + 内容过滤 + filteredDocs := filterDocs(c.documents, where, whereDocument) + for _, doc := range filteredDocs { + docIDs = append(docIDs, doc.ID) + } + } else { + docIDs = ids + } + + // 如果没有剩余的文档,则不执行操作 + if len(docIDs) == 0 { + return nil + } + + for _, docID := range docIDs { + delete(c.documents, docID) + + // 从磁盘删除文档 + if c.persistDirectory != "" { + docPath := c.getDocPath(docID) + err := removeFile(docPath) + if err != nil { + return fmt.Errorf("无法删除文档 %q: %w", docPath, err) + } + } + } + + return nil +} + +// Count 返回集合中的文档数量。 +func (c *Collection) Count() int { + c.documentsLock.RLock() + defer c.documentsLock.RUnlock() + return len(c.documents) +} + +// Result 表示查询结果中的单个结果。 +type Result struct { + ID string + Metadata map[string]string + Embedding []float32 + Content string + + // 查询与文档之间的余弦相似度。 + // 值越高,文档与查询越相似。 + // 值的范围是 [-1, 1]。 + Similarity float32 +} + +// 在集合上执行详尽的最近邻搜索。 +// +// - queryText: 要搜索的文本。其嵌入将使用集合的嵌入函数创建。 +// - nResults: 要返回的结果数量。必须大于 0。 +// - where: 元数据的条件过滤。可选。 +// - whereDocument: 文档的条件过滤。可选。 +func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { + if queryText == "" { + return nil, errors.New("queryText 为空") + } + + queryVectors, err := c.embed(ctx, queryText) + if err != nil { + return nil, fmt.Errorf("无法创建查询的嵌入: %w", err) + } + + return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument) +} + +// 在集合上执行详尽的最近邻搜索。 +// +// - queryEmbedding: 要搜索的查询的嵌入。必须使用与集合中文档嵌入相同的嵌入模型创建。 +// - nResults: 要返回的结果数量。必须大于 0。 +// - where: 元数据的条件过滤。可选。 +// - whereDocument: 文档的条件过滤。可选。 +func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) { + if len(queryEmbedding) == 0 { + return nil, errors.New("queryEmbedding 为空") + } + if nResults <= 0 { + return nil, errors.New("nResults 必须大于 0") + } + c.documentsLock.RLock() + defer c.documentsLock.RUnlock() + // if nResults > len(c.documents) { + // return nil, errors.New("nResults 必须小于或等于集合中的文档数量") + // } + + if len(c.documents) == 0 { + return nil, nil + } + + // 验证 whereDocument 操作符 + for k := range whereDocument { + if !slices.Contains(supportedFilters, k) { + return nil, errors.New("不支持的操作符") + } + } + + // 根据元数据和内容过滤文档 + filteredDocs := filterDocs(c.documents, where, whereDocument) + + // 如果过滤器删除了所有文档,则不继续 + if len(filteredDocs) == 0 { + return nil, nil + } + + // 对于剩余的文档,获取最相似的文档。 + nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults) + if err != nil { + return nil, fmt.Errorf("无法获取最相似的文档: %w", err) + } + length := len(nMaxDocs) + if length > nResults { + length = nResults + } + res := make([]Result, 0, length) + for i := 0; i < length; i++ { + doc := c.documents[nMaxDocs[i].docID] + res = append(res, Result{ + ID: nMaxDocs[i].docID, + Metadata: doc.Metadata, + Embedding: doc.Embedding, + Content: doc.Content, + Similarity: nMaxDocs[i].similarity, + }) + } + + // 返回前 nResults 个结果 + return res, nil +} + +// getDocPath 生成文档文件的路径。 +func (c *Collection) getDocPath(docID string) string { + safeID := hash2hex(docID) + docPath := filepath.Join(c.persistDirectory, safeID) + docPath += ".gob" + if c.compress { + docPath += ".gz" + } + return docPath +} diff --git a/godo/ai/vector/db.go b/godo/ai/vector/db.go new file mode 100644 index 0000000..e50fa18 --- /dev/null +++ b/godo/ai/vector/db.go @@ -0,0 +1,412 @@ +package vector + +import ( + "context" + "errors" + "fmt" + "godo/libs" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" +) + +// EmbeddingFunc 是一个为给定文本创建嵌入的函数。 +// 默认使用 OpenAI 的 "text-embedding-3-small" 模型。 +// 该函数必须返回一个已归一化的向量。 +type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error) + +// DB 包含多个集合,每个集合包含多个文档。 +type DB struct { + collections map[string]*Collection + collectionsLock sync.RWMutex + + persistDirectory string + compress bool +} + +// NewDB 创建一个新的内存中的数据库。 +func NewDB() *DB { + return &DB{ + collections: make(map[string]*Collection), + } +} + +// NewPersistentDB 创建一个新的持久化的数据库。 +// 如果路径为空,默认为 "./godoos/data/godoDB"。 +// 如果 compress 为 true,则文件将使用 gzip 压缩。 +func NewPersistentDB(path string, compress bool) (*DB, error) { + homeDir, err := libs.GetAppDir() + if err != nil { + return nil, err + } + if path == "" { + path = filepath.Join(homeDir, "data", "godoDB") + } else { + path = filepath.Clean(path) + } + + ext := ".gob" + if compress { + ext += ".gz" + } + + db := &DB{ + collections: make(map[string]*Collection), + persistDirectory: path, + compress: compress, + } + + fi, err := os.Stat(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + err := os.MkdirAll(path, 0o700) + if err != nil { + return nil, fmt.Errorf("无法创建持久化目录: %w", err) + } + return db, nil + } + return nil, fmt.Errorf("无法获取持久化目录信息: %w", err) + } else if !fi.IsDir() { + return nil, fmt.Errorf("路径不是目录: %s", path) + } + + dirEntries, err := os.ReadDir(path) + if err != nil { + return nil, fmt.Errorf("无法读取持久化目录: %w", err) + } + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + continue + } + collectionPath := filepath.Join(path, dirEntry.Name()) + collectionDirEntries, err := os.ReadDir(collectionPath) + if err != nil { + return nil, fmt.Errorf("无法读取集合目录: %w", err) + } + c := &Collection{ + documents: make(map[string]*Document), + persistDirectory: collectionPath, + compress: compress, + } + for _, collectionDirEntry := range collectionDirEntries { + if collectionDirEntry.IsDir() { + continue + } + fPath := filepath.Join(collectionPath, collectionDirEntry.Name()) + if collectionDirEntry.Name() == metadataFileName+ext { + pc := struct { + Name string + Metadata map[string]string + }{} + err := readFromFile(fPath, &pc, "") + if err != nil { + return nil, fmt.Errorf("无法读取集合元数据: %w", err) + } + c.Name = pc.Name + c.metadata = pc.Metadata + } else if strings.HasSuffix(collectionDirEntry.Name(), ext) { + d := &Document{} + err := readFromFile(fPath, d, "") + if err != nil { + return nil, fmt.Errorf("无法读取文档: %w", err) + } + c.documents[d.ID] = d + } + } + if c.Name == "" && len(c.documents) == 0 { + continue + } + if c.Name == "" { + return nil, fmt.Errorf("未找到集合元数据文件: %s", collectionPath) + } + db.collections[c.Name] = c + } + + return db, nil +} + +// ImportFromFile 从给定路径的文件导入数据库。 +func (db *DB) ImportFromFile(filePath string, encryptionKey string) error { + if filePath == "" { + return fmt.Errorf("文件路径为空") + } + if encryptionKey != "" && len(encryptionKey) != 32 { + return errors.New("加密密钥必须为 32 字节长") + } + + fi, err := os.Stat(filePath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("文件不存在: %s", filePath) + } + return fmt.Errorf("无法获取文件信息: %w", err) + } else if fi.IsDir() { + return fmt.Errorf("路径是目录: %s", filePath) + } + + type persistenceCollection struct { + Name string + Metadata map[string]string + Documents map[string]*Document + } + persistenceDB := struct { + Collections map[string]*persistenceCollection + }{ + Collections: make(map[string]*persistenceCollection, len(db.collections)), + } + + db.collectionsLock.Lock() + defer db.collectionsLock.Unlock() + + err = readFromFile(filePath, &persistenceDB, encryptionKey) + if err != nil { + return fmt.Errorf("无法读取文件: %w", err) + } + + for _, pc := range persistenceDB.Collections { + c := &Collection{ + Name: pc.Name, + metadata: pc.Metadata, + documents: pc.Documents, + } + if db.persistDirectory != "" { + c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) + c.compress = db.compress + } + db.collections[c.Name] = c + } + + return nil +} + +// ImportFromReader 从 reader 导入数据库。 +func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error { + if encryptionKey != "" && len(encryptionKey) != 32 { + return errors.New("加密密钥必须为 32 字节长") + } + + type persistenceCollection struct { + Name string + Metadata map[string]string + Documents map[string]*Document + } + persistenceDB := struct { + Collections map[string]*persistenceCollection + }{ + Collections: make(map[string]*persistenceCollection, len(db.collections)), + } + + db.collectionsLock.Lock() + defer db.collectionsLock.Unlock() + + err := readFromReader(reader, &persistenceDB, encryptionKey) + if err != nil { + return fmt.Errorf("无法读取流: %w", err) + } + + for _, pc := range persistenceDB.Collections { + c := &Collection{ + Name: pc.Name, + metadata: pc.Metadata, + documents: pc.Documents, + } + if db.persistDirectory != "" { + c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) + c.compress = db.compress + } + db.collections[c.Name] = c + } + + return nil +} + +// ExportToFile 将数据库导出到给定路径的文件。 +func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error { + if filePath == "" { + filePath = "./gododb.gob" + if compress { + filePath += ".gz" + } + if encryptionKey != "" { + filePath += ".enc" + } + } + if encryptionKey != "" && len(encryptionKey) != 32 { + return errors.New("加密密钥必须为 32 字节长") + } + + type persistenceCollection struct { + Name string + Metadata map[string]string + Documents map[string]*Document + } + persistenceDB := struct { + Collections map[string]*persistenceCollection + }{ + Collections: make(map[string]*persistenceCollection, len(db.collections)), + } + + db.collectionsLock.RLock() + defer db.collectionsLock.RUnlock() + + for k, v := range db.collections { + persistenceDB.Collections[k] = &persistenceCollection{ + Name: v.Name, + Metadata: v.metadata, + Documents: v.documents, + } + } + + err := persistToFile(filePath, persistenceDB, compress, encryptionKey) + if err != nil { + return fmt.Errorf("无法导出数据库: %w", err) + } + + return nil +} + +// ExportToWriter 将数据库导出到 writer。 +func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error { + if encryptionKey != "" && len(encryptionKey) != 32 { + return errors.New("加密密钥必须为 32 字节长") + } + + type persistenceCollection struct { + Name string + Metadata map[string]string + Documents map[string]*Document + } + persistenceDB := struct { + Collections map[string]*persistenceCollection + }{ + Collections: make(map[string]*persistenceCollection, len(db.collections)), + } + + db.collectionsLock.RLock() + defer db.collectionsLock.RUnlock() + + for k, v := range db.collections { + persistenceDB.Collections[k] = &persistenceCollection{ + Name: v.Name, + Metadata: v.metadata, + Documents: v.documents, + } + } + + err := persistToWriter(writer, persistenceDB, compress, encryptionKey) + if err != nil { + return fmt.Errorf("无法导出数据库: %w", err) + } + + return nil +} + +// CreateCollection 创建具有给定名称和元数据的新集合。 +func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { + if name == "" { + return nil, errors.New("集合名称为空") + } + if embeddingFunc == nil { + embeddingFunc = NewEmbeddingFuncDefault() + } + collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory, db.compress) + if err != nil { + return nil, fmt.Errorf("无法创建集合: %w", err) + } + + db.collectionsLock.Lock() + defer db.collectionsLock.Unlock() + db.collections[name] = collection + return collection, nil +} + +// ListCollections 返回数据库中的所有集合。 +func (db *DB) ListCollections() map[string]*Collection { + db.collectionsLock.RLock() + defer db.collectionsLock.RUnlock() + + res := make(map[string]*Collection, len(db.collections)) + for k, v := range db.collections { + res[k] = v + } + + return res +} + +// GetCollection 返回具有给定名称的集合。 +func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection { + db.collectionsLock.RLock() + defer db.collectionsLock.RUnlock() + + c, ok := db.collections[name] + if !ok { + return nil + } + + if c.embed == nil { + if embeddingFunc == nil { + c.embed = NewEmbeddingFuncDefault() + } else { + c.embed = embeddingFunc + } + } + return c +} + +// GetOrCreateCollection 返回数据库中已有的集合,或创建一个新的集合。 +func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { + collection := db.GetCollection(name, embeddingFunc) + if collection == nil { + var err error + collection, err = db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + return nil, fmt.Errorf("无法创建集合: %w", err) + } + } + return collection, nil +} + +// DeleteCollection 删除具有给定名称的集合。 +func (db *DB) DeleteCollection(name string) error { + db.collectionsLock.Lock() + defer db.collectionsLock.Unlock() + + col, ok := db.collections[name] + if !ok { + return nil + } + + if db.persistDirectory != "" { + collectionPath := col.persistDirectory + err := os.RemoveAll(collectionPath) + if err != nil { + return fmt.Errorf("无法删除集合目录: %w", err) + } + } + + delete(db.collections, name) + return nil +} + +// Reset 从数据库中移除所有集合。 +func (db *DB) Reset() error { + db.collectionsLock.Lock() + defer db.collectionsLock.Unlock() + + if db.persistDirectory != "" { + err := os.RemoveAll(db.persistDirectory) + if err != nil { + return fmt.Errorf("无法删除持久化目录: %w", err) + } + err = os.MkdirAll(db.persistDirectory, 0o700) + if err != nil { + return fmt.Errorf("无法重新创建持久化目录: %w", err) + } + } + + db.collections = make(map[string]*Collection) + return nil +} diff --git a/godo/ai/vector/document.go b/godo/ai/vector/document.go index 29e5866..0532639 100644 --- a/godo/ai/vector/document.go +++ b/godo/ai/vector/document.go @@ -1,178 +1,52 @@ package vector import ( - "encoding/json" + "context" + "errors" "fmt" - "godo/ai/server" - "godo/libs" - "godo/office" - "log" - "os" - "path/filepath" - "strings" - - "github.com/fsnotify/fsnotify" ) -var MapFilePathMonitors = map[string]uint{} - -func FolderMonitor() { - basePath, err := libs.GetOsDir() - if err != nil { - log.Printf("Error getting base path: %s", err.Error()) - return - } - watcher, err := fsnotify.NewWatcher() - if err != nil { - log.Printf("Error creating watcher: %s", err.Error()) - return - } - defer watcher.Close() - - // 递归添加所有子目录 - addRecursive(basePath, watcher) - - // Start listening for events. - go func() { - for { - select { - case event, ok := <-watcher.Events: - if !ok { - log.Println("error:", err) - return - } - //log.Println("event:", event) - filePath := event.Name - result, knowledgeId := shouldProcess(filePath) - //log.Printf("result:%d,knowledgeId:%d", result, knowledgeId) - if result > 0 { - info, err := os.Stat(filePath) - if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { - log.Println("modified file:", filePath) - if !info.IsDir() { - handleGodoosFile(filePath, knowledgeId) - } - } - if event.Has(fsnotify.Create) || event.Has(fsnotify.Rename) { - // 处理创建或重命名事件,添加新目录 - if err == nil && info.IsDir() { - addRecursive(filePath, watcher) - } - } - if event.Has(fsnotify.Remove) { - // 处理删除事件,移除目录 - if err == nil && info.IsDir() { - watcher.Remove(filePath) - } - } - } - case err, ok := <-watcher.Errors: - if !ok { - return - } - log.Println("error:", err) - } - } - }() +// Document 表示单个文档。 +type Document struct { + ID string // 文档的唯一标识符 + Metadata map[string]string // 文档的元数据 + Embedding []float32 // 文档的嵌入向量 + Content string // 文档的内容 - // Add a path. - err = watcher.Add(basePath) - if err != nil { - log.Fatal(err) - } - - // Block main goroutine forever. - <-make(chan struct{}) + // ⚠️ 当在此处添加未导出字段时,请考虑在 [DB.Export] 和 [DB.Import] 中添加一个持久化结构版本。 } -func shouldProcess(filePath string) (int, uint) { - // 规范化路径 - filePath = filepath.Clean(filePath) - - // 检查文件路径是否在 MapFilePathMonitors 中 - for path, id := range MapFilePathMonitors { - if id < 1 { - return 0, 0 - } - path = filepath.Clean(path) - if filePath == path { - return 1, id // 完全相等 - } - if strings.HasPrefix(filePath, path+string(filepath.Separator)) { - return 2, id // 包含 - } +// NewDocument 创建一个新的文档,包括其嵌入向量。 +// 元数据是可选的。 +// 如果未提供嵌入向量,则使用嵌入函数创建。 +// 如果内容为空但需要存储嵌入向量,可以仅提供嵌入向量。 +// 如果 embeddingFunc 为 nil,则使用默认的嵌入函数。 +// +// 如果你想创建没有嵌入向量的文档,例如让 [Collection.AddDocuments] 并发创建它们, +// 可以使用 `chromem.Document{...}` 而不是这个构造函数。 +func NewDocument(ctx context.Context, id string, metadata map[string]string, embedding []float32, content string, embeddingFunc EmbeddingFunc) (Document, error) { + if id == "" { + return Document{}, errors.New("ID 不能为空") + } + if len(embedding) == 0 && content == "" { + return Document{}, errors.New("嵌入向量或内容必须至少有一个非空") + } + if embeddingFunc == nil { + embeddingFunc = NewEmbeddingFuncDefault() } - return 0, 0 // 不存在 -} -func addRecursive(path string, watcher *fsnotify.Watcher) { - err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + if len(embedding) == 0 { + var err error + embedding, err = embeddingFunc(ctx, content) if err != nil { - log.Printf("Error walking path %s: %v", path, err) - return err + return Document{}, fmt.Errorf("无法生成嵌入向量: %w", err) } - if info.IsDir() { - result, _ := shouldProcess(path) - if result > 0 { - if err := watcher.Add(path); err != nil { - log.Printf("Error adding path %s to watcher: %v", path, err) - return err - } - log.Printf("Added path %s to watcher", path) - } - - } - return nil - }) - if err != nil { - log.Printf("Error adding recursive paths: %v", err) } -} -func handleGodoosFile(filePath string, knowledgeId uint) error { - log.Printf("========Handling .godoos file: %s", filePath) - baseName := filepath.Base(filePath) - if baseName[:8] != ".godoos." { - if baseName[:1] != "." { - office.ProcessFile(filePath, knowledgeId) - } - return nil - } - var doc office.Document - content, err := os.ReadFile(filePath) - if err != nil { - return err - } - err = json.Unmarshal(content, &doc) - if err != nil { - return err - } - if len(doc.Split) == 0 { - return fmt.Errorf("invalid .godoos file: %s", filePath) - } - knowData := GetVector(knowledgeId) - resList, err := server.GetEmbeddings(knowData.Engine, knowData.EmbeddingModel, doc.Split) - if err != nil { - return err - } - if len(resList) != len(doc.Split) { - return fmt.Errorf("invalid file len: %s, expected %d embeddings, got %d", filePath, len(doc.Split), len(resList)) - } - // var vectordocs []model.Vectordoc - // for i, res := range resList { - // //log.Printf("res: %v", res) - // vectordoc := model.Vectordoc{ - // Content: doc.Split[i], - // Embed: res, - // FilePath: filePath, - // KnowledgeID: knowledgeId, - // Pos: fmt.Sprintf("%d", i), - // } - // vectordocs = append(vectordocs, vectordoc) - // } - // result := vectorListDb.Create(&vectordocs) - // if result.Error != nil { - // return result.Error - // } - return nil + return Document{ + ID: id, + Metadata: metadata, + Embedding: embedding, + Content: content, + }, nil } diff --git a/godo/ai/vector/files.go b/godo/ai/vector/files.go new file mode 100644 index 0000000..05c42c2 --- /dev/null +++ b/godo/ai/vector/files.go @@ -0,0 +1,181 @@ +package vector + +import ( + "encoding/json" + "fmt" + "godo/ai/server" + "godo/libs" + "godo/office" + "log" + "os" + "path/filepath" + "strings" + + "github.com/fsnotify/fsnotify" +) + +var MapFilePathMonitors = map[string]uint{} + +func FolderMonitor() { + basePath, err := libs.GetOsDir() + if err != nil { + log.Printf("Error getting base path: %s", err.Error()) + return + } + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Printf("Error creating watcher: %s", err.Error()) + return + } + defer watcher.Close() + + // 递归添加所有子目录 + addRecursive(basePath, watcher) + + // Start listening for events. + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + log.Println("error:", err) + return + } + //log.Println("event:", event) + filePath := event.Name + result, knowledgeId := shouldProcess(filePath) + //log.Printf("result:%d,knowledgeId:%d", result, knowledgeId) + if result > 0 { + info, err := os.Stat(filePath) + if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { + log.Println("modified file:", filePath) + if !info.IsDir() { + handleGodoosFile(filePath, knowledgeId) + } + } + if event.Has(fsnotify.Create) || event.Has(fsnotify.Rename) { + // 处理创建或重命名事件,添加新目录 + if err == nil && info.IsDir() { + addRecursive(filePath, watcher) + } + } + if event.Has(fsnotify.Remove) { + // 处理删除事件,移除目录 + if err == nil && info.IsDir() { + watcher.Remove(filePath) + } + } + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Println("error:", err) + } + } + }() + + // Add a path. + err = watcher.Add(basePath) + if err != nil { + log.Fatal(err) + } + + // Block main goroutine forever. + <-make(chan struct{}) +} + +func shouldProcess(filePath string) (int, uint) { + // 规范化路径 + filePath = filepath.Clean(filePath) + + // 检查文件路径是否在 MapFilePathMonitors 中 + for path, id := range MapFilePathMonitors { + if id < 1 { + return 0, 0 + } + path = filepath.Clean(path) + if filePath == path { + return 1, id // 完全相等 + } + if strings.HasPrefix(filePath, path+string(filepath.Separator)) { + return 2, id // 包含 + } + } + return 0, 0 // 不存在 +} + +func addRecursive(path string, watcher *fsnotify.Watcher) { + err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + if err != nil { + log.Printf("Error walking path %s: %v", path, err) + return err + } + if info.IsDir() { + result, _ := shouldProcess(path) + if result > 0 { + if err := watcher.Add(path); err != nil { + log.Printf("Error adding path %s to watcher: %v", path, err) + return err + } + log.Printf("Added path %s to watcher", path) + } + + } + return nil + }) + if err != nil { + log.Printf("Error adding recursive paths: %v", err) + } +} + +func handleGodoosFile(filePath string, knowledgeId uint) error { + log.Printf("========Handling .godoos file: %s", filePath) + baseName := filepath.Base(filePath) + if baseName[:8] != ".godoos." { + if baseName[:1] != "." { + office.ProcessFile(filePath, knowledgeId) + } + return nil + } + var doc office.Document + content, err := os.ReadFile(filePath) + if err != nil { + return err + } + err = json.Unmarshal(content, &doc) + if err != nil { + return err + } + if len(doc.Split) == 0 { + return fmt.Errorf("invalid .godoos file: %s", filePath) + } + knowData, err := GetVector(knowledgeId) + if err != nil { + return err + } + resList, err := server.GetEmbeddings(knowData.Engine, knowData.EmbeddingModel, doc.Split) + if err != nil { + return err + } + if len(resList) != len(doc.Split) { + return fmt.Errorf("invalid file len: %s, expected %d embeddings, got %d", filePath, len(doc.Split), len(resList)) + } + // var vectordocs []model.Vectordoc + // for i, res := range resList { + // //log.Printf("res: %v", res) + // vectordoc := model.Vectordoc{ + // Content: doc.Split[i], + // Embed: res, + // FilePath: filePath, + // KnowledgeID: knowledgeId, + // Pos: fmt.Sprintf("%d", i), + // } + // vectordocs = append(vectordocs, vectordoc) + // } + // result := vectorListDb.Create(&vectordocs) + // if result.Error != nil { + // return result.Error + // } + return nil +} diff --git a/godo/ai/vector/openai.go b/godo/ai/vector/openai.go new file mode 100644 index 0000000..2923649 --- /dev/null +++ b/godo/ai/vector/openai.go @@ -0,0 +1,125 @@ +package vector + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "sync" +) + +const BaseURLOpenAI = "https://api.openai.com/v1" + +type EmbeddingModelOpenAI string + +const ( + EmbeddingModelOpenAI2Ada EmbeddingModelOpenAI = "text-embedding-ada-002" + EmbeddingModelOpenAI3Small EmbeddingModelOpenAI = "text-embedding-3-small" + EmbeddingModelOpenAI3Large EmbeddingModelOpenAI = "text-embedding-3-large" +) + +type openAIResponse struct { + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` +} + +// NewEmbeddingFuncDefault 返回一个函数,使用 OpenAI 的 "text-embedding-3-small" 模型通过 API 创建文本嵌入向量。 +// 该模型支持的最大文本长度为 8191 个标记。 +// API 密钥从环境变量 "OPENAI_API_KEY" 中读取。 +func NewEmbeddingFuncDefault() EmbeddingFunc { + apiKey := os.Getenv("OPENAI_API_KEY") + return NewEmbeddingFuncOpenAI(apiKey, EmbeddingModelOpenAI3Small) +} + +// NewEmbeddingFuncOpenAI 返回一个函数,使用 OpenAI API 创建文本嵌入向量。 +func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc { + // OpenAI 嵌入向量已归一化 + normalized := true + return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized) +} + +// NewEmbeddingFuncOpenAICompat 返回一个函数,使用兼容 OpenAI 的 API 创建文本嵌入向量。 +// 例如: +// - Azure OpenAI: https://azure.microsoft.com/en-us/products/ai-services/openai-service +// - LitLLM: https://github.com/BerriAI/litellm +// - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md +// +// `normalized` 参数表示嵌入模型返回的向量是否已经归一化。如果为 nil,则会在首次请求时自动检测(有小概率向量恰好长度为 1)。 +func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc { + client := &http.Client{} + + var checkedNormalized bool + checkNormalized := sync.Once{} + + return func(ctx context.Context, text string) ([]float32, error) { + // 准备请求体 + reqBody, err := json.Marshal(map[string]string{ + "input": text, + "model": model, + }) + if err != nil { + return nil, fmt.Errorf("无法序列化请求体: %w", err) + } + + // 创建带有上下文的请求以支持超时 + req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewBuffer(reqBody)) + if err != nil { + return nil, fmt.Errorf("无法创建请求: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + // 发送请求 + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("无法发送请求: %w", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode != http.StatusOK { + return nil, errors.New("嵌入 API 返回错误响应: " + resp.Status) + } + + // 读取并解码响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("无法读取响应体: %w", err) + } + var embeddingResponse openAIResponse + err = json.Unmarshal(body, &embeddingResponse) + if err != nil { + return nil, fmt.Errorf("无法反序列化响应体: %w", err) + } + + // 检查响应中是否包含嵌入向量 + if len(embeddingResponse.Data) == 0 || len(embeddingResponse.Data[0].Embedding) == 0 { + return nil, errors.New("响应中未找到嵌入向量") + } + + v := embeddingResponse.Data[0].Embedding + if normalized != nil { + if *normalized { + return v, nil + } + return normalizeVector(v), nil + } + checkNormalized.Do(func() { + if isNormalized(v) { + checkedNormalized = true + } else { + checkedNormalized = false + } + }) + if !checkedNormalized { + v = normalizeVector(v) + } + + return v, nil + } +} diff --git a/godo/ai/vector/persistence.go b/godo/ai/vector/persistence.go new file mode 100644 index 0000000..e79d1b8 --- /dev/null +++ b/godo/ai/vector/persistence.go @@ -0,0 +1,208 @@ +package vector + +import ( + "bytes" + "compress/gzip" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" +) + +const metadataFileName = "00000000" + +// hash2hex 将字符串转换为 SHA256 哈希并返回前 8 位的十六进制表示。 +func hash2hex(name string) string { + hash := sha256.Sum256([]byte(name)) + return hex.EncodeToString(hash[:4]) +} + +// persistToFile 将对象持久化到文件。支持 Gob 序列化、Gzip 压缩和 AES-GCM 加密。 +func persistToFile(filePath string, obj any, compress bool, encryptionKey string) error { + if filePath == "" { + return fmt.Errorf("文件路径为空") + } + if encryptionKey != "" && len(encryptionKey) != 32 { + return errors.New("加密密钥必须是 32 字节长") + } + + // 确保父目录存在 + if err := os.MkdirAll(filepath.Dir(filePath), 0o700); err != nil { + return fmt.Errorf("无法创建父目录: %w", err) + } + + // 打开或创建文件 + f, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("无法创建文件: %w", err) + } + defer f.Close() + + return persistToWriter(f, obj, compress, encryptionKey) +} + +// persistToWriter 将对象持久化到 io.Writer。支持 Gob 序列化、Gzip 压缩和 AES-GCM 加密。 +func persistToWriter(w io.Writer, obj any, compress bool, encryptionKey string) error { + if encryptionKey != "" && len(encryptionKey) != 32 { + return errors.New("加密密钥必须是 32 字节长") + } + + var chainedWriter io.Writer + if encryptionKey == "" { + chainedWriter = w + } else { + chainedWriter = &bytes.Buffer{} + } + + var gzw *gzip.Writer + var enc *gob.Encoder + if compress { + gzw = gzip.NewWriter(chainedWriter) + enc = gob.NewEncoder(gzw) + } else { + enc = gob.NewEncoder(chainedWriter) + } + + if err := enc.Encode(obj); err != nil { + return fmt.Errorf("无法编码或写入对象: %w", err) + } + + if compress { + if err := gzw.Close(); err != nil { + return fmt.Errorf("无法关闭 Gzip 写入器: %w", err) + } + } + + if encryptionKey == "" { + return nil + } + + block, err := aes.NewCipher([]byte(encryptionKey)) + if err != nil { + return fmt.Errorf("无法创建 AES 密码: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return fmt.Errorf("无法创建 GCM 包装器: %w", err) + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return fmt.Errorf("无法读取随机字节作为 nonce: %w", err) + } + + buf := chainedWriter.(*bytes.Buffer) + encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil) + if _, err := w.Write(encrypted); err != nil { + return fmt.Errorf("无法写入加密数据: %w", err) + } + + return nil +} + +// readFromFile 从文件中读取对象。支持 Gob 反序列化、Gzip 解压和 AES-GCM 解密。 +func readFromFile(filePath string, obj any, encryptionKey string) error { + if filePath == "" { + return fmt.Errorf("文件路径为空") + } + if encryptionKey != "" && len(encryptionKey) != 32 { + return errors.New("加密密钥必须是 32 字节长") + } + + r, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("无法打开文件: %w", err) + } + defer r.Close() + + return readFromReader(r, obj, encryptionKey) +} + +// readFromReader 从 io.Reader 中读取对象。支持 Gob 反序列化、Gzip 解压和 AES-GCM 解密。 +func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error { + if encryptionKey != "" && len(encryptionKey) != 32 { + return errors.New("加密密钥必须是 32 字节长") + } + + var chainedReader io.Reader + if encryptionKey != "" { + encrypted, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("无法读取数据: %w", err) + } + block, err := aes.NewCipher([]byte(encryptionKey)) + if err != nil { + return fmt.Errorf("无法创建 AES 密码: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return fmt.Errorf("无法创建 GCM 包装器: %w", err) + } + nonceSize := gcm.NonceSize() + if len(encrypted) < nonceSize { + return fmt.Errorf("加密数据太短") + } + nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:] + data, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return fmt.Errorf("无法解密数据: %w", err) + } + chainedReader = bytes.NewReader(data) + } else { + chainedReader = r + } + + magicNumber := make([]byte, 2) + _, err := chainedReader.Read(magicNumber) + if err != nil { + return fmt.Errorf("无法读取魔数以确定是否压缩: %w", err) + } + compressed := magicNumber[0] == 0x1f && magicNumber[1] == 0x8b + + // 重置读取器位置 + if s, ok := chainedReader.(io.Seeker); !ok { + return fmt.Errorf("读取器不支持寻址") + } else { + _, err := s.Seek(0, 0) + if err != nil { + return fmt.Errorf("无法重置读取器: %w", err) + } + } + + if compressed { + gzr, err := gzip.NewReader(chainedReader) + if err != nil { + return fmt.Errorf("无法创建 Gzip 读取器: %w", err) + } + defer gzr.Close() + chainedReader = gzr + } + + dec := gob.NewDecoder(chainedReader) + if err := dec.Decode(obj); err != nil { + return fmt.Errorf("无法解码对象: %w", err) + } + + return nil +} + +// removeFile 删除指定路径的文件。如果文件不存在,则无操作。 +func removeFile(filePath string) error { + if filePath == "" { + return fmt.Errorf("文件路径为空") + } + + err := os.Remove(filePath) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("无法删除文件 %q: %w", filePath, err) + } + + return nil +} diff --git a/godo/ai/vector/query.go b/godo/ai/vector/query.go new file mode 100644 index 0000000..9a4ae4f --- /dev/null +++ b/godo/ai/vector/query.go @@ -0,0 +1,207 @@ +package vector + +import ( + "cmp" + "container/heap" + "context" + "fmt" + "runtime" + "slices" + "strings" + "sync" +) + +var supportedFilters = []string{"$contains", "$not_contains"} + +type docSim struct { + docID string + similarity float32 +} + +// docMaxHeap 是基于相似度的最大堆。 +type docMaxHeap []docSim + +func (h docMaxHeap) Len() int { return len(h) } +func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity } +func (h docMaxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *docMaxHeap) Push(x any) { + *h = append(*h, x.(docSim)) +} + +func (h *docMaxHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +// maxDocSims 管理一个固定大小的最大堆,保存最高的 n 个相似度。并发安全,但 values() 返回的结果不是。 +type maxDocSims struct { + h docMaxHeap + lock sync.RWMutex + size int +} + +// newMaxDocSims 创建一个新的固定大小的最大堆。 +func newMaxDocSims(size int) *maxDocSims { + return &maxDocSims{ + h: make(docMaxHeap, 0, size), + size: size, + } +} + +// add 插入一个新的 docSim 到堆中,保持最高的 n 个相似度。 +func (mds *maxDocSims) add(doc docSim) { + mds.lock.Lock() + defer mds.lock.Unlock() + if mds.h.Len() < mds.size { + heap.Push(&mds.h, doc) + } else if mds.h.Len() > 0 && mds.h[0].similarity < doc.similarity { + heap.Pop(&mds.h) + heap.Push(&mds.h, doc) + } +} + +// values 返回堆中的 docSim,按相似度降序排列。调用是并发安全的,但结果不是。 +func (d *maxDocSims) values() []docSim { + d.lock.RLock() + defer d.lock.RUnlock() + slices.SortFunc(d.h, func(i, j docSim) int { + return cmp.Compare(j.similarity, i.similarity) + }) + return d.h +} + +// filterDocs 并发过滤文档,根据元数据和内容进行筛选。 +func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document { + filteredDocs := make([]*Document, 0, len(docs)) + var filteredDocsLock sync.Mutex + + numCPUs := runtime.NumCPU() + numDocs := len(docs) + concurrency := min(numCPUs, numDocs) + + docChan := make(chan *Document, concurrency*2) + + var wg sync.WaitGroup + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for doc := range docChan { + if documentMatchesFilters(doc, where, whereDocument) { + filteredDocsLock.Lock() + filteredDocs = append(filteredDocs, doc) + filteredDocsLock.Unlock() + } + } + }() + } + + for _, doc := range docs { + docChan <- doc + } + close(docChan) + + wg.Wait() + + if len(filteredDocs) == 0 { + return nil + } + return filteredDocs +} + +// documentMatchesFilters 检查文档是否匹配给定的过滤条件。 +func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool { + for k, v := range where { + if document.Metadata[k] != v { + return false + } + } + + for k, v := range whereDocument { + switch k { + case "$contains": + if !strings.Contains(document.Content, v) { + return false + } + case "$not_contains": + if strings.Contains(document.Content, v) { + return false + } + } + } + + return true +} + +// getMostSimilarDocs 获取与查询向量最相似的前 n 个文档。 +func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Document, n int) ([]docSim, error) { + nMaxDocs := newMaxDocSims(n) + + numCPUs := runtime.NumCPU() + numDocs := len(docs) + concurrency := min(numCPUs, numDocs) + + var sharedErr error + var sharedErrLock sync.Mutex + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + + setSharedErr := func(err error) { + sharedErrLock.Lock() + defer sharedErrLock.Unlock() + if sharedErr == nil { + sharedErr = err + cancel(sharedErr) + } + } + + var wg sync.WaitGroup + subSliceSize := len(docs) / concurrency + rem := len(docs) % concurrency + + for i := 0; i < concurrency; i++ { + start := i * subSliceSize + end := start + subSliceSize + if i == concurrency-1 { + end += rem + } + + wg.Add(1) + go func(subSlice []*Document) { + defer wg.Done() + for _, doc := range subSlice { + if ctx.Err() != nil { + return + } + + sim, err := dotProduct(queryVectors, doc.Embedding) + if err != nil { + setSharedErr(fmt.Errorf("无法计算文档 '%s' 的相似度: %w", doc.ID, err)) + return + } + + nMaxDocs.add(docSim{docID: doc.ID, similarity: sim}) + } + }(docs[start:end]) + } + + wg.Wait() + + if sharedErr != nil { + return nil, sharedErr + } + + return nMaxDocs.values(), nil +} + +// 辅助函数:返回两个数中的最小值。 +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/godo/ai/vector/utils.go b/godo/ai/vector/utils.go new file mode 100644 index 0000000..e928cc6 --- /dev/null +++ b/godo/ai/vector/utils.go @@ -0,0 +1,79 @@ +package vector + +import ( + "errors" + "fmt" + "math" +) + +const isNormalizedPrecisionTolerance = 1e-6 + +// cosineSimilarity 计算两个向量的余弦相似度。 +// 向量在计算前会被归一化。 +// 结果值表示相似度,值越高表示向量越相似。 +func cosineSimilarity(a, b []float32) (float32, error) { + // 向量必须具有相同的长度 + if len(a) != len(b) { + return 0, errors.New("向量必须具有相同的长度") + } + + // 归一化向量 + aNorm := normalizeVector(a) + bNorm := normalizeVector(b) + + // 计算点积 + dotProduct, err := dotProduct(aNorm, bNorm) + if err != nil { + return 0, fmt.Errorf("无法计算点积: %w", err) + } + + return dotProduct, nil +} + +// dotProduct 计算两个向量的点积。 +// 对于归一化的向量,点积等同于余弦相似度。 +// 结果值表示相似度,值越高表示向量越相似。 +func dotProduct(a, b []float32) (float32, error) { + // 向量必须具有相同的长度 + if len(a) != len(b) { + return 0, errors.New("向量必须具有相同的长度") + } + + var dotProduct float32 + for i := range a { + dotProduct += a[i] * b[i] + } + + return dotProduct, nil +} + +// normalizeVector 归一化一个浮点数向量。 +// 归一化是指将向量的每个分量除以向量的模(长度),使得归一化后的向量长度为 1。 +func normalizeVector(v []float32) []float32 { + var norm float64 + for _, val := range v { + norm += float64(val * val) + } + if norm == 0 { + return v // 避免除以零的情况 + } + norm = math.Sqrt(norm) + + res := make([]float32, len(v)) + for i, val := range v { + res[i] = float32(float64(val) / norm) + } + + return res +} + +// isNormalized 检查向量是否已经归一化。 +// 如果向量的模接近 1,则认为它是归一化的。 +func isNormalized(v []float32) bool { + var sqSum float64 + for _, val := range v { + sqSum += float64(val) * float64(val) + } + magnitude := math.Sqrt(sqSum) + return math.Abs(magnitude-1) < isNormalizedPrecisionTolerance +} diff --git a/godo/ai/vector/vector.go b/godo/ai/vector/vector.go index 6bc6aae..47b9fcd 100644 --- a/godo/ai/vector/vector.go +++ b/godo/ai/vector/vector.go @@ -2,80 +2,16 @@ package vector import ( "encoding/json" + "fmt" "godo/libs" + "godo/model" + "godo/office" "net/http" "path/filepath" - - _ "embed" - // sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" - // "github.com/ncruces/go-sqlite3" ) -//var VecDb *sqlx.DB - -type VectorList struct { - ID int `json:"id"` - FilePath string `json:"file_path"` - Engine string `json:"engine"` - EmbeddingModel string `json:"model"` -} -type VectorDoc struct { - ID int `json:"id"` - Content string `json:"content"` - FilePath string `json:"file_path"` - ListID int `json:"list_id"` -} -type VectorItem struct { - DocID int `json:"rowid"` - Embedding []float32 `json:"embedding"` -} - -func init() { - - // dbPath := libs.GetVectorDb() - // sqlite_vec.Auto() - - // db, err := sqlx.Connect("sqlite3", dbPath) - // if err != nil { - // fmt.Println("Failed to open SQLite database:", err) - // return - // } - // defer db.Close() - // VecDb = db - // dsn := "file:" + dbPath - // db, err := sqlite3.Open(dsn) - // //db, err := sqlite3.Open(":memory:") - // if err != nil { - // fmt.Println("Failed to open SQLite database:", err) - // return - // } - // stmt, _, err := db.Prepare(`SELECT vec_version()`) - // if err != nil { - // log.Fatal(err) - // } - - // stmt.Step() - // log.Printf("vec_version=%s\n", stmt.ColumnText(0)) - // stmt.Close() - // _, err = db.Exec("CREATE TABLE IF NOT EXISTS vec_list (id INTEGER PRIMARY KEY AUTOINCREMENT,file_path TEXT NOT NULL,engine TEXT NOT NULL,embedding_model TEXT NOT NULL)") - // if err != nil { - // log.Fatal(err) - // } - // _, err = db.Exec("CREATE TABLE IF NOT EXISTS vec_doc (id INTEGER PRIMARY KEY AUTOINCREMENT,list_id INTEGER DEFAULT 0,file_path TEXT,content TEXT)") - // if err != nil { - // log.Fatal(err) - // } - // _, err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[768])") - // if err != nil { - // log.Fatal(err) - // } - // VecDb = db - - //InitMonitor() -} - func HandlerCreateKnowledge(w http.ResponseWriter, r *http.Request) { - var req VectorList + var req model.VecList err := json.NewDecoder(r.Body).Decode(&req) if err != nil { libs.ErrorMsg(w, "the chat request error:"+err.Error()) @@ -92,15 +28,299 @@ func HandlerCreateKnowledge(w http.ResponseWriter, r *http.Request) { } req.FilePath = filepath.Join(basePath, req.FilePath) - // id, err := CreateVector(req) - // if err != nil { - // libs.ErrorMsg(w, err.Error()) - // return - // } - // libs.SuccessMsg(w, id, "create vector success") + id, err := CreateVector(req) + if err != nil { + libs.ErrorMsg(w, err.Error()) + return + } + libs.SuccessMsg(w, id, "create vector success") +} + +// CreateVector 创建一个新的 VectorList 记录 +func CreateVector(data model.VecList) (uint, error) { + if data.FilePath == "" { + return 0, fmt.Errorf("file path is empty") + } + if data.Engine == "" { + return 0, fmt.Errorf("engine is empty") + } + + if !libs.PathExists(data.FilePath) { + return 0, fmt.Errorf("file path does not exist") + } + if data.EmbeddingModel == "" { + return 0, fmt.Errorf("embedding model is empty") + } + if data.EmbedSize == 0 { + data.EmbedSize = 768 + } + // Create the new VectorList + result := model.Db.Create(&data) + if result.Error != nil { + return 0, fmt.Errorf("failed to create vector list: %w", result.Error) + } + + // Start background tasks + go office.SetDocument(data.FilePath, uint(data.ID)) + + return uint(data.ID), nil +} + +// DeleteVector 删除指定id的 VectorList 记录 +func DeleteVector(id int) error { + + return model.Db.Delete(&model.VecList{}, id).Error +} + +// RenameVectorDb 更改指定名称的 VectorList 的数据库名称 +func RenameVectorDb(oldName string, newName string) error { + basePath, err := libs.GetOsDir() + if err != nil { + return fmt.Errorf("failed to find old vector list: %w", err) + } + + // 获取旧的 VectorList 记录 + var oldList model.VecList + oldPath := filepath.Join(basePath, oldName) + if err := model.Db.Where("file_path = ?", oldPath).First(&oldList).Error; err != nil { + return fmt.Errorf("failed to find old vector list: %w", err) + } + + // 更新 VectorList 记录中的 FilePath + newPath := filepath.Join(basePath, newName) + if err := model.Db.Model(&model.VecList{}).Where("id = ?", oldList.ID).Update("file_path", newPath).Error; err != nil { + return fmt.Errorf("failed to update vector list: %w", err) + } + + return nil } -// // CreateVector 创建一个新的 VectorList 记录 +func GetVectorList() ([]model.VecList, error) { + var vectorList []model.VecList + if err := model.Db.Find(&vectorList).Error; err != nil { + return nil, fmt.Errorf("failed to get vector list: %w", err) + } + return vectorList, nil +} + +func GetVector(id uint) (model.VecList, error) { + var vectorList model.VecList + if err := model.Db.First(&vectorList, id).Error; err != nil { + return vectorList, fmt.Errorf("failed to get vector: %w", err) + } + return vectorList, nil +} + +// func SimilaritySearch(query string, numDocuments int, collection string, where map[string]string) ([]vs.Document, error) { +// ef := v.embeddingFunc +// if embeddingFunc != nil { +// ef = embeddingFunc +// } + +// q, err := ef(ctx, query) +// if err != nil { +// return nil, fmt.Errorf("failed to compute embedding: %w", err) +// } + +// qv, err := sqlitevec.SerializeFloat32(q) +// if err != nil { +// return nil, fmt.Errorf("failed to serialize query embedding: %w", err) +// } + +// var docs []vs.Document +// err = v.db.Transaction(func(tx *gorm.DB) error { +// // Query matching document IDs and distances +// rows, err := tx.Raw(fmt.Sprintf(` +// SELECT document_id, distance +// FROM [%s_vec] +// WHERE embedding MATCH ? +// ORDER BY distance +// LIMIT ? +// `, collection), qv, numDocuments).Rows() +// if err != nil { +// return fmt.Errorf("failed to query vector table: %w", err) +// } +// defer rows.Close() + +// for rows.Next() { +// var docID string +// var distance float32 +// if err := rows.Scan(&docID, &distance); err != nil { +// return fmt.Errorf("failed to scan row: %w", err) +// } +// docs = append(docs, vs.Document{ +// ID: docID, +// SimilarityScore: 1 - distance, // Higher score means closer match +// }) +// } + +// // Fetch content and metadata for each document +// for i, doc := range docs { +// var content string +// var metadataJSON []byte +// err := tx.Raw(fmt.Sprintf(` +// SELECT content, metadata +// FROM [%s] +// WHERE id = ? +// `, v.embeddingsTableName), doc.ID).Row().Scan(&content, &metadataJSON) +// if err != nil { +// return fmt.Errorf("failed to query embeddings table for document %s: %w", doc.ID, err) +// } + +// var metadata map[string]interface{} +// if err := json.Unmarshal(metadataJSON, &metadata); err != nil { +// return fmt.Errorf("failed to parse metadata for document %s: %w", doc.ID, err) +// } + +// docs[i].Content = content +// docs[i].Metadata = metadata +// } + +// return nil +// }) + +// if err != nil { +// return nil, err +// } + +// return docs, nil +// } + +// func AddDocuments(docs []VectorDoc, collection string) ([]string, error) { +// ids := make([]string, len(docs)) + +// err := VecDb.Transaction(func(tx *gorm.DB) error { +// if len(docs) > 0 { +// valuePlaceholders := make([]string, len(docs)) +// args := make([]interface{}, 0, len(docs)*2) // 2 args per doc: document_id and embedding + +// for i, doc := range docs { +// emb, err := v.embeddingFunc(ctx, doc.Content) +// if err != nil { +// return fmt.Errorf("failed to compute embedding for document %s: %w", doc.ID, err) +// } + +// serializedEmb, err := sqlitevec.SerializeFloat32(emb) +// if err != nil { +// return fmt.Errorf("failed to serialize embedding for document %s: %w", doc.ID, err) +// } + +// valuePlaceholders[i] = "(?, ?)" +// args = append(args, doc.ID, serializedEmb) + +// ids[i] = doc.ID +// } + +// // Raw query for *_vec as gorm doesn't support virtual tables +// query := fmt.Sprintf(` +// INSERT INTO [%s_vec] (document_id, embedding) +// VALUES %s +// `, collection, strings.Join(valuePlaceholders, ", ")) + +// if err := tx.Exec(query, args...).Error; err != nil { +// return fmt.Errorf("failed to batch insert into vector table: %w", err) +// } +// } + +// embs := make([]map[string]interface{}, len(docs)) +// for i, doc := range docs { +// metadataJson, err := json.Marshal(doc.Metadata) +// if err != nil { +// return fmt.Errorf("failed to marshal metadata for document %s: %w", doc.ID, err) +// } +// embs[i] = map[string]interface{}{ +// "id": doc.ID, +// "collection_id": collection, +// "content": doc.Content, +// "metadata": metadataJson, +// } +// } + +// // Use GORM's Create for the embeddings table +// if err := tx.Table(v.embeddingsTableName).Create(embs).Error; err != nil { +// return fmt.Errorf("failed to batch insert into embeddings table: %w", err) +// } + +// return nil +// }) + +// if err != nil { +// return nil, err +// } + +// return ids, nil +// } +//func init() { + +// dbPath := libs.GetVectorDb() +// sqlite_vec.Auto() + +// db, err := sqlx.Connect("sqlite3", dbPath) +// if err != nil { +// fmt.Println("Failed to open SQLite database:", err) +// return +// } +// defer db.Close() +// VecDb = db +// dsn := "file:" + dbPath +// db, err := sqlite3.Open(dsn) +// //db, err := sqlite3.Open(":memory:") +// if err != nil { +// fmt.Println("Failed to open SQLite database:", err) +// return +// } +// stmt, _, err := db.Prepare(`SELECT vec_version()`) +// if err != nil { +// log.Fatal(err) +// } + +// stmt.Step() +// log.Printf("vec_version=%s\n", stmt.ColumnText(0)) +// stmt.Close() +// _, err = db.Exec("CREATE TABLE IF NOT EXISTS vec_list (id INTEGER PRIMARY KEY AUTOINCREMENT,file_path TEXT NOT NULL,engine TEXT NOT NULL,embedding_model TEXT NOT NULL)") +// if err != nil { +// log.Fatal(err) +// } +// _, err = db.Exec("CREATE TABLE IF NOT EXISTS vec_doc (id INTEGER PRIMARY KEY AUTOINCREMENT,list_id INTEGER DEFAULT 0,file_path TEXT,content TEXT)") +// if err != nil { +// log.Fatal(err) +// } +// _, err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[768])") +// if err != nil { +// log.Fatal(err) +// } +// VecDb = db + +//InitMonitor() +//} + +// func HandlerCreateKnowledge(w http.ResponseWriter, r *http.Request) { +// var req VectorList +// err := json.NewDecoder(r.Body).Decode(&req) +// if err != nil { +// libs.ErrorMsg(w, "the chat request error:"+err.Error()) +// return +// } +// if req.FilePath == "" { +// libs.ErrorMsg(w, "file path is empty") +// return +// } +// basePath, err := libs.GetOsDir() +// if err != nil { +// libs.ErrorMsg(w, "get vector db path error:"+err.Error()) +// return +// } +// req.FilePath = filepath.Join(basePath, req.FilePath) + +// // id, err := CreateVector(req) +// // if err != nil { +// // libs.ErrorMsg(w, err.Error()) +// // return +// // } +// // libs.SuccessMsg(w, id, "create vector success") +// } + +// CreateVector 创建一个新的 VectorList 记录 // func CreateVector(data VectorList) (uint, error) { // if data.FilePath == "" { // return 0, fmt.Errorf("file path is empty") @@ -116,28 +336,9 @@ func HandlerCreateKnowledge(w http.ResponseWriter, r *http.Request) { // return 0, fmt.Errorf("embedding model is empty") // } -// // Check if a VectorList with the same path already exists - -// stmt, _, err := VecDb.Prepare(`SELECT id FROM vec_list WHERE file_path =`+ data.FilePath) -// if err != nil { -// log.Fatal(err) -// } -// defer stmt.Close() -// for stmt.Step() { -// fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1)) -// } -// if err := stmt.Err(); err != nil { -// log.Fatal(err) -// } - -// err = stmt.Close() -// if err != nil { -// log.Fatal(err) -// } // // Create the new VectorList -// err = VecDb.Exec("INSERT INTO vec_list (file_path, engine, embedding_model) VALUES (?, ?, ?)", data.FilePath, data.Engine, data.EmbeddingModel) -// if err != nil { -// return 0, err +// if err := tx.Table("vec_list").Create(data).Error; err != nil { +// return fmt.Errorf("failed to batch insert into embeddings table: %w", err) // } // // Get the ID of the newly created VectorList // vectorID, err := result.LastInsertId() @@ -258,44 +459,44 @@ func HandlerCreateKnowledge(w http.ResponseWriter, r *http.Request) { // FolderMonitor() // } -func GetVectorList() []VectorList { - var vectorList []VectorList - // stmt, _, err := VecDb.Prepare("SELECT id, file_path, engine, embedding_model FROM vec_list") - // if err != nil { - // fmt.Println("Failed to get vector list:", err) - // return vectorList - // } - // stmt.Step() - // log.Printf("vec_version=%s\n", stmt.ColumnText(0)) - // stmt.Close() - // defer rows.Close() - - // for rows.Next() { - // var v VectorList - // err := rows.Scan(&v.ID, &v.FilePath, &v.Engine, &v.EmbeddingModel) - // if err != nil { - // fmt.Println("Failed to scan vector list row:", err) - // continue - // } - // vectorList = append(vectorList, v) - // } - - return vectorList -} -func GetVector(id uint) VectorList { - var vectorList VectorList - // sql := "SELECT id, file_path, engine, embedding_model FROM vec_list WHERE id = " + fmt.Sprintf("%d", id) - // stmt, _, err := VecDb.Prepare(sql) - // if err != nil { - // fmt.Println("Failed to get vector list:", err) - // return vectorList - // } - // stmt.Step() - // log.Printf("vec_version=%s\n", stmt.ColumnText(0)) - // stmt.Close() - // err := VecDb.QueryRow("SELECT id, file_path, engine, embedding_model FROM vec_list WHERE id = ?", id).Scan(&vectorList.ID, &vectorList.FilePath, &vectorList.Engine, &vectorList.EmbeddingModel) - // if err != nil { - // fmt.Println("Failed to get vector:", err) - // } - return vectorList -} +// func GetVectorList() []VectorList { +// var vectorList []VectorList +// // stmt, _, err := VecDb.Prepare("SELECT id, file_path, engine, embedding_model FROM vec_list") +// // if err != nil { +// // fmt.Println("Failed to get vector list:", err) +// // return vectorList +// // } +// // stmt.Step() +// // log.Printf("vec_version=%s\n", stmt.ColumnText(0)) +// // stmt.Close() +// // defer rows.Close() + +// // for rows.Next() { +// // var v VectorList +// // err := rows.Scan(&v.ID, &v.FilePath, &v.Engine, &v.EmbeddingModel) +// // if err != nil { +// // fmt.Println("Failed to scan vector list row:", err) +// // continue +// // } +// // vectorList = append(vectorList, v) +// // } + +// return vectorList +// } +// func GetVector(id uint) VectorList { +// var vectorList VectorList +// // sql := "SELECT id, file_path, engine, embedding_model FROM vec_list WHERE id = " + fmt.Sprintf("%d", id) +// // stmt, _, err := VecDb.Prepare(sql) +// // if err != nil { +// // fmt.Println("Failed to get vector list:", err) +// // return vectorList +// // } +// // stmt.Step() +// // log.Printf("vec_version=%s\n", stmt.ColumnText(0)) +// // stmt.Close() +// // err := VecDb.QueryRow("SELECT id, file_path, engine, embedding_model FROM vec_list WHERE id = ?", id).Scan(&vectorList.ID, &vectorList.FilePath, &vectorList.Engine, &vectorList.EmbeddingModel) +// // if err != nil { +// // fmt.Println("Failed to get vector:", err) +// // } +// return vectorList +// } diff --git a/godo/cmd/main.go b/godo/cmd/main.go index 4122f3c..de5b368 100644 --- a/godo/cmd/main.go +++ b/godo/cmd/main.go @@ -53,7 +53,6 @@ func OsStart() { db.InitDB() proxy.InitProxyHandlers() webdav.InitWebdav() - // vector.InitMonitor() router := mux.NewRouter() router.Use(recoverMiddleware) if libs.GetIsCors() { diff --git a/godo/main.go b/godo/main.go index 1089569..1dd1423 100644 --- a/godo/main.go +++ b/godo/main.go @@ -18,7 +18,9 @@ package main -import "godo/cmd" +import ( + "godo/cmd" +) func main() { cmd.OsStart() diff --git a/godo/model/init.go b/godo/model/init.go index 75347ee..ae4706a 100644 --- a/godo/model/init.go +++ b/godo/model/init.go @@ -3,7 +3,8 @@ package model import ( "godo/libs" - _ "github.com/ncruces/go-sqlite3/embed" + _ "github.com/asg017/sqlite-vec-go-bindings/ncruces" + //_ "github.com/ncruces/go-sqlite3/embed" "github.com/ncruces/go-sqlite3/gormlite" "gorm.io/gorm" ) @@ -25,4 +26,6 @@ func InitDB() { db.AutoMigrate(&SysUser{}) db.AutoMigrate(&ClientUser{}) db.AutoMigrate(&ServerUser{}) + db.AutoMigrate(&VecList{}) + db.AutoMigrate(&VecDoc{}) } diff --git a/godo/model/vec_doc.go b/godo/model/vec_doc.go new file mode 100644 index 0000000..ad82c98 --- /dev/null +++ b/godo/model/vec_doc.go @@ -0,0 +1,14 @@ +package model + +import "gorm.io/gorm" + +type VecDoc struct { + gorm.Model + Content string `json:"content"` + FilePath string `json:"file_path" gorm:"not null"` + ListID int `json:"list_id"` +} + +func (VecDoc) TableName() string { + return "vec_doc" +} diff --git a/godo/model/vec_list.go b/godo/model/vec_list.go new file mode 100644 index 0000000..b28aba6 --- /dev/null +++ b/godo/model/vec_list.go @@ -0,0 +1,65 @@ +package model + +import ( + "fmt" + "log" + + "gorm.io/gorm" +) + +type VecList struct { + gorm.Model + FilePath string `json:"file_path" gorm:"not null"` + Engine string `json:"engine" gorm:"not null"` + EmbedSize int `json:"embed_size"` + EmbeddingModel string `json:"model" gorm:"not null"` +} + +func (*VecList) TableName() string { + return "vec_list" +} + +// BeforeCreate 在插入数据之前检查是否存在相同路径的数据 +func (v *VecList) BeforeCreate(tx *gorm.DB) error { + var count int64 + if err := tx.Model(&VecList{}).Where("file_path = ?", v.FilePath).Count(&count).Error; err != nil { + return err + } + if count > 0 { + return fmt.Errorf("file path already exists: %s", v.FilePath) + } + return nil +} + +// AfterCreate 在插入数据之后创建虚拟表 +func (v *VecList) AfterCreate(tx *gorm.DB) error { + return CreateVirtualTable(tx, v.ID, v.EmbedSize) +} + +// AfterDelete 在删除数据之后删除虚拟表 +func (v *VecList) AfterDelete(tx *gorm.DB) error { + // 删除 VecDoc 表中 ListID 对应的所有数据 + if err := tx.Where("list_id = ?", v.ID).Delete(&VecDoc{}).Error; err != nil { + return err + } + return DropVirtualTable(tx, v.ID) +} + +// CreateVirtualTable 创建虚拟表 +func CreateVirtualTable(db *gorm.DB, vectorID uint, embeddingSize int) error { + sql := fmt.Sprintf(` + CREATE VIRTUAL TABLE IF NOT EXISTS [%d_vec] USING + vec0( + document_id TEXT PRIMARY KEY, + embedding float[%d] distance_metric=cosine + ) + `, vectorID, embeddingSize) + log.Printf("sql: %s", sql) + return db.Exec(sql).Error +} + +// DropVirtualTable 删除虚拟表 +func DropVirtualTable(db *gorm.DB, vectorID uint) error { + sql := fmt.Sprintf(`DROP TABLE IF EXISTS [%d_vec]`, vectorID) + return db.Exec(sql).Error +} diff --git a/godo/vector/go.mod b/godo/vector/go.mod index 452c4f0..bd13ccf 100644 --- a/godo/vector/go.mod +++ b/godo/vector/go.mod @@ -1,14 +1,19 @@ -module vector +module godovec -go 1.22.5 +go 1.23.3 require ( - github.com/asg017/sqlite-vec-go-bindings v0.0.1-alpha.37 - github.com/ncruces/go-sqlite3 v0.17.2-0.20240711235451-21de85e849b7 + github.com/asg017/sqlite-vec-go-bindings v0.1.6 + github.com/ncruces/go-sqlite3/gormlite v0.21.0 + gorm.io/gorm v1.25.12 ) require ( + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/ncruces/go-sqlite3 v0.21.0 // indirect github.com/ncruces/julianday v1.0.0 // indirect - github.com/tetratelabs/wazero v1.7.3 // indirect - golang.org/x/sys v0.22.0 // indirect + github.com/tetratelabs/wazero v1.8.2 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect ) diff --git a/godo/vector/go.sum b/godo/vector/go.sum index 3dcd8e9..cebbd74 100644 --- a/godo/vector/go.sum +++ b/godo/vector/go.sum @@ -1,12 +1,20 @@ -github.com/asg017/sqlite-vec-go-bindings v0.0.1-alpha.37 h1:Gz6YkDCs60k5VwbBPKDfAPPeIBcuaN3qriAozAaIIZI= -github.com/asg017/sqlite-vec-go-bindings v0.0.1-alpha.37/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= -github.com/ncruces/go-sqlite3 v0.17.2-0.20240711235451-21de85e849b7 h1:ssM02uUFDfz0V2TMg2du2BjbW9cpOhFJK0kpDN+X768= -github.com/ncruces/go-sqlite3 v0.17.2-0.20240711235451-21de85e849b7/go.mod h1:FnCyui8SlDoL0mQZ5dTouNo7s7jXS0kJv9lBt1GlM9w= +github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww= +github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/ncruces/go-sqlite3 v0.21.0 h1:EwKFoy1hHEopN4sFZarmi+McXdbCcbTuLixhEayXVbQ= +github.com/ncruces/go-sqlite3 v0.21.0/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA= +github.com/ncruces/go-sqlite3/gormlite v0.21.0 h1:9DsbvW9dS6uxXNFmbrNZixqAXKnIFnLM8oZmKqp8vcI= +github.com/ncruces/go-sqlite3/gormlite v0.21.0/go.mod h1:rP4JXD6jlpOSsg2Ed++kzJIAZZCIBirVYqIpwaLW88E= github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= -github.com/tetratelabs/wazero v1.7.3 h1:PBH5KVahrt3S2AHgEjKu4u+LlDbbk+nsGE3KLucy6Rw= -github.com/tetratelabs/wazero v1.7.3/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4= +github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= diff --git a/godo/vector/init.go b/godo/vector/init.go deleted file mode 100644 index 4b3245b..0000000 --- a/godo/vector/init.go +++ /dev/null @@ -1,5 +0,0 @@ -package main - -func Init() { - -} diff --git a/godo/vector/main.go b/godo/vector/main.go deleted file mode 100644 index ffd7078..0000000 --- a/godo/vector/main.go +++ /dev/null @@ -1,100 +0,0 @@ -package main - -import ( - _ "embed" - "log" - - sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" - "github.com/ncruces/go-sqlite3" -) - -var Db *sqlite3.Conn - -func main() { - db, err := sqlite3.Open(":memory:") - if err != nil { - log.Fatal(err) - } - - stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) - if err != nil { - log.Fatal(err) - } - - stmt.Step() - - log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) - stmt.Close() - - err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4])") - if err != nil { - log.Fatal(err) - } - items := map[int][]float32{ - 1: {0.1, 0.1, 0.1, 0.1}, - 2: {0.2, 0.2, 0.2, 0.2}, - 3: {0.3, 0.3, 0.3, 0.3}, - 4: {0.4, 0.4, 0.4, 0.4}, - 5: {0.5, 0.5, 0.5, 0.5}, - } - q := []float32{0.3, 0.3, 0.3, 0.3} - - stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)") - if err != nil { - log.Fatal(err) - } - - for id, values := range items { - v, err := sqlite_vec.SerializeFloat32(values) - if err != nil { - log.Fatal(err) - } - stmt.BindInt(1, id) - stmt.BindBlob(2, v) - err = stmt.Exec() - if err != nil { - log.Fatal(err) - } - stmt.Reset() - } - stmt.Close() - - stmt, _, err = db.Prepare(` - SELECT - rowid, - distance - FROM vec_items - WHERE embedding MATCH ? - ORDER BY distance - LIMIT 3 - `) - - if err != nil { - log.Fatal(err) - } - - query, err := sqlite_vec.SerializeFloat32(q) - if err != nil { - log.Fatal(err) - } - stmt.BindBlob(1, query) - - for stmt.Step() { - rowid := stmt.ColumnInt64(0) - distance := stmt.ColumnFloat(1) - log.Printf("rowid=%d, distance=%f\n", rowid, distance) - } - if err := stmt.Err(); err != nil { - log.Fatal(err) - } - - err = stmt.Close() - if err != nil { - log.Fatal(err) - } - - err = db.Close() - if err != nil { - log.Fatal(err) - } -} diff --git a/godo/vector/run.go b/godo/vector/run.go new file mode 100644 index 0000000..dab7528 --- /dev/null +++ b/godo/vector/run.go @@ -0,0 +1,54 @@ +package godovec + +import ( + "fmt" + + _ "embed" + + _ "github.com/asg017/sqlite-vec-go-bindings/ncruces" + + "github.com/ncruces/go-sqlite3/gormlite" + "gorm.io/gorm" +) + +var VecDb *gorm.DB + +type VectorList struct { + ID int `json:"id" gorm:"primaryKey"` + FilePath string `json:"file_path" gorm:"not null"` + Engine string `json:"engine" gorm:"not null"` + EmbeddingModel string `json:"model" gorm:"not null"` +} + +type VectorDoc struct { + ID int `json:"id" gorm:"primaryKey"` + Content string `json:"content"` + FilePath string `json:"file_path" gorm:"not null"` + ListID int `json:"list_id"` +} + +func main() { + InitVector() +} +func InitVector() error { + + db, err := gorm.Open(gormlite.Open("./data.db"), &gorm.Config{}) + if err != nil { + return fmt.Errorf("failed to open vector db: %w", err) + } + + // Enable PRAGMAs + // - busy_timeout (ms) to prevent db lockups as we're accessing the DB from multiple separate processes in otto8 + tx := db.Exec(` +PRAGMA busy_timeout = 5000; +`) + if tx.Error != nil { + return fmt.Errorf("failed to execute pragma busy_timeout: %w", tx.Error) + } + err = db.AutoMigrate(&VectorList{}, &VectorDoc{}) + if err != nil { + return fmt.Errorf("failed to auto migrate tables: %w", err) + } + VecDb = db + return nil +}