Using Stable Diffusion on AWS Inferentia 2
AWS Inferentia2 instances are purpose-built for deep learning (DL) inference. They deliver high-performance compute for generative AI models, including large language models (LLMs) and vision transformers, at the lowest cost in Amazon EC2. You can use Inf2 instances to run inference applications for text summarization, code generation, video and image generation, speech recognition, personalization, fraud detection, and more.
Launching an instance
Launching an Inf2 instance requires a dedicated system AMI. Search for "Neuron" in the AMI marketplace and select: Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04)
- As of 2023-08-31, because model conversion is required, the official samples recommend the inf2.8xlarge instance type.
Once inside the system, see the /home/ubuntu/README file. First, activate the PyTorch Inf2 environment:
source /opt/aws_neuron_venv_pytorch/bin/activate
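Before converting anything, it is worth a quick sanity check that the Neuron PyTorch stack is actually importable inside this virtualenv (a minimal sketch; the filename is hypothetical, and the DLAMI's neuron-ls CLI can likewise list the NeuronCores):
# sanity_check.py (hypothetical filename): confirm the Neuron stack imports cleanly
from importlib.metadata import version
import torch
import torch_neuronx  # preinstalled in the aws_neuron_venv_pytorch virtualenv

print("torch:", version("torch"))
print("torch-neuronx:", version("torch-neuronx"))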
Stable Diffusion 1.5
Model conversion
Use the following code to convert the model into a form Inf2 supports.
The code comes from the official samples.
import os
os.environ["NEURON_FUSE_SOFTMAX"] = "1"
import torch
import torch.nn as nn
import torch_neuronx
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import time
import copy
from IPython.display import clear_output
from diffusers import StableDiffusionPipeline
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.cross_attention import CrossAttention
clear_output(wait=False)
###### Parameters #########
DTYPE = torch.float32
COMPILER_WORKDIR_ROOT = '/home/ubuntu/models/deli-inf2'
model_id = "XpucT/Deliberate"
#########################
######## Class and helper definitions #################
def get_attention_scores(self, query, key, attn_mask):
    dtype = query.dtype
    if self.upcast_attention:
        query = query.float()
        key = key.float()
    if query.size() == key.size():
        attention_scores = cust_badbmm(
            key,
            query.transpose(-1, -2),
            self.scale
        )
        if self.upcast_softmax:
            attention_scores = attention_scores.float()
        attention_probs = torch.nn.functional.softmax(attention_scores, dim=1).permute(0,2,1)
        attention_probs = attention_probs.to(dtype)
    else:
        attention_scores = cust_badbmm(
            query,
            key.transpose(-1, -2),
            self.scale
        )
        if self.upcast_softmax:
            attention_scores = attention_scores.float()
        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = attention_probs.to(dtype)
    return attention_probs
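# Note on the branch above: when query and key have the same shape, the scores
# are computed as key @ query^T with softmax over dim=1 followed by a permute.
# That is mathematically equivalent to softmax(query @ key^T, dim=-1); the
# rewrite exists because this layout compiles to faster kernels on Neuron.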
def cust_badbmm(a, b, scale):
    bmm = torch.bmm(a, b)
    scaled = bmm * scale
    return scaled
class UNetWrap(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):
        out_tuple = self.unet(sample, timestep, encoder_hidden_states, return_dict=False)
        return out_tuple
class NeuronUNet(nn.Module):
    def __init__(self, unetwrap):
        super().__init__()
        self.unetwrap = unetwrap
        self.config = unetwrap.unet.config
        self.in_channels = unetwrap.unet.in_channels
        self.device = unetwrap.unet.device

    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):
        sample = self.unetwrap(sample, timestep.float().expand((sample.shape[0],)), encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)
class NeuronTextEncoder(nn.Module):
    def __init__(self, text_encoder):
        super().__init__()
        self.neuron_text_encoder = text_encoder
        self.config = text_encoder.config
        self.dtype = torch.float32
        self.device = text_encoder.device

    def forward(self, emb, attention_mask=None):
        return [self.neuron_text_encoder(emb)['last_hidden_state']]
class NeuronSafetyModelWrap(nn.Module):
    def __init__(self, safety_model):
        super().__init__()
        self.safety_model = safety_model

    def forward(self, clip_inputs):
        return list(self.safety_model(clip_inputs).values())
# --- Compile CLIP text encoder and save ---
print("Load model....")
# Only keep the model being compiled in RAM to minimize memory pressure
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=DTYPE)
print("Compile CLIP text encoder and save")
text_encoder = copy.deepcopy(pipe.text_encoder)
# Apply the wrapper to deal with custom return type
text_encoder = NeuronTextEncoder(text_encoder)
# Compile text encoder
# This is used for indexing a lookup table in torch.nn.Embedding,
# so using random numbers may give errors (out of range).
emb = torch.tensor([[49406, 18376, 525, 7496, 49407, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0]])
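# Equivalently (a sketch), the same kind of input could be produced with the
# pipeline's own tokenizer, padded to CLIP's max length of 77:
#   emb = pipe.tokenizer("a prompt", padding="max_length",
#                        max_length=pipe.tokenizer.model_max_length,
#                        return_tensors="pt").input_ids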
with torch.no_grad():
    text_encoder_neuron = torch_neuronx.trace(
        text_encoder.neuron_text_encoder,
        emb,
        compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder'),
        compiler_args=["--enable-fast-loading-neuron-binaries"]
    )
# Save the compiled text encoder
text_encoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder/model.pt')
torch_neuronx.async_load(text_encoder_neuron)
torch.jit.save(text_encoder_neuron, text_encoder_filename)
# delete unused objects
del text_encoder
del text_encoder_neuron
del emb
# --- Compile VAE decoder and save ---
print("Compile VAE decoder and save")
# Only keep the model being compiled in RAM to minimize memory pressure
decoder = copy.deepcopy(pipe.vae.decoder)
# Compile vae decoder
decoder_in = torch.randn([1, 4, 64, 64])
with torch.no_grad():
    decoder_neuron = torch_neuronx.trace(
        decoder,
        decoder_in,
        compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder'),
        compiler_args=["--enable-fast-loading-neuron-binaries"]
    )
# Save the compiled vae decoder
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder/model.pt')
torch_neuronx.async_load(decoder_neuron)
torch.jit.save(decoder_neuron, decoder_filename)
# delete unused objects
del decoder
del decoder_in
del decoder_neuron
# --- Compile UNet and save ---
print("Compile UNet and save")
# Replace original cross-attention module with custom cross-attention module for better performance
CrossAttention.get_attention_scores = get_attention_scores
# Apply double wrapper to deal with custom return type
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
# Only keep the model being compiled in RAM to minimize memory pressure
unet = copy.deepcopy(pipe.unet.unetwrap)
# Compile unet - FP32
sample_1b = torch.randn([1, 4, 64, 64])
timestep_1b = torch.tensor(999).float().expand((1,))
encoder_hidden_states_1b = torch.randn([1, 77, 768])
example_inputs = sample_1b, timestep_1b, encoder_hidden_states_1b
with torch.no_grad():
    unet_neuron = torch_neuronx.trace(
        unet,
        example_inputs,
        compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'unet'),
        compiler_args=["--model-type=unet-inference", "--enable-fast-loading-neuron-binaries"]
    )
# save compiled unet
unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet/model.pt')
torch_neuronx.async_load(unet_neuron)
torch_neuronx.lazy_load(unet_neuron)
torch.jit.save(unet_neuron, unet_filename)
# delete unused objects
del unet
del unet_neuron
del sample_1b
del timestep_1b
del encoder_hidden_states_1b
# --- Compile VAE post_quant_conv and save ---
print("Compile VAE post_quant_conv and save")
# Only keep the model being compiled in RAM to minimize memory pressure
post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)
# Compile vae post_quant_conv
post_quant_conv_in = torch.randn([1, 4, 64, 64])
with torch.no_grad():
    post_quant_conv_neuron = torch_neuronx.trace(
        post_quant_conv,
        post_quant_conv_in,
        compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv'),
        compiler_args=["--enable-fast-loading-neuron-binaries"]
    )
# Save the compiled vae post_quant_conv
post_quant_conv_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv/model.pt')
torch_neuronx.async_load(post_quant_conv_neuron)
torch.jit.save(post_quant_conv_neuron, post_quant_conv_filename)
# delete unused objects
del post_quant_conv
# --- Compile safety checker and save ---
print("Compile safety checker and save")
# Only keep the model being compiled in RAM to minimize memory pressure
safety_model = copy.deepcopy(pipe.safety_checker.vision_model)
clip_input = torch.randn([1, 3, 224, 224])
with torch.no_grad():
    safety_model_neuron = torch_neuronx.trace(
        safety_model,
        clip_input,
        compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'safety_model_neuron'),
        compiler_args=["--enable-fast-loading-neuron-binaries"]
    )
# Save the compiled safety checker
safety_model_neuron_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'safety_model_neuron/model.pt')
torch_neuronx.async_load(safety_model_neuron)
torch.jit.save(safety_model_neuron, safety_model_neuron_filename)
# delete unused objects
del safety_model_neuron
del pipe
- Requires diffusers version 0.14.0.
- The diffusers-format model XpucT/Deliberate from Hugging Face is used directly; Civitai models converted locally with the latest diffusers code fail at inference, for reasons unknown.
- The model produced by the code above only supports 512 * 512 images; the relevant shape in the code is torch.randn([1, 4, 64, 64]) (see the sketch below).
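That 64 is the latent spatial size: the Stable Diffusion VAE downsamples by a factor of 8, so fixing the trace shape fixes the output resolution. A minimal sketch of deriving the trace shape for a different, still fixed, resolution:
# sketch: derive the latent trace shape from a fixed target resolution
# (8 is the SD VAE downsampling factor; keep the resolution a multiple of 8)
height = width = 512
sample_shape = [1, 4, height // 8, width // 8]
print(sample_shape)  # [1, 4, 64, 64]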
Save the file as trans_1.5.py, then simply run python trans_1.5.py.
This process produces the Neuron models, as follows:
drwxrwxr-x 3 ubuntu ubuntu 4096 Aug 30 02:23 safety_model_neuron
drwxrwxr-x 3 ubuntu ubuntu 4096 Aug 30 02:14 text_encoder
drwxrwxr-x 3 ubuntu ubuntu 4096 Aug 30 02:22 unet
drwxrwxr-x 3 ubuntu ubuntu 4096 Aug 30 02:18 vae_decoder
drwxrwxr-x 3 ubuntu ubuntu 4096 Aug 30 02:22 vae_post_quant_conv
Inference
Straight to the code:
# Imports, class and method definitions, and parameters are the same as above; omitted...
# One addition: the scheduler used below also needs to be imported.
from diffusers import UniPCMultistepScheduler
print("Loading the original model")
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    safety_checker=None,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
text_encoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, "text_encoder/model.pt")
unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, "unet/model.pt")
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, "vae_decoder/model.pt")
post_quant_conv_filename = os.path.join(
    COMPILER_WORKDIR_ROOT, "vae_post_quant_conv/model.pt"
)
safety_model_neuron_filename = os.path.join(
    COMPILER_WORKDIR_ROOT, "safety_model_neuron/model.pt"
)
# Load the compiled UNet onto two neuron cores.
print("加载 Neuro 转换模型")
print("加载 unet")
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
device_ids = [0, 1]
pipe.unet.unetwrap = torch_neuronx.DataParallel(
    torch.jit.load(unet_filename), device_ids, set_dynamic_batching=False
)
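# Why two cores: classifier-free guidance makes the UNet run on a batch of 2
# (conditional + unconditional). The UNet was traced with batch size 1, so
# DataParallel splits that batch across NeuronCores 0 and 1; dynamic batching
# is turned off because the traced shapes are fixed.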
print("加载 text_encoder")
pipe.text_encoder = NeuronTextEncoder(pipe.text_encoder)
pipe.text_encoder.neuron_text_encoder = torch.jit.load(text_encoder_filename)
print("加载 vae decoder")
pipe.vae.decoder = torch.jit.load(decoder_filename)
print("加载 vae post_quant_conv")
pipe.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)
print("加载 safety_checker, skipping...")
# pipe.safety_checker.vision_model = NeuronSafetyModelWrap(torch.jit.load(safety_model_neuron_filename))
# Run pipeline
prompt = [
    "A 24 y.o pretty girl, masterpiece, 8k, highres",
    "a photo of an astronaut riding a horse on mars",
    "sonic on the moon",
    "elvis playing guitar while eating a hotdog",
    "saved by the bell",
    "engineers eating lunch at the opera",
    "panda eating bamboo on a plane",
    "A digital illustration of a steampunk flying machine in the sky with cogs and mechanisms, 4k, detailed, trending in artstation, fantasy vivid colors",
    "kids playing soccer at the FIFA World Cup",
]
os.makedirs("outputs", exist_ok=True)  # image.save below needs the directory to exist
total_time = 0
i = 0
for x in prompt:
    start_time = time.time()
    image = pipe(
        x, negative_prompt="nsfw, bad finger, lowres", num_inference_steps=20
    ).images[0]
    total_time = total_time + (time.time() - start_time)
    image.save(f"outputs/{i}.webp", "WEBP")
    i = i + 1
print("Average time: ", np.round((total_time / len(prompt)), 2), "seconds")
Inference completes without issues, and it is fast.
SDXL
The Neuron 2.13 release, just published on August 28, adds support for SDXL. A matching AMI has not been released yet, so the dependency packages need to be installed manually first.
The package repository lives at: https://pip.repos.neuron.amazonaws.com/
For example, to install the latest neuronx-distributed:
pip install neuronx-distributed transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com/
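Note that the SD 1.5 script above assumed diffusers 0.14.0, while SDXL support only exists in considerably newer diffusers releases (the attention import below also moves from diffusers.models.cross_attention to diffusers.models.attention_processor), so a quick version check can save a long failed compile (a sketch):
# sketch: verify the installed diffusers is new enough for SDXL before converting
from importlib.metadata import version
print("diffusers:", version("diffusers"))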
Model conversion
Without further ado, here is the code:
# filename: trans_xl.py
import os
import numpy as np
import torch
import torch.nn as nn
import torch_neuronx
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.attention_processor import Attention
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import time
import copy
from IPython.display import clear_output
clear_output(wait=False)
COMPILER_WORKDIR_ROOT = "/home/ubuntu/models/sdxl"
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
def get_attention_scores_neuron(self, query, key, attn_mask):
    if query.size() == key.size():
        attention_scores = custom_badbmm(key, query.transpose(-1, -2), self.scale)
        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)
    else:
        attention_scores = custom_badbmm(query, key.transpose(-1, -2), self.scale)
        attention_probs = attention_scores.softmax(dim=-1)
    return attention_probs

def custom_badbmm(a, b, scale):
    bmm = torch.bmm(a, b)
    scaled = bmm * scale
    return scaled
class UNetWrap(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(
        self, sample, timestep, encoder_hidden_states, text_embeds=None, time_ids=None
    ):
        out_tuple = self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids},
            return_dict=False,
        )
        return out_tuple
class NeuronUNet(nn.Module):
    def __init__(self, unetwrap):
        super().__init__()
        self.unetwrap = unetwrap
        self.config = unetwrap.unet.config
        self.in_channels = unetwrap.unet.in_channels
        self.add_embedding = unetwrap.unet.add_embedding
        self.device = unetwrap.unet.device

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        added_cond_kwargs=None,
        return_dict=False,
        cross_attention_kwargs=None,
    ):
        sample = self.unetwrap(
            sample,
            timestep.float().expand((sample.shape[0],)),
            encoder_hidden_states,
            added_cond_kwargs["text_embeds"],
            added_cond_kwargs["time_ids"],
        )[0]
        return UNet2DConditionOutput(sample=sample)
# --- Compile VAE decoder and save ---
# Only keep the model being compiled in RAM to minimize memory pressure
print("Load model....")
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
print("Compile vae decoder....")
decoder = copy.deepcopy(pipe.vae.decoder)
# Compile vae decoder
decoder_in = torch.randn([1, 4, 128, 128])
decoder_neuron = torch_neuronx.trace(
    decoder,
    decoder_in,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, "vae_decoder"),
)
# Save the compiled vae decoder
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, "vae_decoder/model.pt")
torch.jit.save(decoder_neuron, decoder_filename)
# delete unused objects
del decoder
# --- Compile UNet and save ---
print("Compile UNet")
# pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
# Replace original cross-attention module with custom cross-attention module for better performance
Attention.get_attention_scores = get_attention_scores_neuron
# Apply double wrapper to deal with custom return type
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
# Only keep the model being compiled in RAM to minimize memory pressure
unet = copy.deepcopy(pipe.unet.unetwrap)
# del pipe
# Compile unet - FP32
sample_1b = torch.randn([1, 4, 128, 128])
timestep_1b = torch.tensor(999).float().expand((1,))
encoder_hidden_states_1b = torch.randn([1, 77, 2048])
added_cond_kwargs_1b = {
    "text_embeds": torch.randn([1, 1280]),
    "time_ids": torch.randn([1, 6]),
}
example_inputs = (
    sample_1b,
    timestep_1b,
    encoder_hidden_states_1b,
    added_cond_kwargs_1b["text_embeds"],
    added_cond_kwargs_1b["time_ids"],
)
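# Shape notes (for reference): 128 = 1024/8, the latent size of SDXL's native
# 1024x1024 resolution; 2048 is the two CLIP text encoders' hidden states
# concatenated (768 + 1280); text_embeds (1280) is the pooled OpenCLIP
# embedding; time_ids (6) carries SDXL's size/crop micro-conditioning.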
unet_neuron = torch_neuronx.trace(
    unet,
    example_inputs,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, "unet"),
    compiler_args=["--model-type=unet-inference"],
)
# save compiled unet
unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, "unet/model.pt")
torch.jit.save(unet_neuron, unet_filename)
# delete unused objects
del unet
print("Compile VAE post_quant_conv ")
# --- Compile VAE post_quant_conv and save ---
# Only keep the model being compiled in RAM to minimize memory pressure
# pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)
# del pipe
# Compile vae post_quant_conv
post_quant_conv_in = torch.randn([1, 4, 128, 128])
post_quant_conv_neuron = torch_neuronx.trace(
    post_quant_conv,
    post_quant_conv_in,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, "vae_post_quant_conv"),
)
# Save the compiled vae post_quant_conv
post_quant_conv_filename = os.path.join(
    COMPILER_WORKDIR_ROOT, "vae_post_quant_conv/model.pt"
)
torch.jit.save(post_quant_conv_neuron, post_quant_conv_filename)
# delete unused objects
del post_quant_conv
del pipe
Save it as trans_xl.py and run it the same way. Since the SDXL model is much larger, though, conversion takes noticeably longer.
Inference code
# Header definitions omitted (same as the conversion script above)
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, "vae_decoder/model.pt")
unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, "unet/model.pt")
post_quant_conv_filename = os.path.join(
    COMPILER_WORKDIR_ROOT, "vae_post_quant_conv/model.pt"
)
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Load the compiled UNet onto two neuron cores.
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
device_ids = [0, 1]
pipe.unet.unetwrap = torch_neuronx.DataParallel(
    torch.jit.load(unet_filename), device_ids, set_dynamic_batching=False
)
# Load other compiled models onto a single neuron core.
pipe.vae.decoder = torch.jit.load(decoder_filename)
pipe.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)
# Run pipeline
prompt = [
    "A 24 y.o pretty girl, masterpiece, 8k, highres",
    "a photo of an astronaut riding a horse on mars",
]
os.makedirs("outputs", exist_ok=True)  # make sure the output directory exists
total_time = 0
i = 0
for x in prompt:
    start_time = time.time()
    pmt = f"""anime artwork {x} . anime style, key visual, vibrant, studio anime, highly detailed"""
    npmt = """photo, deformed, black and white, realism, disfigured, low contrast"""
    image = pipe(pmt, negative_prompt=npmt, num_inference_steps=20).images[0]
    total_time = total_time + (time.time() - start_time)
    image.save(f"outputs/xl-{i}.webp", "WEBP")
    i = i + 1
print("Average time: ", np.round((total_time / len(prompt)), 2), "seconds")
The process is much the same as for 1.5, but the model takes longer to load.
Summary
- Models need to be converted; conversion basically always succeeds.
- Not many SD 1.5 models complete inference successfully after conversion.
- The SDXL model itself is high quality, but it takes a long time to load.
- Following the official samples, dynamic image sizes are not supported; the image size is fixed at conversion time.
- Inference is fast: SD 1.5 at 512 reaches roughly 19 it/s, and SDXL at 1024 about 3.5 it/s.
References:
https://github.com/aws-neuron/aws-neuron-samples
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/index.html