// Author: NiuJiuRu // Email: niujiuru@qq.com package ftpclient import ( "context" "fmt" "io" "os" "path/filepath" "sync" "sync/atomic" "time" "github.com/jlaffaye/ftp" "hnyfkj.com.cn/rtu/linux/baseapp" ) const MODULE_NAME = "FtpClient" const ( defaultRtyInterval = 1 * time.Second defaultLogInterval = 2 * time.Second DefaultUploadTimeout = 5 * time.Minute DefaultDownloadTimeout = 5 * time.Minute ) var ( fileLock = struct { sync.Mutex m map[string]struct{} }{m: make(map[string]struct{})} FileUpFolder = "" // 上传文件目录 FileUploader = struct { UploadLock sync.Mutex // 上传照片任务只能串行进行时, 可以通过使用该锁来实现排队串行 }{} ) func tryLockFile(file string) (unlock func(), ok bool) { fileLock.Lock() defer fileLock.Unlock() if _, ok := fileLock.m[file]; ok { return nil, false } fileLock.m[file] = struct{}{} return func() { fileLock.Lock() delete(fileLock.m, file) fileLock.Unlock() }, true } type progressReader struct { io.Reader filename string total, transferred int64 label string // "上传"或"下载" ctx context.Context doneLogged int32 } func newProgressReader(r io.Reader, filename string, total, transferred int64, label string, ctx context.Context) *progressReader { pr := &progressReader{Reader: r, filename: filename, total: total, transferred: transferred, label: label, ctx: ctx} go pr.startProgressLogger() return pr } func (p *progressReader) Read(buf []byte) (int, error) { n, err := p.Reader.Read(buf) if n > 0 { atomic.AddInt64(&p.transferred, int64(n)) } if err == io.EOF && atomic.CompareAndSwapInt32(&p.doneLogged, 0, 1) { transferred := atomic.LoadInt64(&p.transferred) baseapp.Logger.Infof("[%s] 文件%q%s进度: 100.00%%, 剩余: %d字节, 总大小: %d字节", MODULE_NAME, p.filename, p.label, p.total-transferred, p.total) } return n, err } func (p *progressReader) startProgressLogger() { ticker := time.NewTicker(defaultLogInterval) defer ticker.Stop() for { select { case <-ticker.C: if atomic.LoadInt32(&p.doneLogged) == 1 { return } transferred := atomic.LoadInt64(&p.transferred) if transferred >= p.total { return } progress := float64(transferred) / float64(p.total) * 100 baseapp.Logger.Infof("[%s] 文件%q%s进度: %.2f%%, 剩余: %d字节, 总大小: %d字节", MODULE_NAME, p.filename, p.label, progress, p.total-transferred, p.total) case <-p.ctx.Done(): return } } } type stopError struct{ err error } func (e *stopError) Error() string { return e.err.Error() } func UploadFileToFtp(ctx context.Context, localFile, serverAddr, loginUser, loginPass string, timeout time.Duration) (string, error) { unlock, ok := tryLockFile(localFile) if !ok { return "", fmt.Errorf("文件%q正在使用中", localFile) } defer unlock() if ctx == nil { ctx = context.Background() } timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() if FileUpFolder == "" { return "", fmt.Errorf("无效的上传目录") } lf, err := os.Open(localFile) if err != nil { return "", err } defer lf.Close() info, err := lf.Stat() if err != nil { return "", err } lfs := info.Size() // 总上传的字节数 rd := FileUpFolder // 远程目录名 remoteFile := filepath.Join(rd, filepath.Base(localFile)) // 远程文件名 for { select { case <-timeoutCtx.Done(): return "", timeoutCtx.Err() default: } err := func() error { c, err := ftp.Dial(serverAddr, ftp.DialWithContext(timeoutCtx)) if err != nil { return &stopError{err} } defer c.Quit() if err := c.Login(loginUser, loginPass); err != nil { return &stopError{err} } _ = c.MakeDir(rd) // 尝试创建远程目录, 忽略已存在和其它错误 rfs, err := c.FileSize(remoteFile) // 已上传的字节数 if err != nil || rfs > lfs { rfs = 0 } if _, err := lf.Seek(rfs, io.SeekStart); err != nil { return &stopError{err} } pr := newProgressReader(lf, localFile, lfs, rfs, "上传", timeoutCtx) if err := c.StorFrom(remoteFile, pr, uint64(rfs)); err != nil { return err } return nil }() if err != nil { if lfe, ok := err.(*stopError); ok { return "", lfe.err } time.Sleep(defaultRtyInterval) continue } return remoteFile, nil } } func DownloadFileFromFtp(ctx context.Context, serverAddr, loginUser, loginPass, remoteFile string, timeout time.Duration) (string, error) { unlock, ok := tryLockFile(remoteFile) if !ok { return "", fmt.Errorf("文件%q正在使用中", remoteFile) } defer unlock() if ctx == nil { ctx = context.Background() } timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() ld := baseapp.VAR_DIR // 本地目录名 localFile := filepath.Join(ld, filepath.Base(remoteFile)) // 本地文件名 lf, err := os.OpenFile(localFile, os.O_CREATE|os.O_RDWR, 0644) if err != nil { return "", err } defer lf.Close() info, err := lf.Stat() if err != nil { return "", err } lfs := info.Size() // 已下载的字节数 for { select { case <-timeoutCtx.Done(): return "", timeoutCtx.Err() default: } err := func() error { c, err := ftp.Dial(serverAddr, ftp.DialWithContext(timeoutCtx)) if err != nil { return &stopError{err} } defer c.Quit() if err := c.Login(loginUser, loginPass); err != nil { return &stopError{err} } rfs, err := c.FileSize(remoteFile) // 总下载的字节数 if err != nil { return &stopError{err} } if lfs > rfs { lfs = 0 } if _, err := lf.Seek(lfs, io.SeekStart); err != nil { return &stopError{err} } resp, err := c.RetrFrom(remoteFile, uint64(lfs)) if err != nil { return err } defer resp.Close() pr := newProgressReader(resp, remoteFile, rfs, lfs, "下载", timeoutCtx) n, err := io.Copy(lf, pr) if err != nil { return err } lfs += n return nil }() if err != nil { if lfe, ok := err.(*stopError); ok { return "", lfe.err } time.Sleep(defaultRtyInterval) continue } return localFile, nil } }