You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

288 lines
8.3 KiB

// Copyright (c) seasonjs. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.
package sd
import (
"github.com/ebitengine/purego"
"runtime"
"unsafe"
)
// LogLevel is the severity passed to a CLogCallback; SetLogCallBack converts
// the native level integer directly to this type.
type LogLevel int

// RNGType selects the random number generator; it is passed to NewCtx.
type RNGType int

// SampleMethod selects the sampler; it is passed to PredictImage and
// ImagePredictImage.
type SampleMethod int

// Schedule selects the schedule; it is passed to NewCtx.
type Schedule int

// WType selects the weight type; it is passed to NewCtx and NewUpscalerCtx.
type WType int

// Log levels, in the order the native library numbers them (0..3).
const (
	DEBUG LogLevel = iota
	INFO
	WARN
	ERROR
)

// Random number generator kinds.
const (
	STD_DEFAULT_RNG RNGType = iota
	CUDA_RNG
)

// Sampling methods. N_SAMPLE_METHODS is the number of methods, not a method.
const (
	EULER_A SampleMethod = iota
	EULER
	HEUN
	DPM2
	DPMPP2S_A
	DPMPP2M
	DPMPP2Mv2
	LCM
	N_SAMPLE_METHODS
)

// Schedules. N_SCHEDULES is the number of schedules, not a schedule.
const (
	DEFAULT Schedule = iota
	DISCRETE
	KARRAS
	N_SCHEDULES
)

// Weight types. Every constant is explicitly typed WType so that, e.g.,
// `w := F16` yields a WType rather than a plain int.
// NOTE(review): values 4 and 5 are intentionally skipped; presumably this
// mirrors the native type enum — confirm against the C header.
const (
	F32   WType = 0
	F16   WType = 1
	Q4_0  WType = 2
	Q4_1  WType = 3
	Q5_0  WType = 6
	Q5_1  WType = 7
	Q8_0  WType = 8
	Q8_1  WType = 9
	Q2_K  WType = 10
	Q3_K  WType = 11
	Q4_K  WType = 12
	Q5_K  WType = 13
	Q6_K  WType = 14
	Q8_K  WType = 15
	I8    WType = 16
	I16   WType = 17
	I32   WType = 18
	COUNT WType = 19 // don't use this when specifying a type
)
// Opaque native handles and the log callback signature used by CStableDiffusion.
type (
	// CStableDiffusionCtx wraps the native handle produced by new_sd_ctx.
	CStableDiffusionCtx struct {
		ctx uintptr
	}

	// CUpScalerCtx wraps the native handle produced by new_upscaler_ctx.
	CUpScalerCtx struct {
		ctx uintptr
	}

	// CLogCallback receives log messages forwarded from the native library.
	CLogCallback func(level LogLevel, text string)
)
// CStableDiffusion is the low-level binding surface over the native
// stable-diffusion library. *CStableDiffusionImpl, returned by
// NewCStableDiffusion, implements every method below.
type CStableDiffusion interface {
	// NewCtx creates a native diffusion context from the given model files and options.
	NewCtx(modelPath, vaePath, taesdPath, loraModelDir string, vaeDecodeOnly, vaeTiling, freeParamsImmediately bool, nThreads int, wType WType, rngType RNGType, schedule Schedule) *CStableDiffusionCtx
	// PredictImage generates batchCount images from a text prompt (txt2img).
	PredictImage(ctx *CStableDiffusionCtx, prompt, negativePrompt string, clipSkip int, cfgScale float32, width, height int, sampleMethod SampleMethod, sampleSteps int, seed int64, batchCount int) []Image
	// ImagePredictImage generates batchCount images from an initial image plus a prompt (img2img).
	ImagePredictImage(ctx *CStableDiffusionCtx, img Image, prompt, negativePrompt string, clipSkip int, cfgScale float32, width, height int, sampleMethod SampleMethod, sampleSteps int, strength float32, seed int64, batchCount int) []Image
	// SetLogCallBack routes native log messages to cb.
	SetLogCallBack(cb CLogCallback)
	// GetSystemInfo returns the native library's system information string.
	GetSystemInfo() string
	// FreeCtx releases a context created by NewCtx.
	FreeCtx(ctx *CStableDiffusionCtx)
	// NewUpscalerCtx creates a native upscaler context.
	NewUpscalerCtx(esrganPath string, nThreads int, wType WType) *CUpScalerCtx
	// FreeUpscalerCtx releases a context created by NewUpscalerCtx.
	FreeUpscalerCtx(ctx *CUpScalerCtx)
	// UpscaleImage upscales img by upscaleFactor.
	UpscaleImage(ctx *CUpScalerCtx, img Image, upscaleFactor uint32) Image
	// Close releases the native library.
	Close() error
}
// cImage is the image record exchanged with the native txt2img, img2img and
// upscale functions: three uint32 dimensions followed by a data address.
// The native side reads/writes this memory directly, so field order and
// types must not change.
// NOTE(review): layout assumed to match the native image struct — confirm
// against the C header.
type cImage struct {
width uint32
height uint32
channel uint32
data uintptr // address of channel*width*height bytes of pixel data
}
// cDarwinImage has the same fields as cImage but types data as *byte.
// NOTE(review): not referenced in this file; presumably used by a
// darwin-specific binding elsewhere in the package — confirm.
type cDarwinImage struct {
width uint32
height uint32
channel uint32
data *byte
}
// Image is a Go-side image of Width x Height pixels with Channel bytes per
// pixel; Data is expected to hold Width*Height*Channel bytes.
type Image struct {
	Width, Height, Channel uint32
	Data                   []byte
}
// CStableDiffusionImpl implements CStableDiffusion on top of a native
// library loaded by NewCStableDiffusion. Each func field is bound with
// purego.RegisterLibFunc to the native symbol named in NewCStableDiffusion.
// The func types define the call ABI and must agree with the native
// signatures.
// NOTE(review): Go int/bool parameters are assumed to line up with the
// native parameter types under the platform calling convention — confirm.
type CStableDiffusionImpl struct {
libSd uintptr // loaded library handle; Close releases it when non-zero
sdGetSystemInfo func() string
newSdCtx func(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wtype int, rngType int, schedule int) uintptr
sdSetLogCallback func(callback func(level int, text uintptr, data uintptr) uintptr, data uintptr)
txt2img func(ctx uintptr, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, seed int64, batchCount int) uintptr // returns address of an array of batchCount cImage
img2img func(ctx uintptr, img uintptr, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, strength float32, seed int64, batchCount int) uintptr // img is the address of a cImage
freeSdCtx func(ctx uintptr)
newUpscalerCtx func(esrganPath string, nThreads int, wtype int) uintptr
freeUpscalerCtx func(ctx uintptr)
upscale func(ctx uintptr, img uintptr, upscaleFactor uint32) uintptr // img and result are cImage addresses
}
// NewCStableDiffusion loads the native library at libraryPath and binds every
// function used by CStableDiffusionImpl.
//
// It returns the error from openLibrary if the library cannot be loaded. The
// library handle is stored so that Close can release it.
// purego.RegisterLibFunc panics if a symbol is missing from the library.
func NewCStableDiffusion(libraryPath string) (*CStableDiffusionImpl, error) {
	libSd, err := openLibrary(libraryPath)
	if err != nil {
		return nil, err
	}

	impl := CStableDiffusionImpl{libSd: libSd}

	// Each field must be bound to its own symbol; sd_get_system_info was
	// previously bound into sdSetLogCallback, leaving sdGetSystemInfo nil.
	purego.RegisterLibFunc(&impl.sdGetSystemInfo, libSd, "sd_get_system_info")
	purego.RegisterLibFunc(&impl.newSdCtx, libSd, "new_sd_ctx")
	purego.RegisterLibFunc(&impl.sdSetLogCallback, libSd, "sd_set_log_callback")
	purego.RegisterLibFunc(&impl.txt2img, libSd, "txt2img")
	purego.RegisterLibFunc(&impl.img2img, libSd, "img2img")
	purego.RegisterLibFunc(&impl.freeSdCtx, libSd, "free_sd_ctx")
	purego.RegisterLibFunc(&impl.newUpscalerCtx, libSd, "new_upscaler_ctx")
	purego.RegisterLibFunc(&impl.freeUpscalerCtx, libSd, "free_upscaler_ctx")
	purego.RegisterLibFunc(&impl.upscale, libSd, "upscale")
	return &impl, nil
}
// NewCtx creates a native diffusion context through new_sd_ctx and wraps the
// returned handle. The enum arguments are forwarded to the native side as ints.
// NOTE(review): the handle is wrapped as-is even if it is zero (presumably a
// creation failure) — callers cannot currently tell the difference.
func (c *CStableDiffusionImpl) NewCtx(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wType WType, rngType RNGType, schedule Schedule) *CStableDiffusionCtx {
	wt, rng, sched := int(wType), int(rngType), int(schedule)
	handle := c.newSdCtx(
		modelPath, vaePath, taesdPath, loraModelDir,
		vaeDecodeOnly, vaeTiling, freeParamsImmediately,
		nThreads, wt, rng, sched,
	)
	sdCtx := new(CStableDiffusionCtx)
	sdCtx.ctx = handle
	return sdCtx
}
// PredictImage runs the native txt2img on ctx and converts the batchCount
// images it returns into Go Images via goImageSlice (nil when the native
// call returns a null pointer).
// NOTE(review): the returned Image.Data slices view native memory and are
// not freed here.
func (c *CStableDiffusionImpl) PredictImage(ctx *CStableDiffusionCtx, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, seed int64, batchCount int) []Image {
	method := int(sampleMethod)
	raw := c.txt2img(ctx.ctx, prompt, negativePrompt, clipSkip, cfgScale, width, height, method, sampleSteps, seed, batchCount)
	return goImageSlice(raw, batchCount)
}
// ImagePredictImage runs the native img2img on ctx with img as the initial
// image and converts the batchCount results into Go Images.
//
// It returns nil without calling into the native library when img.Data is
// empty or holds fewer than Width*Height*Channel bytes: the native side only
// receives a data pointer, so a short buffer would let it read past the end
// (assuming it reads that many bytes, the same size the result conversion
// uses).
func (c *CStableDiffusionImpl) ImagePredictImage(ctx *CStableDiffusionCtx, img Image, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, strength float32, seed int64, batchCount int) []Image {
	need := uint64(img.Width) * uint64(img.Height) * uint64(img.Channel)
	if len(img.Data) == 0 || uint64(len(img.Data)) < need {
		return nil
	}
	ci := cImage{
		width:   img.Width,
		height:  img.Height,
		channel: img.Channel,
		data:    uintptr(unsafe.Pointer(&img.Data[0])),
	}
	images := c.img2img(ctx.ctx, uintptr(unsafe.Pointer(&ci)), prompt, negativePrompt, clipSkip, cfgScale, width, height, int(sampleMethod), sampleSteps, strength, seed, batchCount)
	// During the native call ci and img.Data are referenced only through
	// uintptrs, which the GC does not track; keep them alive until it returns.
	runtime.KeepAlive(&ci)
	runtime.KeepAlive(img.Data)
	return goImageSlice(images, batchCount)
}
// SetLogCallBack registers cb to receive native log messages. The native
// level integer is converted to LogLevel, the message address is converted
// with goString, and 0 is passed as the native user-data pointer.
// NOTE(review): purego turns the Go func into a native callback; purego
// callbacks are a limited, never-released resource, so calling this
// repeatedly presumably consumes a slot each time — confirm.
func (c *CStableDiffusionImpl) SetLogCallBack(cb CLogCallback) {
	trampoline := func(level int, text uintptr, data uintptr) uintptr {
		msg := goString(text)
		cb(LogLevel(level), msg)
		return 0
	}
	c.sdSetLogCallback(trampoline, 0)
}
// GetSystemInfo returns the string produced by the bound sdGetSystemInfo
// function. It panics if that field was never bound (nil func).
func (c *CStableDiffusionImpl) GetSystemInfo() string {
	info := c.sdGetSystemInfo()
	return info
}
// FreeCtx releases the native diffusion context held by ctx via free_sd_ctx.
//
// A nil ctx or a zero handle is a no-op. After freeing, the handle is zeroed,
// so calling FreeCtx again cannot free the same native context twice.
func (c *CStableDiffusionImpl) FreeCtx(ctx *CStableDiffusionCtx) {
	if ctx == nil || ctx.ctx == 0 {
		return
	}
	c.freeSdCtx(ctx.ctx)
	ctx.ctx = 0
	// NOTE(review): a forced GC does not release native memory; kept only to
	// preserve the original behavior.
	runtime.GC()
}
// NewUpscalerCtx creates a native upscaler context from the model at
// esrganPath via new_upscaler_ctx and wraps the returned handle.
func (c *CStableDiffusionImpl) NewUpscalerCtx(esrganPath string, nThreads int, wType WType) *CUpScalerCtx {
	return &CUpScalerCtx{
		ctx: c.newUpscalerCtx(esrganPath, nThreads, int(wType)),
	}
}
// FreeUpscalerCtx releases the native upscaler context held by ctx via
// free_upscaler_ctx.
//
// A nil ctx or a zero handle is a no-op. After freeing, the handle is zeroed,
// so calling FreeUpscalerCtx again cannot free the same native context twice.
func (c *CStableDiffusionImpl) FreeUpscalerCtx(ctx *CUpScalerCtx) {
	if ctx == nil || ctx.ctx == 0 {
		return
	}
	c.freeUpscalerCtx(ctx.ctx)
	ctx.ctx = 0
	// NOTE(review): a forced GC does not release native memory; kept only to
	// preserve the original behavior.
	runtime.GC()
}
// Close unloads the native library if a handle is held.
//
// On success the handle is cleared, so a second Close is a no-op instead of
// closing the same handle twice. If closeLibrary fails, the handle is kept
// (so Close can be retried) and the error is returned.
func (c *CStableDiffusionImpl) Close() error {
	if c.libSd == 0 {
		return nil
	}
	if err := closeLibrary(c.libSd); err != nil {
		return err
	}
	c.libSd = 0
	return nil
}
// UpscaleImage upscales img by upscaleFactor using the native upscaler ctx.
//
// It returns the zero Image without calling into the native library when
// img.Data is empty or holds fewer than Width*Height*Channel bytes. It also
// returns the zero Image when the native call returns a null pointer. If the
// returned image has a null data pointer, its dimensions are kept and Data
// is nil.
//
// The returned Data views memory owned by the native library; it is not
// copied or freed here.
// NOTE(review): this passes and receives cImage addresses; assumes the linked
// library's upscale uses that pointer-based signature — confirm against the
// native header.
func (c *CStableDiffusionImpl) UpscaleImage(ctx *CUpScalerCtx, img Image, upscaleFactor uint32) Image {
	need := uint64(img.Width) * uint64(img.Height) * uint64(img.Channel)
	if len(img.Data) == 0 || uint64(len(img.Data)) < need {
		return Image{}
	}
	in := cImage{
		width:   img.Width,
		height:  img.Height,
		channel: img.Channel,
		data:    uintptr(unsafe.Pointer(&img.Data[0])),
	}
	uptr := c.upscale(ctx.ctx, uintptr(unsafe.Pointer(&in)), upscaleFactor)
	// During the native call in and img.Data are referenced only through
	// uintptrs, which the GC does not track; keep them alive until it returns.
	runtime.KeepAlive(&in)
	runtime.KeepAlive(img.Data)

	// Convert via a pointer to the uintptr to keep go vet's unsafeptr check quiet.
	ptr := *(*unsafe.Pointer)(unsafe.Pointer(&uptr))
	if ptr == nil {
		return Image{}
	}
	out := (*cImage)(ptr)
	result := Image{
		Width:   out.width,
		Height:  out.height,
		Channel: out.channel,
	}
	// unsafe.Slice panics on a nil pointer with non-zero length, so only
	// build the view when the native side actually returned pixel data.
	if dataPtr := *(*unsafe.Pointer)(unsafe.Pointer(&out.data)); dataPtr != nil {
		// Multiply as int to avoid uint32 overflow for very large images.
		size := int(out.channel) * int(out.width) * int(out.height)
		result.Data = unsafe.Slice((*byte)(dataPtr), size)
	}
	return result
}
// goString copies the NUL-terminated C string at address c into a new Go
// string. A zero address yields "".
//
// The bytes are copied rather than aliased. The result therefore stays valid
// and unchanged after the native side frees or reuses its buffer, such as the
// text passed to the log callback, whose lifetime the native side controls.
func goString(c uintptr) string {
	// Converting through a pointer to the uintptr keeps go vet from flagging
	// a possible misuse of unsafe.Pointer.
	ptr := *(*unsafe.Pointer)(unsafe.Pointer(&c))
	if ptr == nil {
		return ""
	}
	n := 0
	for *(*byte)(unsafe.Add(ptr, n)) != 0 {
		n++
	}
	// string([]byte) copies; unsafe.String would alias native memory and
	// break Go's string immutability if that memory is later overwritten.
	return string(unsafe.Slice((*byte)(ptr), n))
}
// goImageSlice converts a native array of size cImage records starting at
// address c into Go Images. It returns nil for a zero address or a
// non-positive size.
//
// Each Image's Data is a channel*width*height-byte view of the native pixel
// buffer; nothing is copied or freed here. A record whose data pointer is
// null yields an Image with nil Data instead of panicking.
func goImageSlice(c uintptr, size int) []Image {
	// Converting through a pointer to the uintptr keeps go vet from flagging
	// a possible misuse of unsafe.Pointer.
	ptr := *(*unsafe.Pointer)(unsafe.Pointer(&c))
	if ptr == nil || size <= 0 {
		return nil
	}
	cImages := unsafe.Slice((*cImage)(ptr), size)
	goImages := make([]Image, 0, size)
	for i := range cImages {
		ci := &cImages[i]
		img := Image{Width: ci.width, Height: ci.height, Channel: ci.channel}
		// Check this record's data pointer (not the array pointer): unsafe.Slice
		// panics on a nil pointer with a non-zero length.
		if dataPtr := *(*unsafe.Pointer)(unsafe.Pointer(&ci.data)); dataPtr != nil {
			// Multiply as int to avoid uint32 overflow for very large images.
			img.Data = unsafe.Slice((*byte)(dataPtr), int(ci.channel)*int(ci.width)*int(ci.height))
		}
		goImages = append(goImages, img)
	}
	return goImages
}