sd_hack.py

import torch
import einops

from transformers import logging
from . import ldm
from .ldm.modules.attention import default


def disable_verbosity():
    # Silence the warning spam from the transformers library.
    logging.set_verbosity_error()
    return


def enable_sliced_attention():
    # Patch CrossAttention to use the memory-friendly sliced implementation below.
    ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
    print('Enabled sliced_attention.')
    return


def hack_everything(clip_skip=0):
    # Patch FrozenCLIPEmbedder so that long prompts (up to 3 x 75 tokens)
    # and CLIP skip are supported.
    disable_verbosity()
    ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
    ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
    return
# Written by Lvmin
def _hacked_clip_forward(self, text):
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
    BOS = self.tokenizer.bos_token_id

    text = [t.replace('_', ' ') for t in text]

    def tokenize(t):
        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]

    def transformer_encode(t):
        # CLIP skip: take an earlier hidden state and re-apply the final layer norm.
        if self.clip_skip > 1:
            rt = self.transformer(input_ids=t, output_hidden_states=True)
            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
        else:
            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state

    def split(x):
        # Split the raw token ids into three chunks of at most 75 tokens each.
        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]

    def pad(x, p, i):
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    raw_tokens_list = tokenize(text)
    tokens_list = []

    for raw_tokens in raw_tokens_list:
        # Wrap each 75-token chunk with BOS/EOS and pad it to the usual length of 77.
        raw_tokens_123 = split(raw_tokens)
        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
        tokens_list.append(raw_tokens_123)

    tokens_list = torch.IntTensor(tokens_list).to(self.device)

    # Encode the three chunks in one batch, then concatenate them along the
    # sequence dimension, yielding 3 x 77 = 231 conditioning tokens per prompt.
    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
    y = transformer_encode(feed)
    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)

    return z
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    # Compute attention one (batch * head) slice at a time to cap peak memory.
    limit = k.shape[0]
    att_step = 1
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))

    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()
    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v

    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
        del k_buffer, q_buffer

        # attention, what we cannot get enough of, by chunks
        sim_buffer = sim_buffer.softmax(dim=-1)
        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
        del v_buffer

        sim[i:i + att_step, :, :] = sim_buffer
        del sim_buffer

    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(sim)
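

# --- Usage sketch (illustrative; not part of the original file) ---
# Assuming this module sits next to an `ldm` package, as in ControlNet-style
# repositories, the hacks would be applied once at startup, before the model
# is built. The import path below depends on where the file actually lives.
#
#     from sd_hack import disable_verbosity, enable_sliced_attention, hack_everything
#
#     disable_verbosity()             # quiet transformers warnings
#     hack_everything(clip_skip=2)    # long prompts + CLIP skip on the text encoder
#     enable_sliced_attention()       # low-memory sliced cross-attention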