# attention.py

from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any

from .diffusionmodules.util import checkpoint

try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILABLE = True
except ImportError:
    XFORMERS_IS_AVAILABLE = False

# CrossAttn precision handling
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
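
# Usage note (added): set ATTN_PRECISION in the environment before importing this module
# to control the precision of the attention-score matmul in CrossAttention below, e.g.
# (illustrative shell invocation, script name is just an example):
#   ATTN_PRECISION=fp16 python your_script.py
# With the default "fp32", queries and keys are upcast to float32 under autocast to
# avoid overflow in the softmax logits.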


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)
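
# Example (added, illustrative): a gated feed-forward block maps
# (batch, tokens, dim) -> (batch, tokens, dim) with a 4x hidden expansion:
#   ff = FeedForward(dim=320, glu=True)
#   y = ff(torch.randn(2, 4096, 320))  # -> shape (2, 4096, 320)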


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x + h_
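
# Shape note (added): SpatialSelfAttention keeps the spatial layout, mapping
# (b, c, h, w) -> (b, c, h, w) and adding the attended features back to the input
# as a residual, e.g. (illustrative):
#   attn = SpatialSelfAttention(in_channels=64)
#   y = attn(torch.randn(1, 64, 32, 32))  # -> shape (1, 64, 32, 32)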


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == "fp32":
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k = q.float(), k.float()
                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', sim, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
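
# Example (added, illustrative): cross-attention from image tokens to a conditioning
# sequence; the output keeps the query's token count and query_dim channels:
#   attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
#   x = torch.randn(2, 4096, 320)   # image tokens
#   ctx = torch.randn(2, 77, 768)   # e.g. text-encoder embeddings
#   y = attn(x, context=ctx)        # -> shape (2, 4096, 320)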


class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
              f"{heads} heads.")
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        if exists(mask):
            raise NotImplementedError
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)
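
# Note (added): this class requires xformers; it is only selected by BasicTransformerBlock
# below when XFORMERS_IS_AVAILABLE is True, and masking is not implemented. The call
# interface otherwise mirrors CrossAttention, e.g. (illustrative, assuming xformers is installed):
#   attn = MemoryEfficientCrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
#   y = attn(torch.randn(2, 4096, 320), context=torch.randn(2, 77, 768))  # -> (2, 4096, 320)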


class BasicTransformerBlock(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention
    }

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILABLE else "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
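
# Structure note (added): each block is pre-norm with residual connections --
# self-attention (or cross-attention when disable_self_attn is set), then
# cross-attention over `context`, then a gated feed-forward. `checkpoint` here is
# the gradient-checkpointing helper imported from diffusionmodules.util, not
# torch.utils.checkpoint.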


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding) and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape back to an image.
    NEW: use_linear for more efficiency instead of the 1x1 convs.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
             for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                  in_channels,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0))
        else:
            # project back from inner_dim to in_channels so the residual in forward() matches
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
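
# Example (added, illustrative): a SpatialTransformer over 64-channel feature maps with
# a 768-dim conditioning sequence; the residual output preserves the input shape:
#   st = SpatialTransformer(in_channels=64, n_heads=8, d_head=8, depth=1,
#                           context_dim=768, use_linear=True, use_checkpoint=False)
#   x = torch.randn(1, 64, 32, 32)
#   ctx = torch.randn(1, 77, 768)
#   y = st(x, context=ctx)  # -> shape (1, 64, 32, 32)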