mirror of https://gitee.com/godoos/godoos.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
288 lines
8.3 KiB
288 lines
8.3 KiB
// Copyright (c) seasonjs. All rights reserved.
|
|
// Licensed under the MIT License. See License.txt in the project root for license information.
|
|
|
|
package sd
|
|
|
|
import (
|
|
"github.com/ebitengine/purego"
|
|
"runtime"
|
|
"unsafe"
|
|
)
|
|
|
|
type LogLevel int
|
|
|
|
type RNGType int
|
|
|
|
type SampleMethod int
|
|
|
|
type Schedule int
|
|
|
|
type WType int
|
|
|
|
const (
|
|
DEBUG LogLevel = iota
|
|
INFO
|
|
WARN
|
|
ERROR
|
|
)
|
|
|
|
const (
|
|
STD_DEFAULT_RNG RNGType = iota
|
|
CUDA_RNG
|
|
)
|
|
|
|
const (
|
|
EULER_A SampleMethod = iota
|
|
EULER
|
|
HEUN
|
|
DPM2
|
|
DPMPP2S_A
|
|
DPMPP2M
|
|
DPMPP2Mv2
|
|
LCM
|
|
N_SAMPLE_METHODS
|
|
)
|
|
|
|
const (
|
|
DEFAULT Schedule = iota
|
|
DISCRETE
|
|
KARRAS
|
|
N_SCHEDULES
|
|
)
|
|
|
|
const (
|
|
F32 WType = 0
|
|
F16 = 1
|
|
Q4_0 = 2
|
|
Q4_1 = 3
|
|
Q5_0 = 6
|
|
Q5_1 = 7
|
|
Q8_0 = 8
|
|
Q8_1 = 9
|
|
Q2_K = 10
|
|
Q3_K = 11
|
|
Q4_K = 12
|
|
Q5_K = 13
|
|
Q6_K = 14
|
|
Q8_K = 15
|
|
I8 = 16
|
|
I16 = 17
|
|
I32 = 18
|
|
COUNT = 19 // don't use this when specifying a type
|
|
)
|
|
|
|
type CStableDiffusionCtx struct {
|
|
ctx uintptr
|
|
}
|
|
|
|
type CUpScalerCtx struct {
|
|
ctx uintptr
|
|
}
|
|
|
|
type CLogCallback func(level LogLevel, text string)
|
|
|
|
type CStableDiffusion interface {
|
|
NewCtx(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wType WType, rngType RNGType, schedule Schedule) *CStableDiffusionCtx
|
|
PredictImage(ctx *CStableDiffusionCtx, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, seed int64, batchCount int) []Image
|
|
ImagePredictImage(ctx *CStableDiffusionCtx, img Image, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, strength float32, seed int64, batchCount int) []Image
|
|
SetLogCallBack(cb CLogCallback)
|
|
GetSystemInfo() string
|
|
FreeCtx(ctx *CStableDiffusionCtx)
|
|
|
|
NewUpscalerCtx(esrganPath string, nThreads int, wType WType) *CUpScalerCtx
|
|
FreeUpscalerCtx(ctx *CUpScalerCtx)
|
|
UpscaleImage(ctx *CUpScalerCtx, img Image, upscaleFactor uint32) Image
|
|
|
|
Close() error
|
|
}
|
|
|
|
type cImage struct {
|
|
width uint32
|
|
height uint32
|
|
channel uint32
|
|
data uintptr
|
|
}
|
|
|
|
type cDarwinImage struct {
|
|
width uint32
|
|
height uint32
|
|
channel uint32
|
|
data *byte
|
|
}
|
|
|
|
type Image struct {
|
|
Width uint32
|
|
Height uint32
|
|
Channel uint32
|
|
Data []byte
|
|
}
|
|
|
|
type CStableDiffusionImpl struct {
|
|
libSd uintptr
|
|
|
|
sdGetSystemInfo func() string
|
|
|
|
newSdCtx func(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wtype int, rngType int, schedule int) uintptr
|
|
|
|
sdSetLogCallback func(callback func(level int, text uintptr, data uintptr) uintptr, data uintptr)
|
|
|
|
txt2img func(ctx uintptr, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, seed int64, batchCount int) uintptr
|
|
|
|
img2img func(ctx uintptr, img uintptr, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, strength float32, seed int64, batchCount int) uintptr
|
|
|
|
freeSdCtx func(ctx uintptr)
|
|
|
|
newUpscalerCtx func(esrganPath string, nThreads int, wtype int) uintptr
|
|
|
|
freeUpscalerCtx func(ctx uintptr)
|
|
|
|
upscale func(ctx uintptr, img uintptr, upscaleFactor uint32) uintptr
|
|
}
|
|
|
|
func NewCStableDiffusion(libraryPath string) (*CStableDiffusionImpl, error) {
|
|
libSd, err := openLibrary(libraryPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
impl := CStableDiffusionImpl{}
|
|
|
|
purego.RegisterLibFunc(&impl.sdSetLogCallback, libSd, "sd_get_system_info")
|
|
purego.RegisterLibFunc(&impl.newSdCtx, libSd, "new_sd_ctx")
|
|
purego.RegisterLibFunc(&impl.sdSetLogCallback, libSd, "sd_set_log_callback")
|
|
purego.RegisterLibFunc(&impl.txt2img, libSd, "txt2img")
|
|
purego.RegisterLibFunc(&impl.img2img, libSd, "img2img")
|
|
purego.RegisterLibFunc(&impl.freeSdCtx, libSd, "free_sd_ctx")
|
|
purego.RegisterLibFunc(&impl.newUpscalerCtx, libSd, "new_upscaler_ctx")
|
|
purego.RegisterLibFunc(&impl.freeUpscalerCtx, libSd, "free_upscaler_ctx")
|
|
purego.RegisterLibFunc(&impl.upscale, libSd, "upscale")
|
|
|
|
return &impl, nil
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) NewCtx(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wType WType, rngType RNGType, schedule Schedule) *CStableDiffusionCtx {
|
|
ctx := c.newSdCtx(modelPath, vaePath, taesdPath, loraModelDir, vaeDecodeOnly, vaeTiling, freeParamsImmediately, nThreads, int(wType), int(rngType), int(schedule))
|
|
return &CStableDiffusionCtx{
|
|
ctx: ctx,
|
|
}
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) PredictImage(ctx *CStableDiffusionCtx, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, seed int64, batchCount int) []Image {
|
|
images := c.txt2img(ctx.ctx, prompt, negativePrompt, clipSkip, cfgScale, width, height, int(sampleMethod), sampleSteps, seed, batchCount)
|
|
return goImageSlice(images, batchCount)
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) ImagePredictImage(ctx *CStableDiffusionCtx, img Image, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, strength float32, seed int64, batchCount int) []Image {
|
|
ci := cImage{
|
|
width: img.Width,
|
|
height: img.Height,
|
|
channel: img.Channel,
|
|
data: uintptr(unsafe.Pointer(&img.Data[0])),
|
|
}
|
|
images := c.img2img(ctx.ctx, uintptr(unsafe.Pointer(&ci)), prompt, negativePrompt, clipSkip, cfgScale, width, height, int(sampleMethod), sampleSteps, strength, seed, batchCount)
|
|
return goImageSlice(images, batchCount)
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) SetLogCallBack(cb CLogCallback) {
|
|
c.sdSetLogCallback(func(level int, text uintptr, data uintptr) uintptr {
|
|
cb(LogLevel(level), goString(text))
|
|
return 0
|
|
}, 0)
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) GetSystemInfo() string {
|
|
return c.sdGetSystemInfo()
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) FreeCtx(ctx *CStableDiffusionCtx) {
|
|
ptr := *(*unsafe.Pointer)(unsafe.Pointer(&ctx.ctx))
|
|
if ptr != nil {
|
|
c.freeSdCtx(ctx.ctx)
|
|
}
|
|
ctx = nil
|
|
runtime.GC()
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) NewUpscalerCtx(esrganPath string, nThreads int, wType WType) *CUpScalerCtx {
|
|
ctx := c.newUpscalerCtx(esrganPath, nThreads, int(wType))
|
|
|
|
return &CUpScalerCtx{ctx: ctx}
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) FreeUpscalerCtx(ctx *CUpScalerCtx) {
|
|
ptr := *(*unsafe.Pointer)(unsafe.Pointer(&ctx.ctx))
|
|
if ptr != nil {
|
|
c.freeUpscalerCtx(ctx.ctx)
|
|
}
|
|
ctx = nil
|
|
runtime.GC()
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) Close() error {
|
|
if c.libSd != 0 {
|
|
err := closeLibrary(c.libSd)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *CStableDiffusionImpl) UpscaleImage(ctx *CUpScalerCtx, img Image, upscaleFactor uint32) Image {
|
|
ci := cImage{
|
|
width: img.Width,
|
|
height: img.Height,
|
|
channel: img.Channel,
|
|
data: uintptr(unsafe.Pointer(&img.Data[0])),
|
|
}
|
|
uptr := c.upscale(ctx.ctx, uintptr(unsafe.Pointer(&ci)), upscaleFactor)
|
|
ptr := *(*unsafe.Pointer)(unsafe.Pointer(&uptr))
|
|
if ptr == nil {
|
|
return Image{}
|
|
}
|
|
cimg := (*cImage)(ptr)
|
|
dataPtr := *(*unsafe.Pointer)(unsafe.Pointer(&cimg.data))
|
|
return Image{
|
|
Width: cimg.width,
|
|
Height: cimg.height,
|
|
Channel: cimg.channel,
|
|
Data: unsafe.Slice((*byte)(dataPtr), cimg.channel*cimg.width*cimg.height),
|
|
}
|
|
}
|
|
|
|
func goString(c uintptr) string {
|
|
// We take the address and then dereference it to trick go vet from creating a possible misuse of unsafe.Pointer
|
|
ptr := *(*unsafe.Pointer)(unsafe.Pointer(&c))
|
|
if ptr == nil {
|
|
return ""
|
|
}
|
|
var length int
|
|
for {
|
|
if *(*byte)(unsafe.Add(ptr, uintptr(length))) == '\x00' {
|
|
break
|
|
}
|
|
length++
|
|
}
|
|
return unsafe.String((*byte)(ptr), length)
|
|
}
|
|
|
|
func goImageSlice(c uintptr, size int) []Image {
|
|
// We take the address and then dereference it to trick go vet from creating a possible misuse of unsafe.Pointer
|
|
ptr := *(*unsafe.Pointer)(unsafe.Pointer(&c))
|
|
if ptr == nil {
|
|
return nil
|
|
}
|
|
img := (*cImage)(ptr)
|
|
goImages := make([]Image, 0, size)
|
|
imgSlice := unsafe.Slice(img, size)
|
|
for _, image := range imgSlice {
|
|
var gImg Image
|
|
gImg.Channel = image.channel
|
|
gImg.Width = image.width
|
|
gImg.Height = image.height
|
|
dataPtr := *(*unsafe.Pointer)(unsafe.Pointer(&image.data))
|
|
if ptr != nil {
|
|
gImg.Data = unsafe.Slice((*byte)(dataPtr), image.channel*image.width*image.height)
|
|
}
|
|
goImages = append(goImages, gImg)
|
|
}
|
|
return goImages
|
|
}
|
|
|