-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generic Multipart reader/writer implementaion (#53)
- Loading branch information
1 parent
2c6edc6
commit b2d67ea
Showing
4 changed files
with
145 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package filestream | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
ioutils "github.com/jfrog/gofrog/io" | ||
"io" | ||
"mime/multipart" | ||
"net/http" | ||
"os" | ||
) | ||
|
||
const ( | ||
contentType = "Content-Type" | ||
FileType = "file" | ||
) | ||
|
||
// The expected type of function that should be provided to the ReadFilesFromStream func, that returns the writer that should handle each file | ||
type FileHandlerFunc func(fileName string) (writer io.WriteCloser, err error) | ||
|
||
func ReadFilesFromStream(multipartReader *multipart.Reader, fileHandlerFunc FileHandlerFunc) error { | ||
for { | ||
// Read the next file streamed from client | ||
fileReader, err := multipartReader.NextPart() | ||
if err != nil { | ||
if errors.Is(err, io.EOF) { | ||
break | ||
} | ||
return fmt.Errorf("failed to read file: %w", err) | ||
} | ||
err = readFile(fileReader, fileHandlerFunc) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
} | ||
return nil | ||
} | ||
|
||
func readFile(fileReader *multipart.Part, fileHandlerFunc FileHandlerFunc) (err error) { | ||
fileName := fileReader.FileName() | ||
fileWriter, err := fileHandlerFunc(fileName) | ||
if err != nil { | ||
return err | ||
} | ||
defer ioutils.Close(fileWriter, &err) | ||
if _, err = io.Copy(fileWriter, fileReader); err != nil { | ||
return fmt.Errorf("failed writing '%s' file: %w", fileName, err) | ||
} | ||
return err | ||
} | ||
|
||
func WriteFilesToStream(responseWriter http.ResponseWriter, filePaths []string) (err error) { | ||
multipartWriter := multipart.NewWriter(responseWriter) | ||
responseWriter.Header().Set(contentType, multipartWriter.FormDataContentType()) | ||
|
||
for _, filePath := range filePaths { | ||
if err = writeFile(multipartWriter, filePath); err != nil { | ||
return | ||
} | ||
} | ||
|
||
// Close finishes the multipart message and writes the trailing | ||
// boundary end line to the output. | ||
return multipartWriter.Close() | ||
} | ||
|
||
func writeFile(multipartWriter *multipart.Writer, filePath string) (err error) { | ||
fileReader, err := os.Open(filePath) | ||
if err != nil { | ||
return fmt.Errorf("failed to open file: %w", err) | ||
} | ||
defer ioutils.Close(fileReader, &err) | ||
fileWriter, err := multipartWriter.CreateFormFile(FileType, filePath) | ||
if err != nil { | ||
return fmt.Errorf("failed to CreateFormFile: %w", err) | ||
} | ||
_, err = io.Copy(fileWriter, fileReader) | ||
return err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
package filestream | ||
|
||
import ( | ||
"github.com/stretchr/testify/assert" | ||
"io" | ||
"mime/multipart" | ||
"net/http/httptest" | ||
"os" | ||
"path/filepath" | ||
"strings" | ||
"testing" | ||
) | ||
|
||
var targetDir string | ||
|
||
func TestWriteFilesToStreamAndReadFilesFromStream(t *testing.T) { | ||
sourceDir := t.TempDir() | ||
// Create 2 file to be transferred via our multipart stream | ||
file1 := filepath.Join(sourceDir, "test1.txt") | ||
file2 := filepath.Join(sourceDir, "test2.txt") | ||
file1Content := []byte("test content1") | ||
file2Content := []byte("test content2") | ||
assert.NoError(t, os.WriteFile(file1, file1Content, 0600)) | ||
assert.NoError(t, os.WriteFile(file2, file2Content, 0600)) | ||
|
||
// Create the multipart writer that will stream our files | ||
responseWriter := httptest.NewRecorder() | ||
assert.NoError(t, WriteFilesToStream(responseWriter, []string{file1, file2})) | ||
|
||
// Create local temp dir that will store our files | ||
targetDir = t.TempDir() | ||
|
||
// Get boundary hash from writer | ||
boundary := strings.Split(responseWriter.Header().Get(contentType), "boundary=")[1] | ||
// Create the multipart reader that will read the files from the stream | ||
multipartReader := multipart.NewReader(responseWriter.Body, boundary) | ||
assert.NoError(t, ReadFilesFromStream(multipartReader, simpleFileHandler)) | ||
|
||
// Validate file 1 transferred successfully | ||
file1 = filepath.Join(targetDir, "test1.txt") | ||
assert.FileExists(t, file1) | ||
content, err := os.ReadFile(file1) | ||
assert.NoError(t, err) | ||
assert.Equal(t, file1Content, content) | ||
assert.NoError(t, os.Remove(file1)) | ||
|
||
// Validate file 2 transferred successfully | ||
file2 = filepath.Join(targetDir, "test2.txt") | ||
assert.FileExists(t, file2) | ||
content, err = os.ReadFile(file2) | ||
assert.NoError(t, err) | ||
assert.Equal(t, file2Content, content) | ||
assert.NoError(t, os.Remove(file2)) | ||
} | ||
|
||
func simpleFileHandler(fileName string) (fileWriter io.WriteCloser, err error) { | ||
return os.Create(filepath.Join(targetDir, fileName)) | ||
} |