Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Jan 27, 2021
1 parent 60b2894 commit ed56815
Show file tree
Hide file tree
Showing 23 changed files with 341 additions and 158 deletions.
8 changes: 8 additions & 0 deletions torchvision/csrc/io/image/cpu/jpegcommon.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#include "jpegcommon.h"

namespace vision {
namespace image {
namespace detail {

#if JPEG_FOUND
void torch_jpeg_error_exit(j_common_ptr cinfo) {
/* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce
Expand All @@ -16,3 +20,7 @@ void torch_jpeg_error_exit(j_common_ptr cinfo) {
longjmp(myerr->setjmp_buffer, 1);
}
#endif

} // namespace detail
} // namespace image
} // namespace vision
13 changes: 13 additions & 0 deletions torchvision/csrc/io/image/cpu/jpegcommon.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
#pragma once

#if JPEG_FOUND
#include <stdio.h>

#include <jpeglib.h>
#include <setjmp.h>
#endif

namespace vision {
namespace image {
namespace detail {

#if JPEG_FOUND

static const JOCTET EOI_BUFFER[1] = {JPEG_EOI};
struct torch_jpeg_error_mgr {
Expand All @@ -15,3 +24,7 @@ using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*;
void torch_jpeg_error_exit(j_common_ptr cinfo);

#endif

} // namespace detail
} // namespace image
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/io/image/cpu/pngcommon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once

#if PNG_FOUND
#include <png.h>
#include <setjmp.h>
#endif
6 changes: 6 additions & 0 deletions torchvision/csrc/io/image/cpu/read_image_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include "readjpeg_impl.h"
#include "readpng_impl.h"

namespace vision {
namespace image {

torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
Expand All @@ -27,3 +30,6 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
"are currently supported.");
}
}

} // namespace image
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/io/image/cpu/read_image_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace image
} // namespace vision
12 changes: 12 additions & 0 deletions torchvision/csrc/io/image/cpu/read_write_file_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
#include "read_write_file_impl.h"

#include <sys/stat.h>

#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#endif

namespace vision {
namespace image {

#ifdef _WIN32
namespace {
std::wstring utf8_decode(const std::string& str) {
if (str.empty()) {
return std::wstring();
Expand All @@ -21,6 +29,7 @@ std::wstring utf8_decode(const std::string& str) {
size_needed);
return wstrTo;
}
} // namespace
#endif

torch::Tensor read_file(const std::string& filename) {
Expand Down Expand Up @@ -90,3 +99,6 @@ void write_file(const std::string& filename, torch::Tensor& data) {
fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile);
fclose(outfile);
}

} // namespace image
} // namespace vision
7 changes: 6 additions & 1 deletion torchvision/csrc/io/image/cpu/read_write_file_impl.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#pragma once

#include <sys/stat.h>
#include <torch/types.h>

namespace vision {
namespace image {

C10_EXPORT torch::Tensor read_file(const std::string& filename);

C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data);

} // namespace image
} // namespace vision
16 changes: 14 additions & 2 deletions torchvision/csrc/io/image/cpu/readjpeg_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
#include "readjpeg_impl.h"
#include "jpegcommon.h"

namespace vision {
namespace image {

using namespace detail;

#if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support");
}
#else
#include "jpegcommon.h"

namespace {

struct torch_jpeg_mgr {
struct jpeg_source_mgr pub;
Expand Down Expand Up @@ -64,6 +71,8 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data;
}

} // namespace

torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
Expand Down Expand Up @@ -146,4 +155,7 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
return tensor.permute({2, 0, 1});
}

#endif // JPEG_FOUND
#endif

} // namespace image
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/io/image/cpu/readjpeg_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decodeJPEG(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace image
} // namespace vision
11 changes: 8 additions & 3 deletions torchvision/csrc/io/image/cpu/readpng_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#include "readpng_impl.h"
#include "pngcommon.h"

namespace vision {
namespace image {

#if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
}
#else
#include <png.h>
#include <setjmp.h>

torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
Expand Down Expand Up @@ -160,4 +162,7 @@ torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1});
}
#endif // PNG_FOUND
#endif

} // namespace image
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/io/image/cpu/readpng_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decodePNG(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace image
} // namespace vision
11 changes: 10 additions & 1 deletion torchvision/csrc/io/image/cpu/writejpeg_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
#include "writejpeg_impl.h"

#include "jpegcommon.h"

namespace vision {
namespace image {

using namespace detail;

#if !JPEG_FOUND

torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
Expand All @@ -8,7 +15,6 @@ torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
}

#else
#include "jpegcommon.h"

torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
// Define compression structures and error handling
Expand Down Expand Up @@ -98,3 +104,6 @@ torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
return outTensor;
}
#endif

} // namespace image
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/io/image/cpu/writejpeg_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,10 @@

#include <torch/types.h>

namespace vision {
namespace image {

C10_EXPORT torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality);

} // namespace image
} // namespace vision
14 changes: 12 additions & 2 deletions torchvision/csrc/io/image/cpu/writepng_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
#include "writejpeg_impl.h"

#include "pngcommon.h"

namespace vision {
namespace image {

#if !PNG_FOUND

torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
TORCH_CHECK(false, "encodePNG: torchvision not compiled with libpng support");
}

#else
#include <png.h>
#include <setjmp.h>

namespace {

struct torch_mem_encode {
char* buffer;
Expand Down Expand Up @@ -59,6 +64,8 @@ void torch_png_write_data(
p->size += length;
}

} // namespace

torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
// Define compression structures and error handling
png_structp png_write;
Expand Down Expand Up @@ -171,3 +178,6 @@ torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
}

#endif

} // namespace image
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/io/image/cpu/writepng_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

#include <torch/types.h>

namespace vision {
namespace image {

C10_EXPORT torch::Tensor encodePNG(
const torch::Tensor& data,
int64_t compression_level);

} // namespace image
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ PyMODINIT_FUNC PyInit_image(void) {
}
#endif

namespace vision {
namespace image {

static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decodePNG)
.op("image::encode_png", &encodePNG)
Expand All @@ -19,3 +22,6 @@ static auto registry = torch::RegisterOperators()
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image);

} // namespace image
} // namespace vision
8 changes: 8 additions & 0 deletions torchvision/csrc/io/image/image_read_mode.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
#pragma once

#include <stdint.h>

namespace vision {
namespace image {

/* Should be kept in-sync with Python ImageReadMode enum */
using ImageReadMode = int64_t;
const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0;
const ImageReadMode IMAGE_READ_MODE_GRAY = 1;
const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
const ImageReadMode IMAGE_READ_MODE_RGB = 3;
const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;

} // namespace image
} // namespace vision
14 changes: 0 additions & 14 deletions torchvision/csrc/io/video/register.cpp

This file was deleted.

Loading

0 comments on commit ed56815

Please sign in to comment.