# t-sne_0423.py — t-SNE visualization comparing person-specific training
# landmarks against vae / postnet / lle predicted 3D landmarks.
  1. from openTSNE import TSNE
  2. import numpy as np
  3. import matplotlib
  4. import matplotlib.pyplot as plt
  5. import random
  6. def visualize(
  7. x,
  8. y,
  9. ax=None,
  10. title=None,
  11. draw_legend=True,
  12. draw_centers=False,
  13. draw_cluster_labels=False,
  14. colors=None,
  15. legend_kwargs=None,
  16. label_order=None,
  17. **kwargs
  18. ):
  19. if ax is None:
  20. _, ax = matplotlib.pyplot.subplots(figsize=(10, 8))
  21. if title is not None:
  22. ax.set_title(title)
  23. plot_params = {"alpha": kwargs.get("alpha", 0.6), "s": kwargs.get("s", 1)}
  24. # Create main plot
  25. if label_order is not None:
  26. assert all(np.isin(np.unique(y), label_order))
  27. classes = [l for l in label_order if l in np.unique(y)]
  28. else:
  29. classes = np.unique(y)
  30. if colors is None:
  31. default_colors = matplotlib.rcParams["axes.prop_cycle"]
  32. colors = {k: v["color"] for k, v in zip(classes, default_colors())}
  33. point_colors = list(map(colors.get, y))
  34. ax.scatter(x[:, 0], x[:, 1], c=point_colors, rasterized=True, **plot_params)
  35. # Plot mediods
  36. if draw_centers:
  37. centers = []
  38. for yi in classes:
  39. mask = yi == y
  40. centers.append(np.median(x[mask, :2], axis=0))
  41. centers = np.array(centers)
  42. center_colors = list(map(colors.get, classes))
  43. ax.scatter(
  44. centers[:, 0], centers[:, 1], c=center_colors, s=48, alpha=1, edgecolor="k"
  45. )
  46. # Draw mediod labels
  47. if draw_cluster_labels:
  48. for idx, label in enumerate(classes):
  49. ax.text(
  50. centers[idx, 0],
  51. centers[idx, 1] + 2.2,
  52. label,
  53. fontsize=kwargs.get("fontsize", 6),
  54. horizontalalignment="center",
  55. )
  56. # Hide ticks and axis
  57. ax.set_xticks([]), ax.set_yticks([]), ax.axis("off")
  58. if draw_legend:
  59. legend_handles = [
  60. matplotlib.lines.Line2D(
  61. [],
  62. [],
  63. marker="s",
  64. color="w",
  65. markerfacecolor=colors[yi],
  66. ms=10,
  67. alpha=1,
  68. linewidth=0,
  69. label=yi,
  70. markeredgecolor="k",
  71. )
  72. for yi in classes
  73. ]
  74. legend_kwargs_ = dict(loc="best", bbox_to_anchor=(0.05, 0.5), frameon=False, )
  75. if legend_kwargs is not None:
  76. legend_kwargs_.update(legend_kwargs)
  77. ax.legend(handles=legend_handles, **legend_kwargs_)
  78. tsne = TSNE(
  79. perplexity=30,
  80. metric="euclidean",
  81. n_jobs=8,
  82. random_state=42,
  83. verbose=True,
  84. )
  85. # idexp_lm3d_pred_lrs3 = np.load("autio2motion_dream_it_possible.npy")
  86. # idx = np.random.choice(np.arange(len(idexp_lm3d_pred_lrs3)), 10000)
  87. # idexp_lm3d_pred_lrs3 = idexp_lm3d_pred_lrs3[idx]
  88. person_ds = np.load("data/binary/videos/May/trainval_dataset.npy", allow_pickle=True).tolist()
  89. person_idexp_mean = person_ds['idexp_lm3d_mean'].reshape([1,204])
  90. person_idexp_std = person_ds['idexp_lm3d_std'].reshape([1,204])
  91. person_idexp_lm3d_train = np.stack([s['idexp_lm3d_normalized'].reshape([204,]) for s in person_ds['train_samples']])
  92. person_idexp_lm3d_val = np.stack([s['idexp_lm3d_normalized'].reshape([204,]) for s in person_ds['val_samples']])
  93. person_idexp_lm3d_train = person_idexp_lm3d_train * person_idexp_std + person_idexp_mean
  94. person_idexp_lm3d_val = person_idexp_lm3d_val * person_idexp_std + person_idexp_mean
  95. # lrs3_stats = np.load('/home/yezhenhui/datasets/binary/lrs3_0702/stats.npy',allow_pickle=True).tolist()
  96. # lrs3_idexp_mean = lrs3_stats['idexp_lm3d_mean'].reshape([1,204])
  97. # lrs3_idexp_std = lrs3_stats['idexp_lm3d_std'].reshape([1,204])
  98. # person_idexp_lm3d_train = (person_idexp_lm3d_train - lrs3_idexp_mean) / lrs3_idexp_std
  99. # person_idexp_lm3d_val = (person_idexp_lm3d_val - lrs3_idexp_mean) / lrs3_idexp_std
  100. # idexp_lm3d_pred_lrs3 = idexp_lm3d_pred_lrs3 * lrs3_idexp_std + lrs3_idexp_mean
  101. idexp_lm3d_pred_vae = np.load("autio2motion_dream_it_possible.npy").reshape([-1,204])[:1000]
  102. idexp_lm3d_pred_postnet = np.load("postnet_dream_it_possible.npy").reshape([-1,204])[:1000]
  103. idexp_lm3d_pred_lle = np.load("lle_dream_it_possible.npy").reshape([-1,204])[:1000]
  104. # idexp_lm3d_pred_postnet = idexp_lm3d_pred_postnet * lrs3_idexp_std + lrs3_idexp_mean
  105. idexp_lm3d_all = np.concatenate([person_idexp_lm3d_train,idexp_lm3d_pred_vae, idexp_lm3d_pred_postnet,idexp_lm3d_pred_lle])
  106. idexp_lm3d_all_emb = tsne.fit(idexp_lm3d_all) # array(float64) [B,50]==>[B, 2]
  107. # z_p_emb = tsne.fit(z_p) # array(float64) [B,50]==>[B, 2]
  108. # y1 = ["pred_lrs3" for _ in range(len(idexp_lm3d_pred_lrs3))]
  109. y2 = ["person_train" for _ in range(len(person_idexp_lm3d_train))]
  110. y3 = ["vae" for _ in range(len(idexp_lm3d_pred_vae))]
  111. y4 = ["postnet" for _ in range(len(idexp_lm3d_pred_postnet))]
  112. y5 = ["lle" for _ in range(len(idexp_lm3d_pred_lle))]
  113. visualize(idexp_lm3d_all_emb, y2+y3+y4+y5)
  114. plt.savefig("0.png")
  115. idexp_lm3d_pred_vae = np.load("autio2motion_dream_it_possible.npy").reshape([-1,204])[1000:2000]
  116. idexp_lm3d_pred_postnet = np.load("postnet_dream_it_possible.npy").reshape([-1,204])[1000:2000]
  117. idexp_lm3d_pred_lle = np.load("lle_dream_it_possible.npy").reshape([-1,204])[1000:2000]
  118. # idexp_lm3d_pred_postnet = idexp_lm3d_pred_postnet * lrs3_idexp_std + lrs3_idexp_mean
  119. idexp_lm3d_all = np.concatenate([person_idexp_lm3d_train,idexp_lm3d_pred_vae, idexp_lm3d_pred_postnet,idexp_lm3d_pred_lle])
  120. idexp_lm3d_all_emb = tsne.fit(idexp_lm3d_all) # array(float64) [B,50]==>[B, 2]
  121. # z_p_emb = tsne.fit(z_p) # array(float64) [B,50]==>[B, 2]
  122. # y1 = ["pred_lrs3" for _ in range(len(idexp_lm3d_pred_lrs3))]
  123. y2 = ["person_train" for _ in range(len(person_idexp_lm3d_train))]
  124. y3 = ["vae" for _ in range(len(idexp_lm3d_pred_vae))]
  125. y4 = ["postnet" for _ in range(len(idexp_lm3d_pred_postnet))]
  126. y5 = ["lle" for _ in range(len(idexp_lm3d_pred_lle))]
  127. visualize(idexp_lm3d_all_emb, y2+y3+y4+y5)
  128. plt.savefig("1.png")
  129. idexp_lm3d_pred_vae = np.load("autio2motion_dream_it_possible.npy").reshape([-1,204])[2000:2500]
  130. idexp_lm3d_pred_postnet = np.load("postnet_dream_it_possible.npy").reshape([-1,204])[2000:2500]
  131. idexp_lm3d_pred_lle = np.load("lle_dream_it_possible.npy").reshape([-1,204])[2000:2500]
  132. # idexp_lm3d_pred_postnet = idexp_lm3d_pred_postnet * lrs3_idexp_std + lrs3_idexp_mean
  133. idexp_lm3d_all = np.concatenate([person_idexp_lm3d_train,idexp_lm3d_pred_vae, idexp_lm3d_pred_postnet,idexp_lm3d_pred_lle])
  134. idexp_lm3d_all_emb = tsne.fit(idexp_lm3d_all) # array(float64) [B,50]==>[B, 2]
  135. # z_p_emb = tsne.fit(z_p) # array(float64) [B,50]==>[B, 2]
  136. # y1 = ["pred_lrs3" for _ in range(len(idexp_lm3d_pred_lrs3))]
  137. y2 = ["person_train" for _ in range(len(person_idexp_lm3d_train))]
  138. y3 = ["vae" for _ in range(len(idexp_lm3d_pred_vae))]
  139. y4 = ["postnet" for _ in range(len(idexp_lm3d_pred_postnet))]
  140. y5 = ["lle" for _ in range(len(idexp_lm3d_pred_lle))]
  141. visualize(idexp_lm3d_all_emb, y2+y3+y4+y5)
  142. plt.savefig("2.png")