Skip to content

Commit 142a00e

Browse files
committed
add lesson6
1 parent aaecd1c commit 142a00e

File tree

15 files changed

+26447
-0
lines changed

15 files changed

+26447
-0
lines changed

lesson6-Segmentation/ResNet.cpp

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
#include "ResNet.h"
2+
3+
BlockImpl::BlockImpl(int64_t inplanes, int64_t planes, int64_t stride_,
    torch::nn::Sequential downsample_, int groups, int base_width, bool _is_basic)
{
    // One residual unit. `_is_basic` selects between the BasicBlock layout
    // (two 3x3 convs, resnet18/34) and the Bottleneck layout
    // (1x1 reduce -> 3x3 -> 1x1 expand x4, resnet50+/resnext).
    //
    // Fix vs. the previous version: conv1/conv2 were first constructed as the
    // BasicBlock 3x3 variants and then, for bottleneck blocks, immediately
    // discarded and rebuilt — constructing each module exactly once instead.
    downsample = downsample_;
    stride = stride_;
    is_basic = _is_basic;

    // resnext-style width: scales with base_width and groups
    // (for plain resnets this reduces to `planes`).
    int width = int(planes * (base_width / 64.)) * groups;

    if (is_basic) {
        // BasicBlock: 3x3 (carries the stride) -> 3x3.
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 3, stride_, 1, groups, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, 1, 1, groups, false));
    }
    else {
        // Bottleneck: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand (x4).
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 1, 1, 0, 1, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, stride_, 1, groups, false));
        conv3 = torch::nn::Conv2d(conv_options(width, planes * 4, 1, 1, 0, 1, false));
        bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * 4));
    }
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));

    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    if (!is_basic) {
        register_module("conv3", conv3);
        register_module("bn3", bn3);
    }

    // Only register the projection shortcut when one was actually supplied.
    if (!downsample->is_empty()) {
        register_module("downsample", downsample);
    }
}
35+
36+
torch::Tensor BlockImpl::forward(torch::Tensor x) {
    // Residual unit: main branch -> optional projection shortcut -> add -> ReLU.
    //
    // Fix vs. the previous version: `x.clone()` made a full deep copy of the
    // input on every forward pass. It is unnecessary — each op on the main
    // branch below returns a fresh tensor, so the tensor referenced by
    // `residual` is never modified in place.
    torch::Tensor residual = x;

    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);

    x = conv2->forward(x);
    x = bn2->forward(x);

    if (!is_basic) {
        // Bottleneck variant: third 1x1 conv expands the channels.
        x = torch::relu(x);
        x = conv3->forward(x);
        x = bn3->forward(x);
    }

    // Project the identity when the shape changes (stride > 1 or channel growth).
    if (!downsample->is_empty()) {
        residual = downsample->forward(residual);
    }

    // In-place add on the freshly produced bn output, then final activation.
    x += residual;
    x = torch::relu(x);

    return x;
}
61+
62+
// Builds the full backbone: stem (conv1/bn1) + four residual stages + classifier head.
// layers: number of residual blocks in each of the four stages.
// model_type: architecture name; anything other than resnet18/resnet34 is
//             treated as a bottleneck architecture (expansion = 4).
// _groups/_width_per_group: resnext-style grouped-convolution parameters.
// NOTE: the register_module calls define the parameter naming/order used by
// serialization (torch::load/save) — do not reorder them.
ResNetImpl::ResNetImpl(std::vector<int> layers, int num_classes, std::string model_type, int _groups, int _width_per_group)
{
    // Bottleneck variants expand each stage's output channels 4x.
    if (model_type != "resnet18" && model_type != "resnet34")
    {
        expansion = 4;
        is_basic = false;
    }
    groups = _groups;
    base_width = _width_per_group;
    // Stem: 7x7 stride-2 conv on 3-channel input -> 64 feature maps.
    conv1 = torch::nn::Conv2d(conv_options(3, 64, 7, 2, 3, 1, false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
    // _make_layer mutates `inplanes` as it goes, so these must run in order.
    layer1 = torch::nn::Sequential(_make_layer(64, layers[0]));
    layer2 = torch::nn::Sequential(_make_layer(128, layers[1], 2));
    layer3 = torch::nn::Sequential(_make_layer(256, layers[2], 2));
    layer4 = torch::nn::Sequential(_make_layer(512, layers[3], 2));

    // Classifier head sized by the final stage's channel count.
    fc = torch::nn::Linear(512 * expansion, num_classes);
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);
    register_module("fc", fc);
}
87+
88+
89+
torch::Tensor ResNetImpl::forward(torch::Tensor x) {
    // Classification forward pass; returns per-class log-probabilities.

    // Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool (overall stride 4).
    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);
    x = torch::max_pool2d(x, 3, 2, 1);

    // Four residual stages (strides 1, 2, 2, 2 — overall stride 32).
    x = layer1->forward(x);
    x = layer2->forward(x);
    x = layer3->forward(x);
    x = layer4->forward(x);

    // Global average pooling to 1x1. This generalizes the previous hard-coded
    // avg_pool2d(x, 7, 1): for a 224x224 input the final feature map is
    // exactly 7x7 so the result is identical, but other input sizes now work
    // instead of producing a shape mismatch at the fully connected layer.
    x = torch::adaptive_avg_pool2d(x, { 1, 1 });
    x = x.view({ x.sizes()[0], -1 });
    x = fc->forward(x);

    return torch::log_softmax(x, 1);
}
106+
107+
std::vector<torch::Tensor> ResNetImpl::features(torch::Tensor x){
    // Collects the pyramid of intermediate activations for a decoder:
    // [raw input, stem output, layer1..layer4 outputs] — six tensors total.
    std::vector<torch::Tensor> pyramid;
    pyramid.push_back(x);

    // Stem (before the max-pool, matching the original tap point).
    x = torch::relu(bn1->forward(conv1->forward(x)));
    pyramid.push_back(x);
    x = torch::max_pool2d(x, 3, 2, 1);

    // Residual stages; module holders are cheap shared handles, so copying
    // them into the initializer list runs the same underlying modules.
    for (auto stage : { layer1, layer2, layer3, layer4 }) {
        x = stage->forward(x);
        pyramid.push_back(x);
    }
    return pyramid;
}
127+
128+
torch::nn::Sequential ResNetImpl::_make_layer(int64_t planes, int64_t blocks, int64_t stride) {
    // Builds one residual stage of `blocks` blocks. The first block carries
    // the stride and (if needed) a projection shortcut; the rest are stride-1
    // identity blocks. Mutates `inplanes` to the stage's output width.
    torch::nn::Sequential shortcut;
    const bool needs_projection = (stride != 1) || (inplanes != planes * expansion);
    if (needs_projection) {
        // 1x1 conv + BN to match the residual's spatial size and channels.
        shortcut = torch::nn::Sequential(
            torch::nn::Conv2d(conv_options(inplanes, planes * expansion, 1, stride, 0, 1, false)),
            torch::nn::BatchNorm2d(planes * expansion)
        );
    }

    torch::nn::Sequential stage;
    stage->push_back(Block(inplanes, planes, stride, shortcut, groups, base_width, is_basic));
    inplanes = planes * expansion;
    for (int64_t block_idx = 1; block_idx < blocks; ++block_idx) {
        stage->push_back(Block(inplanes, planes, 1, torch::nn::Sequential(), groups, base_width, is_basic));
    }

    return stage;
}
146+
147+
void ResNetImpl::make_dilated(std::vector<int> stage_list, std::vector<int> dilation_list) {
148+
if (stage_list.size() != dilation_list.size()) {
149+
std::cout << "make sure stage list len equal to dilation list len";
150+
return;
151+
}
152+
std::map<int, torch::nn::Sequential> stage_dict = {};
153+
stage_dict.insert(std::pair<int, torch::nn::Sequential>(5, this->layer4));
154+
stage_dict.insert(std::pair<int, torch::nn::Sequential>(4, this->layer3));
155+
stage_dict.insert(std::pair<int, torch::nn::Sequential>(3, this->layer2));
156+
stage_dict.insert(std::pair<int, torch::nn::Sequential>(2, this->layer1));
157+
for (int i = 0; i < stage_list.size(); i++) {
158+
int dilation_rate = dilation_list[i];
159+
for (auto m : stage_dict[stage_list[i]]->modules()) {
160+
if (m->name() == "torch::nn::Conv2dImpl") {
161+
m->as<torch::nn::Conv2d>()->options.stride(1);
162+
m->as<torch::nn::Conv2d>()->options.dilation(dilation_rate);
163+
int kernel_size = m->as<torch::nn::Conv2d>()->options.kernel_size()->at(0);
164+
m->as<torch::nn::Conv2d>()->options.padding((kernel_size / 2) * dilation_rate);
165+
}
166+
}
167+
}
168+
return;
169+
}
170+
171+
// Factory: ResNet-18 (BasicBlock, two blocks per stage).
ResNet resnet18(int64_t num_classes) {
    return ResNet({ 2, 2, 2, 2 }, num_classes, "resnet18");
}
176+
177+
// Factory: ResNet-34 (BasicBlock, 3-4-6-3 blocks per stage).
ResNet resnet34(int64_t num_classes) {
    return ResNet({ 3, 4, 6, 3 }, num_classes, "resnet34");
}
182+
183+
// Factory: ResNet-50 (Bottleneck, 3-4-6-3 blocks per stage).
ResNet resnet50(int64_t num_classes) {
    return ResNet({ 3, 4, 6, 3 }, num_classes, "resnet50");
}
188+
189+
// Factory: ResNet-101 (Bottleneck, 3-4-23-3 blocks per stage).
ResNet resnet101(int64_t num_classes) {
    return ResNet({ 3, 4, 23, 3 }, num_classes, "resnet101");
}
194+
195+
ResNet pretrained_resnet(int64_t num_classes, std::string model_name, std::string weight_path){
    // Loads pretrained ImageNet (1000-class) weights from `weight_path`, then
    // transfers everything except the classifier head into a fresh network
    // with `num_classes` outputs.
    //
    // Fixes vs. the previous version:
    //  * BUG: the re-headed network was built WITHOUT groups/width_per_group,
    //    so resnext weights could never be copied (shape mismatch). It now
    //    gets the same architecture parameters as the pretrained network.
    //  * strstr on .data() replaced with std::string::find (same matching).
    //  * manual GradMode::set_enabled(false)/set_enabled(true) replaced with
    //    the RAII NoGradGuard, which restores the previous mode even if
    //    copy_ throws.
    std::map<std::string, std::vector<int>> name2layers = getParams();
    int groups = 1;
    int width_per_group = 64;
    if (model_name == "resnext50_32x4d") {
        groups = 32; width_per_group = 4;
    }
    if (model_name == "resnext101_32x8d") {
        groups = 32; width_per_group = 8;
    }

    // Build with the original 1000-class head so the checkpoint loads cleanly.
    ResNet net_pretrained = ResNet(name2layers[model_name], 1000, model_name, groups, width_per_group);
    torch::load(net_pretrained, weight_path);
    if (num_classes == 1000) return net_pretrained;

    // Same backbone, new classifier head.
    ResNet module = ResNet(name2layers[model_name], num_classes, model_name, groups, width_per_group);

    torch::OrderedDict<std::string, at::Tensor> pretrained_dict = net_pretrained->named_parameters();
    torch::OrderedDict<std::string, at::Tensor> model_dict = module->named_parameters();

    // Overwrite every entry except the head ("fc.*"), whose shape depends on
    // num_classes and must keep its fresh initialization.
    for (const auto& item : pretrained_dict)
    {
        if (item.key().find("fc.") != std::string::npos) {
            continue;
        }
        model_dict[item.key()] = item.value();
    }

    // Copy the collected tensors into the new module's parameters/buffers.
    torch::NoGradGuard no_grad; // copying into leaf params requires grad off
    auto params = module->named_parameters(true /*recurse*/);
    auto buffers = module->named_buffers(true /*recurse*/);
    for (auto& val : model_dict) {
        const auto& name = val.key();
        if (auto* t = params.find(name)) {
            t->copy_(val.value());
        }
        else if (auto* b = buffers.find(name)) {
            b->copy_(val.value());
        }
    }
    return module;
}

lesson6-Segmentation/ResNet.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#ifndef RESNET_H
2+
#define RESNET_H
3+
#include"util.h"
4+
5+
// One residual unit. Covers both the BasicBlock layout (two 3x3 convs,
// resnet18/34) and the Bottleneck layout (1x1 -> 3x3 -> 1x1 with 4x channel
// expansion, resnet50+/resnext), selected by `is_basic`.
class BlockImpl : public torch::nn::Module {
public:
    // inplanes: input channels. planes: base width of the block.
    // stride_: stride of the spatial-reduction conv. downsample_: projection
    // shortcut for the residual (empty when the identity already matches).
    // groups/base_width: resnext-style grouped-convolution parameters.
    // is_basic: true = BasicBlock, false = Bottleneck.
    BlockImpl(int64_t inplanes, int64_t planes, int64_t stride_ = 1,
        torch::nn::Sequential downsample_ = nullptr, int groups = 1, int base_width = 64, bool is_basic = true);
    torch::Tensor forward(torch::Tensor x);
    // Public so callers can inspect/replace the shortcut branch.
    torch::nn::Sequential downsample{ nullptr };
private:
    bool is_basic = true;
    int64_t stride = 1;
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };
    torch::nn::Conv2d conv2{ nullptr };
    torch::nn::BatchNorm2d bn2{ nullptr };
    // conv3/bn3 are only constructed and registered for the Bottleneck variant.
    torch::nn::Conv2d conv3{ nullptr };
    torch::nn::BatchNorm2d bn3{ nullptr };

};
// Wraps BlockImpl in the usual libtorch module holder (shared-ptr semantics).
TORCH_MODULE(Block);
23+
24+
25+
// ResNet/ResNeXt backbone: stem + four residual stages + linear classifier.
// Also exposes `features` (intermediate activations for segmentation
// decoders) and `make_dilated` (convert stages to dilated convolutions).
class ResNetImpl : public torch::nn::Module {
public:
    // layers: blocks per stage (e.g. {2,2,2,2} for resnet18).
    // model_type: names other than resnet18/resnet34 select bottleneck blocks.
    ResNetImpl(std::vector<int> layers, int num_classes = 1000, std::string model_type = "resnet18",
        int groups = 1, int width_per_group = 64);
    // Returns per-class log-probabilities (log_softmax over dim 1).
    torch::Tensor forward(torch::Tensor x);
    // Returns [input, stem, layer1..layer4 outputs] for decoder use.
    std::vector<torch::Tensor> features(torch::Tensor x);
    // Builds one residual stage; mutates `inplanes` as a side effect.
    torch::nn::Sequential _make_layer(int64_t planes, int64_t blocks, int64_t stride = 1);
    // Converts the listed stages (2..5 -> layer1..layer4) to dilated convs.
    void make_dilated(std::vector<int> stage_list, std::vector<int> dilation_list);
private:
    // expansion: per-block output-channel multiplier (1 basic, 4 bottleneck).
    int expansion = 1; bool is_basic = true;
    // inplanes: running input width consumed/updated by _make_layer.
    int64_t inplanes = 64; int groups = 1; int base_width = 64;
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };
    torch::nn::Sequential layer1{ nullptr };
    torch::nn::Sequential layer2{ nullptr };
    torch::nn::Sequential layer3{ nullptr };
    torch::nn::Sequential layer4{ nullptr };
    torch::nn::Linear fc{nullptr};
};
// Wraps ResNetImpl in the usual libtorch module holder (shared-ptr semantics).
TORCH_MODULE(ResNet);
45+
46+
// Returns the per-stage block counts for every supported architecture name.
// The table is built once (static const, brace-initialized) instead of being
// re-inserted pair by pair on every call as before; callers still receive
// their own copy, so the interface is unchanged.
inline std::map<std::string, std::vector<int>> getParams(){
    static const std::map<std::string, std::vector<int>> name2layers = {
        { "resnet18",          { 2, 2, 2, 2 } },
        { "resnet34",          { 3, 4, 6, 3 } },
        { "resnet50",          { 3, 4, 6, 3 } },
        { "resnet101",         { 3, 4, 23, 3 } },
        { "resnet152",         { 3, 8, 36, 3 } },
        { "resnext50_32x4d",   { 3, 4, 6, 3 } },
        { "resnext101_32x8d",  { 3, 4, 23, 3 } },
    };
    return name2layers;
}
58+
59+
// Factory functions: build a randomly initialized network of the given depth
// with `num_classes` outputs.
ResNet resnet18(int64_t num_classes);
ResNet resnet34(int64_t num_classes);
ResNet resnet50(int64_t num_classes);
ResNet resnet101(int64_t num_classes);

// Loads pretrained weights from `weight_path` for `model_name` (a key of
// getParams()) and transfers them onto a network with `num_classes` outputs.
ResNet pretrained_resnet(int64_t num_classes, std::string model_name, std::string weight_path);
65+
#endif // RESNET_H

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy