built_in.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. // Copyright 2020 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package model
  15. import (
  16. "archive/zip"
  17. "fmt"
  18. "github.com/zhenghaoz/gorse/base/log"
  19. "io"
  20. "net/http"
  21. "os"
  22. "os/user"
  23. "path/filepath"
  24. "strings"
  25. "go.uber.org/zap"
  26. )
  // DatasetFormat identifies the file format of a built-in dataset's
// train/test files.
type DatasetFormat int

// Dataset formats known to this package. In builtInDataSets, FormatNCF
// datasets ship "*.txt" files and FormatLibFM datasets ship "*.libfm" files.
const (
	// FormatNCF is the format of the NCF-style text datasets.
	FormatNCF DatasetFormat = iota
	// FormatLibFM is the format of the libFM datasets.
	FormatLibFM
)
  32. // Built-in Data set
  33. type _BuiltInDataSet struct {
  34. downloadURL string
  35. trainFile string
  36. testFile string
  37. format DatasetFormat
  38. }
  39. var builtInDataSets = map[string]_BuiltInDataSet{
  40. "pinterest-20": {
  41. downloadURL: "https://cdn.gorse.io/datasets/pinterest-20.zip",
  42. trainFile: "pinterest-20/train.txt",
  43. testFile: "pinterest-20/test.txt",
  44. format: FormatNCF,
  45. },
  46. "ml-100k": {
  47. downloadURL: "https://cdn.gorse.io/datasets/ml-100k.zip",
  48. trainFile: "ml-100k/train.txt",
  49. testFile: "ml-100k/test.txt",
  50. format: FormatNCF,
  51. },
  52. "ml-1m": {
  53. downloadURL: "https://cdn.gorse.io/datasets/ml-1m.zip",
  54. trainFile: "ml-1m/train.txt",
  55. testFile: "ml-1m/test.txt",
  56. format: FormatNCF,
  57. },
  58. "ml-tag": {
  59. downloadURL: "https://cdn.gorse.io/datasets/ml-tag.zip",
  60. trainFile: "ml-tag/train.libfm",
  61. testFile: "ml-tag/test.libfm",
  62. format: FormatLibFM,
  63. },
  64. "frappe": {
  65. downloadURL: "https://cdn.gorse.io/datasets/frappe.zip",
  66. trainFile: "frappe/train.libfm",
  67. testFile: "frappe/test.libfm",
  68. format: FormatLibFM,
  69. },
  70. "criteo": {
  71. downloadURL: "https://cdn.gorse.io/datasets/criteo.zip",
  72. trainFile: "criteo/train.libfm",
  73. testFile: "criteo/test.libfm",
  74. format: FormatLibFM,
  75. },
  76. }
  // Local directories used for built-in datasets. They are empty until the
// package's init function fills them in from the current user's home
// directory.
var (
	// GorseDir is the root data directory (".gorse" under the home directory).
	GorseDir string
	// DataSetDir is where dataset archives are extracted ("dataset" under GorseDir).
	DataSetDir string
	// TempDir is where dataset archives are downloaded ("temp" under GorseDir).
	TempDir string
)
  83. func init() {
  84. usr, err := user.Current()
  85. if err != nil {
  86. log.Logger().Fatal("failed to get user directory", zap.Error(err))
  87. }
  88. GorseDir = usr.HomeDir + "/.gorse"
  89. DataSetDir = GorseDir + "/dataset"
  90. TempDir = GorseDir + "/temp"
  91. // create all folders
  92. if err = os.MkdirAll(DataSetDir, os.ModePerm); err != nil {
  93. log.Logger().Fatal("failed to create directory", zap.Error(err), zap.String("path", DataSetDir))
  94. }
  95. if err = os.MkdirAll(TempDir, os.ModePerm); err != nil {
  96. log.Logger().Fatal("failed to create directory", zap.Error(err), zap.String("path", TempDir))
  97. }
  98. }
  99. func LocateBuiltInDataset(name string, format DatasetFormat) (string, string, error) {
  100. // Extract Data set information
  101. dataSet, exist := builtInDataSets[name]
  102. if !exist {
  103. return "", "", fmt.Errorf("no such dataset %v", name)
  104. }
  105. if dataSet.format != format {
  106. return "", "", fmt.Errorf("format not matchs %v != %v", format, dataSet.format)
  107. }
  108. // Download if not exists
  109. trainFilePah := filepath.Join(DataSetDir, dataSet.trainFile)
  110. testFilePath := filepath.Join(DataSetDir, dataSet.testFile)
  111. if _, err := os.Stat(trainFilePah); os.IsNotExist(err) {
  112. zipFileName, _ := downloadFromUrl(dataSet.downloadURL, TempDir)
  113. if _, err := unzip(zipFileName, DataSetDir); err != nil {
  114. return "", "", err
  115. }
  116. }
  117. return trainFilePah, testFilePath, nil
  118. }
  119. // downloadFromUrl downloads file from URL.
  120. func downloadFromUrl(src, dst string) (string, error) {
  121. log.Logger().Info("Download dataset", zap.String("source", src))
  122. // Extract file name
  123. tokens := strings.Split(src, "/")
  124. fileName := filepath.Join(dst, tokens[len(tokens)-1])
  125. // Create file
  126. if err := os.MkdirAll(filepath.Dir(fileName), os.ModePerm); err != nil {
  127. return fileName, err
  128. }
  129. output, err := os.Create(fileName)
  130. if err != nil {
  131. log.Logger().Error("failed to create file", zap.Error(err), zap.String("filename", fileName))
  132. return fileName, err
  133. }
  134. defer output.Close()
  135. // Download file
  136. response, err := http.Get(src)
  137. if err != nil {
  138. log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src))
  139. return fileName, err
  140. }
  141. defer response.Body.Close()
  142. // Save file
  143. _, err = io.Copy(output, response.Body)
  144. if err != nil {
  145. log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src))
  146. return fileName, err
  147. }
  148. return fileName, nil
  149. }
  // unzip extracts every entry of the zip archive src into directory dst and
// returns the paths of the extracted entries (directories included), in
// archive order.
//
// Entries whose path would resolve outside dst ("zip slip", e.g. names
// containing "../") are rejected with an error. On failure, the paths
// recorded so far are returned together with the error.
func unzip(src, dst string) ([]string, error) {
	r, err := zip.OpenReader(src)
	if err != nil {
		return nil, err
	}
	defer r.Close()
	// Every extracted path must lie strictly below this prefix.
	// More info on ZipSlip: http://bit.ly/2MsjAWE
	root := filepath.Clean(dst) + string(os.PathSeparator)
	var fileNames []string
	for _, f := range r.File {
		filePath := filepath.Join(dst, f.Name)
		if !strings.HasPrefix(filePath, root) {
			return fileNames, fmt.Errorf("%s: illegal file path", filePath)
		}
		fileNames = append(fileNames, filePath)
		if f.FileInfo().IsDir() {
			if err := os.MkdirAll(filePath, os.ModePerm); err != nil {
				return fileNames, err
			}
			continue
		}
		if err := extractZipEntry(f, filePath); err != nil {
			return fileNames, err
		}
	}
	return fileNames, nil
}

// extractZipEntry writes the contents of the regular-file entry f to filePath,
// creating parent directories as needed. The entry reader and the output file
// are closed on every path, including errors.
func extractZipEntry(f *zip.File, filePath string) (err error) {
	if err = os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
		return err
	}
	rc, err := f.Open()
	if err != nil {
		return err
	}
	defer func() {
		// Surface a close error only if nothing else failed first.
		if cerr := rc.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	outFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
	if err != nil {
		return err
	}
	if _, err = io.Copy(outFile, rc); err != nil {
		_ = outFile.Close()
		return err
	}
	return outFile.Close()
}